diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index f18aef61771..7375dd4a2c4 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -94,7 +94,7 @@ steps: - "/fsx/hf_cache:/fsx/hf_cache" - label: "Diffusion Parallelism Test" - timeout_in_minutes: 20 + timeout_in_minutes: 25 depends_on: image-build commands: - pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py @@ -116,7 +116,7 @@ steps: timeout_in_minutes: 20 depends_on: image-build commands: - - pytest -s -v tests/diffusion/test_gpu_worker.py + - pytest -s -v tests/diffusion/test_gpu_diffusion_worker.py agents: queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU plugins: diff --git a/.buildkite/scripts/simple_test.sh b/.buildkite/scripts/simple_test.sh index 33248d99cde..55ac27cec9f 100755 --- a/.buildkite/scripts/simple_test.sh +++ b/.buildkite/scripts/simple_test.sh @@ -52,3 +52,4 @@ VENV_PYTHON="${VENV_DIR}/bin/python" "${VENV_PYTHON}" -m pytest -v -s tests/entrypoints/ "${VENV_PYTHON}" -m pytest -v -s tests/diffusion/cache/ "${VENV_PYTHON}" -m pytest -v -s tests/model_executor/models/qwen2_5_omni/test_audio_length.py +"${VENV_PYTHON}" -m pytest -v -s tests/worker/ diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index bcece8c495d..86d65f15bcf 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -54,7 +54,7 @@ steps: commands: - export MIOPEN_DEBUG_CONV_DIRECT=0 - export MIOPEN_DEBUG_CONV_GEMM=0 - - pytest -s -v tests/diffusion/test_gpu_worker.py + - pytest -s -v tests/diffusion/test_gpu_diffusion_worker.py - label: "Omni Model Test Qwen2-5-Omni" timeout_in_minutes: 15 diff --git a/README.md b/README.md index 829d683c3c3..1a8ad999050 100644 --- a/README.md +++ b/README.md @@ -9,13 +9,13 @@ Easy, fast, and cheap omni-modality model serving for everyone

-| Documentation | User Forum | Developer Slack | +| Documentation | User Forum | Developer Slack | WeChat |

--- *Latest News* 🔥 - +- [2026/01] We released [0.14.0rc1](https://github.com/vllm-project/vllm-omni/releases/tag/v0.14.0rc2). - [2026/01] We released [0.12.0rc1](https://github.com/vllm-project/vllm-omni/releases/tag/v0.12.0rc1) - a major RC milestone focused on maturing the diffusion stack, strengthening OpenAI-compatible serving, expanding omni-model coverage, and improving stability across platforms (GPU/NPU/ROCm), please check our latest [design](https://docs.google.com/presentation/d/1qv4qMW1rKAqDREMXiUDLIgqqHQe7TDPj/edit?usp=sharing&ouid=110473603432222024453&rtpof=true&sd=true). - [2025/11] vLLM community officially released [vllm-project/vllm-omni](https://github.com/vllm-project/vllm-omni) in order to support omni-modality models serving. @@ -70,6 +70,10 @@ Please check out [Contributing to vLLM-Omni](https://vllm-omni.readthedocs.io/en ## Join the Community Feel free to ask questions, provide feedbacks and discuss with fellow users of vLLM-Omni in `#sig-omni` slack channel at [slack.vllm.ai](https://slack.vllm.ai) or vLLM user forum at [discuss.vllm.ai](https://discuss.vllm.ai). +## Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=vllm-project/vllm-omni&type=date&legend=top-left)](https://www.star-history.com/#vllm-project/vllm-omni&type=date&legend=top-left) + ## License Apache License 2.0, as found in the [LICENSE](./LICENSE) file. 
diff --git a/collect_env.py b/collect_env.py index 71cec0c4a87..8b09379e1a3 100644 --- a/collect_env.py +++ b/collect_env.py @@ -57,6 +57,7 @@ "cpu_info", "rocm_version", # vllm specific field "vllm_version", # vllm specific field + "vllm_omni_version", # vllm-omni specific field "vllm_build_flags", # vllm specific field "gpu_topo", # vllm specific field "env_vars", @@ -289,6 +290,31 @@ def get_vllm_version(): return __version__ +def get_vllm_omni_version(run_lambda): + try: + import vllm_omni + from vllm_omni import __version__, __version_tuple__ + + version_str = __version_tuple__[-1] + if isinstance(version_str, str) and version_str.startswith("g"): + if "." in version_str: + git_sha = version_str.split(".")[0][1:] + date = version_str.split(".")[-1][1:] + return f"{__version__} (git sha: {git_sha}, date: {date})" + else: + git_sha = version_str[1:] + return f"{__version__} (git sha: {git_sha})" + + package_dir = os.path.dirname(os.path.abspath(vllm_omni.__file__)) + git_sha = run_and_read_all(run_lambda, f"git -C {package_dir} rev-parse --short HEAD") + if git_sha: + return f"{__version__} (git sha: {git_sha})" + + return __version__ + except ImportError: + return "N/A (vllm_omni not installed)" + + def summarize_vllm_build_flags(): # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. 
return "CUDA Archs: {}; ROCm: {}".format( @@ -524,6 +550,7 @@ def get_version_or_na(cfg, prefix): rocm_version = get_rocm_version(run_lambda) vllm_version = get_vllm_version() + vllm_omni_version = get_vllm_omni_version(run_lambda) vllm_build_flags = summarize_vllm_build_flags() gpu_topo = get_gpu_topo(run_lambda) @@ -555,6 +582,7 @@ def get_version_or_na(cfg, prefix): cpu_info=get_cpu_info(run_lambda), rocm_version=rocm_version, vllm_version=vllm_version, + vllm_omni_version=vllm_omni_version, vllm_build_flags=vllm_build_flags, gpu_topo=gpu_topo, env_vars=get_env_vars(), @@ -621,6 +649,7 @@ def get_version_or_na(cfg, prefix): ============================== ROCM Version : {rocm_version} vLLM Version : {vllm_version} +vLLM-Omni Version : {vllm_omni_version} vLLM Build Flags: {vllm_build_flags} GPU Topology: diff --git a/docker/Dockerfile.ci b/docker/Dockerfile.ci index 5e1d00a5f88..55599df29ed 100644 --- a/docker/Dockerfile.ci +++ b/docker/Dockerfile.ci @@ -1,11 +1,17 @@ ARG VLLM_BASE_IMAGE=vllm/vllm-openai -ARG VLLM_BASE_TAG=v0.12.0 +ARG VLLM_BASE_TAG=v0.14.0rc2 FROM ${VLLM_BASE_IMAGE}:${VLLM_BASE_TAG} ARG APP_DIR=/workspace/vllm-omni WORKDIR ${APP_DIR} COPY . . +# Install system dependencies +RUN apt-get update && \ + apt-get install -y ffmpeg && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + # Install vllm-omni into the same uv-managed Python environment used by the base image. 
RUN uv pip install --python "$(python3 -c 'import sys; print(sys.executable)')" --no-cache-dir ".[dev]" diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 7fabb9c3c68..80f709deae0 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -2,7 +2,7 @@ ARG BASE_IMAGE=rocm/vllm-dev:nightly_main_20251205 FROM ${BASE_IMAGE} ARG COMMON_WORKDIR=/app -ARG VLLM_VERSION=v0.12.0 +ARG VLLM_VERSION=v0.14.0rc2 ARG PYTORCH_ROCM_ARCH="gfx942;gfx950" WORKDIR ${COMMON_WORKDIR} diff --git a/docs/.nav.yml b/docs/.nav.yml index abe4865af33..be930637175 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -3,7 +3,7 @@ nav: - User Guide: - Getting Started: - getting_started/quickstart.md - - getting_started/installation + - getting_started/installation/* - Serving: - OpenAI-Compatible API: - Image Generation: serving/image_generation_api.md @@ -26,17 +26,16 @@ nav: - Configuration: - configuration/README.md - configuration/* - - Diffusion Acceleration: - - Overview: user_guide/diffusion_acceleration.md - - Acceleration Methods: - - TeaCache: user_guide/acceleration/teacache.md - - Cache-DiT: user_guide/acceleration/cache_dit_acceleration.md - - Parallelism Acceleration: user_guide/acceleration/parallelism_acceleration.md - Models: - models/supported_models.md - Features: - Sleep Mode: features/sleep_mode.md - - CPU Offloading for Diffusion Model: features/cpu_offload_diffusion.md + - Diffusion Features: + - Overview: user_guide/diffusion_acceleration.md + - TeaCache: user_guide/diffusion/teacache.md + - Cache-DiT: user_guide/diffusion/cache_dit_acceleration.md + - Parallelism Acceleration: user_guide/diffusion/parallelism_acceleration.md + - CPU Offloading: user_guide/diffusion/cpu_offload_diffusion.md - Developer Guide: - General: - contributing/README.md @@ -47,7 +46,6 @@ nav: - contributing/model/adding_omni_model.md - contributing/model/adding_diffusion_model.md - CI: contributing/ci - - Tests: contributing/tests - Design Documents: - design/index.md - 
design/architecture_overview.md diff --git a/docs/api/README.md b/docs/api/README.md index 4fa85cdc663..a1f07011118 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -82,6 +82,7 @@ Model execution components. - [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_code_predictor_mtp.Qwen3OmniMoeTalkerCodePredictor][] - [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_talker.Qwen3OmniMoeModel][] - [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_talker.Qwen3OmniMoeTalkerForConditionalGeneration][] +- [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_talker.Qwen3OmniMoeTalkerSharedExpertWrapper][] - [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_thinker.Qwen3MoeLLMForCausalLM][] - [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_thinker.Qwen3MoeLLMModel][] - [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_thinker.Qwen3OmniMoeConditionalGenerationMixin][] @@ -102,8 +103,9 @@ Configuration classes. Worker classes and model runners for distributed inference. -- [vllm_omni.diffusion.worker.gpu_worker.GPUWorker][] -- [vllm_omni.diffusion.worker.gpu_worker.WorkerProc][] +- [vllm_omni.diffusion.worker.gpu_diffusion_model_runner.GPUDiffusionModelRunner][] +- [vllm_omni.diffusion.worker.gpu_diffusion_worker.GPUDiffusionWorker][] +- [vllm_omni.diffusion.worker.gpu_diffusion_worker.WorkerProc][] - [vllm_omni.diffusion.worker.npu.npu_worker.NPUWorker][] - [vllm_omni.diffusion.worker.npu.npu_worker.NPUWorkerProc][] - [vllm_omni.worker.gpu_ar_model_runner.ExecuteModelState][] diff --git a/docs/assets/WeChat.jpg b/docs/assets/WeChat.jpg new file mode 100644 index 00000000000..5a63afde85c Binary files /dev/null and b/docs/assets/WeChat.jpg differ diff --git a/docs/configuration/README.md b/docs/configuration/README.md index 40439d51121..02440b95dce 100644 --- a/docs/configuration/README.md +++ b/docs/configuration/README.md @@ -2,7 +2,7 @@ This section lists the most common options for running vLLM-Omni. 
-For options within a vLLM Engine. Please refer to [vLLM Configuration](https://docs.vllm.ai/en/v0.12.0/configuration/index.html) +For options within a vLLM Engine, please refer to [vLLM Configuration](https://docs.vllm.ai/en/v0.14.0/configuration/index.html) Currently, the main options are maintained by stage configs for each model. @@ -16,6 +16,6 @@ For introduction, please check [Introduction for stage config](./stage_configs.m ## Optimization Features -- **[TeaCache Configuration](../user_guide/acceleration/teacache.md)** - Enable TeaCache adaptive caching for DiT models to achieve 1.5x-2.0x speedup with minimal quality loss -- **[Cache-DiT Configuration](../user_guide/acceleration/cache_dit_acceleration.md)** - Enable Cache-DiT as cache acceleration backends for DiT models -- **[Parallelism Configuration](../user_guide/acceleration/parallelism_acceleration.md)** - Enable parallelism (e.g., sequence parallelism) for for DiT models +- **[TeaCache Configuration](../user_guide/diffusion/teacache.md)** - Enable TeaCache adaptive caching for DiT models to achieve 1.5x-2.0x speedup with minimal quality loss +- **[Cache-DiT Configuration](../user_guide/diffusion/cache_dit_acceleration.md)** - Enable Cache-DiT as cache acceleration backends for DiT models +- **[Parallelism Configuration](../user_guide/diffusion/parallelism_acceleration.md)** - Enable parallelism (e.g., sequence parallelism) for DiT models diff --git a/docs/contributing/ci/tests_markers.md b/docs/contributing/ci/tests_markers.md new file mode 100644 index 00000000000..bf56914f8da --- /dev/null +++ b/docs/contributing/ci/tests_markers.md @@ -0,0 +1,160 @@ +# Markers for Tests + +By adding markers before test functions, tests can later be executed uniformly by simply declaring the corresponding marker type. 
+ +## Current Markers +Defined in `pyproject.toml`: + +| Marker | Description | +| ------------------ | ------------------------------------------------------- | +| `core_model` | Core model tests (run in each PR) | +| `diffusion` | Diffusion model tests | +| `omni` | Omni model tests | +| `cache` | Cache backend tests | +| `parallel` | Parallelism/distributed tests | +| `cpu` | Tests that run on CPU | +| `gpu` | Tests that run on GPU (auto-added) | +| `cuda` | Tests that run on CUDA (auto-added) | +| `rocm` | Tests that run on AMD/ROCm (auto-added) | +| `npu` | Tests that run on NPU/Ascend (auto-added) | +| `H100` | Tests that require H100 GPU | +| `L4` | Tests that require L4 GPU | +| `MI325` | Tests that require MI325 GPU (AMD/ROCm) | +| `A2` | Tests that require A2 NPU | +| `A3` | Tests that require A3 NPU | +| `distributed_cuda` | Tests that require multi cards on CUDA platform | +| `distributed_rocm` | Tests that require multi cards on ROCm platform | +| `distributed_npu` | Tests that require multi cards on NPU platform | +| `skipif_cuda` | Skip if the num of CUDA cards is less than the required | +| `skipif_rocm` | Skip if the num of ROCm cards is less than the required | +| `skipif_npu` | Skip if the num of NPU cards is less than the required | +| `slow` | Slow tests (may skip in quick CI) | +| `benchmark` | Benchmark tests | + +Markers shown as auto-added are applied automatically by the `@hardware_test` decorator. + +### Example usage for markers + +```python +from tests.utils import hardware_test + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test( + res={"cuda": "L4", "rocm": "MI325", "npu": "A2"}, + num_cards=2, +) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_video_to_audio(): + ... +``` +### Decorator: `@hardware_test` + +This decorator is intended to make hardware-aware, cross-platform test authoring easier and more robust for CI/CD environments. 
The `hardware_test` decorator in `vllm-omni/tests/utils.py` performs the following actions: + +1. **Applies platform and resource markers** + Adds the appropriate pytest markers for each specified hardware platform (e.g., `cuda`, `rocm`, `npu`) and resource type (e.g., `L4`, `H100`, `MI325`, `A2`, `A3`). + ``` + @pytest.mark.cuda + @pytest.mark.L4 + ``` +2. **Handles multi-card (distributed) scenarios** + For tests requiring multiple cards, it automatically adds distributed markers such as `distributed_cuda`, `distributed_rocm`, or `distributed_npu`. + ``` + @pytest.mark.distributed_cuda(num_cards=num_cards) + ``` +3. **Supports flexible card requirements** + Accepts `num_cards` as either a single integer for all platforms or as a dictionary with per-platform values. If not specified, defaults to 1 card per platform. + +4. **Integrates resource validation** + On CUDA, adds a skip marker (`skipif_cuda`) if the system does not have the required number of devices. + Support for `skipif_rocm` and `skipif_npu` will be implemented later. + + +5. **Runs each test in a new process** + Automatically wraps the distributed test with a decorator (`@create_new_process_for_each_test`) to ensure isolation and compatibility with multi-process hardware backends. + +6. **Works with pytest filtering** + Allows tests to be filtered and selected at runtime using standard pytest marker expressions (e.g., `-m "distributed_cuda and L4"`). 
+ +#### Example usage for decorator +- Single call for multiple platforms: + ```python + @hardware_test( + res={"cuda": "L4", "rocm": "MI325", "npu": "A2"}, + num_cards={"cuda": 2, "rocm": 2, "npu": 2}, + ) + ``` + or + ```python + @hardware_test( + res={"cuda": "L4", "rocm": "MI325", "npu": "A2"}, + num_cards=2, + ) + ``` +- `res` must be a dict; supported resources: CUDA (L4/H100), ROCm (MI325), NPU (A2/A3) +- `num_cards` can be int (all platforms) or dict (per platform); defaults to 1 when missing +- `hardware_test` automatically applies `@create_new_process_for_each_test` for distributed tests. +- Distributed markers (`distributed_cuda`, `distributed_rocm`, `distributed_npu`) are auto-added for multi-card cases +- Filtering examples: + - CUDA only: `pytest -m "distributed_cuda and L4"` + - ROCm only: `pytest -m "distributed_rocm and MI325"` + - NPU only: `pytest -m "distributed_npu"` + +## Add Support for a New Platform + +If you want to add support for a new platform (e.g., "tpu" for a new accelerator), follow these steps: + +1. **Extend the marker list in your pytest config** so that platform/resource markers are defined: + ```toml + # In pyproject.toml or pytest.ini + [tool.pytest.ini_options] + markers = [ + # ... existing markers ... + "tpu: Tests that require TPU device", + "TPU_V3: Tests that require TPU v3 hardware", + "distributed_tpu: Tests that require multiple TPU devices", + ] + ``` +2. **Implement a marker construction function for your platform** in `vllm-omni/tests/utils.py`: + ```python + # In vllm-omni/tests/utils.py + + def tpu_marks(*, res: str, num_cards: int): + test_platform = pytest.mark.tpu + if res == "TPU_V3": + test_resource = pytest.mark.TPU_V3 + else: + raise ValueError( + f"Invalid TPU resource type: {res}. 
Supported: TPU_V3") + + if num_cards == 1: + return [test_platform, test_resource] + else: + test_distributed = pytest.mark.distributed_tpu(num_cards=num_cards) + # Optionally: add skipif_tpu when implemented + return [test_platform, test_resource, test_distributed] + ``` +3. **Update `hardware_test` to recognize your new platform**: + In the relevant place (see the `hardware_test` implementation), add: + ```python + if platform == "tpu": + marks = tpu_marks(res=resource, num_cards=cards) + ``` +4. **(Recommended) Add a test using your new markers**: + ```python + @hardware_test( + res={"tpu": "TPU_V3"}, + num_cards=2, + ) + def test_my_tpu_feature(): + ... + ``` + +**Summary**: +- Add pytest markers for your new platform/resources +- Implement a marker function (`xxx_marks`) +- Plug into `hardware_test` +- You're done: tests decorated with `@hardware_test` using your platform now automatically get the correct markers, distribution, and isolation! + +See code in `vllm-omni/tests/utils.py` for existing examples (`cuda_marks`, `rocm_marks`, `npu_marks`). diff --git a/docs/contributing/tests/tests_style.md b/docs/contributing/ci/tests_style.md similarity index 94% rename from docs/contributing/tests/tests_style.md rename to docs/contributing/ci/tests_style.md index c88e17dee34..65c2b044346 100644 --- a/docs/contributing/tests/tests_style.md +++ b/docs/contributing/ci/tests_style.md @@ -139,7 +139,7 @@ vllm_omni/ tests/ 4. **Documentation**: Add docstrings to all test functions 5. **Environment variables**: Set uniformly in `conftest.py` or at the top of files 6. **Type annotations**: Add type annotations to all test function parameters -7. **Resources**, Using pytest tag to specify the computation resources the test required. +7. **Pytest Markers**: Add necessary markers like `@pytest.mark.core_model` and use `@hardware_test` to declare hardware requirements (see details in [Markers for Tests](../ci/tests_markers.md)). 
### Template #### E2E - Online serving @@ -155,6 +155,7 @@ from pathlib import Path import pytest import openai +from tests.utils import hardware_test # Optional: set process start method for workers os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @@ -184,6 +185,12 @@ def base64_encoded_video() -> str: def dummy_messages_from_video_data(video_data_url: str, content_text: str) -> str: xxx +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test( + res={"cuda": "L4", "rocm": "MI325", "npu": "A2"}, + num_cards={"cuda": 2, "rocm": 2, "npu": 4}, +) @pytest.mark.parametrize("omni_server", test_params, indirect=True) def test_video_to_audio( client: openai.OpenAI, @@ -226,6 +233,7 @@ from pathlib import Path import pytest from vllm.assets.video import VideoAsset +from tests.utils import hardware_test from ..multi_stages.conftest import OmniRunner # Optional: set process start method for workers @@ -239,7 +247,12 @@ test_params = [(model, stage_config) for model in models for stage_config in sta # function name: test_{input_modality}_to_{output_modality} # modality candidate: text, image, audio, video, mixed_modalities -@pytest.mark.gpu_mem_high # requires high-memory GPU node +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test( + res={"cuda": "L4", "rocm": "MI325", "npu": "A2"}, + num_cards=2, +) @pytest.mark.parametrize("test_config", test_params) def test_video_to_audio(omni_runner: type[OmniRunner], model: str) -> None: """Offline inference: video input, audio output.""" @@ -263,4 +276,5 @@ def test_video_to_audio(omni_runner: type[OmniRunner], model: str) -> None: 1. The file is saved in an appropriate place and the file name is clear. 2. The coding style follows the requirements outlined above. -3. For e2e model test, please ensure the test is configured under the `./buildkite/` folder. +3. **All test functions have appropriate pytest markers** +4. For tests that need to run in CI, please ensure the test is configured under the `.buildkite/` folder. 
diff --git a/docs/contributing/model/adding_diffusion_model.md b/docs/contributing/model/adding_diffusion_model.md index 70fdc6a0817..7eb56d5f5bc 100644 --- a/docs/contributing/model/adding_diffusion_model.md +++ b/docs/contributing/model/adding_diffusion_model.md @@ -140,7 +140,7 @@ Key point for writing the example: + Save or display the generated results so users can validate the integration. # Testing -For comprehensive testing guidelines, please refer to the [Test File Structure and Style Guide](../tests/tests_style.md). +For comprehensive testing guidelines, please refer to the [Test File Structure and Style Guide](../ci/tests_style.md). ## Adding a Model Recipe diff --git a/docs/contributing/model/adding_omni_model.md b/docs/contributing/model/adding_omni_model.md index 2a91a305091..81499118623 100644 --- a/docs/contributing/model/adding_omni_model.md +++ b/docs/contributing/model/adding_omni_model.md @@ -572,7 +572,7 @@ def talker2code2wav( ## Testing -For comprehensive testing guidelines, please refer to the [Test File Structure and Style Guide](../tests/tests_style.md). +For comprehensive testing guidelines, please refer to the [Test File Structure and Style Guide](../ci/tests_style.md). 
## Adding a Model Recipe diff --git a/docs/design/architecture_overview.md b/docs/design/architecture_overview.md index ea7eff4397d..6793895cd46 100644 --- a/docs/design/architecture_overview.md +++ b/docs/design/architecture_overview.md @@ -64,13 +64,13 @@ According to analysis for current popular open-source models, most of them have ## Key Components -| Component | Description | -|-----------|-------------| -| **OmniRouter** | provide an intelligent router for Omni-modality requests dispatch | -| **EntryPoints** | define the APIs for offline/online serving (APIServer, Omni/AsyncOmni) and provide the OmniStage abstraction for different AR/DiT stages | -| **AR** | adapted for omni-modality models while inheriting efficient features from vLLM, such as cache management | -| **Diffusion** | natively implemented and optimized using acceleration components | -| **OmniConnector** | supports fully disaggregation based on E/P/D/G (Encoding/Processing/Decoding/Generation) disaggregation across stages | +| Component | Description | +| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------- | +| **OmniRouter** | provide an intelligent router for Omni-modality requests dispatch | +| **EntryPoints** | define the APIs for offline/online serving (APIServer, Omni/AsyncOmni) and provide the OmniStage abstraction for different AR/DiT stages | +| **AR** | adapted for omni-modality models while inheriting efficient features from vLLM, such as cache management | +| **Diffusion** | natively implemented and optimized using acceleration components | +| **OmniConnector** | supports fully disaggregation based on E/P/D/G (Encoding/Processing/Decoding/Generation) disaggregation across stages | Disaggregated stages are managed through configuration, such as in the Qwen3-Omni example, where stages like Thinker, Talker, and Code2wav are defined as separate OmniStage instances with specific 
resources and input/output type. @@ -192,4 +192,4 @@ curl -sS -X POST http://localhost:8091/v1/chat/completions \ } ``` -For more usages, please refer to [examples](../user_guide/examples/). +For more usages, please refer to [examples](https://github.com/vllm-project/vllm-omni/tree/main/examples). diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md index 540b1852947..c073c152ee9 100644 --- a/docs/getting_started/installation/gpu/cuda.inc.md +++ b/docs/getting_started/installation/gpu/cuda.inc.md @@ -58,11 +58,11 @@ If you want to check, modify or debug with source code of vLLM, install the libr ```bash git clone https://github.com/vllm-project/vllm.git cd vllm -git checkout v0.12.0 +git checkout v0.14.0rc2 ``` Set up environment variables to get pre-built wheels. If there are internet problems, just download the whl file manually. And set `VLLM_PRECOMPILED_WHEEL_LOCATION` as your local absolute path of whl file. ```bash -export VLLM_PRECOMPILED_WHEEL_LOCATION=https://github.com/vllm-project/vllm/releases/download/v0.12.0/vllm-0.12.0-cp38-abi3-manylinux_2_31_x86_64.whl +export VLLM_PRECOMPILED_WHEEL_LOCATION=https://github.com/vllm-project/vllm/releases/download/v0.14.0/vllm-0.14.0rc2-cp38-abi3-manylinux_2_31_x86_64.whl ``` Install vllm with command below (If you have no existing PyTorch). 
```bash @@ -93,7 +93,7 @@ docker run --runtime nvidia --gpus 2 \ --env "HF_TOKEN=$HF_TOKEN" \ -p 8091:8091 \ --ipc=host \ - vllm/vllm-omni:v0.12.0rc1 \ + vllm/vllm-omni:v0.14.0rc2 \ --model Qwen/Qwen3-Omni-30B-A3B-Instruct --port 8091 ``` diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md index 1fa751e2508..638c689d4be 100644 --- a/docs/getting_started/installation/gpu/rocm.inc.md +++ b/docs/getting_started/installation/gpu/rocm.inc.md @@ -68,7 +68,7 @@ docker run -it \ -v :/app/model \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=$HF_TOKEN" \ - vllm/vllm-omni-rocm:v0.12.0rc1 \ + vllm/vllm-omni-rocm:v0.14.0rc2 \ vllm serve --model Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 ``` @@ -86,7 +86,7 @@ docker run -it \ -v :/app/model \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=$HF_TOKEN" \ - vllm/vllm-omni-rocm:v0.12.0rc1 \ + vllm/vllm-omni-rocm:v0.14.0rc2 \ bash ``` diff --git a/docs/getting_started/installation/npu/npu.inc.md b/docs/getting_started/installation/npu/npu.inc.md index 9c36c2be626..5714449a70a 100644 --- a/docs/getting_started/installation/npu/npu.inc.md +++ b/docs/getting_started/installation/npu/npu.inc.md @@ -13,10 +13,10 @@ export DEVICE0=/dev/davinci0 export DEVICE1=/dev/davinci1 # Update the vllm-ascend image # Atlas A2: -# export IMAGE=quay.io/ascend/vllm-ascend:v0.12.0rc1 +# export IMAGE=quay.io/ascend/vllm-ascend:v0.14.0rc2 # Atlas A3: -# export IMAGE=quay.io/ascend/vllm-ascend:v0.12.0rc1-a3 -export IMAGE=quay.io/ascend/vllm-ascend:v0.12.0rc1 +# export IMAGE=quay.io/ascend/vllm-ascend:v0.14.0rc2-a3 +export IMAGE=quay.io/ascend/vllm-ascend:v0.14.0rc2 docker run --rm \ --name vllm-omni-npu \ --shm-size=1g \ @@ -42,7 +42,7 @@ source ~/.bashrc # Inside the container, install vLLM-Omni from source cd /vllm-workspace -git clone -b v0.12.0rc1 https://github.com/vllm-project/vllm-omni.git +git clone -b v0.14.0rc2 
https://github.com/vllm-project/vllm-omni.git cd vllm-omni pip install -v -e . export VLLM_WORKER_MULTIPROC_METHOD=spawn diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 1b993152837..cd70019de32 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -32,6 +32,7 @@ th { |`LongcatImagePipeline` | LongCat-Image | `meituan-longcat/LongCat-Image` | |`LongCatImageEditPipeline` | LongCat-Image-Edit | `meituan-longcat/LongCat-Image-Edit` | |`StableDiffusion3Pipeline` | Stable-Diffusion-3 | `stabilityai/stable-diffusion-3.5-medium` | +|`Flux2KleinPipeline` | FLUX.2-klein | `black-forest-labs/FLUX.2-klein-4B`, `black-forest-labs/FLUX.2-klein-9B` | |`StableAudioPipeline` | Stable-Audio-Open | `stabilityai/stable-audio-open-1.0` | diff --git a/docs/user_guide/acceleration/cache_dit_acceleration.md b/docs/user_guide/diffusion/cache_dit_acceleration.md similarity index 100% rename from docs/user_guide/acceleration/cache_dit_acceleration.md rename to docs/user_guide/diffusion/cache_dit_acceleration.md diff --git a/docs/features/cpu_offload_diffusion.md b/docs/user_guide/diffusion/cpu_offload_diffusion.md similarity index 93% rename from docs/features/cpu_offload_diffusion.md rename to docs/user_guide/diffusion/cpu_offload_diffusion.md index aaa4243a3a2..533b6b3b964 100644 --- a/docs/features/cpu_offload_diffusion.md +++ b/docs/user_guide/diffusion/cpu_offload_diffusion.md @@ -23,7 +23,7 @@ if __name__ == "__main__": m = Omni(model="Qwen/Qwen-Image",enable_cpu_offload=True) ``` -- **CLI**: pass `--dit-cpu-offload` to the diffusion service entrypoint. +- **CLI**: pass `--enable-cpu-offload` to the diffusion service entrypoint. 
## Known Limitations - Cold start latency increases for over one minute for some models(e.g., Qwen-Image) diff --git a/docs/user_guide/acceleration/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md similarity index 95% rename from docs/user_guide/acceleration/parallelism_acceleration.md rename to docs/user_guide/diffusion/parallelism_acceleration.md index dfacd2183ff..324301158d8 100644 --- a/docs/user_guide/acceleration/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -23,13 +23,23 @@ The following table shows which models are currently supported by parallelism me | **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ❌ | | **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ❌ | | **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ | -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | | **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ❌ | | **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ❌ | | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ❌ | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ❌ | ❌ | ✅ (TP=2 only) | | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ | + +!!! note "TP Limitations for Diffusion Models" + We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP. + + - Good news: The text_encoder typically has minimal impact on overall inference performance. + - Bad news: When TP is enabled, every TP process retains a full copy of the text_encoder weights, leading to significant GPU memory waste. + + We are actively refactoring this design to address this. 
For details and progress, please refer to [Issue #771](https://github.com/vllm-project/vllm-omni/issues/771). + + !!! note "Why Z-Image is TP=2 only" Z-Image Turbo is currently limited to `tensor_parallel_size` of **1 or 2** due to model shape divisibility constraints. For example, the model has `n_heads=30` and a final projection out dimension of `64`, so valid TP sizes must divide both 30 and 64; the only common divisors are **1 and 2**. diff --git a/docs/user_guide/acceleration/teacache.md b/docs/user_guide/diffusion/teacache.md similarity index 100% rename from docs/user_guide/acceleration/teacache.md rename to docs/user_guide/diffusion/teacache.md diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 8f78ae32e50..0184a8fcd39 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -6,8 +6,8 @@ vLLM-Omni supports various cache acceleration methods to speed up diffusion mode vLLM-Omni currently supports two main cache acceleration backends: -1. **[TeaCache](acceleration/teacache.md)** - Hook-based adaptive caching that caches transformer computations when consecutive timesteps are similar -2. **[Cache-DiT](acceleration/cache_dit_acceleration.md)** - Library-based acceleration using multiple techniques: +1. **[TeaCache](diffusion/teacache.md)** - Hook-based adaptive caching that caches transformer computations when consecutive timesteps are similar +2. 
**[Cache-DiT](diffusion/cache_dit_acceleration.md)** - Library-based acceleration using multiple techniques: - **DBCache** (Dual Block Cache): Caches intermediate transformer block outputs based on residual differences - **TaylorSeer**: Uses Taylor expansion-based forecasting for faster inference - **SCM** (Step Computation Masking): Selectively computes steps based on adaptive masking @@ -16,11 +16,11 @@ Both methods can provide significant speedups (typically **1.5x-2.0x**) while ma vLLM-Omni also supports parallelism methods for diffusion models, including: -1. [Ulysses-SP](acceleration/parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. +1. [Ulysses-SP](diffusion/parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. -2. [Ring-Attention](acceleration/parallelism_acceleration.md#ring-attention) - splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results, keeping the sequence dimension sharded. +2. [Ring-Attention](diffusion/parallelism_acceleration.md#ring-attention) - splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results, keeping the sequence dimension sharded. -3. [CFG-Parallel](acceleration/parallelism_acceleration.md#cfg-parallel) - runs the positive/negative prompts of classifier-free guidance (CFG) on different devices, then merges on a single device to perform the scheduler step. +3. [CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel) - runs the positive/negative prompts of classifier-free guidance (CFG) on different devices, then merges on a single device to perform the scheduler step. 
## Quick Comparison @@ -49,6 +49,7 @@ The following table shows which models are currently supported by each accelerat | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ✅ | ❌ | ❌ | ❌ | | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ❌ | +| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ❌ | ✅ | ❌ | ❌ | ❌ | ### VideoGen @@ -197,7 +198,7 @@ outputs = omni.generate(prompt="turn this cat to a dog", For detailed information on each acceleration method: -- **[TeaCache Guide](acceleration/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices -- **[Cache-DiT Acceleration Guide](acceleration/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters -- **[Sequence Parallelism](acceleration/parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration. -- **[CFG-Parallel](acceleration/parallelism_acceleration.md#cfg-parallel)** - Guidance on how to set CFG-Parallel to run positive/negative branches across ranks. +- **[TeaCache Guide](diffusion/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices +- **[Cache-DiT Acceleration Guide](diffusion/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters +- **[Sequence Parallelism](diffusion/parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration. +- **[CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel)** - Guidance on how to set CFG-Parallel to run positive/negative branches across ranks. 
diff --git a/docs/user_guide/examples/online_serving/gradio_demo.md b/docs/user_guide/examples/online_serving/gradio_demo.md deleted file mode 100644 index 38278d9cf5a..00000000000 --- a/docs/user_guide/examples/online_serving/gradio_demo.md +++ /dev/null @@ -1,7 +0,0 @@ -# Gradio Demo - -Source . - -``````py ---8<-- "examples/online_serving/gradio_demo.py" -`````` diff --git a/docs/user_guide/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.md b/docs/user_guide/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.md deleted file mode 100644 index ca3fa8306b3..00000000000 --- a/docs/user_guide/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.md +++ /dev/null @@ -1,7 +0,0 @@ -# OpenAI Chat Completion Client For Multimodal Generation - -Source . - -``````py ---8<-- "examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py" -`````` diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index 8f4dbeef98e..c31d098252b 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -55,7 +55,15 @@ --prompt "Edit description" \ --cfg_parallel_size 2 \ --num_inference_steps 50 \ - --cfg_scale 4.0 \ + --cfg_scale 4.0 + +Usage (disable torch.compile): + python image_edit.py \ + --image input.png \ + --prompt "Edit description" \ + --enforce_eager \ + --num_inference_steps 50 \ + --cfg_scale 4.0 For more options, run: python image_edit.py --help @@ -173,6 +181,12 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for ring sequence parallelism.", ) + parser.add_argument( + "--tensor_parallel_size", + type=int, + default=1, + help="Number of GPUs used for tensor parallelism (TP) inside the DiT.", + ) parser.add_argument("--layers", type=int, default=4, help="Number of layers to decompose the input 
image into.") parser.add_argument( "--resolution", @@ -260,6 +274,11 @@ def parse_args() -> argparse.Namespace: choices=[1, 2], help="Number of GPUs used for classifier free guidance parallel size.", ) + parser.add_argument( + "--enforce_eager", + action="store_true", + help="Disable torch.compile and force eager execution.", + ) return parser.parse_args() @@ -288,7 +307,10 @@ def main(): vae_use_slicing = is_npu() vae_use_tiling = is_npu() parallel_config = DiffusionParallelConfig( - ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree, cfg_parallel_size=args.cfg_parallel_size + ulysses_degree=args.ulysses_degree, + ring_degree=args.ring_degree, + cfg_parallel_size=args.cfg_parallel_size, + tensor_parallel_size=args.tensor_parallel_size, ) # Configure cache based on backend type @@ -321,6 +343,7 @@ def main(): cache_backend=args.cache_backend, cache_config=cache_config, parallel_config=parallel_config, + enforce_eager=args.enforce_eager, ) print("Pipeline loaded") @@ -337,7 +360,7 @@ def main(): else: print(f" Input image size: {input_image.size}") print( - f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}" + f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, tensor_parallel_size={args.tensor_parallel_size}" ) print(f"{'=' * 60}\n") diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py index d1e0264cd7a..22f2161c22b 100644 --- a/examples/offline_inference/qwen2_5_omni/end2end.py +++ b/examples/offline_inference/qwen2_5_omni/end2end.py @@ -278,9 +278,9 @@ def get_audio_query(question: str = None, audio_path: str | None = None, samplin query_map = { - "mixed_modalities": get_mixed_modalities_query, + "use_mixed_modalities": get_mixed_modalities_query, "use_audio_in_video": get_use_audio_in_video_query, - 
"multi_audios": get_multi_audios_query, + "use_multi_audios": get_multi_audios_query, "use_image": get_image_query, "use_video": get_video_query, "use_audio": get_audio_query, @@ -434,7 +434,7 @@ def parse_args(): "--query-type", "-q", type=str, - default="mixed_modalities", + default="use_mixed_modalities", choices=query_map.keys(), help="Query type.", ) diff --git a/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh b/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh index 26cbf0172c5..c8e4cd2cbf3 100644 --- a/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh +++ b/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh @@ -1,2 +1,2 @@ python end2end.py --output-wav output_audio \ - --query-type mixed_modalities + --query-type use_mixed_modalities diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py index 9a1324305cf..3cd8918208e 100644 --- a/examples/offline_inference/qwen3_omni/end2end.py +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -12,8 +12,8 @@ import librosa import numpy as np import soundfile as sf -from PIL import Image import vllm +from PIL import Image from vllm import SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset @@ -226,19 +226,76 @@ def get_multi_audios_query() -> QueryResult: ) +# def get_use_audio_in_video_query(video_path: str | None = None) -> QueryResult: +# question = ( +# "Describe the content of the video in details, then convert what the " +# "baby say into text." 
+# ) +# prompt = ( +# f"<|im_start|>system\n{default_system}<|im_end|>\n" +# "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>" +# f"{question}<|im_end|>\n" +# f"<|im_start|>assistant\n" +# ) +# if video_path: +# if not os.path.exists(video_path): +# raise FileNotFoundError(f"Video file not found: {video_path}") +# video_frames = video_to_ndarrays(video_path, num_frames=16) +# else: +# video_frames = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays +# audio = extract_video_audio(video_path, sampling_rate=16000) +# return QueryResult( +# inputs={ +# "prompt": prompt, +# "multi_modal_data": { +# "video": video_frames, +# "audio": audio, +# }, +# "mm_processor_kwargs": { +# "use_audio_in_video": True, +# }, +# }, +# limit_mm_per_prompt={"audio": 1, "video": 1}, +# ) +def get_use_audio_in_video_query() -> QueryResult: + question = "Describe the content of the video in details, then convert what the baby say into text." + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + asset = VideoAsset(name="baby_reading", num_frames=16) + audio = asset.get_audio(sampling_rate=16000) + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "video": asset.np_ndarrays, + "audio": audio, + }, + "mm_processor_kwargs": { + "use_audio_in_video": True, + }, + }, + limit_mm_per_prompt={"audio": 1, "video": 1}, + ) + + query_map = { "text": get_text_query, "use_audio": get_audio_query, "use_image": get_image_query, "use_video": get_video_query, - "multi_audios": get_multi_audios_query, - "mixed_modalities": get_mixed_modalities_query, + "use_multi_audios": get_multi_audios_query, + "use_mixed_modalities": get_mixed_modalities_query, + "use_audio_in_video": get_use_audio_in_video_query, } def main(args): model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct" - print(f"="*20,"\n",f"vllm version: 
{vllm.__version__}","\n","="*20) + print("=" * 20, "\n", f"vllm version: {vllm.__version__}", "\n", "=" * 20) # Get paths from args video_path = getattr(args, "video_path", None) @@ -261,6 +318,10 @@ def main(args): num_frames=getattr(args, "num_frames", 16), sampling_rate=getattr(args, "sampling_rate", 16000), ) + elif args.query_type == "multi_audios": + query_result = query_func() + elif args.query_type == "use_audio_in_video": + query_result = query_func() else: query_result = query_func() @@ -272,7 +333,7 @@ def main(args): ) thinker_sampling_params = SamplingParams( - temperature=0.4, + temperature=0.9, top_p=0.9, top_k=-1, max_tokens=1200, @@ -304,8 +365,8 @@ def main(args): sampling_params_list = [ thinker_sampling_params, - # talker_sampling_params, # code predictor is integrated into talker for Qwen3 Omni - # code2wav_sampling_params, + talker_sampling_params, # code predictor is integrated into talker for Qwen3 Omni + code2wav_sampling_params, ] if args.txt_prompts is None: @@ -333,6 +394,8 @@ def main(args): total_requests = len(prompts) processed_count = 0 + print(f"query type: {args.query_type}") + for stage_outputs in omni_generator: if stage_outputs.final_output_type == "text": for output in stage_outputs.request_output: @@ -387,7 +450,7 @@ def parse_args(): "--query-type", "-q", type=str, - default="mixed_modalities", + default="use_mixed_modalities", choices=query_map.keys(), help="Query type.", ) diff --git a/examples/online_serving/image_to_image/openai_chat_client.py b/examples/online_serving/image_to_image/openai_chat_client.py index 14bec8a3be4..0fe4b0edece 100644 --- a/examples/online_serving/image_to_image/openai_chat_client.py +++ b/examples/online_serving/image_to_image/openai_chat_client.py @@ -127,7 +127,7 @@ def main(): parser.add_argument("--width", type=int, default=1024, help="Output image width") parser.add_argument("--steps", type=int, default=50, help="Inference steps") parser.add_argument("--guidance", type=float, default=7.5, 
help="Guidance scale") - parser.add_argument("--seed", type=int, help="Random seed") + parser.add_argument("--seed", type=int, default=0, help="Random seed") parser.add_argument("--negative", help="Negative prompt") args = parser.parse_args() diff --git a/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py index 0d502063576..b9a76858161 100644 --- a/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py +++ b/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py @@ -304,6 +304,25 @@ def get_multi_audios_query(custom_prompt: str | None = None): ], } +def get_use_audio_in_video_query( + video_path: str | None = None, + audio_path: str | None = None, + custom_prompt: str | None = None, +): + question = custom_prompt or ( + "Describe the content of the video in details, then convert what the " + "baby say into text." 
+ ) + video_url = get_video_url_from_path(video_path) + audio_url = get_audio_url_from_path(audio_path) + return { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": video_url}}, + {"type": "audio_url", "audio_url": {"url": audio_url}}, + {"type": "text", "text": question}, + ], + } query_map = { "text": get_text_query, @@ -312,6 +331,7 @@ def get_multi_audios_query(custom_prompt: str | None = None): "use_video": get_video_query, "use_mixed_modalities": get_mixed_modalities_query, "use_multi_audios": get_multi_audios_query, + "use_audio_in_video": get_use_audio_in_video_query, } @@ -372,6 +392,12 @@ def run_multimodal_generation(args) -> None: prompt = query_func(audio_path=audio_path, custom_prompt=custom_prompt) elif args.query_type == "text": prompt = query_func(custom_prompt=custom_prompt) + elif args.query_type == "use_audio_in_video": + prompt = query_func( + video_path=video_path, + audio_path=audio_path, + custom_prompt=custom_prompt, + ) else: prompt = query_func() diff --git a/examples/online_serving/text_to_image/README.md b/examples/online_serving/text_to_image/README.md index a4f1ad63321..744b7b2921d 100644 --- a/examples/online_serving/text_to_image/README.md +++ b/examples/online_serving/text_to_image/README.md @@ -116,6 +116,7 @@ Use `extra_body` to pass generation parameters: | `seed` | int | None | Random seed (reproducible) | | `negative_prompt` | str | None | Negative prompt | | `num_outputs_per_prompt` | int | 1 | Number of images to generate | +| `--cfg-parallel-size`
| int | 1 | Number of GPUs for CFG parallelism | ## Response Format diff --git a/examples/online_serving/text_to_image/openai_chat_client.py b/examples/online_serving/text_to_image/openai_chat_client.py index c529bf203fd..39fa7dc22b7 100644 --- a/examples/online_serving/text_to_image/openai_chat_client.py +++ b/examples/online_serving/text_to_image/openai_chat_client.py @@ -100,7 +100,7 @@ def main(): parser.add_argument("--width", type=int, default=1024, help="Image width") parser.add_argument("--steps", type=int, default=50, help="Inference steps") parser.add_argument("--cfg-scale", type=float, default=4.0, help="True CFG scale") - parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--seed", type=int, default=0, help="Random seed") parser.add_argument("--negative", help="Negative prompt") args = parser.parse_args() diff --git a/mkdocs.yml b/mkdocs.yml index 71cfe030569..1e8e38f5104 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -60,6 +60,10 @@ hooks: - docs/mkdocs/hooks/url_schemes.py - docs/mkdocs/hooks/generate_examples.py +# Exclude include files from navigation warnings +exclude_docs: | + **/*.inc.md + # Plugins plugins: - meta diff --git a/pyproject.toml b/pyproject.toml index 209a085bf87..2e2cddc5b7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,8 @@ dev = [ "pytest-cov>=4.0.0", "mypy==1.11.1", "pre-commit==4.0.1", + "openai-whisper>=20250625", + "psutil>=7.2.0" ] docs = [ @@ -151,11 +153,34 @@ addopts = [ "--cov-report=xml", ] markers = [ - "unit: Unit tests", - "integration: Integration tests", + # ci/cd required + "core_model: Core model tests (run in each PR)", + # function module markers + "diffusion: Diffusion model tests", + "omni: Omni model tests", + "cache: Cache backend tests", + "parallel: Parallelism/distributed tests", + # platform markers + "cpu: Tests that run on CPU", + "gpu: Tests that run on GPU (auto-added)", + "cuda: Tests that run on CUDA (auto-added)", + "rocm: Tests that run on 
AMD/ROCm (auto-added)", + "npu: Tests that run on NPU/Ascend (auto-added)", + # specified computation resources marks (auto-added) + "H100: Tests that require H100 GPU", + "L4: Tests that require L4 GPU", + "MI325: Tests that require MI325 GPU (AMD/ROCm)", + "A2: Tests that require A2 NPU", + "A3: Tests that require A3 NPU", + "distributed_cuda: Tests that require multi cards on CUDA platform", + "distributed_rocm: Tests that require multi cards on ROCm platform", + "distributed_npu: Tests that require multi cards on NPU platform", + "skipif_cuda: Skip if the num of CUDA cards is less than the required", + "skipif_rocm: Skip if the num of ROCm cards is less than the required", + "skipif_npu: Skip if the num of NPU cards is less than the required", + # more detailed markers + "slow: Slow tests (may skip in quick CI)", "benchmark: Benchmark tests", - "slow: Slow tests", - "core_model: enable this model test in each PR instead of only nightly", ] [tool.typos.default] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 8fb4beb9755..00000000000 --- a/pytest.ini +++ /dev/null @@ -1,3 +0,0 @@ -[pytest] -markers = - gpu_mem_high: needs high VRAM diff --git a/tests/conftest.py b/tests/conftest.py index 82c959f07ca..5b21f671bdb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,19 @@ +import base64 import os +import socket +import subprocess +import sys +import time +from pathlib import Path +from typing import Any +import psutil import pytest import torch +import whisper +import yaml from vllm.logger import init_logger +from vllm.utils import get_open_port logger = init_logger(__name__) @@ -34,3 +45,286 @@ def clean_gpu_memory_between_tests(): if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + + +def dummy_messages_from_mix_data( + system_prompt: dict[str, Any] = None, + video_data_url: Any = None, + audio_data_url: Any = None, + image_data_url: Any = None, + content_text: str = None, +): + """Create messages with 
video、image、audio data URL for OpenAI API.""" + + if content_text is not None: + content = [{"type": "text", "text": content_text}] + else: + content = [] + + media_items = [] + if isinstance(video_data_url, list): + for video_url in video_data_url: + media_items.append((video_url, "video")) + else: + media_items.append((video_data_url, "video")) + + if isinstance(image_data_url, list): + for url in image_data_url: + media_items.append((url, "image")) + else: + media_items.append((image_data_url, "image")) + + if isinstance(audio_data_url, list): + for url in audio_data_url: + media_items.append((url, "audio")) + else: + media_items.append((audio_data_url, "audio")) + + content.extend( + {"type": f"{media_type}_url", f"{media_type}_url": {"url": url}} + for url, media_type in media_items + if url is not None + ) + messages = [{"role": "user", "content": content}] + if system_prompt is not None: + messages = [system_prompt] + messages + return messages + + +def cosine_similarity_text(s1, s2): + """ + Calculate cosine similarity between two text strings. + Notes: + ------ + - Higher score means more similar texts + - Score of 1.0 means identical word composition (bag-of-words) + - Score of 0.0 means completely different vocabulary + """ + from sklearn.feature_extraction.text import CountVectorizer + from sklearn.metrics.pairwise import cosine_similarity + + vectorizer = CountVectorizer().fit_transform([s1, s2]) + vectors = vectorizer.toarray() + return cosine_similarity([vectors[0]], [vectors[1]])[0][0] + + +def convert_audio_to_text(audio_data): + """ + Convert base64 encoded audio data to text using speech recognition. 
+ """ + + audio_data = base64.b64decode(audio_data) + output_path = f"./test_{int(time.time())}" + with open(output_path, "wb") as audio_file: + audio_file.write(audio_data) + + print(f"audio data is saved: {output_path}") + model = whisper.load_model("base") + text = model.transcribe(output_path)["text"] + if text: + return text + else: + return "" + + +def modify_stage_config( + yaml_path: str, + stage_updates: dict[int, dict[str, Any]], +) -> str: + """ + Batch modify configurations for multiple stages in a YAML file. + + Args: + yaml_path: Path to the YAML configuration file. + stage_updates: Dictionary where keys are stage IDs and values are dictionaries of + modifications for that stage. Each modification dictionary uses + dot-separated paths as keys and new configuration values as values. + Example: { + 0: {'engine_args.max_model_len': 5800}, + 1: {'runtime.max_batch_size': 2} + } + + Returns: + str: Path to the newly created modified YAML file with timestamp suffix. + + Example: + >>> output_file = modify_stage_config( + ... 'config.yaml', + ... { + ... 0: {'engine_args.max_model_len': 5800}, + ... 1: {'runtime.max_batch_size': 2} + ... } + ... 
) + >>> print(f"Modified configuration saved to: {output_file}") + Modified configuration saved to: config_1698765432.yaml + """ + path = Path(yaml_path) + if not path.exists(): + raise FileNotFoundError(f"yaml does not exist: {path}") + try: + with open(yaml_path, encoding="utf-8") as f: + config = yaml.safe_load(f) or {} + except Exception as e: + raise ValueError(f"Cannot parse YAML file: {e}") + + stage_args = config.get("stage_args", []) + if not stage_args: + raise ValueError("the stage_args does not exist") + + for stage_id, config_dict in stage_updates.items(): + target_stage = None + for stage in stage_args: + if stage.get("stage_id") == stage_id: + target_stage = stage + break + + if target_stage is None: + available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s] + raise KeyError(f"Stage ID {stage_id} is not exist, available IDs: {available_ids}") + + for key_path, value in config_dict.items(): + current = target_stage + keys = key_path.split(".") + for i in range(len(keys) - 1): + key = keys[i] + if key not in current: + raise KeyError(f"the {'.'.join(keys[: i + 1])} does not exist") + + elif not isinstance(current[key], dict) and i < len(keys) - 2: + raise ValueError(f"{'.'.join(keys[: i + 1])}' cannot continue deeper because it's not a dict") + current = current[key] + current[keys[-1]] = value + + output_path = f"{yaml_path.split('.')[0]}_{int(time.time())}.yaml" + with open(output_path, "w", encoding="utf-8") as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False, allow_unicode=True, indent=2) + + return output_path + + +class OmniServer: + """Omniserver for vLLM-Omni tests.""" + + def __init__( + self, + model: str, + serve_args: list[str], + *, + env_dict: dict[str, str] | None = None, + ) -> None: + self.model = model + self.serve_args = serve_args + self.env_dict = env_dict + self.proc: subprocess.Popen | None = None + self.host = "127.0.0.1" + self.port = get_open_port() + + def _start_server(self) -> None: + 
"""Start the vLLM-Omni server subprocess.""" + env = os.environ.copy() + env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + if self.env_dict is not None: + env.update(self.env_dict) + + cmd = [ + sys.executable, + "-m", + "vllm_omni.entrypoints.cli.main", + "serve", + self.model, + "--omni", + "--host", + self.host, + "--port", + str(self.port), + ] + self.serve_args + + print(f"Launching OmniServer with: {' '.join(cmd)}") + self.proc = subprocess.Popen( + cmd, + env=env, + cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # Set working directory to vllm-omni root + ) + + # Wait for server to be ready + max_wait = 600 # 10 minutes + start_time = time.time() + while time.time() - start_time < max_wait: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(1) + result = sock.connect_ex((self.host, self.port)) + if result == 0: + print(f"Server ready on {self.host}:{self.port}") + return + except Exception: + pass + time.sleep(2) + + raise RuntimeError(f"Server failed to start within {max_wait} seconds") + + def _kill_process_tree(self, pid): + """kill process and its children""" + try: + parent = psutil.Process(pid) + children = parent.children(recursive=True) + for child in children: + try: + child.terminate() + except psutil.NoSuchProcess: + pass + + gone, still_alive = psutil.wait_procs(children, timeout=10) + + for child in still_alive: + try: + child.kill() + except psutil.NoSuchProcess: + pass + + try: + parent.terminate() + parent.wait(timeout=10) + except (psutil.NoSuchProcess, psutil.TimeoutExpired): + try: + parent.kill() + except psutil.NoSuchProcess: + pass + + except psutil.NoSuchProcess: + pass + + def __enter__(self): + self._start_server() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.proc: + try: + parent = psutil.Process(self.proc.pid) + children = parent.children(recursive=True) + for child in children: + try: + child.terminate() + except psutil.NoSuchProcess: + pass + + 
gone, still_alive = psutil.wait_procs(children, timeout=10) + + for child in still_alive: + try: + child.kill() + except psutil.NoSuchProcess: + pass + + try: + parent.terminate() + parent.wait(timeout=10) + except (psutil.NoSuchProcess, psutil.TimeoutExpired): + try: + parent.kill() + except psutil.NoSuchProcess: + pass + + except psutil.NoSuchProcess: + pass diff --git a/tests/diffusion/attention/test_flash_attn.py b/tests/diffusion/attention/test_flash_attn.py new file mode 100644 index 00000000000..3f3862405ed --- /dev/null +++ b/tests/diffusion/attention/test_flash_attn.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Test script for FlashAttention backend with padding handling. + +This script tests two main scenarios: +1. Case 1: Comparing padded vs unpadded inputs for batch_size=1 +2. Case 2: Comparing FlashAttention and SDPA backends for batch_size=2 with padding +""" + +import pytest +import torch + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionImpl +from vllm_omni.diffusion.attention.backends.sdpa import SDPAImpl + + +def create_attention_mask(batch_size: int, seq_len: int, valid_len: int, device: torch.device) -> torch.Tensor: + """ + Create attention mask where first valid_len tokens are valid (1) and rest are padding (0). + + Args: + batch_size: Batch size + seq_len: Total sequence length (including padding) + valid_len: Number of valid (non-padded) tokens + + Returns: + Attention mask of shape (batch_size, seq_len) + """ + mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device) + mask[:, :valid_len] = True + return mask + + +def pad_tensor(tensor: torch.Tensor, target_seq_len: int, pad_value: float = 0.0) -> torch.Tensor: + """ + Pad tensor along sequence dimension (dim=1). 
+ + Args: + tensor: Input tensor of shape (batch_size, seq_len, num_heads, head_dim) + target_seq_len: Target sequence length after padding + pad_value: Value to use for padding + + Returns: + Padded tensor of shape (batch_size, target_seq_len, num_heads, head_dim) + """ + batch_size, seq_len, num_heads, head_dim = tensor.shape + if target_seq_len <= seq_len: + return tensor + + padding = torch.full( + (batch_size, target_seq_len - seq_len, num_heads, head_dim), pad_value, dtype=tensor.dtype, device=tensor.device + ) + return torch.cat([tensor, padding], dim=1) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention requires CUDA") +def test_padding_equivalence(): + """ + Case 1: Test that padded and unpadded inputs produce similar outputs. + + - Input A: batch_size=1, hidden_states (1, 48), encoder_hidden_states (1, 16) + Concatenated length: 64, NO attention_mask + - Input B: Same data but padded: hidden_states (1, 58), encoder_hidden_states (1, 26) + Concatenated length: 84, WITH attention_mask + + Expected: Output A and Output B should be very close. 
+ """ + device = torch.device("cuda") + dtype = torch.bfloat16 + + # Configuration + batch_size = 1 + hidden_seq_len = 48 + encoder_seq_len = 16 + pad_length = 10 + num_heads = 8 + head_dim = 64 + + # Initialize FlashAttention + fa_impl = FlashAttentionImpl( + num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False + ) + + # Create base tensors with random values (same for both A and B) + torch.manual_seed(42) + hidden_states_base = torch.randn(batch_size, hidden_seq_len, num_heads, head_dim, device=device, dtype=dtype) + encoder_hidden_states_base = torch.randn( + batch_size, encoder_seq_len, num_heads, head_dim, device=device, dtype=dtype + ) + + # ========== Input A: Unpadded, no attention mask ========== + query_a = torch.cat([hidden_states_base, encoder_hidden_states_base], dim=1) + key_a = query_a.clone() + value_a = query_a.clone() + + attn_metadata_a = AttentionMetadata(attn_mask=None) + + output_a = fa_impl.forward(query=query_a, key=key_a, value=value_a, attn_metadata=attn_metadata_a) + + # ========== Input B: Padded with attention mask ========== + hidden_states_padded = pad_tensor(hidden_states_base, hidden_seq_len + pad_length) + encoder_hidden_states_padded = pad_tensor(encoder_hidden_states_base, encoder_seq_len + pad_length) + + query_b = torch.cat([hidden_states_padded, encoder_hidden_states_padded], dim=1) + key_b = query_b.clone() + value_b = query_b.clone() + + # Create attention mask + attn_mask_b = torch.cat( + [ + create_attention_mask(batch_size, hidden_seq_len + pad_length, hidden_seq_len, device), + create_attention_mask(batch_size, encoder_seq_len + pad_length, encoder_seq_len, device), + ], + dim=1, + ) + + attn_metadata_b = AttentionMetadata(attn_mask=attn_mask_b) + + output_b = fa_impl.forward(query=query_b, key=key_b, value=value_b, attn_metadata=attn_metadata_b) + + # Extract non-padded portion from output_b + output_b_unpadded = torch.cat( + [ + output_b[:, :hidden_seq_len, :, :], + output_b[:, 
@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention requires CUDA")
def test_fa_vs_sdpa():
    """
    Case 2: Compare FlashAttention and SDPA backends with padding.

    - batch_size=2
    - hidden_states: (2, 48) padded to (2, 58)
    - encoder_hidden_states: (2, 16) padded to (2, 26)
    - Concatenated length: 84
    - Compare FA and SDPA outputs

    Expected: FA and SDPA outputs should be very close.
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16

    # Configuration
    batch_size = 2
    hidden_seq_len = 48
    encoder_seq_len = 16
    pad_length = 10
    num_heads = 8
    head_dim = 64
    softmax_scale = 1.0 / (head_dim**0.5)

    # Initialize both backends with identical hyper-parameters
    fa_impl = FlashAttentionImpl(num_heads=num_heads, head_size=head_dim, softmax_scale=softmax_scale, causal=False)
    sdpa_impl = SDPAImpl(num_heads=num_heads, head_size=head_dim, softmax_scale=softmax_scale, causal=False)

    # Create base tensors (hidden first, then encoder, so the RNG stream is stable)
    torch.manual_seed(123)
    hidden_states_base = torch.randn(batch_size, hidden_seq_len, num_heads, head_dim, device=device, dtype=dtype)
    encoder_hidden_states_base = torch.randn(
        batch_size, encoder_seq_len, num_heads, head_dim, device=device, dtype=dtype
    )

    # Pad each segment separately, then concatenate along the sequence axis
    hidden_states_padded = pad_tensor(hidden_states_base, hidden_seq_len + pad_length)
    encoder_hidden_states_padded = pad_tensor(encoder_hidden_states_base, encoder_seq_len + pad_length)

    query = torch.cat([hidden_states_padded, encoder_hidden_states_padded], dim=1)
    key = query.clone()
    value = query.clone()

    # Per-token validity mask covering both padded segments
    attn_mask = torch.cat(
        [
            create_attention_mask(batch_size, hidden_seq_len + pad_length, hidden_seq_len, device),
            create_attention_mask(batch_size, encoder_seq_len + pad_length, encoder_seq_len, device),
        ],
        dim=1,
    )

    # Run FlashAttention with the 2D mask
    output_fa = fa_impl.forward(
        query=query.clone(),
        key=key.clone(),
        value=value.clone(),
        attn_metadata=AttentionMetadata(attn_mask=attn_mask),
    )

    # Run SDPA with an additive float mask. The mask is always a tensor here
    # (built just above), so no None branch is needed.
    # (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
    attn_mask_4d = attn_mask.unsqueeze(1).unsqueeze(2)
    # True -> 0.0, False -> -inf (presumes create_attention_mask yields a bool mask - TODO confirm)
    attn_mask_float = torch.zeros_like(attn_mask_4d, dtype=dtype)
    attn_mask_float.masked_fill_(~attn_mask_4d, float("-inf"))

    output_sdpa = sdpa_impl.forward(
        query=query.clone(),
        key=key.clone(),
        value=value.clone(),
        attn_metadata=AttentionMetadata(attn_mask=attn_mask_float),
    )

    def valid_region(output):
        """Drop the padded positions of both segments from an attention output."""
        encoder_start = hidden_seq_len + pad_length
        return torch.cat(
            [
                output[:, :hidden_seq_len, :, :],
                output[:, encoder_start : encoder_start + encoder_seq_len, :, :],
            ],
            dim=1,
        )

    # Compare outputs (only the non-padded positions are meaningful)
    abs_diff = torch.abs(valid_region(output_fa) - valid_region(output_sdpa))
    max_diff = torch.max(abs_diff).item()
    mean_diff = torch.mean(abs_diff).item()

    print("\n=== Case 2: FA vs SDPA Comparison ===")
    print(f"Batch size: {batch_size}")
    print(f"FA output shape: {output_fa.shape}")
    print(f"SDPA output shape: {output_sdpa.shape}")
    print(f"Max absolute difference (valid region): {max_diff:.6f}")
    print(f"Mean absolute difference (valid region): {mean_diff:.6f}")

    # Thresholds allow for bfloat16 rounding and for the two different kernels.
    assert max_diff < 0.01, f"Max difference {max_diff} exceeds threshold 0.01"
    assert mean_diff < 0.001, f"Mean difference {mean_diff} exceeds threshold 0.001"

    print("✓ Case 2 PASSED: FA and SDPA outputs are very close!")
+ print("\n[Running Case 2: FA vs SDPA]") + test_fa_vs_sdpa() + except Exception as e: + print(f"βœ— Case 2 failed: {e}") + import traceback + + traceback.print_exc() + else: + raise RuntimeError("CUDA is not available") + print("\n" + "=" * 60) + print("Test suite completed!") diff --git a/tests/diffusion/test_gpu_worker.py b/tests/diffusion/test_gpu_diffusion_worker.py similarity index 81% rename from tests/diffusion/test_gpu_worker.py rename to tests/diffusion/test_gpu_diffusion_worker.py index defeffe5b56..7a43710c878 100644 --- a/tests/diffusion/test_gpu_worker.py +++ b/tests/diffusion/test_gpu_diffusion_worker.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Unit tests for GPUWorker class. +Unit tests for GPUDiffusionWorker class. -This module tests the GPUWorker implementation: +This module tests the GPUDiffusionWorker implementation: - load_weights: Loading model weights - sleep: Putting worker into sleep mode (levels 1 and 2) - wake_up: Waking worker from sleep mode @@ -15,7 +15,7 @@ import pytest import torch -from vllm_omni.diffusion.worker.gpu_worker import GPUWorker +from vllm_omni.diffusion.worker.gpu_diffusion_worker import GPUDiffusionWorker @pytest.fixture @@ -33,20 +33,21 @@ def mock_od_config(): @pytest.fixture def mock_gpu_worker(mock_od_config): - """Create a GPUWorker with mocked initialization.""" - with patch.object(GPUWorker, "init_device_and_model"): - worker = GPUWorker(local_rank=0, rank=0, od_config=mock_od_config) - # Mock the pipeline - worker.pipeline = Mock() - worker.cache_backend = None + """Create a GPUDiffusionWorker with mocked initialization.""" + with patch.object(GPUDiffusionWorker, "init_device"): + worker = GPUDiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config) + # Mock the model_runner with pipeline + worker.model_runner = Mock() + worker.model_runner.pipeline = Mock() + worker._sleep_saved_buffers = {} return worker -class TestGPUWorkerLoadWeights: - """Test 
GPUWorker.load_weights method.""" +class TestGPUDiffusionWorkerLoadWeights: + """Test GPUDiffusionWorker.load_weights method.""" def test_load_weights_calls_pipeline(self, mock_gpu_worker): - """Test that load_weights delegates to pipeline.load_weights.""" + """Test that load_weights delegates to model_runner.load_weights.""" # Setup mock weights mock_weights = [ ("layer1.weight", torch.randn(10, 10)), @@ -54,30 +55,30 @@ def test_load_weights_calls_pipeline(self, mock_gpu_worker): ] expected_loaded = {"layer1.weight", "layer2.weight"} - # Configure pipeline mock - mock_gpu_worker.pipeline.load_weights = Mock(return_value=expected_loaded) + # Configure model_runner mock + mock_gpu_worker.model_runner.load_weights = Mock(return_value=expected_loaded) # Call load_weights result = mock_gpu_worker.load_weights(mock_weights) - # Verify pipeline.load_weights was called with the weights - mock_gpu_worker.pipeline.load_weights.assert_called_once_with(mock_weights) + # Verify model_runner.load_weights was called with the weights + mock_gpu_worker.model_runner.load_weights.assert_called_once_with(mock_weights) assert result == expected_loaded def test_load_weights_empty_iterable(self, mock_gpu_worker): """Test load_weights with empty weights iterable.""" - mock_gpu_worker.pipeline.load_weights = Mock(return_value=set()) + mock_gpu_worker.model_runner.load_weights = Mock(return_value=set()) result = mock_gpu_worker.load_weights([]) - mock_gpu_worker.pipeline.load_weights.assert_called_once_with([]) + mock_gpu_worker.model_runner.load_weights.assert_called_once_with([]) assert result == set() -class TestGPUWorkerSleep: - """Test GPUWorker.sleep method.""" +class TestGPUDiffusionWorkerSleep: + """Test GPUDiffusionWorker.sleep method.""" - @patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info") + @patch("vllm_omni.diffusion.worker.gpu_diffusion_worker.torch.cuda.mem_get_info") @patch("vllm.device_allocator.cumem.CuMemAllocator") def test_sleep_level_1(self, 
mock_allocator_class, mock_mem_info, mock_gpu_worker): """Test sleep mode level 1 (offload weights only).""" @@ -103,7 +104,7 @@ def test_sleep_level_1(self, mock_allocator_class, mock_mem_info, mock_gpu_worke # Verify buffers were NOT saved (level 1 doesn't save buffers) assert len(mock_gpu_worker._sleep_saved_buffers) == 0 - @patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info") + @patch("vllm_omni.diffusion.worker.gpu_diffusion_worker.torch.cuda.mem_get_info") @patch("vllm.device_allocator.cumem.CuMemAllocator") def test_sleep_level_2(self, mock_allocator_class, mock_mem_info, mock_gpu_worker): """Test sleep mode level 2 (offload all, save buffers).""" @@ -121,7 +122,7 @@ def test_sleep_level_2(self, mock_allocator_class, mock_mem_info, mock_gpu_worke # Mock pipeline buffers mock_buffer1 = torch.randn(10, 10) mock_buffer2 = torch.randn(20, 20) - mock_gpu_worker.pipeline.named_buffers = Mock( + mock_gpu_worker.model_runner.pipeline.named_buffers = Mock( return_value=[ ("buffer1", mock_buffer1), ("buffer2", mock_buffer2), @@ -140,7 +141,7 @@ def test_sleep_level_2(self, mock_allocator_class, mock_mem_info, mock_gpu_worke assert "buffer1" in mock_gpu_worker._sleep_saved_buffers assert "buffer2" in mock_gpu_worker._sleep_saved_buffers - @patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info") + @patch("vllm_omni.diffusion.worker.gpu_diffusion_worker.torch.cuda.mem_get_info") @patch("vllm.device_allocator.cumem.CuMemAllocator") def test_sleep_memory_freed_validation(self, mock_allocator_class, mock_mem_info, mock_gpu_worker): """Test that sleep validates memory was actually freed.""" @@ -159,8 +160,8 @@ def test_sleep_memory_freed_validation(self, mock_allocator_class, mock_mem_info mock_gpu_worker.sleep(level=1) -class TestGPUWorkerWakeUp: - """Test GPUWorker.wake_up method.""" +class TestGPUDiffusionWorkerWakeUp: + """Test GPUDiffusionWorker.wake_up method.""" @patch("vllm.device_allocator.cumem.CuMemAllocator") def 
test_wake_up_without_buffers(self, mock_allocator_class, mock_gpu_worker): @@ -202,7 +203,7 @@ def test_wake_up_with_buffers(self, mock_allocator_class, mock_gpu_worker): mock_buffer2 = Mock() mock_buffer2.data = Mock() - mock_gpu_worker.pipeline.named_buffers = Mock( + mock_gpu_worker.model_runner.pipeline.named_buffers = Mock( return_value=[ ("buffer1", mock_buffer1), ("buffer2", mock_buffer2), @@ -243,7 +244,7 @@ def test_wake_up_partial_buffer_restore(self, mock_allocator_class, mock_gpu_wor mock_buffer2 = Mock() mock_buffer2.data = Mock() - mock_gpu_worker.pipeline.named_buffers = Mock( + mock_gpu_worker.model_runner.pipeline.named_buffers = Mock( return_value=[ ("buffer1", mock_buffer1), ("buffer2", mock_buffer2), diff --git a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py index 0066d49b161..cefda891571 100644 --- a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py +++ b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py @@ -1,11 +1,10 @@ import sys -import threading -import time from pathlib import Path import pytest import torch +from tests.utils import GPUMemoryMonitor from vllm_omni.utils.platform_utils import is_npu, is_rocm # ruff: noqa: E402 @@ -15,39 +14,6 @@ from vllm_omni import Omni - -class GPUMemoryMonitor: - """Poll global device memory usage via CUDA APIs.""" - - def __init__(self, device_index: int, interval: float = 0.05): - self.device_index = device_index - self.interval = interval - self.peak_used_mb = 0.0 - self._stop_event = threading.Event() - self._thread: threading.Thread | None = None - - def start(self) -> None: - def monitor_loop() -> None: - while not self._stop_event.is_set(): - try: - with torch.cuda.device(self.device_index): - free_bytes, total_bytes = torch.cuda.mem_get_info() - used_mb = (total_bytes - free_bytes) / (1024**2) - self.peak_used_mb = max(self.peak_used_mb, used_mb) - except Exception: - pass - time.sleep(self.interval) - - 
self._thread = threading.Thread(target=monitor_loop, daemon=True) - self._thread.start() - - def stop(self) -> None: - if self._thread is None: - return - self._stop_event.set() - self._thread.join(timeout=2.0) - - models = ["riverclouds/qwen_image_random"] @@ -73,13 +39,7 @@ def inference(offload: bool = True): generator=torch.Generator("cuda").manual_seed(42), ) - monitor.stop() - torch.cuda.synchronize(device_index) - fallback_alloc = torch.cuda.max_memory_allocated(device=device_index) / (1024**2) - fallback_reserved = torch.cuda.max_memory_reserved(device=device_index) / (1024**2) - peak_memory_mb = max(monitor.peak_used_mb, fallback_alloc, fallback_reserved) - - return peak_memory_mb + return monitor.peak_used_mb offload_peak_memory = inference(offload=True) no_offload_peak_memory = inference(offload=False) diff --git a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py index d32bb2b8223..60686992278 100644 --- a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py +++ b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py @@ -17,6 +17,7 @@ if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) +from tests.utils import GPUMemoryMonitor from vllm_omni import Omni from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.outputs import OmniRequestOutput @@ -66,7 +67,12 @@ def _extract_single_image(outputs) -> Image.Image: def _run_zimage_generate( *, tp_size: int, height: int, width: int, num_inference_steps: int, seed: int -) -> tuple[Image.Image, float]: +) -> tuple[Image.Image, float, float]: + torch.cuda.empty_cache() + device_index = torch.cuda.current_device() + monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02) + monitor.start() + m = Omni( model=_get_zimage_model(), parallel_config=DiffusionParallelConfig(tensor_parallel_size=tp_size), @@ -107,7 +113,10 @@ def _run_zimage_generate( pass median_time_s = 
float(np.median(per_request_times_s)) - return _extract_single_image([last_output]), median_time_s + + peak_memory_mb = monitor.peak_used_mb + + return _extract_single_image([last_output]), median_time_s, peak_memory_mb finally: m.close() cleanup_dist_env_and_memory() @@ -125,14 +134,14 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): num_inference_steps = 2 seed = 42 - tp1_img, tp1_time_s = _run_zimage_generate( + tp1_img, tp1_time_s, tp1_peak_mem = _run_zimage_generate( tp_size=1, height=height, width=width, num_inference_steps=num_inference_steps, seed=seed, ) - tp2_img, tp2_time_s = _run_zimage_generate( + tp2_img, tp2_time_s, tp2_peak_mem = _run_zimage_generate( tp_size=2, height=height, width=width, @@ -164,3 +173,8 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): print(f"Z-Image TP perf (lower is better): tp1_time_s={tp1_time_s:.6f}, tp2_time_s={tp2_time_s:.6f}") assert tp2_time_s < tp1_time_s, f"Expected TP=2 to be faster than TP=1 (tp1={tp1_time_s}, tp2={tp2_time_s})" + + print(f"Z-Image TP peak memory (MB): tp1_peak_mem={tp1_peak_mem:.2f}, tp2_peak_mem={tp2_peak_mem:.2f}") + assert tp2_peak_mem < tp1_peak_mem, ( + f"Expected TP=2 to use less peak memory than TP=1 (tp1={tp1_peak_mem}, tp2={tp2_peak_mem})" + ) diff --git a/tests/e2e/online_serving/test_qwen3_omni_expansion.py b/tests/e2e/online_serving/test_qwen3_omni_expansion.py new file mode 100644 index 00000000000..6a47e96f866 --- /dev/null +++ b/tests/e2e/online_serving/test_qwen3_omni_expansion.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E Online tests for Qwen3-Omni model. 
def client(omni_server):
    """OpenAI client for the running vLLM-Omni server."""
    return openai.OpenAI(
        base_url=f"http://{omni_server.host}:{omni_server.port}/v1",
        api_key="EMPTY",
    )


def get_system_prompt() -> dict:
    """Return a fresh system message in the OpenAI chat-completions layout."""
    return {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": (
                    "You are Qwen, a virtual human developed by the Qwen Team, "
                    "Alibaba Group, capable of perceiving auditory and visual inputs, "
                    "as well as generating text and speech."
                ),
            }
        ],
    }


def get_prompt(prompt_type: str = "text_only") -> str:
    """Return the user prompt for ``prompt_type``.

    Unknown prompt types fall back to the ``"text_only"`` prompt.
    """
    prompts = {
        "text_only": "What is the capital of China?",
        "mix": "What is recited in the audio? What is in this image? Describe the video briefly.",
    }
    return prompts.get(prompt_type, prompts["text_only"])


def get_max_batch_size(size_type: str = "few") -> int:
    """Return the batch size for ``size_type``.

    Unknown size types fall back to the ``"few"`` batch size.
    """
    batch_sizes = {"few": 5, "medium": 100, "large": 256}
    # Derive the fallback from the table (as get_prompt does) so the two
    # cannot drift apart if the "few" value is ever retuned.
    return batch_sizes.get(size_type, batch_sizes["few"])
@pytest.mark.parametrize("test_config", test_params)
def test_text_to_text_audio_001(test_config: tuple[str, str]) -> None:
    """Test processing text, generating text and audio output via OpenAI API.

    Sends ``get_max_batch_size()`` identical requests concurrently and checks
    that every response carries a text choice and an audio choice whose
    transcription matches the text.
    """
    model, stage_config_path = test_config
    num_concurrent_requests = get_max_batch_size()
    # Let stages 0 and 1 batch as many requests as are sent concurrently.
    stage_config_path = modify_stage_config(
        stage_config_path,
        {
            0: {"runtime.max_batch_size": num_concurrent_requests},
            1: {"runtime.max_batch_size": num_concurrent_requests},
        },
    )
    with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server:
        messages = dummy_messages_from_mix_data(
            system_prompt=get_system_prompt(), content_text="What is the capital of China?"
        )

        # Test concurrent completions
        api_client = client(server)
        e2e_list = []
        chat_completions = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor:
            # Start the clock before submitting: the workers begin sending
            # requests as soon as they are submitted, so a later start would
            # under-report every latency.
            start_time = time.perf_counter()
            futures = [
                executor.submit(api_client.chat.completions.create, model=server.model, messages=messages)
                for _ in range(num_concurrent_requests)
            ]
            # Wait for all requests to complete and collect results
            for future in concurrent.futures.as_completed(futures):
                chat_completions.append(future.result())
                # Verify E2E
                current_e2e = time.perf_counter() - start_time
                print(f"the request e2e is: {current_e2e}")
                # TODO: Verify the E2E latency after confirmation baseline.
                e2e_list.append(current_e2e)

        print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
        # Verify all completions succeeded
        assert len(chat_completions) == num_concurrent_requests, "Not all requests succeeded."
        for chat_completion in chat_completions:
            # Fail with a readable message instead of an IndexError below.
            assert len(chat_completion.choices) >= 2, "Expected both a text choice and an audio choice."

            # Verify audio output success
            audio_message = chat_completion.choices[1].message
            audio_data = audio_message.audio.data
            assert audio_data is not None, "No audio output is generated"
            assert audio_message.audio.expires_at > time.time(), "The generated audio has expired."

            # Verify text output success
            text_content = chat_completion.choices[0].message.content
            assert text_content is not None, "No text output is generated"
            assert "beijing" in text_content.lower(), "The output do not contain keywords."

            # Verify text output same as audio output
            audio_content = convert_audio_to_text(audio_data)
            print(f"text content is: {text_content}")
            print(f"audio content is: {audio_content}")
            assert cosine_similarity_text(audio_content.lower(), text_content.lower()) > 0.9, (
                "The audio content is not same as the text"
            )
+stage_args: +- stage_id: 0 + runtime: + devices: "0,1" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.6 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output hidden states for talker + distributed_executor_backend: "mp" + max_num_batched_tokens: 32768 + enable_prefix_caching: false + hf_config_name: thinker_config + tensor_parallel_size: 2 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 100 + seed: 42 + ignore_eos: False + detokenize: True + repetition_penalty: 1.05 + +- stage_id: 1 + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: talker + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.3 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # Output codec codes for code2wav + enable_prefix_caching: false + max_num_batched_tokens: 32768 + distributed_executor_backend: "mp" + hf_config_name: talker_config + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 100 + seed: 42 + detokenize: False + repetition_penalty: 1.05 + stop_token_ids: [2150] + +- stage_id: 2 + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3OmniMoeForConditionalGeneration + worker_cls: vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + 
enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio # Final output: audio waveform + gpu_memory_utilization: 0.1 + distributed_executor_backend: "mp" + max_num_batched_tokens: 1000000 + hf_config_name: thinker_config + engine_input_source: [1] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 200 + seed: 42 + detokenize: True + repetition_penalty: 1.1 diff --git a/tests/utils.py b/tests/utils.py index aba734501eb..8e5593d6501 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,11 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +# Some functions are copied from vllm/tests/utils.py +import functools import os +import signal +import subprocess +import sys +import tempfile +import threading import time -from contextlib import contextmanager +from collections.abc import Callable +from contextlib import ExitStack, contextmanager, suppress +from typing import Any, Literal +import cloudpickle +import pytest +import torch +from typing_extensions import ParamSpec from vllm.platforms import current_platform +from vllm.utils.torch_utils import cuda_device_count_stateless + +_P = ParamSpec("_P") if current_platform.is_rocm(): from amdsmi import ( @@ -90,10 +105,16 @@ def wait_for_gpu_memory_to_clear( print("") if threshold_bytes is not None: - is_free = lambda used, total: used <= threshold_bytes / 2**30 # noqa E731 + + def is_free(used, total): + return used <= threshold_bytes / 2**30 # noqa E731 + threshold = f"{threshold_bytes / 2**30} GiB" else: - is_free = lambda used, total: used / total <= threshold_ratio # noqa E731 + + def is_free(used, total): + return used / total <= threshold_ratio # noqa E731 + threshold = f"{threshold_ratio:.2f}" dur_s = time.time() - 
def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to fork a new process for each test function.

    See https://github.com/vllm-project/vllm/issues/7053 for more details.

    The decorated test body runs in a forked child. The child always leaves
    through ``os._exit``: status 0 on success or skip, status 1 on any other
    exception. On failure it stores a cloudpickle'd description of the
    exception in a temporary file; the parent reads it back and re-raises.
    """

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process
        os.setpgrp()
        from _pytest.outcomes import Skipped

        # Create a unique temporary file to store exception info from child
        # process. Use test function name and process ID to avoid collisions.
        with (
            tempfile.NamedTemporaryFile(
                delete=False, mode="w+b", prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", suffix=".exc"
            ) as exc_file,
            ExitStack() as delete_after,
        ):
            exc_file_path = exc_file.name
            delete_after.callback(os.remove, exc_file_path)

            pid = os.fork()
            print(f"Fork a new process to run a test {pid}")
            if pid == 0:
                # Parent process responsible for deleting, don't delete
                # in child.
                delete_after.pop_all()
                try:
                    func(*args, **kwargs)
                except Skipped as e:
                    # convert Skipped to exit code 0
                    print(str(e))
                    os._exit(0)
                except BaseException as e:
                    # Catch BaseException, not just Exception: pytest's
                    # ``Failed`` outcome (raised by ``pytest.fail``) and
                    # KeyboardInterrupt/SystemExit are not Exception
                    # subclasses. Letting them propagate would return the
                    # forked child into the surrounding pytest session.
                    import traceback

                    tb_string = traceback.format_exc()

                    # String-based description; always serializable.
                    exc_to_serialize: dict[str, Any] = {
                        "exception_type": type(e).__name__,
                        "exception_msg": str(e),
                        "traceback": tb_string,
                    }
                    if isinstance(e, Exception):
                        # Prefer shipping the real exception object when it
                        # can be pickled; the parent re-raises it as-is.
                        pickled_form: dict[str, Any] = {"pickled_exception": e}
                        try:
                            cloudpickle.dumps(pickled_form)
                        except (Exception, KeyboardInterrupt):
                            # Not picklable: keep the string-based description.
                            pass
                        else:
                            exc_to_serialize = pickled_form
                    try:
                        with open(exc_file_path, "wb") as f:
                            cloudpickle.dump(exc_to_serialize, f)
                    except Exception:
                        # Fallback: just print the traceback.
                        print(tb_string)
                    os._exit(1)
                else:
                    os._exit(0)
            else:
                pgid = os.getpgid(pid)
                _pid, _exitcode = os.waitpid(pid, 0)
                # ignore SIGTERM signal itself
                old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
                # kill all child processes
                os.killpg(pgid, signal.SIGTERM)
                # restore the signal handler
                signal.signal(signal.SIGTERM, old_signal_handler)
                if _exitcode != 0:
                    # Try to read the exception from the child process.
                    # An unreadable or empty file leaves exc_info empty.
                    exc_info = {}
                    if os.path.exists(exc_file_path):
                        with suppress(Exception), open(exc_file_path, "rb") as f:
                            exc_info = cloudpickle.load(f)

                    if (original_exception := exc_info.get("pickled_exception")) is not None:
                        # Re-raise the actual exception object if it was
                        # successfully pickled.
                        assert isinstance(original_exception, Exception)
                        raise original_exception

                    if (original_tb := exc_info.get("traceback")) is not None:
                        # Use string-based traceback for fallback case
                        raise AssertionError(
                            f"Test {func.__name__} failed when called with"
                            f" args {args} and kwargs {kwargs}"
                            f" (exit code: {_exitcode}):\n{original_tb}"
                        ) from None

                    # Fallback to the original generic error
                    raise AssertionError(
                        f"function {func.__name__} failed when called with"
                        f" args {args} and kwargs {kwargs}"
                        f" (exit code: {_exitcode})"
                    ) from None

    return wrapper
def cuda_marks(*, res: str, num_cards: int):
    """
    Build the pytest marks that `@cuda_test` applies.

    Args:
        res: Resource type, e.g., "L4" or "H100".
        num_cards: Number of GPU cards required.

    Returns:
        List of pytest marks: resource and platform marks, plus a distributed
        mark and a device-count skip mark when ``num_cards`` is not 1.

    Raises:
        ValueError: If ``res`` is not a supported CUDA resource type.
    """
    platform_mark = pytest.mark.cuda

    if res not in ("L4", "H100"):
        raise ValueError(f"Invalid CUDA resource type: {res}. Supported: L4, H100")
    resource_mark = getattr(pytest.mark, res)

    collected = [resource_mark, platform_mark]
    if num_cards == 1:
        return collected

    distributed_mark = pytest.mark.distributed_cuda(num_cards=num_cards)
    too_few_devices_mark = pytest.mark.skipif_cuda(
        cuda_device_count_stateless() < num_cards,
        reason=f"Need at least {num_cards} CUDA GPUs to run the test.",
    )
    return collected + [distributed_mark, too_few_devices_mark]


def rocm_marks(*, res: str, num_cards: int):
    """
    Build the pytest marks that `@rocm_test` applies.

    Args:
        res: Resource type, e.g., "MI325".
        num_cards: Number of GPU cards required.

    Returns:
        List of pytest marks: resource and platform marks, plus a distributed
        mark when ``num_cards`` is not 1.

    Raises:
        ValueError: If ``res`` is not a supported ROCm resource type.
    """
    platform_mark = pytest.mark.rocm

    if res != "MI325":
        raise ValueError(f"Invalid ROCm resource type: {res}. Supported: MI325")
    resource_mark = pytest.mark.MI325

    collected = [resource_mark, platform_mark]
    if num_cards == 1:
        return collected

    # TODO: add ROCm support for `skipif_rocm` marker
    return collected + [pytest.mark.distributed_rocm(num_cards=num_cards)]


def gpu_marks(*, res: str, num_cards: int):
    """
    Build the pytest marks that `@gpu_test` applies.
    The platform (CUDA or ROCm) is picked from the resource type.

    Args:
        res: Resource type, e.g., "L4", "H100" for CUDA, or "MI325" for ROCm.
        num_cards: Number of GPU cards required.

    Returns:
        List of pytest marks, starting with the generic ``gpu`` mark.

    Raises:
        ValueError: If ``res`` is not a supported resource type.
    """
    generic_mark = pytest.mark.gpu
    if res in ("L4", "H100"):
        platform_specific = cuda_marks(res=res, num_cards=num_cards)
    elif res == "MI325":
        platform_specific = rocm_marks(res=res, num_cards=num_cards)
    else:
        raise ValueError(f"Invalid resource type: {res}. Supported: L4, H100, MI325")
    return [generic_mark] + platform_specific


def npu_marks(*, res: str, num_cards: int):
    """Build the pytest marks that `@npu_test` applies."""
    collected = [pytest.mark.npu]

    # TODO: Currently we don't have various NPU card types defined
    # Unknown resource types get no resource-specific mark.
    if res == "A2":
        collected.append(pytest.mark.A2)
    elif res == "A3":
        collected.append(pytest.mark.A3)

    if num_cards != 1:
        # Multiple cards scenario needs distributed_npu mark
        # TODO: add NPU support for `skipif_npu` marker
        collected.append(pytest.mark.distributed_npu(num_cards=num_cards))
    return collected
+ """ + # Validate platforms + # Don't validate platform details in this decorator + for platform, _ in res.items(): + if platform not in ("cuda", "rocm", "npu"): + raise ValueError(f"Unsupported platform: {platform}") + + # Normalize num_cards + if isinstance(num_cards, int): + num_cards_dict = {platform: num_cards for platform in res.keys()} + else: + num_cards_dict = num_cards + for platform in num_cards_dict.keys(): + if platform not in res: + raise ValueError( + f"Platform '{platform}' in num_cards but not in res. Available platforms: {list(res.keys())}" + ) + for platform in res.keys(): + if platform not in num_cards_dict: + num_cards_dict[platform] = 1 + + # Collect marks from all platforms + all_marks: list[Callable[[Callable[_P, None]], Callable[_P, None]]] = [] + for platform, resource in res.items(): + cards = num_cards_dict[platform] + if platform == "cuda" or platform == "rocm": + marks = gpu_marks(res=resource, num_cards=cards) + elif platform == "npu": + marks = npu_marks(res=resource, num_cards=cards) + else: + raise ValueError(f"Unsupported platform: {platform}") + all_marks.extend(marks) + + create_new_process_flag = False + for cards in num_cards_dict.values(): + if cards > 1: + create_new_process_flag = True + break + + def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: + if create_new_process_flag: + # only for distributed tests + func = create_new_process_for_each_test()(f) + else: + func = f + for mark in reversed(all_marks): + func = mark(func) + return func + + return wrapper + + +class GPUMemoryMonitor: + """Poll global device memory usage via CUDA APIs.""" + + def __init__(self, device_index: int, interval: float = 0.05): + self.device_index = device_index + self.interval = interval + self._peak_used_mb = 0.0 + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + + def start(self) -> None: + def monitor_loop() -> None: + while not self._stop_event.is_set(): + try: + with 
torch.cuda.device(self.device_index): + free_bytes, total_bytes = torch.cuda.mem_get_info() + used_mb = (total_bytes - free_bytes) / (1024**2) + self._peak_used_mb = max(self._peak_used_mb, used_mb) + except Exception: + pass + time.sleep(self.interval) + + self._thread = threading.Thread(target=monitor_loop, daemon=True) + self._thread.start() + + def stop(self) -> None: + if self._thread is None: + return + self._stop_event.set() + self._thread.join(timeout=2.0) + + @property + def peak_used_mb(self) -> float: + fallback_alloc = torch.cuda.max_memory_allocated(device=self.device_index) / (1024**2) + fallback_reserved = torch.cuda.max_memory_reserved(device=self.device_index) / (1024**2) + return max(self._peak_used_mb, fallback_alloc, fallback_reserved) + + def __del__(self): + self.stop() diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py new file mode 100644 index 00000000000..b0132306c81 --- /dev/null +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -0,0 +1,123 @@ +from contextlib import contextmanager + +import torch + +from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner + + +class DummyBuffer: + """A minimal buffer wrapper that exposes the `.gpu` attribute.""" + + def __init__(self, t: torch.Tensor): + self.gpu = t + + +class DummyInputBatch: + """A minimal input batch that only provides `req_ids`.""" + + def __init__(self, req_ids): + self.req_ids = req_ids + + +class DummyReqState: + """A minimal request state container.""" + + pass + + +class DummyTalkerMTP(torch.nn.Module): + """A fake talker_mtp module for deterministic CPU testing.""" + + def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step): + # Deterministic behavior: + # - output embeds = input embeds + 1 + # - output codes = [[0], [1], ...] 
+ bsz = req_embeds.shape[0] + new_embeds = req_embeds + 1.0 + codes = torch.arange(bsz, dtype=torch.int64).view(bsz, 1) + return new_embeds, codes + + +@contextmanager +def _noop_forward_context(*args, **kwargs): + """A no-op context manager to replace vLLM forward context in CPU tests.""" + yield + + +def _make_runner(req_ids=("r1", "r2"), hidden_size=4): + # Create an instance without calling OmniGPUModelRunner.__init__ + runner = object.__new__(OmniGPUModelRunner) + + # Minimal attributes used by OmniGPUModelRunner._talker_mtp_forward + runner.input_batch = DummyInputBatch(list(req_ids)) + runner.requests = {rid: DummyReqState() for rid in req_ids} + + # query_start_loc.cpu[req_index] is used to locate the token position + # in the flattened `inputs_embeds`. + runner.query_start_loc = type("QSL", (), {})() + # Map: r1 -> offset 0, r2 -> offset 3 + runner.query_start_loc.cpu = torch.tensor([0, 3], dtype=torch.int32) + + bsz = len(req_ids) + runner.talker_mtp_input_ids = DummyBuffer(torch.zeros((bsz,), dtype=torch.int64)) + runner.talker_mtp_inputs_embeds = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) + runner.last_talker_hidden = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) + runner.text_step = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) + + runner.talker_mtp = DummyTalkerMTP() + runner.vllm_config = object() + + # Provide a minimal implementation that returns the expected 4-tuple. + def _determine_batch_execution_and_padding(**kwargs): + return None, object(), None, None + + runner._determine_batch_execution_and_padding = _determine_batch_execution_and_padding + + # Use the real merge method from OmniGPUModelRunner. + return runner + + +def test_talker_mtp_forward_cpu_updates_inputs_and_info(monkeypatch): + # Patch the module-level `set_forward_context` symbol used inside + # OmniGPUModelRunner._talker_mtp_forward. 
+ import vllm_omni.worker.gpu_model_runner as mod # Must be the same module that defines OmniGPUModelRunner + + monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context) + + runner = _make_runner(req_ids=("r1", "r2"), hidden_size=4) + + # Initialize per-request embeds (batch-major inside talker_mtp_inputs_embeds) + runner.talker_mtp_inputs_embeds.gpu[0] = torch.tensor([1.0, 2.0, 3.0, 4.0]) + runner.talker_mtp_inputs_embeds.gpu[1] = torch.tensor([10.0, 20.0, 30.0, 40.0]) + + # Flattened `inputs_embeds`: offsets 0 and 3 will be overwritten + inputs_embeds = torch.zeros((6, 4), dtype=torch.float32) + + # Call the original implementation from OmniGPUModelRunner (no re-implementation) + OmniGPUModelRunner._talker_mtp_forward(runner, ["r1", "r2"], inputs_embeds) + + # Validate embeds were written back (+1) + assert torch.allclose(inputs_embeds[0], torch.tensor([2.0, 3.0, 4.0, 5.0])) + assert torch.allclose(inputs_embeds[3], torch.tensor([11.0, 21.0, 31.0, 41.0])) + + # Validate per-request additional_information_cpu was updated + info_r1 = runner.requests["r1"].additional_information_cpu + info_r2 = runner.requests["r2"].additional_information_cpu + assert int(info_r1["code_predictor_codes"][0, 0]) == 0 + assert int(info_r2["code_predictor_codes"][0, 0]) == 1 + + +def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch): + import vllm_omni.worker.gpu_model_runner as mod + + monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context) + + runner = _make_runner(req_ids=("r1",), hidden_size=4) + + inputs_embeds = torch.randn((2, 4)) + before = inputs_embeds.clone() + + OmniGPUModelRunner._talker_mtp_forward(runner, [], inputs_embeds) + + # Ensure no changes were made + assert torch.allclose(inputs_embeds, before) diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py index 562999d7e58..db45f29900d 100644 --- a/tools/pre_commit/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -18,6 
+18,7 @@ ALLOWED_FILES = { "vllm_omni/entrypoints/omni_llm.py", "tests/e2e/offline_inference/utils.py", + "tests/utils.py", "vllm_omni/diffusion/distributed/group_coordinator.py", "tests/diffusion/attention/test_sequence_parallel.py", } diff --git a/vllm_omni/assets/video.py b/vllm_omni/assets/video.py new file mode 100644 index 00000000000..361e2ac785f --- /dev/null +++ b/vllm_omni/assets/video.py @@ -0,0 +1,14 @@ +import librosa +import numpy as np +from vllm.assets.video import VideoAsset +def extract_video_audio(path: str = None, sampling_rate: int = 16000) -> np.ndarray: + """ This function extracts the audio from a video file path and returns the audio as a numpy array. + Args: + path: The path to the video file. + Returns: + The audio as a numpy array. + """ + if not path: + path = VideoAsset(name="baby_reading").video_path + audio_signal, sr = librosa.load(path, sr=sampling_rate) + return audio_signal \ No newline at end of file diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index 2e53a7af2e1..592c501a631 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -1,12 +1,9 @@ import warnings -from importlib.util import find_spec from typing import Any import torch -import vllm.envs as envs from pydantic import ConfigDict from pydantic.dataclasses import dataclass -from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.config import ModelConfig, config from vllm.config.model import ( _RUNNER_CONVERTS, @@ -24,10 +21,12 @@ get_pooling_config, ) from vllm.transformers_utils.gguf_utils import ( + is_gguf, maybe_patch_hf_config_from_gguf, ) from vllm.transformers_utils.utils import maybe_model_redirect -from vllm.transformers_utils.gguf_utils import is_gguf +from vllm.v1.attention.backends.registry import AttentionBackendEnum + import vllm_omni.model_executor.models as me_models logger = init_logger(__name__) @@ -108,9 +107,7 @@ def __post_init__( video_pruning_rate: float | None, ) -> None: # Keep set 
served_model_name before maybe_model_redirect(self.model) - self.served_model_name = get_served_model_name( - self.model, self.served_model_name - ) + self.served_model_name = get_served_model_name(self.model, self.served_model_name) self.model = maybe_model_redirect(self.model) # The tokenizer is consistent with the model by default. if self.tokenizer is None: @@ -167,9 +164,7 @@ def __post_init__( if dict_overrides: self._apply_dict_overrides(hf_config, dict_overrides) self.hf_text_config = self.draw_hf_text_config() - self.attention_chunk_size = getattr( - self.hf_text_config, "attention_chunk_size", None - ) + self.attention_chunk_size = getattr(self.hf_text_config, "attention_chunk_size", None) self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=self.hf_token, revision=self.revision @@ -182,9 +177,7 @@ def __post_init__( is_pooling_model = registry.is_pooling_model(architectures, self) self.runner_type = self._get_runner_type(architectures, self.runner) - self.convert_type = self._get_convert_type( - architectures, self.runner_type, self.convert - ) + self.convert_type = self._get_convert_type(architectures, self.runner_type, self.convert) if self.runner_type == "generate" and not is_generative_model: generate_converts = _RUNNER_CONVERTS["generate"] @@ -244,10 +237,7 @@ def __post_init__( # Init multimodal config if needed if self._model_info.supports_multimodal: - if ( - mm_encoder_tp_mode == "data" - and not self._model_info.supports_multimodal_encoder_tp_data - ): + if mm_encoder_tp_mode == "data" and not self._model_info.supports_multimodal_encoder_tp_data: logger.warning_once( "This model does not support `--mm-encoder-tp-mode data`. " "Falling back to `--mm-encoder-tp-mode weights`." 
@@ -269,9 +259,7 @@ def __post_init__( video_pruning_rate=video_pruning_rate, ) - mm_config_kwargs = { - k: v for k, v in mm_config_kwargs.items() if v is not None - } + mm_config_kwargs = {k: v for k, v in mm_config_kwargs.items() if v is not None} self.multimodal_config = MultiModalConfig(**mm_config_kwargs) diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index dc3f56ac2db..3c262b7c927 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -3,6 +3,8 @@ from collections import defaultdict from time import time +import numpy as np +from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.distributed.kv_events import KVEventBatch from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger @@ -10,10 +12,13 @@ from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs +from vllm.v1.metrics.perf import PerfStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats +logger = init_logger(__name__) + class OmniARScheduler(VLLMScheduler): """ @@ -73,7 +78,7 @@ def update_from_output( pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits kv_connector_output = model_runner_output.kv_connector_output - cudagraph_stats = model_runner_output.cudagraph_stats + cudagraph_stats: CUDAGraphStat | None = model_runner_output.cudagraph_stats perf_stats: PerfStats | None = None if self.perf_metrics and self.perf_metrics.is_enabled(): @@ -152,7 +157,7 @@ def update_from_output( new_token_ids, stopped = self._update_request_with_output(request, new_token_ids) if pooler_output: - # Note: As we occupied the pooler output, for multimodal 
outputs, we do not intermediate stop checking for pooler output + # Note: For multimodal outputs, we skip intermediate stop checks. if request.output_token_ids: stopped = check_stop(request, self.max_model_len) routed_experts = None @@ -172,13 +177,10 @@ def update_from_output( # compute slot mapping: slot = block_id * block_size + offset slot_mapping = ( - block_offsets.reshape((1, block_size)) - + block_ids_array.reshape((num_blocks, 1)) * block_size + block_offsets.reshape((1, block_size)) + block_ids_array.reshape((num_blocks, 1)) * block_size ).flatten()[:num_tokens] - routed_experts = self.routed_experts_reader.get_routed_experts( - indices=slot_mapping - ) + routed_experts = self.routed_experts_reader.get_routed_experts(indices=slot_mapping) kv_transfer_params = self._free_request(request) if status_before_stop == RequestStatus.RUNNING: stopped_running_reqs.add(request) @@ -221,6 +223,7 @@ def update_from_output( kv_transfer_params=kv_transfer_params, trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, + routed_experts=routed_experts, num_nans_in_logits=request.num_nans_in_logits, ) ) @@ -234,7 +237,7 @@ def update_from_output( if stopped_preempted_reqs: # This is a rare case and unlikely to impact performance. self.waiting.remove_requests(stopped_preempted_reqs) - + if failed_kv_load_req_ids and not self.recompute_kv_load_failures: requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids] self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR) @@ -286,7 +289,7 @@ def update_from_output( engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set) finished_req_ids.clear() - if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats)) is not None: + if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats, cudagraph_stats, perf_stats)) is not None: # Return stats to only one of the front-ends. 
if (eco := next(iter(engine_core_outputs.values()), None)) is None: # We must return the stats even if there are no request diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 7acef3bdcb5..2d04faeeea7 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -90,7 +90,7 @@ def schedule(self) -> SchedulerOutput: any_request = self.running[0] num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id) - # Assemble SchedulerOutput (align with v0.12.0) + # Assemble SchedulerOutput (align with v0.14.0) if self.use_v2_model_runner: # No resumed reqs in fast path; pass prefill_token_ids for new reqs. new_reqs_data = [ @@ -129,7 +129,7 @@ def schedule(self) -> SchedulerOutput: preempted_req_ids=set(), ) - # Record the request ids scheduled in this step (v0.12.0 behavior). + # Record the request ids scheduled in this step (v0.14.0 behavior). 
self.prev_step_scheduled_req_ids.clear() self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) @@ -176,7 +176,7 @@ def update_from_output( outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: SpecDecodingStats | None = None kv_connector_stats = kv_connector_output.kv_connector_stats if kv_connector_output else None - # Merge connector-side stats (align with v0.12.0) + # Merge connector-side stats (align with v0.14.0) if kv_connector_stats and self.connector: kv_stats = self.connector.get_kv_connector_stats() if kv_stats: @@ -294,7 +294,7 @@ def update_from_output( if kv_connector_output: self._update_from_kv_xfer_finished(kv_connector_output) - # Collect and publish KV cache events (align with v0.12.0) + # Collect and publish KV cache events (align with v0.14.0) events = self.kv_cache_manager.take_events() if self.connector is not None: connector_events = self.connector.take_events() diff --git a/vllm_omni/diffusion/attention/backends/flash_attn.py b/vllm_omni/diffusion/attention/backends/flash_attn.py index 6921c79aee9..3623d49db8f 100644 --- a/vllm_omni/diffusion/attention/backends/flash_attn.py +++ b/vllm_omni/diffusion/attention/backends/flash_attn.py @@ -9,13 +9,14 @@ AttentionImpl, AttentionMetadata, ) +from vllm_omni.diffusion.attention.backends.utils.fa import _pad_input, _unpad_input, _upad_input logger = init_logger(__name__) try: # only tested with flash_attn v3 # from flash_attn_interface import flash_attn_func as flash_attn_3_func # not available in flash-attn 2.8.1 - from flash_attn import flash_attn_func # can be FA2 or FA3 + from flash_attn import flash_attn_func, flash_attn_varlen_func # can be FA2 or FA3 except ImportError: logger.warning( "FlashAttentionBackend is not available. You may install flash-attn " @@ -63,12 +64,51 @@ def forward( value: torch.Tensor, attn_metadata: AttentionMetadata = None, ) -> torch.Tensor: - # TODO: flash_attn_func does not support attn_mask. 
- out: torch.Tensor = flash_attn_func( - query, - key, - value, - causal=self.causal, - softmax_scale=self.softmax_scale, - ) + """ + Flash attention implementation. + + Args: + query: (batch_size, seq_len, num_heads, head_dim) + key: (batch_size, seq_len, num_heads, head_dim) + value: (batch_size, seq_len, num_heads, head_dim) + attn_metadata: AttentionMetadata. Attention mask is supported as attn_metadata.attn_mask + + Returns: + (batch_size, seq_len, num_heads, head_dim) + """ + query_length = query.size(1) + attention_mask = attn_metadata.attn_mask if attn_metadata is not None else None + # Contains at least one padding token in the sequence + if attention_mask is not None and torch.any(~attention_mask): + assert attention_mask.ndim == 2, "attention_mask must be 2D, (batch_size, seq_len)" + q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input( + query, key, value, attention_mask, query_length, _unpad_input + ) + + out_unpad = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + **{ + "causal": self.causal, + "softmax_scale": self.softmax_scale, + }, + ) + if isinstance(out_unpad, tuple): + out_unpad = out_unpad[0] + + out = _pad_input(out_unpad, indices_q, query.size(0), query_length) + + else: + out: torch.Tensor = flash_attn_func( + query, + key, + value, + causal=self.causal, + softmax_scale=self.softmax_scale, + ) return out diff --git a/vllm_omni/diffusion/attention/backends/ring_flash_attn.py b/vllm_omni/diffusion/attention/backends/ring_flash_attn.py index e163ef1ace1..dd27f88a8c1 100644 --- a/vllm_omni/diffusion/attention/backends/ring_flash_attn.py +++ b/vllm_omni/diffusion/attention/backends/ring_flash_attn.py @@ -28,6 +28,20 @@ def ring_flash_attn_forward( joint_tensor_value=None, joint_strategy="front", ): + # Validate causal + joint_strategy combination + # When causal=True and joint_strategy="rear", 
the causal mask would incorrectly + # prevent local query tokens from attending to joint key tokens (which are + # concatenated at the end). This breaks the semantics where joint tokens + # (e.g., text conditioning) should be visible to all local tokens. + if causal and joint_tensor_key is not None and joint_strategy == "rear": + raise ValueError( + "joint_strategy='rear' is not compatible with causal=True in Ring Attention. " + "When using causal attention with joint tokens, use joint_strategy='front' " + "to ensure joint tokens act as a visible prefix for all local tokens. " + "With 'rear' strategy, the causal mask would incorrectly block local tokens " + "from seeing the joint tokens." + ) + comm = RingComm(process_group) out = None diff --git a/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py b/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py index 43ee35f7098..9ed2c2076c4 100644 --- a/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py +++ b/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py @@ -62,6 +62,20 @@ def forward( joint_tensor_value=None, joint_strategy="front", ): + # Validate causal + joint_strategy combination + # When causal=True and joint_strategy="rear", the causal mask would incorrectly + # prevent local query tokens from attending to joint key tokens (which are + # concatenated at the end). This breaks the semantics where joint tokens + # (e.g., text conditioning) should be visible to all local tokens. + if is_causal and joint_tensor_key is not None and joint_strategy == "rear": + raise ValueError( + "joint_strategy='rear' is not compatible with causal=True in Ring Attention. " + "When using causal attention with joint tokens, use joint_strategy='front' " + "to ensure joint tokens act as a visible prefix for all local tokens. " + "With 'rear' strategy, the causal mask would incorrectly block local tokens " + "from seeing the joint tokens." 
+ ) + comm = RingComm(group) # Ensure tensors are contiguous for P2P communication q = q.contiguous() diff --git a/vllm_omni/diffusion/attention/backends/utils/__init__.py b/vllm_omni/diffusion/attention/backends/utils/__init__.py new file mode 100644 index 00000000000..92c7c8027cb --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/utils/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Utils for attention backends. +""" + +from vllm_omni.diffusion.attention.backends.utils.fa import _pad_input, _unpad_input, _upad_input + +__all__ = [ + "_pad_input", + "_unpad_input", + "_upad_input", +] diff --git a/vllm_omni/diffusion/attention/backends/utils/fa.py b/vllm_omni/diffusion/attention/backends/utils/fa.py new file mode 100644 index 00000000000..d89082e717a --- /dev/null +++ b/vllm_omni/diffusion/attention/backends/utils/fa.py @@ -0,0 +1,209 @@ +# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py +import torch +import torch.nn.functional as F + + +def _index_first_axis(tensor, indices): + """ + A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, + after flattening the first two dimensions of the tensor. 
This is functionally equivalent to + FA2's `index_first_axis` and replaces the need to import it. + """ + # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first + # two dimensions to get (total_tokens, ...) before indexing. + reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) + return reshaped_tensor[indices] + + +def _unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. + + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + return ( + _index_first_axis(hidden_states, indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def _pad_input(hidden_states, indices, batch, seqlen): + """ + pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. 
+ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) + output[indices] = hidden_states + return output.view(batch, seqlen, *dim) + + +def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. + `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. 
+ """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile, + # this might cause a graph break + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + unpad_input_func, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong + to different batches. This function is used instead of `flash_attn.bert_padding.unpad_input` in + order to avoid the recomputation of the same intermediary tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + unpad_input_func: + The function to use for unpadding the input tensors. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). 
+ indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into + ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, + `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + if torch.compiler.is_compiling(): + # allow PyTorch compiler to include operations that return scalar values (like .item()) + torch._dynamo.config.capture_scalar_outputs = True + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + # With static caches, the k/v states may be larger than the mask -> + # we need to slice them to avoid generating garbage + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attention scores + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] + + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = _index_first_axis(key_layer, indices_k) + value_layer = _index_first_axis(value_layer, indices_k) + if query_length == kv_seq_len: + query_layer = _index_first_axis(query_layer, indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def _is_packed_sequence(position_ids, batch_size): + """ + Check the position ids whether packed sequences are indicated or not + 1. Position ids exist + 2. Flattened sequences only are supported + 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. + we have multiple increasing sequences + """ + if position_ids is None: + return False + + increasing_position_sequences = torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min() + return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool() diff --git a/vllm_omni/diffusion/attention/selector.py b/vllm_omni/diffusion/attention/selector.py index f2f60243f04..b9f62e0681d 100644 --- a/vllm_omni/diffusion/attention/selector.py +++ b/vllm_omni/diffusion/attention/selector.py @@ -31,6 +31,8 @@ "ASCEND": {"module": "vllm_omni.diffusion.attention.backends.ascend_attn", "class": "AscendAttentionBackend"}, } +_BACKENDS_SUPPORT_ATTENTION_MASK = ["SDPA", "ASCEND", "FLASH_ATTN"] + def load_backend(backend_name: str) -> type[AttentionBackend]: config = _BACKEND_CONFIG[backend_name] diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index 3485e2262c9..0c43659c9a6 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -7,11 +7,19 @@ pipelines in vllm-omni, supporting both single and dual-transformer architectures. 
""" +import functools from collections.abc import Callable +from contextlib import ExitStack from typing import Any, Optional import cache_dit +import torch from cache_dit import BlockAdapter, DBCacheConfig, ForwardPattern, ParamsModifier, TaylorSeerCalibratorConfig +from cache_dit.caching.block_adapters import FakeDiffusionPipeline +from cache_dit.caching.cache_adapters.cache_adapter import CachedAdapter +from cache_dit.caching.cache_blocks.pattern_0_1_2 import CachedBlocks_Pattern_0_1_2 +from cache_dit.caching.cache_contexts import BasicCacheConfig +from cache_dit.caching.cache_contexts.cache_manager import CachedContextManager from vllm.logger import init_logger from vllm_omni.diffusion.cache.base import CacheBackend @@ -401,6 +409,384 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool return refresh_cache_context +class BagelCachedContextManager(CachedContextManager): + """ + Custom CachedContextManager for Bagel that safely handles NaiveCache objects + (mapped to encoder_hidden_states) by skipping tensor operations on them. + """ + + @torch.compiler.disable + def apply_cache( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + prefix: str = "Bn", + encoder_prefix: str = "Bn_encoder", + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # Allow Bn and Fn prefix to be used for residual cache. 
+ if "Bn" in prefix: + hidden_states_prev = self.get_Bn_buffer(prefix) + else: + hidden_states_prev = self.get_Fn_buffer(prefix) + + assert hidden_states_prev is not None, f"{prefix}_buffer must be set before" + + if self.is_cache_residual(): + hidden_states = hidden_states_prev + hidden_states + else: + # If cache is not residual, we use the hidden states directly + hidden_states = hidden_states_prev + + hidden_states = hidden_states.contiguous() + + if encoder_hidden_states is not None: + if "Bn" in encoder_prefix: + encoder_hidden_states_prev = self.get_Bn_encoder_buffer(encoder_prefix) + else: + encoder_hidden_states_prev = self.get_Fn_encoder_buffer(encoder_prefix) + + if encoder_hidden_states_prev is not None: + if self.is_encoder_cache_residual(): + # FIX: Check if encoder_hidden_states is a tensor before adding + if isinstance(encoder_hidden_states, torch.Tensor) and isinstance( + encoder_hidden_states_prev, torch.Tensor + ): + encoder_hidden_states = encoder_hidden_states_prev + encoder_hidden_states + else: + # If encoder cache is not residual, we use the encoder hidden states directly + encoder_hidden_states = encoder_hidden_states_prev + + # FIX: Check if encoder_hidden_states is a tensor before calling contiguous + if isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states.contiguous() + + return hidden_states, encoder_hidden_states + + +class BagelCachedBlocks(CachedBlocks_Pattern_0_1_2): + """ + Custom CachedBlocks for Bagel that safely handles NaiveCache objects + by adding isinstance checks in call_Mn_blocks and compute_or_prune. 
+ """ + + def call_Mn_blocks( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + *args, + **kwargs, + ): + original_hidden_states = hidden_states + original_encoder_hidden_states = encoder_hidden_states + for block in self._Mn_blocks(): + hidden_states = block( + hidden_states, + encoder_hidden_states, + *args, + **kwargs, + ) + hidden_states, encoder_hidden_states = self._process_block_outputs(hidden_states, encoder_hidden_states) + + # compute hidden_states residual + hidden_states = hidden_states.contiguous() + + hidden_states_residual = hidden_states - original_hidden_states + + if ( + encoder_hidden_states is not None + and original_encoder_hidden_states is not None + and isinstance(encoder_hidden_states, torch.Tensor) # FIX: Added Check + ): + encoder_hidden_states = encoder_hidden_states.contiguous() + encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states + else: + encoder_hidden_states_residual = None + + return ( + hidden_states, + encoder_hidden_states, + hidden_states_residual, + encoder_hidden_states_residual, + ) + + def compute_or_prune( + self, + block_id: int, # Block index in the transformer blocks + # Below are the inputs to the block + block, # The transformer block to be executed + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + *args, + **kwargs, + ): + # NOTE: Although Bagel likely won't use pruning, implementing safe version just in case. + # Copy-pasted from original but adding checks. 
+ + original_hidden_states = hidden_states + original_encoder_hidden_states = encoder_hidden_states + + can_use_prune = self._maybe_prune( + block_id, + hidden_states, + prefix=f"{self.cache_prefix}_{block_id}_Fn_original", + ) + + torch._dynamo.graph_break() + if can_use_prune: + self.context_manager.add_pruned_step() + hidden_states, encoder_hidden_states = self.context_manager.apply_prune( + hidden_states, + encoder_hidden_states, + prefix=( + f"{self.cache_prefix}_{block_id}_Bn_residual" + if self.context_manager.is_cache_residual() + else f"{self.cache_prefix}_Bn_hidden_states" + ), + encoder_prefix=( + f"{self.cache_prefix}_{block_id}_Bn_encoder_residual" + if self.context_manager.is_encoder_cache_residual() + else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states" + ), + ) + torch._dynamo.graph_break() + else: + # Normal steps: Compute the block and cache the residuals. + hidden_states = block( + hidden_states, + encoder_hidden_states, + *args, + **kwargs, + ) + hidden_states, encoder_hidden_states = self._process_block_outputs(hidden_states, encoder_hidden_states) + if not self._skip_prune(block_id): + hidden_states = hidden_states.contiguous() + hidden_states_residual = hidden_states - original_hidden_states + + if ( + encoder_hidden_states is not None + and original_encoder_hidden_states is not None + and isinstance(encoder_hidden_states, torch.Tensor) # FIX: Added Check + ): + encoder_hidden_states = encoder_hidden_states.contiguous() + encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states + else: + encoder_hidden_states_residual = None + + self.context_manager.set_Fn_buffer( + original_hidden_states, + prefix=f"{self.cache_prefix}_{block_id}_Fn_original", + ) + if self.context_manager.is_cache_residual(): + self.context_manager.set_Bn_buffer( + hidden_states_residual, + prefix=f"{self.cache_prefix}_{block_id}_Bn_residual", + ) + else: + self.context_manager.set_Bn_buffer( + hidden_states, + 
prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states", + ) + if encoder_hidden_states_residual is not None: + if self.context_manager.is_encoder_cache_residual(): + self.context_manager.set_Bn_encoder_buffer( + encoder_hidden_states_residual, + prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual", + ) + else: + self.context_manager.set_Bn_encoder_buffer( + encoder_hidden_states_residual, + prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states", + ) + torch._dynamo.graph_break() + + return hidden_states, encoder_hidden_states + + +class BagelCachedAdapter(CachedAdapter): + """ + Custom CachedAdapter for Bagel that uses BagelCachedContextManager and BagelCachedBlocks. + """ + + @classmethod + def create_context( + cls, + block_adapter: BlockAdapter, + **context_kwargs, + ) -> tuple[list[str], list[dict[str, Any]]]: + # Override to use BagelCachedContextManager + + BlockAdapter.assert_normalized(block_adapter) + + if BlockAdapter.is_cached(block_adapter.pipe): + return block_adapter.pipe + + # Check context_kwargs + context_kwargs = cls.check_context_kwargs(block_adapter, **context_kwargs) + + # Each Pipeline should have its own context manager instance. + cache_config: BasicCacheConfig = context_kwargs.get("cache_config", None) + assert cache_config is not None, "cache_config can not be None."
+ + # Apply cache on pipeline: wrap cache context + pipe_cls_name = block_adapter.pipe.__class__.__name__ + + # USE CUSTOM CONTEXT MANAGER + context_manager = BagelCachedContextManager( + name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}", + persistent_context=isinstance(block_adapter.pipe, FakeDiffusionPipeline), + ) + + flatten_contexts, contexts_kwargs = cls.modify_context_params(block_adapter, **context_kwargs) + + block_adapter.pipe._context_manager = context_manager # instance level + + if not context_manager.persistent_context: + original_call = block_adapter.pipe.__class__.__call__ + + @functools.wraps(original_call) + def new_call(self, *args, **kwargs): + with ExitStack() as stack: + # cache context will be reset for each pipe inference + for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs): + stack.enter_context( + context_manager.enter_context( + context_manager.reset_context( + context_name, + **context_kwargs, + ), + ) + ) + outputs = original_call(self, *args, **kwargs) + cls.apply_stats_hooks(block_adapter) + return outputs + + block_adapter.pipe.__class__.__call__ = new_call + block_adapter.pipe.__class__._original_call = original_call + + else: + # Init persistent cache context for transformer + for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs): + context_manager.reset_context( + context_name, + **context_kwargs, + ) + + block_adapter.pipe.__class__._is_cached = True + + cls.apply_params_hooks(block_adapter, contexts_kwargs) + + return flatten_contexts, contexts_kwargs + + @classmethod + def collect_unified_blocks( + cls, + block_adapter: BlockAdapter, + contexts_kwargs: list[dict], + ) -> list[dict[str, torch.nn.ModuleList]]: + # Override to use BagelCachedBlocks + + BlockAdapter.assert_normalized(block_adapter) + + total_cached_blocks: list[dict[str, torch.nn.ModuleList]] = [] + assert hasattr(block_adapter.pipe, "_context_manager") + # Skipping isinstance check for 
ContextManager._supported_managers to avoid import issues + + for i in range(len(block_adapter.transformer)): + unified_blocks_bind_context = {} + for j in range(len(block_adapter.blocks[i])): + cache_config: BasicCacheConfig = contexts_kwargs[i * len(block_adapter.blocks[i]) + j]["cache_config"] + + # Directly instantiate BagelCachedBlocks + unified_blocks_bind_context[block_adapter.unique_blocks_name[i][j]] = torch.nn.ModuleList( + [ + BagelCachedBlocks( + # 0. Transformer blocks configuration + block_adapter.blocks[i][j], + transformer=block_adapter.transformer[i], + forward_pattern=block_adapter.forward_pattern[i][j], + check_forward_pattern=block_adapter.check_forward_pattern, + check_num_outputs=block_adapter.check_num_outputs, + # 1. Cache/Prune context configuration + cache_prefix=block_adapter.blocks_name[i][j], + cache_context=block_adapter.unique_blocks_name[i][j], + context_manager=block_adapter.pipe._context_manager, + cache_type=cache_config.cache_type, + ) + ] + ) + + total_cached_blocks.append(unified_blocks_bind_context) + + return total_cached_blocks + + +def enable_cache_for_bagel(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for Bagel model (via OmniDiffusion pipeline). + + Args: + pipeline: The OmniDiffusion pipeline instance. + cache_config: DiffusionCacheConfig instance with cache configuration. + + Returns: + A refresh function that can be called to update cache context with new num_inference_steps. 
+ """ + # Build DBCacheConfig + db_cache_config = _build_db_cache_config(cache_config) + + # Build calibrator config if TaylorSeer is enabled + calibrator_config = None + if cache_config.enable_taylorseer: + taylorseer_order = cache_config.taylorseer_order + calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) + logger.info(f"TaylorSeer enabled with order={taylorseer_order}") + + # Access the transformer: BagelPipeline -> Qwen2MoTForCausalLM -> Qwen2MoTModel + # BagelPipeline has self.language_model which is Qwen2MoTForCausalLM + # Qwen2MoTForCausalLM has self.model which is Qwen2MoTModel + transformer = pipeline.language_model.model + + logger.info( + f"Enabling cache-dit on Bagel transformer: " + f"Fn={db_cache_config.Fn_compute_blocks}, " + f"Bn={db_cache_config.Bn_compute_blocks}, " + f"W={db_cache_config.max_warmup_steps}, " + ) + + # Enable cache-dit on the transformer + # Pattern_0 corresponds to (hidden_states, encoder_hidden_states) input, output + # Custom adapter for Bagel to handle NaiveCache correctly + # from vllm_omni.diffusion.cache.bagel_cache_adapter import BagelCachedAdapter # No longer needed + BagelCachedAdapter.apply( + BlockAdapter( + transformer=transformer, + blocks=transformer.layers, + forward_pattern=ForwardPattern.Pattern_0, + ), + cache_config=db_cache_config, + calibrator_config=calibrator_config, + ) + + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + transformer = pipeline.language_model.model + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + 
steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + # Register custom cache-dit enablers after function definitions CUSTOM_DIT_ENABLERS.update( { @@ -409,6 +795,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool "LongCatImagePipeline": enable_cache_for_longcat_image, "LongCatImageEditPipeline": enable_cache_for_longcat_image, "StableDiffusion3Pipeline": enable_cache_for_sd3, + "BagelPipeline": enable_cache_for_bagel, } ) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 64184caa749..1948b40a3b4 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -293,6 +293,7 @@ def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]): def _dummy_run(self): """A dummy run to warm up the model.""" prompt = "dummy run" + # note that num_inference_steps=1 will cause timestep and temb None in the pipeline num_inference_steps = 1 height = 1024 width = 1024 diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index 7950389ef1f..256f25e0839 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -314,10 +314,12 @@ def __init__( def forward( self, - packed_query_sequence: torch.Tensor, - query_lens: torch.Tensor, - packed_query_position_embeddings: torch.Tensor, - packed_query_indexes: torch.Tensor, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + packed_query_sequence: torch.Tensor | None = None, + query_lens: torch.Tensor = None, + packed_query_position_embeddings: torch.Tensor = None, + packed_query_indexes: torch.Tensor = None, past_key_values: NaiveCache | None = None, key_values_lens: torch.Tensor | None = None, packed_key_value_indexes: torch.Tensor | None = None, @@ -327,6 +329,8 @@ 
def forward( packed_vae_token_indexes=None, packed_text_indexes=None, ) -> BaseNavitOutputWithPast: + if packed_query_sequence is None: + packed_query_sequence = hidden_states residual = packed_query_sequence if mode == "und": packed_query_sequence = self.input_layernorm(packed_query_sequence) @@ -437,7 +441,8 @@ def forward( for layer_idx, decoder_layer in enumerate(self.layers): packed_query_sequence, past_key_values = decoder_layer( - packed_query_sequence=packed_query_sequence, + hidden_states=packed_query_sequence, + encoder_hidden_states=None, query_lens=query_lens, packed_query_position_embeddings=packed_query_position_embeddings, packed_query_indexes=packed_query_indexes, diff --git a/vllm_omni/diffusion/models/flux2_klein/__init__.py b/vllm_omni/diffusion/models/flux2_klein/__init__.py new file mode 100644 index 00000000000..0d477ab0a48 --- /dev/null +++ b/vllm_omni/diffusion/models/flux2_klein/__init__.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Flux2 klein diffusion model components.""" + +from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( + Flux2Transformer2DModel, +) +from vllm_omni.diffusion.models.flux2_klein.pipeline_flux2_klein import ( + Flux2KleinPipeline, + get_flux2_klein_post_process_func, +) + +__all__ = [ + "Flux2KleinPipeline", + "Flux2Transformer2DModel", + "get_flux2_klein_post_process_func", +] diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py new file mode 100644 index 00000000000..86658a01deb --- /dev/null +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -0,0 +1,723 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Iterable +from types import SimpleNamespace +from typing import Any + +import torch +import torch.nn as nn +from diffusers.models.embeddings import ( + TimestepEmbedding, + Timesteps, + get_1d_rotary_pos_embed, +) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + + +class Flux2SwiGLU(nn.Module): + """SwiGLU activation used by Flux2.""" + + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + return self.gate_fn(x1) * x2 + + +class Flux2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: float = 3.0, + inner_dim: int | None = None, + bias: bool = False, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias) + self.act_fn = 
Flux2SwiGLU() + self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_in(x) + x = self.act_fn(x) + return self.linear_out(x) + + +class Flux2Attention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + ): + super().__init__() + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + self.dropout = dropout + self.added_kv_proj_dim = added_kv_proj_dim + + self.to_qkv = QKVParallelLinear( + hidden_size=query_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + disable_tp=True, + bias=bias, + ) + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + self.to_out = nn.ModuleList( + [ReplicatedLinear(self.inner_dim, self.out_dim, bias=out_bias), nn.Dropout(dropout)] + ) + + if added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + self.add_kv_proj = QKVParallelLinear( + hidden_size=added_kv_proj_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + disable_tp=True, + bias=added_proj_bias, + ) + self.to_add_out = ReplicatedLinear(self.inner_dim, query_dim, bias=out_bias) + + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=self.heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: 
torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + qkv, _ = self.to_qkv(hidden_states) + query, key, value = qkv.chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and self.added_kv_proj_dim is not None: + encoder_qkv, _ = self.add_kv_proj(encoder_hidden_states) + encoder_query, encoder_key, encoder_value = encoder_qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if encoder_hidden_states is not None and self.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (self.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (self.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (self.heads, -1)) + + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + hidden_states = self.attn(query, key, value, attn_metadata) + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + + if encoder_hidden_states is not None: + context_len = encoder_hidden_states.shape[1] + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [context_len, hidden_states.shape[1] - context_len], + dim=1, + ) + 
encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states) + + hidden_states, _ = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + return hidden_states + + +class Flux2ParallelSelfAttention(nn.Module): + """ + Parallel attention block that fuses QKV projections with MLP input projections. + """ + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + mlp_ratio: float = 4.0, + mlp_mult_factor: int = 2, + ): + super().__init__() + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + self.dropout = dropout + + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(query_dim * self.mlp_ratio) + self.mlp_mult_factor = mlp_mult_factor + + self.to_qkv_mlp_proj = nn.Linear( + self.query_dim, + self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, + bias=bias, + ) + self.mlp_act_fn = Flux2SwiGLU() + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + self.to_out = nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias) + self.rope = RotaryEmbedding(is_neox_style=False) + self.attn = Attention( + num_heads=self.heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = self.to_qkv_mlp_proj(hidden_states) + qkv, mlp_hidden_states = torch.split( + 
hidden_states, + [3 * self.inner_dim, self.mlp_hidden_dim * self.mlp_mult_factor], + dim=-1, + ) + + query, key, value = qkv.chunk(3, dim=-1) + query = query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + query = self.rope(query, cos, sin) + key = self.rope(key, cos, sin) + + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + attn_output = self.attn(query, key, value, attn_metadata) + attn_output = attn_output.flatten(2, 3).to(query.dtype) + + mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states) + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1) + return self.to_out(hidden_states) + + +class Flux2SingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.attn = Flux2ParallelSelfAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + out_bias=bias, + eps=eps, + mlp_ratio=mlp_ratio, + mlp_mult_factor=2, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + temb_mod_params: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + split_hidden_states: bool = False, + text_seq_len: int | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if encoder_hidden_states is not 
None: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + mod_shift, mod_scale, mod_gate = temb_mod_params + + norm_hidden_states = self.norm(hidden_states) + norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift + + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = hidden_states + mod_gate * attn_output + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + if split_hidden_states: + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + return hidden_states + + +class Flux2TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 3.0, + eps: float = 1e-6, + bias: bool = False, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + self.attn = Flux2Attention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + added_proj_bias=bias, + out_bias=bias, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_mod_params_img: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...], + temb_mod_params_txt: tuple[tuple[torch.Tensor, torch.Tensor, 
torch.Tensor], ...], + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + joint_attention_kwargs = joint_attention_kwargs or {} + + (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img + (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt + + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa + + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa + + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate_mlp * ff_output + + context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class Flux2PosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: list[int]): + super().__init__() + self.theta = theta + 
self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(len(self.axes_dim)): + freqs_cis = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[..., i], + theta=self.theta, + use_real=False, + freqs_dtype=freqs_dtype, + ) + cos_out.append(freqs_cis.real) + sin_out.append(freqs_cis.imag) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Flux2TimestepGuidanceEmbeddings(nn.Module): + def __init__( + self, + in_channels: int = 256, + embedding_dim: int = 6144, + bias: bool = False, + guidance_embeds: bool = True, + ): + super().__init__() + self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=embedding_dim, + sample_proj_bias=bias, + ) + + if guidance_embeds: + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=embedding_dim, + sample_proj_bias=bias, + ) + else: + self.guidance_embedder = None + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor | None) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) + + if guidance is not None and self.guidance_embedder is not None: + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) + return timesteps_emb + guidance_emb + return timesteps_emb + + +class Flux2Modulation(nn.Module): + def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False): + super().__init__() + self.mod_param_sets = mod_param_sets + self.linear = nn.Linear(dim, 
dim * 3 * self.mod_param_sets, bias=bias) + self.act_fn = nn.SiLU() + + def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]: + mod = self.act_fn(temb) + mod = self.linear(mod) + if mod.ndim == 2: + mod = mod.unsqueeze(1) + mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1) + return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets)) + + +class Flux2Transformer2DModel(nn.Module): + """ + The Transformer model introduced in Flux 2. + """ + + _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + + def __init__( + self, + patch_size: int = 1, + in_channels: int = 128, + out_channels: int | None = None, + num_layers: int = 8, + num_single_layers: int = 48, + attention_head_dim: int = 128, + num_attention_heads: int = 48, + joint_attention_dim: int = 15360, + timestep_guidance_channels: int = 256, + mlp_ratio: float = 3.0, + axes_dims_rope: tuple[int, ...] = (32, 32, 32, 32), + rope_theta: int = 2000, + eps: float = 1e-6, + guidance_embeds: bool = True, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.config = SimpleNamespace( + patch_size=patch_size, + in_channels=in_channels, + out_channels=self.out_channels, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + timestep_guidance_channels=timestep_guidance_channels, + mlp_ratio=mlp_ratio, + axes_dims_rope=axes_dims_rope, + rope_theta=rope_theta, + eps=eps, + guidance_embeds=guidance_embeds, + ) + + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope)) + self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( + in_channels=timestep_guidance_channels, + embedding_dim=self.inner_dim, + bias=False, + guidance_embeds=guidance_embeds, + ) + + 
self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False) + self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False) + + self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [ + Flux2TransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + Flux2SingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=mlp_ratio, + eps=eps, + bias=False, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, + self.inner_dim, + elementwise_affine=False, + eps=eps, + bias=False, + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + img_ids: torch.Tensor, + txt_ids: torch.Tensor, + guidance: torch.Tensor | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + ) -> torch.Tensor | Transformer2DModelOutput: + joint_attention_kwargs = joint_attention_kwargs or {} + + num_txt_tokens = encoder_hidden_states.shape[1] + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = self.time_guidance_embed(timestep, guidance) + + 
double_stream_mod_img = self.double_stream_modulation_img(temb) + double_stream_mod_txt = self.double_stream_modulation_txt(temb) + single_stream_mod = self.single_stream_modulation(temb)[0] + + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if img_ids.ndim == 3: + img_ids = img_ids[0] + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for block in self.single_transformer_blocks: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=None, + temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = hidden_states[:, num_txt_tokens:, ...] 
+ hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + (".add_kv_proj", ".add_q_proj", "q"), + (".add_kv_proj", ".add_k_proj", "k"), + (".add_kv_proj", ".add_v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + + for name, buffer in self.named_buffers(): + if name.endswith(".beta") or name.endswith(".eps"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "to_qkvkv_mlp_proj" in name: + name = name.replace("to_qkvkv_mlp_proj", "to_qkv_mlp_proj") + if "to_qkv_mlp_proj" in name: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py new file mode 100644 index 00000000000..ba29e681c32 --- /dev/null +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -0,0 +1,963 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright 2025 Black Forest Labs and The HuggingFace 
Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import math +import os +from collections.abc import Callable, Iterable +from typing import Any + +import numpy as np +import PIL.Image +import torch +import torch.nn as nn +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders.autoencoder_kl_flux2 import AutoencoderKLFlux2 +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils.torch_utils import randn_tensor +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( + Flux2Transformer2DModel, +) +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.model_executor.model_loader.weight_utils 
import download_weights_from_hf_specific + +logger = init_logger(__name__) + + +class Flux2ImageProcessor(VaeImageProcessor): + """Image processor to preprocess the reference image for Flux2 klein.""" + + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + vae_latent_channels: int = 32, + do_normalize: bool = True, + do_convert_rgb: bool = True, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + do_normalize=do_normalize, + do_convert_rgb=do_convert_rgb, + ) + + @staticmethod + def check_image_input( + image: PIL.Image.Image, + max_aspect_ratio: int = 8, + min_side_length: int = 64, + max_area: int = 1024 * 1024, + ) -> PIL.Image.Image: + if not isinstance(image, PIL.Image.Image): + raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}") + + width, height = image.size + if width < min_side_length or height < min_side_length: + raise ValueError(f"Image too small: {width}x{height}. Both dimensions must be at least {min_side_length}px") + + aspect_ratio = max(width / height, height / width) + if aspect_ratio > max_aspect_ratio: + raise ValueError( + f"Aspect ratio too extreme: {width}x{height} (ratio: {aspect_ratio:.1f}:1). 
" + f"Maximum allowed ratio is {max_aspect_ratio}:1" + ) + + if width * height > max_area: + logger.warning("Image area exceeds recommended maximum; resizing will be applied.") + + return image + + @staticmethod + def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image: + image_width, image_height = image.size + scale = math.sqrt(target_area / (image_width * image_height)) + width = int(image_width * scale) + height = int(image_height * scale) + return image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + @staticmethod + def _resize_if_exceeds_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image: + image_width, image_height = image.size + if image_width * image_height <= target_area: + return image + return Flux2ImageProcessor._resize_to_target_area(image, target_area) + + def _resize_and_crop(self, image: PIL.Image.Image, width: int, height: int) -> PIL.Image.Image: + image_width, image_height = image.size + left = (image_width - width) // 2 + top = (image_height - height) // 2 + right = left + width + bottom = top + height + return image.crop((left, top, right, bottom)) + + @staticmethod + def concatenate_images(images: list[PIL.Image.Image]) -> PIL.Image.Image: + if len(images) == 1: + return images[0].copy() + + images = [img.convert("RGB") if img.mode != "RGB" else img for img in images] + total_width = sum(img.width for img in images) + max_height = max(img.height for img in images) + background_color = (255, 255, 255) + new_img = PIL.Image.new("RGB", (total_width, max_height), background_color) + + x_offset = 0 + for img in images: + y_offset = (max_height - img.height) // 2 + new_img.paste(img, (x_offset, y_offset)) + x_offset += img.width + + return new_img + + +def get_flux2_klein_post_process_func( + od_config: OmniDiffusionConfig, +): + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = 
download_weights_from_hf_specific(model_name, None, ["*"]) + + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = Flux2ImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + def post_process_func(images: torch.Tensor): + return image_processor.postprocess(images) + + return post_process_func + + +# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +class Flux2KleinPipeline(nn.Module, SupportImageInput): + """Flux2 klein pipeline for text-to-image generation.""" + + support_image_input = True + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + is_distilled: bool = False, + ): + super().__init__() + self.od_config = od_config + self.is_distilled = is_distilled + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + self._execution_device = get_local_device() + model = od_config.model + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, + subfolder="scheduler", + local_files_only=local_files_only, + ) + self.text_encoder = Qwen3ForCausalLM.from_pretrained( + model, + subfolder="text_encoder", + local_files_only=local_files_only, + ) + self.tokenizer = Qwen2TokenizerFast.from_pretrained( + 
model, + subfolder="tokenizer", + local_files_only=local_files_only, + ) + self.vae = AutoencoderKLFlux2.from_pretrained( + model, + subfolder="vae", + local_files_only=local_files_only, + ).to(self._execution_device) + + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel) + self.transformer = Flux2Transformer2DModel(**transformer_kwargs) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + self._guidance_scale = None + self._attention_kwargs = None + self._num_timesteps = None + self._current_timestep = None + self._interrupt = False + + @staticmethod + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + hidden_states_layers: list[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + 
input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: torch.Tensor | None = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_positions = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_positions) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. 
+ + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + layer_ids = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, layer_ids) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids + def _prepare_image_ids( + image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (List[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. 
Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + 
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int, ...] 
= (9, 18, 27), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: torch.Tensor | None = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents + def prepare_image_latents( + self, + images: list[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return 
image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + "`height` and `width` have to be divisible by %s but are %s and %s. " + "Dimensions will be resized accordingly", + self.vae_scale_factor * 2, + height, + width, + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in ["latents", "prompt_embeds"] for k in callback_on_step_end_tensor_inputs + ): + raise ValueError("`callback_on_step_end_tensor_inputs` must be a subset of ['latents', 'prompt_embeds'].") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if guidance_scale > 1.0 and self.is_distilled: + logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale is not None and self._guidance_scale > 1 and not self.is_distilled + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def forward( + self, + req: OmniDiffusionRequest, + image: PIL.Image.Image | list[PIL.Image.Image] | None = None, + prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int, ...] = (9, 18, 27), + ) -> DiffusionOutput: + r""" + Function invoked when calling the pipeline for generation. 
+ + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list of these): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list + of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2 + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models, + `guidance_scale` is ignored. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference.
+ sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline. + If not provided, will be generated from "". + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + Pass `"latent"` to skip VAE decoding and return the latents instead. + return_dict (`bool`, *optional*, defaults to `True`): + Not used by this implementation; a `DiffusionOutput` is always returned. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as the `callback_kwargs` argument. Only `"latents"` and `"prompt_embeds"` are accepted + (enforced by `check_inputs`). + max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`Tuple[int]`, defaults to `(9, 18, 27)`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + + Examples: + + Returns: + `DiffusionOutput`: its `output` field holds the result of `self.vae.decode`, or the latents when + `output_type` is `"latent"`. `return_dict` is not consulted, so a `tuple` is never + returned.
+ """ + + prompt = req.prompt if req.prompt is not None else prompt + image = req.pil_image if req.pil_image is not None else image + height = req.height or height + width = req.width or width + num_inference_steps = req.num_inference_steps or num_inference_steps + guidance_scale = req.guidance_scale if req.guidance_scale is not None else guidance_scale + generator = req.generator or generator + req_num_outputs = getattr(req, "num_outputs_per_prompt", None) + if req_num_outputs and req_num_outputs > 0: + num_images_per_prompt = req_num_outputs + + if isinstance(req.prompt_embeds, torch.Tensor): + prompt_embeds = req.prompt_embeds + if isinstance(req.negative_prompt_embeds, torch.Tensor): + negative_prompt_embeds = req.negative_prompt_embeds + + if req.max_sequence_length is not None: + max_sequence_length = req.max_sequence_length + if getattr(req, "text_encoder_out_layers", None) is not None: + text_encoder_out_layers = req.text_encoder_out_layers + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + guidance_scale=guidance_scale, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. 
prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + if self.do_classifier_free_guidance: + negative_prompt = "" + if prompt is not None and isinstance(prompt, list): + negative_prompt = [negative_prompt] * len(prompt) + negative_prompt_embeds, negative_text_ids = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. 
prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + self._num_timesteps = len(timesteps) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. 
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + if self.do_classifier_free_guidance: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1) :] + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype and torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + self._current_timestep = None + + latents = 
self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + if output_type == "latent": + image = latents + else: + if latents.dtype != self.vae.dtype: + latents = latents.to(self.vae.dtype) + image = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py index 09f7b17e133..615b9194af4 100644 --- a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py +++ b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py @@ -8,8 +8,8 @@ import torch import torch.nn as nn from diffusers.models.attention import FeedForward -from diffusers.models.transformers.transformer_glm_image import GlmImageCombinedTimestepSizeEmbeddings from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.transformers.transformer_glm_image import GlmImageCombinedTimestepSizeEmbeddings from vllm.logger import init_logger from vllm.model_executor.layers.linear import QKVParallelLinear from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -414,9 +414,10 @@ def forward( query_img = query[:, text_seq_length:, :, :] key_img = key[:, text_seq_length:, :, :] from diffusers.models.embeddings import apply_rotary_emb - query_img = apply_rotary_emb(query_img,image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) + + query_img = apply_rotary_emb(query_img, image_rotary_emb, 
sequence_dim=1, use_real_unbind_dim=-2) # key_img = self.rope(key_img, cos, sin) - key_img = apply_rotary_emb(key_img,image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) + key_img = apply_rotary_emb(key_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1) key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1) @@ -555,7 +556,7 @@ def __init__( od_config: OmniDiffusionConfig, ): super().__init__() - + patch_size = od_config.tf_model_config.patch_size in_channels = od_config.tf_model_config.in_channels out_channels = od_config.tf_model_config.out_channels @@ -565,8 +566,6 @@ def __init__( condition_dim = od_config.tf_model_config.condition_dim prior_vq_quantizer_codebook_size = od_config.tf_model_config.prior_vq_quantizer_codebook_size text_embed_dim = od_config.tf_model_config.text_embed_dim - - # Get num_layers from config if available model_config = od_config.tf_model_config diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 6401b526a5f..87c7c2f73c2 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -36,6 +36,7 @@ QwenImageTransformer2DModel, ) from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -274,7 +275,8 @@ def __init__( self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( self.device ) - self.transformer = QwenImageTransformer2DModel(od_config=od_config) + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel) + self.transformer = QwenImageTransformer2DModel(od_config=od_config, 
**transformer_kwargs) self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 53609c87757..901625e34b6 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -38,6 +38,7 @@ QwenImageTransformer2DModel, ) from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -231,7 +232,8 @@ def __init__( self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( self.device ) - self.transformer = QwenImageTransformer2DModel(od_config=od_config) + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel) + self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs) self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) self.processor = Qwen2VLProcessor.from_pretrained( model, subfolder="processor", local_files_only=local_files_only diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 80e48112c71..6df502d48f0 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -41,6 +41,7 @@ QwenImageTransformer2DModel, ) from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from 
vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -191,7 +192,9 @@ def __init__( self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( self.device ) - self.transformer = QwenImageTransformer2DModel(od_config=od_config) + + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel) + self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs) self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) self.processor = Qwen2VLProcessor.from_pretrained( model, subfolder="processor", local_files_only=local_files_only diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index eb5e2aa987a..4642b3eb418 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -37,6 +37,7 @@ QwenImageTransformer2DModel, ) from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -211,18 +212,8 @@ def __init__( ) ] - use_additional_t_cond = od_config.tf_model_config.use_additional_t_cond - zero_cond_t = od_config.tf_model_config.zero_cond_t - use_layer3d_rope = od_config.tf_model_config.use_layer3d_rope - guidance_embeds = od_config.tf_model_config.guidance_embeds - - self.transformer = QwenImageTransformer2DModel( - od_config=od_config, - use_additional_t_cond=use_additional_t_cond, - zero_cond_t=zero_cond_t, - use_layer3d_rope=use_layer3d_rope, - guidance_embeds=guidance_embeds, - ) + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, 
QwenImageTransformer2DModel) + self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs) # Pipeline configuration & processing parameters self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index 2aa9c29104d..cbf0b7e10ac 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn -from diffusers.models.attention import FeedForward +import torch.nn.functional as F # TODO replace this with vLLM implementation from diffusers.models.embeddings import TimestepEmbedding, Timesteps @@ -16,14 +16,18 @@ from diffusers.models.normalization import AdaLayerNormContinuous from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm_omni.diffusion.attention.backends.abstract import ( AttentionMetadata, ) from vllm_omni.diffusion.attention.layer import Attention -from vllm_omni.diffusion.attention.selector import get_attn_backend +from vllm_omni.diffusion.attention.selector import _BACKENDS_SUPPORT_ATTENTION_MASK, get_attn_backend from vllm_omni.diffusion.cache.base import CachedTransformer from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.diffusion.distributed.parallel_state import ( @@ -287,79 +291,133 @@ def _compute_video_freqs(self, frame, height, width, idx=0): return freqs.clone().contiguous() +class ColumnParallelApproxGELU(nn.Module): + def __init__(self, dim_in: int, dim_out: 
int, *, approximate: str, bias: bool = True): + super().__init__() + self.proj = ColumnParallelLinear( + dim_in, + dim_out, + bias=bias, + gather_output=False, + return_bias=False, + ) + self.approximate = approximate + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return F.gelu(x, approximate=self.approximate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: int = 4, + activation_fn: str = "gelu-approximate", + inner_dim: int | None = None, + bias: bool = True, + ) -> None: + super().__init__() + + assert activation_fn == "gelu-approximate", "Only gelu-approximate is supported." + + inner_dim = inner_dim or int(dim * mult) + dim_out = dim_out or dim + + layers: list[nn.Module] = [ + ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias), + nn.Identity(), # placeholder for weight loading + RowParallelLinear( + inner_dim, + dim_out, + input_is_parallel=True, + return_bias=False, + ), + ] + + self.net = nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + class QwenImageCrossAttention(nn.Module): def __init__( self, dim: int, # query_dim num_heads: int, head_dim: int, - window_size=(-1, -1), - added_kv_proj_dim: int = None, + added_kv_proj_dim: int, + window_size: tuple[int, int] = (-1, -1), out_bias: bool = True, - qk_norm=True, # rmsnorm - eps=1e-6, - pre_only=False, + qk_norm: bool = True, + eps: float = 1e-6, + pre_only: bool = False, context_pre_only: bool = False, - parallel_attention=False, - out_dim: int = None, + out_dim: int | None = None, ) -> None: - assert dim % num_heads == 0 super().__init__() + assert dim % num_heads == 0 + self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads + self.head_dim = head_dim + self.total_num_heads = num_heads self.window_size = window_size self.qk_norm = qk_norm 
self.eps = eps - self.parallel_attention = parallel_attention - # layers - # self.to_q = ReplicatedLinear(dim, dim) - # self.to_k = ReplicatedLinear(dim, dim) - # self.to_v = ReplicatedLinear(dim, dim) self.to_qkv = QKVParallelLinear( hidden_size=dim, head_size=self.head_dim, total_num_heads=num_heads, - disable_tp=True, ) + self.query_num_heads = self.to_qkv.num_heads + self.kv_num_heads = self.to_qkv.num_kv_heads + self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() - self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads - self.inner_kv_dim = self.inner_dim - if added_kv_proj_dim is not None: - assert context_pre_only is not None - # self.add_k_proj = ReplicatedLinear(added_kv_proj_dim, self.inner_kv_dim, bias=True) - # self.add_v_proj = ReplicatedLinear(added_kv_proj_dim, self.inner_kv_dim, bias=True) - # self.add_q_proj = ReplicatedLinear( - # added_kv_proj_dim, self.inner_dim, bias=True - # ) - self.add_kv_proj = QKVParallelLinear( - added_kv_proj_dim, - head_size=self.inner_kv_dim // self.num_heads, - total_num_heads=self.num_heads, - disable_tp=True, - ) - if context_pre_only is not None and not context_pre_only: - self.to_add_out = ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias) - else: - self.to_add_out = None + self.inner_dim = out_dim if out_dim is not None else head_dim * self.total_num_heads - if not pre_only: - self.to_out = nn.ModuleList([]) - self.to_out.append(ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)) - else: - self.to_out = None + assert context_pre_only is not None + self.add_kv_proj = QKVParallelLinear( + hidden_size=added_kv_proj_dim, + head_size=head_dim, + total_num_heads=num_heads, + ) + self.add_query_num_heads = self.add_kv_proj.num_heads + self.add_kv_num_heads = self.add_kv_proj.num_kv_heads + + assert not context_pre_only + self.to_add_out = RowParallelLinear( + self.inner_dim, + self.dim, + 
bias=out_bias, + input_is_parallel=True, + return_bias=False, + ) + + assert not pre_only + self.to_out = RowParallelLinear( + self.inner_dim, + self.dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ) self.norm_added_q = RMSNorm(head_dim, eps=eps) self.norm_added_k = RMSNorm(head_dim, eps=eps) self.attn = Attention( - num_heads=num_heads, + num_heads=self.query_num_heads, head_size=self.head_dim, softmax_scale=1.0 / (self.head_dim**0.5), causal=False, + num_kv_heads=self.kv_num_heads, ) self.rope = RotaryEmbedding(is_neox_style=False) @@ -377,61 +435,55 @@ def forward( txt_freqs: torch.Tensor, hidden_states_mask: torch.Tensor | None = None, encoder_hidden_states_mask: torch.Tensor | None = None, - ): - # if mask is all true, set it to None + ) -> tuple[torch.Tensor, torch.Tensor]: if hidden_states_mask is not None and hidden_states_mask.all(): hidden_states_mask = None if encoder_hidden_states_mask is not None and encoder_hidden_states_mask.all(): encoder_hidden_states_mask = None - seq_len_txt = encoder_hidden_states.shape[1] - # Compute QKV for image stream (sample projections) - qkv, _ = self.to_qkv(hidden_states) - img_query, img_key, img_value = qkv.chunk(3, dim=-1) + img_qkv, _ = self.to_qkv(hidden_states) + q_size = self.query_num_heads * self.head_dim + kv_size = self.kv_num_heads * self.head_dim + img_query, img_key, img_value = img_qkv.split([q_size, kv_size, kv_size], dim=-1) - # Compute QKV for text stream (context projections) - qkv, _ = self.add_kv_proj(encoder_hidden_states) - txt_query, txt_key, txt_value = qkv.chunk(3, dim=-1) + txt_qkv, _ = self.add_kv_proj(encoder_hidden_states) + add_q_size = self.add_query_num_heads * self.head_dim + add_kv_size = self.add_kv_num_heads * self.head_dim + txt_query, txt_key, txt_value = txt_qkv.split([add_q_size, add_kv_size, add_kv_size], dim=-1) - # Reshape for multi-head attention - img_query = img_query.unflatten(-1, (self.num_heads, -1)) - img_key = img_key.unflatten(-1, 
(self.num_heads, -1)) - img_value = img_value.unflatten(-1, (self.num_heads, -1)) + img_query = img_query.unflatten(-1, (self.query_num_heads, self.head_dim)) + img_key = img_key.unflatten(-1, (self.kv_num_heads, self.head_dim)) + img_value = img_value.unflatten(-1, (self.kv_num_heads, self.head_dim)) - txt_query = txt_query.unflatten(-1, (self.num_heads, -1)) - txt_key = txt_key.unflatten(-1, (self.num_heads, -1)) - txt_value = txt_value.unflatten(-1, (self.num_heads, -1)) + txt_query = txt_query.unflatten(-1, (self.add_query_num_heads, self.head_dim)) + txt_key = txt_key.unflatten(-1, (self.add_kv_num_heads, self.head_dim)) + txt_value = txt_value.unflatten(-1, (self.add_kv_num_heads, self.head_dim)) - # Apply QK normalization img_query = self.norm_q(img_query) img_key = self.norm_k(img_key) txt_query = self.norm_added_q(txt_query) txt_key = self.norm_added_k(txt_key) - # Apply RoPE img_cos = vid_freqs.real.to(img_query.dtype) img_sin = vid_freqs.imag.to(img_query.dtype) txt_cos = txt_freqs.real.to(txt_query.dtype) txt_sin = txt_freqs.imag.to(txt_query.dtype) + img_query = self.rope(img_query, img_cos, img_sin) img_key = self.rope(img_key, img_cos, img_sin) txt_query = self.rope(txt_query, txt_cos, txt_sin) txt_key = self.rope(txt_key, txt_cos, txt_sin) - # Concatenate for joint attention - # Order: [text, image] + seq_len_txt = encoder_hidden_states.shape[1] joint_query = torch.cat([txt_query, img_query], dim=1) joint_key = torch.cat([txt_key, img_key], dim=1) joint_value = torch.cat([txt_value, img_value], dim=1) - # Compute joint attention if ( self.parallel_config is not None and self.parallel_config.sequence_parallel_size > 1 and not get_forward_context().split_text_embed_in_sp ): - # if using sequence parallel, but not splitting text embed, - # we need to pass text embedding to attention layer as joint qkv attn_metadata = AttentionMetadata( joint_query=txt_query, joint_key=txt_key, @@ -443,22 +495,17 @@ def forward( if encoder_hidden_states_mask is not 
None: attn_metadata.joint_attn_mask = encoder_hidden_states_mask - joint_hidden_states = self.attn( - img_query, - img_key, - img_value, - attn_metadata, - ) + joint_hidden_states = self.attn(img_query, img_key, img_value, attn_metadata) else: attn_metadata = None if hidden_states_mask is not None or encoder_hidden_states_mask is not None: - mask_list = [] + mask_list: list[torch.Tensor] = [] if encoder_hidden_states_mask is not None: mask_list.append(encoder_hidden_states_mask) else: mask_list.append( torch.ones( - [encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]], + encoder_hidden_states.shape[:2], dtype=torch.bool, device=encoder_hidden_states.device, ) @@ -468,34 +515,22 @@ def forward( else: mask_list.append( torch.ones( - [hidden_states.shape[0], hidden_states.shape[1]], + hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device, ) ) - joint_mask = ( - None if len(mask_list) == 0 else torch.cat(mask_list, dim=1) if len(mask_list) > 1 else mask_list[0] - ) + joint_mask = torch.cat(mask_list, dim=1) if len(mask_list) > 1 else mask_list[0] attn_metadata = AttentionMetadata(attn_mask=joint_mask) - joint_hidden_states = self.attn( - joint_query, - joint_key, - joint_value, - attn_metadata, - ) - joint_hidden_states = joint_hidden_states.flatten(2, 3) - joint_hidden_states = joint_hidden_states.to(joint_query.dtype) - # Split attention outputs back - txt_attn_output = joint_hidden_states[:, :seq_len_txt, :] # Text part - img_attn_output = joint_hidden_states[:, seq_len_txt:, :] # Image part + joint_hidden_states = self.attn(joint_query, joint_key, joint_value, attn_metadata) - # Apply output projections - img_attn_output, _ = self.to_out[0](img_attn_output) - if len(self.to_out) > 1: - (img_attn_output,) = self.to_out[1](img_attn_output) # dropout + joint_hidden_states = joint_hidden_states.flatten(2, 3).to(joint_query.dtype) + txt_attn_output = joint_hidden_states[:, :seq_len_txt, :] + img_attn_output = joint_hidden_states[:, 
seq_len_txt:, :] - txt_attn_output, _ = self.to_add_out(txt_attn_output) + img_attn_output = self.to_out(img_attn_output) + txt_attn_output = self.to_add_out(txt_attn_output) return img_attn_output, txt_attn_output @@ -530,7 +565,7 @@ def __init__( head_dim=attention_head_dim, ) self.img_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) - self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + self.img_mlp = FeedForward(dim=dim, dim_out=dim) # Text processing modules self.txt_mod = nn.Sequential( @@ -540,7 +575,7 @@ def __init__( self.txt_norm1 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) # Text doesn't need separate attention - it's handled by img_attn joint computation self.txt_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) - self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + self.txt_mlp = FeedForward(dim=dim, dim_out=dim) self.zero_cond_t = zero_cond_t @@ -604,7 +639,7 @@ def forward( txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] # Process image stream - norm1 + modulation - img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1) + img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1, modulate_index) # Process text stream - norm1 + modulation txt_modulated, txt_gate1 = self.txt_norm1(encoder_hidden_states, txt_mod1) @@ -632,7 +667,8 @@ def forward( encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output # Process image stream - norm2 + MLP - img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2) + img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2, modulate_index) + img_mlp_output = self.img_mlp(img_modulated2) hidden_states = hidden_states + img_gate2 * img_mlp_output @@ -692,15 +728,13 @@ def __init__( attention_head_dim: int = 128, num_attention_heads: int = 24, joint_attention_dim: int = 3584, - guidance_embeds: bool = False, # TODO: this should probably be removed 
+ guidance_embeds: bool = False, axes_dims_rope: tuple[int, int, int] = (16, 56, 56), zero_cond_t: bool = False, use_additional_t_cond: bool = False, use_layer3d_rope: bool = False, ): super().__init__() - model_config = od_config.tf_model_config - num_layers = model_config.num_layers self.parallel_config = od_config.parallel_config self.in_channels = in_channels self.out_channels = out_channels or in_channels @@ -791,8 +825,8 @@ def forward( sp_size = get_sequence_parallel_world_size() if seq_len % sp_size != 0: - # flash_attn, ring_attn, sage_attn do not support attention_mask - if get_attn_backend(-1).get_name() != "SDPA" and get_attn_backend(-1).get_name() != "ASCEND": + # ring_attn, sage_attn do not support attention_mask + if get_attn_backend(-1).get_name() not in _BACKENDS_SUPPORT_ATTENTION_MASK: raise ValueError( f"When generating image shape that the sequence length is NOT divisible by sp_size={sp_size}," f"cannot use {get_attn_backend(-1).get_name()} which does not support attention_mask." 
@@ -893,6 +927,8 @@ def get_rotary_emb_chunk(freqs, padding=0): if original_seq_len is not None: output = output[:, :original_seq_len, :] + torch.cuda.empty_cache() + return Transformer2DModelOutput(sample=output) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -917,17 +953,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: + original_name = name + lookup_name = name for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if weight_name not in original_name: continue - name = name.replace(weight_name, param_name) - param = params_dict[name] + lookup_name = original_name.replace(weight_name, param_name) + param = params_dict[lookup_name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - param = params_dict[name] + if lookup_name not in params_dict and ".to_out.0." in lookup_name: + lookup_name = lookup_name.replace(".to_out.0.", ".to_out.") + param = params_dict[lookup_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) + loaded_params.add(original_name) + loaded_params.add(lookup_name) return loaded_params diff --git a/vllm_omni/diffusion/models/schedulers/__init__.py b/vllm_omni/diffusion/models/schedulers/__init__.py new file mode 100644 index 00000000000..6f8df78ebf0 --- /dev/null +++ b/vllm_omni/diffusion/models/schedulers/__init__.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.models.schedulers.scheduling_flow_unipc_multistep import ( + FlowUniPCMultistepScheduler, +) + +__all__ = [ + "FlowUniPCMultistepScheduler", +] diff --git a/vllm_omni/diffusion/models/schedulers/base.py b/vllm_omni/diffusion/models/schedulers/base.py new file mode 100644 
index 00000000000..bc9d87d7f55 --- /dev/null +++ b/vllm_omni/diffusion/models/schedulers/base.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://github.com/hao-ai-lab/FastVideo +# Originally from https://github.com/huggingface/diffusers +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +"""Base scheduler class for diffusion models.""" + +from abc import ABC, abstractmethod + +import torch + + +class BaseScheduler(ABC): + """ + Abstract base class for schedulers. + + Subclasses must define: + - timesteps: torch.Tensor + - order: int + - num_train_timesteps: int + """ + + timesteps: torch.Tensor + order: int + num_train_timesteps: int + + def __init__(self): + required_attrs = ["timesteps", "order", "num_train_timesteps"] + for attr in required_attrs: + if not hasattr(self, attr): + raise AttributeError( + f"Subclass {self.__class__.__name__} must define `{attr}` before calling super().__init__()" + ) + + @abstractmethod + def set_shift(self, shift: float) -> None: + """Set the shift parameter for the scheduler.""" + raise NotImplementedError + + @abstractmethod + def set_timesteps(self, *args, **kwargs) -> None: + """Set the timesteps for the scheduler.""" + raise NotImplementedError + + @abstractmethod + def scale_model_input(self, sample: torch.Tensor, timestep: int | None = None) -> torch.Tensor: + """Scale the model input.""" + raise NotImplementedError diff --git a/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py b/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py new file mode 100644 index 00000000000..3efe564bc61 --- /dev/null +++ b/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py @@ -0,0 +1,741 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://github.com/hao-ai-lab/FastVideo +# 
# Originally from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
# Convert unipc for flow matching
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
"""
FlowUniPCMultistepScheduler - A training-free framework for fast sampling of flow-matching diffusion models.

This scheduler implements the UniPC (Unified Predictor-Corrector) algorithm adapted for flow matching,
providing faster convergence than simple Euler methods while maintaining quality.
"""

from __future__ import annotations

import math
from typing import Any

import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils import deprecate

from vllm_omni.diffusion.models.schedulers.base import BaseScheduler


class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):
    """
    `FlowUniPCMultistepScheduler` is a training-free framework designed for the fast sampling of
    flow-matching diffusion models.

    This scheduler implements the UniPC (Unified Predictor-Corrector) algorithm adapted for flow matching,
    which can achieve the same quality as Euler methods in fewer steps (typically 20-30 steps vs 40-50).

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        solver_order (`int`, default `2`):
            The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
            due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
            unconditional sampling.
        prediction_type (`str`, defaults to "flow_prediction"):
            Prediction type of the scheduler function; must be `flow_prediction` for this scheduler.
        shift (`float`, defaults to 1.0):
            The shift parameter for the noise schedule. For Wan2.2: use 5.0 for 720p, 12.0 for 480p.
            Required (not `None`) unless `use_dynamic_shifting` is True.
        use_dynamic_shifting (`bool`, defaults to False):
            Whether to use dynamic shifting based on image resolution. When True, `set_timesteps`
            requires `mu`.
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding.
        predict_x0 (`bool`, defaults to `True`):
            Whether to use the updating algorithm on the predicted x0.
        solver_type (`str`, default `bh2`):
            Solver type for UniPC. Use `bh1` for unconditional sampling when steps < 10, `bh2` otherwise.
            `midpoint`, `heun` and `logrho` are accepted and mapped to `bh2`.
        lower_order_final (`bool`, default `True`):
            Whether to use lower-order solvers in the final steps. Stabilizes sampling for steps < 15.
        disable_corrector (`tuple`, default `()`):
            Steps to disable the corrector to mitigate misalignment with large guidance scales.
            Stored internally as a list.
        solver_p (`SchedulerMixin`, *optional*):
            When given, the predictor step is delegated to `solver_p.step(...)` instead of UniP.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Only recorded in the config by this class.
        steps_offset (`int`, defaults to 0):
            Only recorded in the config by this class.
        final_sigmas_type (`str`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule. Either `"zero"` or `"sigma_min"`.
        **kwargs:
            Accepted for signature compatibility; not used by this class.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        solver_order: int = 2,
        prediction_type: str = "flow_prediction",
        shift: float | None = 1.0,
        use_dynamic_shifting: bool = False,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        lower_order_final: bool = True,
        disable_corrector: tuple = (),
        solver_p: SchedulerMixin | None = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        final_sigmas_type: str | None = "zero",
        **kwargs,
    ):
        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                # Legacy solver names fall back to bh2.
                self.register_to_config(solver_type="bh2")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        self.predict_x0 = predict_x0
        self.num_inference_steps: int | None = None

        # Initialize the training sigma schedule: linear in [0, 1 - 1/N], ascending after the flip of alphas.
        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)

        if not use_dynamic_shifting:
            # Apply timestep shifting based on shift parameter.
            # Raise instead of assert so the check survives `python -O`.
            if shift is None:
                raise ValueError("`shift` must be provided when `use_dynamic_shifting` is False")
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.sigmas = sigmas
        self.timesteps = sigmas * num_train_timesteps
        self.num_train_timesteps = num_train_timesteps

        # State for multistep solver
        self.model_outputs: list[torch.Tensor | None] = [None] * solver_order
        self.timestep_list: list[Any | None] = [None] * solver_order
        self.lower_order_nums = 0
        self.disable_corrector = list(disable_corrector)
        self.solver_p = solver_p
        self.last_sample: torch.Tensor | None = None
        self._step_index: int | None = None
        self._begin_index: int | None = None
        self.this_order: int = 1

        # Move sigmas to CPU to reduce GPU/CPU communication
        self.sigmas = self.sigmas.to("cpu")
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

        # Must run last: it checks that timesteps/order/num_train_timesteps exist.
        BaseScheduler.__init__(self)

    @property
    def step_index(self) -> int | None:
        """The index counter for current timestep. Increases by 1 after each scheduler step."""
        return self._step_index

    @property
    def begin_index(self) -> int | None:
        """The index for the first timestep. Should be set from pipeline with `set_begin_index` method."""
        return self._begin_index

    def set_shift(self, shift: float) -> None:
        """
        Set the shift parameter for the scheduler.

        The new value is read by the next `set_timesteps` call that does not pass
        an explicit `shift`; the current sigma schedule is not recomputed here.
        """
        # Go through register_to_config (as __init__ does for solver_type) so the
        # stored config is rebuilt instead of mutating the frozen config object.
        self.register_to_config(shift=shift)

    def set_begin_index(self, begin_index: int = 0) -> None:
        """
        Sets the begin index for the scheduler. Run from pipeline before inference.

        Args:
            begin_index (`int`): The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def set_timesteps(
        self,
        num_inference_steps: int | None = None,
        device: str | torch.device | None = None,
        sigmas: list[float] | None = None,
        mu: float | None = None,
        shift: float | None = None,
    ) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain (run before inference).

        Also resets the multistep solver state (model output history, step/begin index).

        Args:
            num_inference_steps (`int`):
                Total number of timesteps. Required when `sigmas` is not given.
            device (`str` or `torch.device`, *optional*):
                The device to move timesteps to. `self.sigmas` stays on CPU.
            sigmas (`list[float]`, *optional*):
                Custom sigma schedule (a list or a NumPy array). The shift / dynamic
                shift is still applied to it.
            mu (`float`, *optional*):
                Parameter for dynamic shifting.
            shift (`float`, *optional*):
                Override shift parameter. Defaults to `self.config.shift`.

        Raises:
            ValueError: if `mu` is missing while dynamic shifting is enabled, if neither
                `sigmas` nor `num_inference_steps` is given, or if `final_sigmas_type`
                is not `"zero"` / `"sigma_min"`.
        """
        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError("Must pass a value for `mu` when `use_dynamic_shifting` is True")

        if sigmas is None:
            if num_inference_steps is None:
                raise ValueError("Must pass either `num_inference_steps` or `sigmas`")
            sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]
        elif not isinstance(sigmas, np.ndarray):
            # The signature advertises list[float]; the arithmetic below needs an array.
            sigmas = np.asarray(sigmas, dtype=np.float64)

        if self.config.use_dynamic_shifting:
            assert mu is not None  # guaranteed by the check above; narrows the type
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            if shift is None:
                shift = self.config.shift
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = self.sigma_min
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(f"`final_sigmas_type` must be 'zero' or 'sigma_min', got {self.config.final_sigmas_type}")

        timesteps = sigmas * self.config.num_train_timesteps
        # One extra trailing sigma so that `sigmas[step_index + 1]` is valid on the last step.
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        # Reset state
        self.model_outputs = [None] * self.config.solver_order
        self.timestep_list = [None] * self.config.solver_order
        self.lower_order_nums = 0
        self.last_sample = None

        if self.solver_p:
            self.solver_p.set_timesteps(self.num_inference_steps, device=device)

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")

    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Dynamic thresholding to prevent pixel saturation.

        From "Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding"
        https://arxiv.org/abs/2205.11487

        The sample is flattened per batch element, clamped to the per-element quantile
        `s` (itself clamped to `[1, sample_max_value]`), divided by `s`, and returned
        in its original shape and dtype.
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            # Computed in float32 for half-precision inputs; cast back at the end.
            sample = sample.float()

        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
        abs_sample = sample.abs()

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(s, min=1, max=self.config.sample_max_value)
        s = s.unsqueeze(1)
        sample = torch.clamp(sample, -s, s) / s

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
        """Convert sigma to timestep."""
        return sigma * self.config.num_train_timesteps

    def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Convert sigma to alpha and sigma_t for flow matching."""
        return 1 - sigma, sigma

    def time_shift(self, mu: float, sigma: float, t: np.ndarray) -> np.ndarray:
        """Apply time shift transformation: exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)."""
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Convert the model output to the format needed by the UniPC algorithm.

        Args:
            model_output (`torch.Tensor`): Direct output from the diffusion model.
            sample (`torch.Tensor`): Current sample in the diffusion process.

        Returns:
            `torch.Tensor`: The x0 prediction when `predict_x0` is True, otherwise the
            epsilon-style prediction.

        Raises:
            ValueError: if `sample` is missing or `prediction_type` is not `flow_prediction`.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion "
                "is now handled via an internal counter `self.step_index`",
            )

        # For flow matching sigma_t is the sigma itself (alpha_t = 1 - sigma_t is not needed here).
        sigma_t = self.sigmas[self.step_index].to(sample.device)

        if self.config.prediction_type != "flow_prediction":
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be `flow_prediction` "
                "for the FlowUniPCMultistepScheduler."
            )

        if self.predict_x0:
            x0_pred = sample - sigma_t * model_output

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        epsilon = sample - (1 - sigma_t) * model_output

        if self.config.thresholding:
            x0_pred = sample - sigma_t * model_output
            x0_pred = self._threshold_sample(x0_pred)
            epsilon = model_output + x0_pred

        return epsilon

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor | None = None,
        order: int | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniP (B(h) version) predictor.

        Args:
            model_output (`torch.Tensor`): Direct output from the diffusion model. Only used
                when `solver_p` is set; otherwise the history in `self.model_outputs` is used.
            sample (`torch.Tensor`): Current sample.
            order (`int`): The order of UniP at this timestep.

        Returns:
            `torch.Tensor`: The sample tensor at the previous timestep.

        Raises:
            ValueError: if `sample` or `order` is missing.
        """
        prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError("missing `order` as a required keyword argument")
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect.",
            )

        model_output_list = self.model_outputs

        s0 = self.timestep_list[-1]
        m0 = model_output_list[-1]
        x = sample

        if self.solver_p:
            # An external predictor replaces UniP entirely.
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        device = sample.device
        sigma_t, sigma_s0 = (
            self.sigmas[self.step_index + 1].to(device),
            self.sigmas[self.step_index].to(device),
        )
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0

        rks = []
        D1s: list[Any] | None = []
        for i in range(1, order):
            si = self.step_index - i
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si].to(device))
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            assert mi is not None
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if D1s is not None and len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
            if order == 2:
                # Closed-form coefficient for order 2; no linear solve needed.
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                assert isinstance(R, torch.Tensor)
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
        else:
            D1s = None

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - sigma_t * B_h * pred_res

        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.Tensor,
        *args,
        last_sample: torch.Tensor | None = None,
        this_sample: torch.Tensor | None = None,
        order: int | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniC (B(h) version) corrector.

        Args:
            this_model_output (`torch.Tensor`): Model outputs at `x_t`.
            last_sample (`torch.Tensor`): Sample before the last predictor `x_{t-1}`.
            this_sample (`torch.Tensor`): Sample after the last predictor `x_{t}`.
            order (`int`): The order of UniC-p. Effective accuracy is `order + 1`.

        Returns:
            `torch.Tensor`: The corrected sample tensor.

        Raises:
            ValueError: if `last_sample`, `this_sample` or `order` is missing.
        """
        this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
        if last_sample is None:
            if len(args) > 1:
                last_sample = args[1]
            else:
                raise ValueError("missing `last_sample` as a required keyword argument")
        if this_sample is None:
            if len(args) > 2:
                this_sample = args[2]
            else:
                raise ValueError("missing `this_sample` as a required keyword argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError("missing `order` as a required keyword argument")
        if this_timestep is not None:
            deprecate(
                "this_timestep",
                "1.0.0",
                "Passing `this_timestep` is deprecated and has no effect.",
            )

        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output

        device = this_sample.device
        sigma_t, sigma_s0 = (
            self.sigmas[self.step_index].to(device),
            self.sigmas[self.step_index - 1].to(device),
        )
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0

        rks = []
        D1s: list[Any] | None = []
        for i in range(1, order):
            si = self.step_index - (i + 1)
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si].to(device))
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            assert mi is not None
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if D1s is not None and len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None

        if order == 1:
            # Closed-form coefficient for order 1; no linear solve needed.
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)

        x_t = x_t.to(x.dtype)
        return x_t

    def index_for_timestep(self, timestep: torch.Tensor, schedule_timesteps: torch.Tensor | None = None) -> int:
        """
        Get the index for a given timestep.

        When the timestep occurs more than once in the schedule, the second match is
        returned. A timestep that is not in the schedule raises `IndexError`.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()
        pos = 1 if len(indices) > 1 else 0
        step_index: int = indices[pos].item()

        return step_index

    def _init_step_index(self, timestep: torch.Tensor) -> None:
        """Initialize the step_index counter for the scheduler."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int | torch.Tensor,
        sample: torch.Tensor,
        return_dict: bool = True,
        generator: torch.Generator | None = None,
    ) -> SchedulerOutput | tuple:
        """
        Predict the sample from the previous timestep by reversing the SDE using multistep UniPC.

        Args:
            model_output (`torch.Tensor`): Direct output from the diffusion model.
            timestep (`int` or `torch.Tensor`): Current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`): Current sample created by the diffusion process.
            return_dict (`bool`): Whether to return a SchedulerOutput or tuple.
            generator (`torch.Generator`, *optional*): Not used by this scheduler.

        Returns:
            `SchedulerOutput` or `tuple`: The sample tensor at the previous timestep.

        Raises:
            ValueError: if `set_timesteps` has not been called yet.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # The corrector needs the sample from before the previous predictor step,
        # so it can never run on the first step.
        use_corrector = (
            self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
        )

        model_output_convert = self.convert_model_output(model_output, sample=sample)

        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        # Update model output history
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        # Determine order for this step
        if self.config.lower_order_final:
            this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
        else:
            this_order = self.config.solver_order

        self.this_order = min(this_order, self.lower_order_nums + 1)  # warmup for multistep
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        assert self._step_index is not None
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input.

        Args:
            sample (`torch.Tensor`): The input sample.

        Returns:
            `torch.Tensor`: A scaled input sample (unchanged for this scheduler).
        """
        return sample

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Add noise to the original samples.

        Args:
            original_samples (`torch.Tensor`): Original samples.
            noise (`torch.Tensor`): Noise to add.
            timesteps (`torch.IntTensor`): Timesteps for noise addition.

        Returns:
            `torch.Tensor`: Noisy samples, `alpha_t * original_samples + sigma_t * noise`.
        """
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)

        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # Floating-point timesteps are compared as float32 on mps.
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        # Append trailing singleton dims so sigma broadcasts against the samples.
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self) -> int:
        return self.config.num_train_timesteps
b/vllm_omni/diffusion/models/sd3/sd3_transformer.py index a2afa8c0946..22a11741a53 100644 --- a/vllm_omni/diffusion/models/sd3/sd3_transformer.py +++ b/vllm_omni/diffusion/models/sd3/sd3_transformer.py @@ -102,8 +102,8 @@ def __init__( else: self.to_out = None - self.norm_added_q = RMSNorm(head_dim, eps=eps) - self.norm_added_k = RMSNorm(head_dim, eps=eps) + self.norm_added_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.norm_added_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() self.attn = Attention( num_heads=num_heads, @@ -341,8 +341,10 @@ def __init__( self.pooled_projection_dim = model_config.pooled_projection_dim self.joint_attention_dim = model_config.joint_attention_dim self.patch_size = model_config.patch_size - self.dual_attention_layers = model_config.dual_attention_layers - self.qk_norm = model_config.qk_norm + self.dual_attention_layers = ( + model_config.dual_attention_layers if hasattr(model_config, "dual_attention_layers") else () + ) + self.qk_norm = model_config.qk_norm if hasattr(model_config, "qk_norm") else "" self.pos_embed_max_size = model_config.pos_embed_max_size self.pos_embed = PatchEmbed( diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 61a858ca4d2..3b04e2a5a9b 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -9,7 +9,7 @@ import PIL.Image import torch -from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler +from diffusers import AutoencoderKLWan from diffusers.utils.torch_utils import randn_tensor from torch import nn from transformers import AutoTokenizer, UMT5EncoderModel @@ -18,6 +18,7 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from 
vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -244,12 +245,13 @@ def __init__( else: self.transformer_2 = None - self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - model, subfolder="scheduler", local_files_only=local_files_only + # Initialize UniPC scheduler + flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p + self.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=flow_shift, + prediction_type="flow_prediction", ) - # Apply flow_shift if specified (12.0 for 480p, 5.0 for 720p recommended for Wan2.2) - if od_config.flow_shift is not None: - self.scheduler.config.flow_shift = od_config.flow_shift self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index d72afc9ee84..ed0d6e4c6b6 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -9,7 +9,7 @@ import numpy as np import PIL.Image import torch -from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler +from diffusers import AutoencoderKLWan from diffusers.utils.torch_utils import randn_tensor from torch import nn from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel @@ -18,6 +18,8 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from 
vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( create_transformer_from_config, load_transformer_config, @@ -112,7 +114,7 @@ def pre_process_func(requests: list[OmniDiffusionRequest]) -> list[OmniDiffusion return pre_process_func -class Wan22I2VPipeline(nn.Module): +class Wan22I2VPipeline(nn.Module, SupportImageInput): """ Wan2.2 Image-to-Video Pipeline. @@ -198,12 +200,13 @@ def __init__( else: self.transformer_2 = None - # Scheduler - self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - model, subfolder="scheduler", local_files_only=local_files_only + # Initialize UniPC scheduler + flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p + self.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=flow_shift, + prediction_type="flow_prediction", ) - if od_config.flow_shift is not None: - self.scheduler.config.flow_shift = od_config.flow_shift # VAE scale factors self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4 diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index bee70a7a96b..6a9a6a6a0e9 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -22,7 +22,7 @@ import numpy as np import PIL.Image import torch -from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler +from diffusers import AutoencoderKLWan from diffusers.utils.torch_utils import randn_tensor from torch import nn from transformers import AutoTokenizer, UMT5EncoderModel @@ -31,6 +31,8 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from 
vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( create_transformer_from_config, load_transformer_config, @@ -102,7 +104,7 @@ def pre_process_func(requests: list[OmniDiffusionRequest]) -> list[OmniDiffusion return pre_process_func -class Wan22TI2VPipeline(nn.Module): +class Wan22TI2VPipeline(nn.Module, SupportImageInput): """ Wan2.2 Text-Image-to-Video (TI2V) Pipeline. @@ -156,12 +158,13 @@ def __init__( transformer_config = load_transformer_config(model, "transformer", local_files_only) self.transformer = create_transformer_from_config(transformer_config) - # Scheduler - self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - model, subfolder="scheduler", local_files_only=local_files_only + # Initialize UniPC scheduler + flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p + self.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=flow_shift, + prediction_type="flow_prediction", ) - if od_config.flow_shift is not None: - self.scheduler.config.flow_shift = od_config.flow_shift # VAE scale factors self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4 diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 94cf889210f..99b0ea4c765 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -79,6 +79,11 @@ "pipeline_sd3", "StableDiffusion3Pipeline", ), + "Flux2KleinPipeline": ( + "flux2_klein", + "pipeline_flux2_klein", + "Flux2KleinPipeline", + ), } @@ -127,6 +132,7 @@ def initialize_model( "BagelPipeline": "get_bagel_post_process_func", "LongCatImageEditPipeline": "get_longcat_image_post_process_func", "StableDiffusion3Pipeline": 
"get_sd3_image_post_process_func", + "Flux2KleinPipeline": "get_flux2_klein_post_process_func", } _DIFFUSION_PRE_PROCESS_FUNCS = { diff --git a/vllm_omni/diffusion/utils/tf_utils.py b/vllm_omni/diffusion/utils/tf_utils.py new file mode 100644 index 00000000000..44a78804452 --- /dev/null +++ b/vllm_omni/diffusion/utils/tf_utils.py @@ -0,0 +1,54 @@ +import inspect +from typing import Any + +from vllm_omni.diffusion.data import TransformerConfig + + +def get_transformer_config_kwargs( + tf_model_config: TransformerConfig, model_class: type[Any] | None = None +) -> dict[str, Any]: + """ + This function extracts parameters from a TransformerConfig instance and filters out internal + diffusers metadata keys (those starting with '_') that should not be passed to model initialization. + Also filters out parameters that are not accepted by the model's __init__ method (e.g., pooled_projection_dim + for QwenImageTransformer2DModel). + + This uses inspect.signature to dynamically detect accepted parameters, making it general for any model class. + Similar to how diffusers' @register_to_config decorator works. + + Args: + tf_model_config: TransformerConfig instance containing model parameters + model_class: Optional model class to inspect for accepted __init__ parameters. + If None, all non-internal parameters are returned (backward compatibility). 
+ + Returns: + dict: Filtered dictionary of parameters suitable for transformer model initialization + """ + # Extract transformer config parameters, filtering out internal diffusers metadata + # TransformerConfig stores params in a 'params' dict, and we need to exclude + # internal keys like '_class_name' and '_diffusers_version' + tf_config_params = tf_model_config.to_dict() + + # Filter out internal diffusers metadata keys that start with '_' + filtered_params = {k: v for k, v in tf_config_params.items() if not k.startswith("_")} + + # If model_class is provided, use inspect.signature to get accepted parameters + if model_class is not None: + try: + # Get the signature of the model's __init__ method + sig = inspect.signature(model_class.__init__) + # Get all parameter names (excluding 'self' and special parameters) + accepted_params = { + name + for name, param in sig.parameters.items() + if name != "self" and param.kind != inspect.Parameter.VAR_KEYWORD # Exclude **kwargs + } + + # Filter to only include parameters that are in the model's signature + filtered_params = {k: v for k, v in filtered_params.items() if k in accepted_params} + except (TypeError, AttributeError): + # If inspection fails, fall back to returning all non-internal params + # This maintains backward compatibility + pass + + return filtered_params diff --git a/vllm_omni/diffusion/worker/__init__.py b/vllm_omni/diffusion/worker/__init__.py index dc3306dae3f..dfec4596bc2 100644 --- a/vllm_omni/diffusion/worker/__init__.py +++ b/vllm_omni/diffusion/worker/__init__.py @@ -2,6 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Worker classes for diffusion models.""" -from vllm_omni.diffusion.worker.gpu_worker import GPUWorker, WorkerProc +from vllm_omni.diffusion.worker.gpu_diffusion_model_runner import GPUDiffusionModelRunner +from vllm_omni.diffusion.worker.gpu_diffusion_worker import ( + GPUDiffusionWorker, + WorkerProc, +) -__all__ = ["GPUWorker", "WorkerProc"] 
+__all__ = [ + "GPUDiffusionModelRunner", + "GPUDiffusionWorker", + "WorkerProc", +] diff --git a/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py b/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py new file mode 100644 index 00000000000..8ffa12ed2cc --- /dev/null +++ b/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Diffusion Model Runner for vLLM-Omni. + +Handles model loading, compilation, caching, and execution of diffusion model +forward passes. This follows the AR pattern where the Runner handles all +model-related operations. +""" + +from __future__ import annotations + +import time +from collections.abc import Iterable +from contextlib import nullcontext + +import torch +from vllm.config import LoadConfig +from vllm.logger import init_logger +from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes + +from vllm_omni.diffusion.cache.selector import get_cache_backend +from vllm_omni.diffusion.compile import regionally_compile +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.forward_context import set_forward_context +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.offload import apply_offload_hooks +from vllm_omni.diffusion.request import OmniDiffusionRequest + +logger = init_logger(__name__) + + +class GPUDiffusionModelRunner: + """ + Model runner that handles model loading and execution for diffusion models. + + This class follows the AR pattern where the Runner handles all model-related + operations including loading, compilation, offloading, caching, and execution. + The Worker only handles infrastructure (device, distributed env). 
+ """ + + def __init__( + self, + vllm_config, + od_config: OmniDiffusionConfig, + device: torch.device, + ): + """ + Initialize the diffusion model runner. + + Args: + vllm_config: vLLM configuration. + od_config: OmniDiffusion configuration. + device: The device to run on. + """ + self.vllm_config = vllm_config + self.od_config = od_config + self.device = device + self.pipeline = None + self.cache_backend = None + + def load_model( + self, + memory_pool_context_fn: callable | None = None, + ) -> None: + """ + Load the diffusion model, apply compilation and offloading. + + Args: + memory_pool_context_fn: Optional function that returns a context manager + for memory pool allocation (used for sleep mode). + """ + load_device = "cpu" if self.od_config.enable_cpu_offload else str(self.device) + + def get_memory_context(): + if memory_pool_context_fn is not None: + return memory_pool_context_fn(tag="weights") + return nullcontext() + + # Load model within forward context + with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): + load_config = LoadConfig() + model_loader = DiffusersPipelineLoader(load_config) + time_before_load = time.perf_counter() + + with get_memory_context(): + with DeviceMemoryProfiler() as m: + self.pipeline = model_loader.load_model( + od_config=self.od_config, + load_device=load_device, + ) + time_after_load = time.perf_counter() + + logger.info( + "Model loading took %.4f GiB and %.6f seconds", + m.consumed_memory / GiB_bytes, + time_after_load - time_before_load, + ) + logger.info("Model runner: Model loaded successfully.") + + # Apply CPU offloading (DiT <-> encoders mutual exclusion) + if self.od_config.enable_cpu_offload: + for name in ["vae"]: + module = getattr(self.pipeline, name, None) + if module is None: + continue + try: + module.to(self.device, non_blocking=True) + except Exception as exc: + logger.debug("Failed to move %s to GPU: %s", name, exc) + + apply_offload_hooks(self.pipeline, 
self.od_config, device=self.device) + + # Apply torch.compile if not in eager mode + if not self.od_config.enforce_eager: + try: + self.pipeline.transformer = regionally_compile( + self.pipeline.transformer, + dynamic=True, + ) + logger.info("Model runner: Model compiled with torch.compile.") + except Exception as e: + logger.warning(f"Model runner: torch.compile failed with error: {e}. Using eager mode.") + + # Setup cache backend + self.cache_backend = get_cache_backend(self.od_config.cache_backend, self.od_config.cache_config) + + if self.cache_backend is not None: + self.cache_backend.enable(self.pipeline) + + logger.info("Model runner: Initialization complete.") + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights into the pipeline.""" + return self.pipeline.load_weights(weights) + + @torch.inference_mode() + def execute_model(self, reqs: list[OmniDiffusionRequest]) -> DiffusionOutput: + """ + Execute a forward pass for the given requests. + + Args: + reqs: List of diffusion requests to process. + + Returns: + DiffusionOutput with generated results. + """ + assert self.pipeline is not None, "Model not loaded. Call load_model() first." 
+ if not reqs or len(reqs) == 0: + raise ValueError("Cannot execute model with empty request list") + + # TODO: dealing with first req for now + req = reqs[0] + + if req.generator is None and req.seed is not None: + req.generator = torch.Generator(device=self.device).manual_seed(req.seed) + + # Refresh cache context if needed + if self.cache_backend is not None and self.cache_backend.is_enabled(): + self.cache_backend.refresh(self.pipeline, req.num_inference_steps) + + with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): + output = self.pipeline.forward(req) + + return output diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_diffusion_worker.py similarity index 65% rename from vllm_omni/diffusion/worker/gpu_worker.py rename to vllm_omni/diffusion/worker/gpu_diffusion_worker.py index 99a718e389f..7fa1e3da4d5 100644 --- a/vllm_omni/diffusion/worker/gpu_worker.py +++ b/vllm_omni/diffusion/worker/gpu_diffusion_worker.py @@ -1,20 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Diffusion Worker for vLLM-Omni. + +Handles GPU infrastructure initialization and delegates model operations +to GPUDiffusionModelRunner. 
+""" + import multiprocessing as mp import os -import time -from collections.abc import Iterable from contextlib import AbstractContextManager, nullcontext import torch import zmq -from vllm.config import LoadConfig, VllmConfig +from vllm.config import VllmConfig from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.logger import init_logger -from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes +from vllm.utils.mem_utils import GiB_bytes -from vllm_omni.diffusion.cache.selector import get_cache_backend -from vllm_omni.diffusion.compile import regionally_compile from vllm_omni.diffusion.data import ( DiffusionOutput, OmniDiffusionConfig, @@ -25,16 +28,23 @@ initialize_model_parallel, ) from vllm_omni.diffusion.forward_context import set_forward_context -from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.offload import apply_offload_hooks from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.worker.gpu_diffusion_model_runner import GPUDiffusionModelRunner logger = init_logger(__name__) -class GPUWorker: +class GPUDiffusionWorker: """ - A worker that executes the model on a single GPU. + A worker that manages GPU infrastructure and delegates to the model runner. + + This class handles infrastructure initialization only: + - Device setup (CUDA device selection) + - Distributed environment (NCCL, model parallel) + - Memory management (sleep/wake) + + All model-related operations (loading, compilation, execution) are + delegated to GPUDiffusionModelRunner. 
""" def __init__( @@ -46,15 +56,17 @@ def __init__( self.local_rank = local_rank self.rank = rank self.od_config = od_config - self.pipeline = None - self.device = None + self.device: torch.device | None = None + self.vllm_config: VllmConfig | None = None + self.model_runner: GPUDiffusionModelRunner | None = None self._sleep_saved_buffers: dict[str, torch.Tensor] = {} - self.init_device_and_model() + self.init_device() - def init_device_and_model(self) -> None: - """Initialize the device and load the model.""" + def init_device(self) -> None: + """Initialize the device and distributed environment.""" world_size = self.od_config.num_gpus rank = self.rank + # Set environment variables for distributed initialization os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(self.od_config.master_port) @@ -62,19 +74,21 @@ def init_device_and_model(self) -> None: os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) + # Setup device self.device = torch.device(f"cuda:{rank}") torch.cuda.set_device(self.device) - # hack + # Create vllm_config for parallel configuration vllm_config = VllmConfig() vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size self.vllm_config = vllm_config - load_device = "cpu" if self.od_config.enable_cpu_offload else str(self.device) + # Initialize distributed environment with set_forward_context(vllm_config=vllm_config, omni_diffusion_config=self.od_config): init_distributed_environment(world_size=world_size, rank=rank) logger.info(f"Worker {self.rank}: Initialized device and distributed environment.") + parallel_config = self.od_config.parallel_config initialize_model_parallel( data_parallel_size=parallel_config.data_parallel_size, @@ -86,107 +100,45 @@ def init_device_and_model(self) -> None: pipeline_parallel_size=parallel_config.pipeline_parallel_size, ) - 
load_config = LoadConfig() - model_loader = DiffusersPipelineLoader(load_config) - time_before_load = time.perf_counter() - with self._maybe_get_memory_pool_context(tag="weights"): - with DeviceMemoryProfiler() as m: - self.pipeline = model_loader.load_model( - od_config=self.od_config, - load_device=load_device, - ) - time_after_load = time.perf_counter() - - logger.info( - "Model loading took %.4f GiB and %.6f seconds", - m.consumed_memory / GiB_bytes, - time_after_load - time_before_load, + # Create model runner and load model + self.model_runner = GPUDiffusionModelRunner( + vllm_config=self.vllm_config, + od_config=self.od_config, + device=self.device, ) - logger.info(f"Worker {self.rank}: Model loaded successfully.") - - # Apply CPU offloading (DiT <-> encoders mutual exclusion) - if self.od_config.enable_cpu_offload: - for name in ["vae"]: - module = getattr(self.pipeline, name, None) - if module is None: - continue - try: - module.to(self.device, non_blocking=True) - except Exception as exc: - logger.debug("Failed to move %s to GPU: %s", name, exc) - - apply_offload_hooks(self.pipeline, self.od_config, device=self.device) - - if not self.od_config.enforce_eager: - try: - self.pipeline.transformer = regionally_compile( - self.pipeline.transformer, - dynamic=True, - ) - logger.info(f"Worker {self.rank}: Model compiled with torch.compile.") - except Exception as e: - logger.warning(f"Worker {self.rank}: torch.compile failed with error: {e}. 
Using eager mode.") - - # Setup cache backend based on type (both backends use enable()/reset() interface) - self.cache_backend = get_cache_backend(self.od_config.cache_backend, self.od_config.cache_config) - - if self.cache_backend is not None: - self.cache_backend.enable(self.pipeline) + self.model_runner.load_model( + memory_pool_context_fn=self._maybe_get_memory_pool_context, + ) + logger.info(f"Worker {self.rank}: Initialization complete.") def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput: - """ - Generate output for the given requests. - - Args: - requests: List of diffusion requests - - Returns: - DiffusionOutput with generated results - """ + """Generate output for the given requests.""" return self.execute_model(requests, self.od_config) - @torch.inference_mode() def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput: - """ - Execute a forward pass. - """ - assert self.pipeline is not None - if not reqs or len(reqs) == 0: - raise ValueError("Cannot execute model with empty request list") - # TODO: dealing with first req for now - req = reqs[0] + """Execute a forward pass by delegating to the model runner.""" + assert self.model_runner is not None, "Model runner not initialized" + return self.model_runner.execute_model(reqs) - if req.generator is None and req.seed is not None: - req.generator = torch.Generator(device=self.device).manual_seed(req.seed) - - # Refresh cache context if needed - if self.cache_backend is not None and self.cache_backend.is_enabled(): - self.cache_backend.refresh(self.pipeline, req.num_inference_steps) - with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): - output = self.pipeline.forward(req) - return output - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - return self.pipeline.load_weights(weights) + def load_weights(self, weights) -> set[str]: + """Load weights by 
delegating to the model runner.""" + assert self.model_runner is not None, "Model runner not initialized" + return self.model_runner.load_weights(weights) def sleep(self, level: int = 1) -> bool: """ - Put the worker to sleep. The worker should not process any requests. - The caller should guarantee that no requests are being processed - during the sleep period, before `wake_up` is called. + Put the worker to sleep, offloading model weights. Args: - level: The sleep level. Level 1 sleep will offload the model - weights and discard the kv cache. - Currently only support level 1. + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. """ from vllm.device_allocator.cumem import CuMemAllocator free_bytes_before_sleep = torch.cuda.mem_get_info()[0] # Save the buffers before level 2 sleep - if level == 2: - model = self.pipeline + if level == 2 and self.model_runner is not None: + model = self.model_runner.pipeline self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()} allocator = CuMemAllocator.get_instance() @@ -220,8 +172,8 @@ def wake_up(self, tags: list[str] | None = None) -> bool: allocator.wake_up(tags) # Restore the buffers after level 2 sleep - if len(self._sleep_saved_buffers): - model = self.pipeline + if len(self._sleep_saved_buffers) and self.model_runner is not None: + model = self.model_runner.pipeline for name, buffer in model.named_buffers(): if name in self._sleep_saved_buffers: buffer.data.copy_(self._sleep_saved_buffers[name].data) @@ -229,6 +181,7 @@ def wake_up(self, tags: list[str] | None = None) -> bool: return True def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: + """Get memory pool context for sleep mode support.""" if self.od_config.enable_sleep_mode: from vllm.device_allocator.cumem import CuMemAllocator @@ -240,6 +193,7 @@ def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: return nullcontext() def shutdown(self) -> None: + 
"""Shutdown the worker and cleanup distributed environment.""" destroy_distributed_env() @@ -257,17 +211,14 @@ def __init__( # Inter-process Communication self.context = zmq.Context(io_threads=2) - # Initialize MessageQueue reader from handle (unified for generation & RPC) + # Initialize MessageQueue reader from handle self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id) self.result_mq = None self.result_mq_handle = None - # Setup result sender (only for rank 0 for now, or whoever needs to reply) - # Assuming only rank 0 replies to scheduler as per original logic + # Setup result sender (only for rank 0) if gpu_id == 0: - # Create MessageQueue for results (1 writer -> 1 reader) - # We assume the reader (SyncScheduler) will act as rank 0 self.result_mq = MessageQueue(n_reader=1, n_local_reader=1, local_reader_ranks=[0]) self.result_mq_handle = self.result_mq.export_handle() logger.info(f"Worker {gpu_id} created result MessageQueue") @@ -277,31 +228,25 @@ def __init__( self.gpu_id = gpu_id self._running = True - def _create_worker(self, gpu_id: int, od_config: OmniDiffusionConfig) -> GPUWorker: + def _create_worker(self, gpu_id: int, od_config: OmniDiffusionConfig) -> GPUDiffusionWorker: """Create a worker instance. Override in subclasses for different worker types.""" - return GPUWorker( + return GPUDiffusionWorker( local_rank=gpu_id, rank=gpu_id, od_config=od_config, ) def return_result(self, output: DiffusionOutput): - """ - replies to client, only on rank 0 - """ + """Reply to client, only on rank 0.""" if self.result_mq is not None: self.result_mq.enqueue(output) def recv_message(self): - """ - Receive unified messages (RPC requests, shutdown) from broadcast queue. - Uses indefinite=True to block until a message arrives. 
- """ + """Receive messages from broadcast queue.""" return self.mq.dequeue(indefinite=True) def execute_rpc(self, rpc_request: dict) -> tuple[object | None, bool]: """Execute an RPC request and indicate whether to reply.""" - method = rpc_request["method"] args = rpc_request.get("args", ()) kwargs = rpc_request.get("kwargs", {}) @@ -325,14 +270,11 @@ def execute_rpc(self, rpc_request: dict) -> tuple[object | None, bool]: logger.error(f"Error executing RPC: {e}", exc_info=True) return {"status": "error", "error": str(e)}, should_reply - # TODO: queueing, cancellation def worker_busy_loop(self) -> None: - """Main busy loop for Multiprocessing Workers""" - + """Main busy loop for Multiprocessing Workers.""" logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory") while self._running: - # Receive unified message (generation request, RPC request, or shutdown) msg = None try: msg = self.recv_message() @@ -349,7 +291,6 @@ def worker_busy_loop(self) -> None: # Route message based on type if isinstance(msg, dict) and msg.get("type") == "rpc": - # Handle RPC request try: result, should_reply = self.execute_rpc(msg) if should_reply: @@ -360,13 +301,12 @@ def worker_busy_loop(self) -> None: self.return_result({"status": "error", "error": str(e)}) elif isinstance(msg, dict) and msg.get("type") == "shutdown": - # Handle shutdown message logger.info("Worker %s: Received shutdown message", self.gpu_id) self._running = False continue else: - # Handle generation request (OmniDiffusionRequest list) + # Handle generation request try: output = self.worker.execute_model(msg, self.od_config) except Exception as e: @@ -379,17 +319,14 @@ def worker_busy_loop(self) -> None: try: self.return_result(output) except zmq.ZMQError as e: - # Reply failed; log and keep loop alive to accept future requests logger.error(f"ZMQ error sending reply: {e}") continue logger.info("event loop terminated.") try: self.worker.shutdown() - except Exception as exc: # pragma: no cover - 
best effort cleanup + except Exception as exc: logger.warning("Worker %s: Shutdown encountered an error: %s", self.gpu_id, exc) - # if self.result_sender is not None: - # self.result_sender.close() self.context.term() @staticmethod @@ -400,7 +337,6 @@ def worker_main( broadcast_handle, ) -> None: """Worker initialization and execution loops.""" - worker_proc = WorkerProc( od_config, gpu_id=rank, diff --git a/vllm_omni/diffusion/worker/npu/npu_worker.py b/vllm_omni/diffusion/worker/npu/npu_worker.py index bfeb0d914c9..446c29cae4d 100644 --- a/vllm_omni/diffusion/worker/npu/npu_worker.py +++ b/vllm_omni/diffusion/worker/npu/npu_worker.py @@ -3,6 +3,8 @@ import multiprocessing as mp import os import time +from collections.abc import Iterable +from contextlib import AbstractContextManager, nullcontext import torch from vllm.config import LoadConfig, VllmConfig @@ -11,25 +13,42 @@ from vllm_omni.diffusion.cache.selector import get_cache_backend from vllm_omni.diffusion.data import ( + DiffusionOutput, OmniDiffusionConfig, ) from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, init_distributed_environment, initialize_model_parallel, ) from vllm_omni.diffusion.forward_context import set_forward_context from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.worker.gpu_worker import GPUWorker, WorkerProc +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.worker.gpu_diffusion_worker import WorkerProc logger = init_logger(__name__) -class NPUWorker(GPUWorker): +class NPUWorker: """ A worker that executes the model on a single NPU. Inherits from GPUWorker and overrides device-specific initialization. 
""" + def __init__( + self, + local_rank: int, + rank: int, + od_config: OmniDiffusionConfig, + ): + self.local_rank = local_rank + self.rank = rank + self.od_config = od_config + self.pipeline = None + self.device = None + self._sleep_saved_buffers: dict[str, torch.Tensor] = {} + self.init_device_and_model() + def init_device_and_model(self) -> None: """Initialize the NPU device and load the model.""" world_size = self.od_config.num_gpus @@ -86,6 +105,115 @@ def init_device_and_model(self) -> None: if self.cache_backend is not None: self.cache_backend.enable(self.pipeline) + def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput: + """ + Generate output for the given requests. + + Args: + requests: List of diffusion requests + + Returns: + DiffusionOutput with generated results + """ + return self.execute_model(requests, self.od_config) + + @torch.inference_mode() + def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput: + """ + Execute a forward pass. + """ + assert self.pipeline is not None + if not reqs or len(reqs) == 0: + raise ValueError("Cannot execute model with empty request list") + # TODO: dealing with first req for now + req = reqs[0] + + if req.generator is None and req.seed is not None: + req.generator = torch.Generator(device=self.device).manual_seed(req.seed) + + # Refresh cache context if needed + if self.cache_backend is not None and self.cache_backend.is_enabled(): + self.cache_backend.refresh(self.pipeline, req.num_inference_steps) + with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): + output = self.pipeline.forward(req) + return output + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + return self.pipeline.load_weights(weights) + + def sleep(self, level: int = 1) -> bool: + """ + Put the worker to sleep. The worker should not process any requests. 
+ The caller should guarantee that no requests are being processed + during the sleep period, before `wake_up` is called. + + Args: + level: The sleep level. Level 1 sleep will offload the model + weights and discard the kv cache. + Currently only support level 1. + """ + from vllm.device_allocator.cumem import CuMemAllocator + + free_bytes_before_sleep = torch.cuda.mem_get_info()[0] + + # Save the buffers before level 2 sleep + if level == 2: + model = self.pipeline + self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()} + + allocator = CuMemAllocator.get_instance() + allocator.sleep(offload_tags=("weights",) if level == 1 else tuple()) + free_bytes_after_sleep, total = torch.cuda.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." + logger.info( + "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.", + freed_bytes / GiB_bytes, + used_bytes / GiB_bytes, + ) + return True + + def wake_up(self, tags: list[str] | None = None) -> bool: + """ + Wake up the worker from sleep mode. See the sleep function + method for more details. + + Args: + tags: An optional list of tags to reallocate the worker memory + for specific memory allocations. Values must be in + `("weights")`. If None, all memory is reallocated. + wake_up should be called with all tags (or None) before the + worker is used again. 
+ """ + from vllm.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + allocator.wake_up(tags) + + # Restore the buffers after level 2 sleep + if len(self._sleep_saved_buffers): + model = self.pipeline + for name, buffer in model.named_buffers(): + if name in self._sleep_saved_buffers: + buffer.data.copy_(self._sleep_saved_buffers[name].data) + self._sleep_saved_buffers = {} + return True + + def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: + if self.od_config.enable_sleep_mode: + from vllm.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + if tag == "weights": + assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process." + return allocator.use_memory_pool(tag=tag) + else: + return nullcontext() + + def shutdown(self) -> None: + destroy_distributed_env() + class NPUWorkerProc(WorkerProc): """Wrapper that runs one NPUWorker in a separate process.""" diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py index 79dc25cf494..47d094e4163 100644 --- a/vllm_omni/engine/__init__.py +++ b/vllm_omni/engine/__init__.py @@ -9,11 +9,9 @@ import msgspec import torch from vllm.v1.engine import ( - EngineCoreEvent, EngineCoreRequest, - FinishReason, - LogprobsLists, - LogprobsTensors, + EngineCoreOutput, + EngineCoreOutputs, SchedulerStats, UtilityOutput, ) @@ -79,64 +77,10 @@ class OmniEngineCoreRequest(EngineCoreRequest): additional_information: AdditionalInformationPayload | None = None -class OmniEngineCoreOutput( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False, -): # type: ignore[call-arg] - request_id: str - new_token_ids: list[int] +class OmniEngineCoreOutput(EngineCoreOutput): + pooling_output: dict[str, torch.Tensor] | None = None - new_logprobs: LogprobsLists | None = None - new_prompt_logprobs_tensors: LogprobsTensors | None = 
None - pooling_output: dict[str, torch.Tensor] | None = None - finish_reason: FinishReason | None = None - stop_reason: int | str | None = None - events: list[EngineCoreEvent] | None = None - kv_transfer_params: dict[str, Any] | None = None - - trace_headers: Mapping[str, str] | None = None - # The number of tokens with prefix cache hits. - num_cached_tokens: int = 0 - - # The number of NaNs in logits. - # A value greater than 0 indicates that the output is corrupted. - num_nans_in_logits: int = 0 - - @property - def finished(self) -> bool: - return self.finish_reason is not None - - -class OmniEngineCoreOutputs( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False, -): # type: ignore[call-arg] - # NOTE(Nick): We could consider ways to make this more compact, - # e.g. columnwise layout - - engine_index: int = 0 - - # [num_reqs] - outputs: list[OmniEngineCoreOutput] = [] - scheduler_stats: SchedulerStats | None = None - timestamp: float = 0.0 - - utility_output: UtilityOutput | None = None - finished_requests: set[str] | None = None - - # In DP case, used to signal that the current wave of requests - # has finished and the engines are paused. - wave_complete: int | None = None - # In DP case, used to signal that a request was received for an - # "old" wave, so the next wave needs to be started in other engines. 
- start_wave: int | None = None - - def __post_init__(self): - if self.timestamp == 0.0: - self.timestamp = time.monotonic() +class OmniEngineCoreOutputs(EngineCoreOutputs): + outputs: list[OmniEngineCoreOutput] = [] \ No newline at end of file diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index 714b8dfcc53..6aef7fb162c 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -8,8 +8,6 @@ from vllm.sampling_params import RequestOutputKind from vllm.tokenizers import TokenizerLike from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason -from vllm.v1.engine.detokenizer import IncrementalDetokenizer -from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.output_processor import OutputProcessor as VLLMOutputProcessor from vllm.v1.engine.output_processor import OutputProcessorOutput, RequestOutputCollector, RequestState from vllm.v1.engine.parallel_sampling import ParentRequest diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 1cd7f506319..3c275147fa0 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -131,19 +131,21 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st ulysses_degree = kwargs.get("ulysses_degree") or 1 ring_degree = kwargs.get("ring_degree") or 1 sequence_parallel_size = kwargs.get("sequence_parallel_size") + tensor_parallel_size = kwargs.get("tensor_parallel_size") or 1 + cfg_parallel_size = kwargs.get("cfg_parallel_size") or 1 if sequence_parallel_size is None: sequence_parallel_size = ulysses_degree * ring_degree - num_devices = sequence_parallel_size + num_devices = sequence_parallel_size * tensor_parallel_size * cfg_parallel_size for i in range(1, num_devices): devices += f",{i}" parallel_config = DiffusionParallelConfig( pipeline_parallel_size=1, data_parallel_size=1, - tensor_parallel_size=1, + 
tensor_parallel_size=tensor_parallel_size, sequence_parallel_size=sequence_parallel_size, ulysses_degree=ulysses_degree, ring_degree=ring_degree, - cfg_parallel_size=1, + cfg_parallel_size=cfg_parallel_size, ) default_stage_cfg = [ { @@ -161,6 +163,7 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st "cache_backend": cache_backend, "cache_config": cache_config, "enable_cpu_offload": kwargs.get("enable_cpu_offload", False), + "enforce_eager": kwargs.get("enforce_eager", False), }, "final_output": True, "final_output_type": "image", @@ -347,12 +350,6 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator result = await req_state.queue.get() assert stage_id == req_state.stage_id - req_id = result.get("request_id") - if "error" in result: - logger.error( - f"[{self._name}] Stage {stage_id} error on request {req_id}: {result['error']}", - ) - raise RuntimeError(result) # Request Finished due to error req_id = result.get("request_id") if "error" in result: logger.error( diff --git a/vllm_omni/entrypoints/chat_utils.py b/vllm_omni/entrypoints/chat_utils.py index 6517ae666fe..0fdef5edbb7 100644 --- a/vllm_omni/entrypoints/chat_utils.py +++ b/vllm_omni/entrypoints/chat_utils.py @@ -22,7 +22,6 @@ _postprocess_messages, _ToolParser, ) -from vllm.transformers_utils.tokenizer import AnyTokenizer class OmniAsyncMultiModalItemTracker(AsyncMultiModalItemTracker): @@ -129,7 +128,6 @@ def _cleanup_file_sync(file_path: str) -> None: def parse_chat_messages_futures( messages: list[ChatCompletionMessageParam], model_config: ModelConfig, - tokenizer: AnyTokenizer, content_format: _ChatTemplateContentFormat, mm_processor_kwargs: dict[str, Any] | None = None, ) -> tuple[ @@ -138,7 +136,7 @@ def parse_chat_messages_futures( MultiModalUUIDDict | None, ]: conversation: list[ConversationMessage] = [] - mm_tracker = OmniAsyncMultiModalItemTracker(model_config, tokenizer) + mm_tracker = OmniAsyncMultiModalItemTracker(model_config) 
for msg in messages: sub_messages = _parse_chat_message_content( diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 3b222c8179c..c3a37e3c82e 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -208,6 +208,9 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu default=None, help="Scheduler flow_shift for video models (e.g., 5.0 for 720p, 12.0 for 480p).", ) + omni_config_group.add_argument( + "--cfg-parallel-size", type=int, default=1, help="Number of GPUs for CFG parallel computation" + ) return serve_parser diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index 74fe6a80376..8db4b11d94d 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -1,14 +1,15 @@ +from collections.abc import Callable from typing import Any import cloudpickle from pydantic import ValidationError from tqdm import tqdm -from vllm.outputs import RequestOutput, PoolingRequestOutput -from typing import Callable + # External library imports (vLLM) from vllm.config import CompilationConfig, StructuredOutputsConfig, is_init_field from vllm.entrypoints.llm import LLM from vllm.logger import init_logger +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors import get_io_processor from vllm.usage.usage_lib import UsageContext from vllm.utils.counter import Counter @@ -193,9 +194,7 @@ def __del__(self) -> None: # best-effort except Exception as e: logger.debug("[Orchestrator] __del__ close() raised: %s", e, exc_info=True) - def _run_engine( - self, *, use_tqdm: bool | Callable[..., tqdm] = True - ) -> list[RequestOutput | PoolingRequestOutput]: + def _run_engine(self, *, use_tqdm: bool | Callable[..., tqdm] = True) -> list[RequestOutput | PoolingRequestOutput]: # Initialize tqdm. 
if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -223,14 +222,9 @@ def _run_engine( assert output.prompt_token_ids is not None total_in_toks += len(output.prompt_token_ids) * n in_spd = total_in_toks / pbar.format_dict["elapsed"] - total_out_toks += sum( - len(stp.token_ids) for stp in output.outputs - ) + total_out_toks += sum(len(stp.token_ids) for stp in output.outputs) out_spd = total_out_toks / pbar.format_dict["elapsed"] - pbar.postfix = ( - f"est. speed input: {in_spd:.2f} toks/s, " - f"output: {out_spd:.2f} toks/s" - ) + pbar.postfix = f"est. speed input: {in_spd:.2f} toks/s, output: {out_spd:.2f} toks/s" pbar.update(n) else: pbar.update(1) @@ -242,4 +236,4 @@ def _run_engine( # Sort the outputs by the int part of request ID which is in format of 'int-uuid'. # This is necessary because some requests may be finished earlier than # its previous requests. - return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0])) \ No newline at end of file + return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0])) diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index af6f60f0420..a2790cd06e5 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -474,12 +474,14 @@ def _stage_worker( data_parallel_size = parallel_config.get("data_parallel_size", 1) prefill_context_parallel_size = 1 # not used for diffusion sequence_parallel_size = parallel_config.get("sequence_parallel_size", 1) + cfg_parallel_size = parallel_config.get("cfg_parallel_size", 1) else: tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) data_parallel_size = engine_args.get("data_parallel_size", 1) prefill_context_parallel_size = engine_args.get("prefill_context_parallel_size", 1) sequence_parallel_size = 1 # not use in omni model + cfg_parallel_size = 1 # not used in omni model # Calculate total 
number of devices needed for this stage # For a single stage worker: @@ -488,7 +490,8 @@ def _stage_worker( # - DP: replicates model, but each replica uses TP devices # - PCP: context parallelism, typically uses TP devices # - SP: sequence parallelism, typically uses TP devices - # The number of devices per stage is determined by TP * PP * DP * PCP * SP size + # - CFG: Classifier-Free Guidance parallelism for diffusion models + # The number of devices per stage is determined by TP * PP * DP * PCP * SP * CFG size # (PP/DP/PCP are higher-level parallelism that don't add devices per stage) num_devices_per_stage = ( tensor_parallel_size @@ -496,6 +499,7 @@ def _stage_worker( * data_parallel_size * prefill_context_parallel_size * sequence_parallel_size + * cfg_parallel_size ) # Get physical device IDs from CUDA_VISIBLE_DEVICES @@ -951,8 +955,6 @@ async def _stage_worker_async( except Exception as e: logger.warning("Device setup failed: %s", e) - max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1) - engine_args["max_num_seqs"] = max_batch_size # Initialize OmniConnectors if configured to match sync worker behavior connectors: dict[Any, Any] = {} if connectors_config: diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 49958eb66f6..888d39085ac 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import multiprocessing import multiprocessing.forkserver as forkserver import os @@ -14,31 +16,48 @@ from fastapi import Depends, HTTPException, Request from fastapi.responses import JSONResponse, StreamingResponse from starlette.datastructures import State -from vllm.config import VllmConfig +from starlette.routing import Route from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import load_chat_template, 
resolve_hf_chat_template, resolve_mistral_chat_template +from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.api_server import ( base, build_app, load_log_config, - maybe_register_tokenizer_info_endpoint, router, setup_server, - validate_json_request, ) -from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse -from vllm.entrypoints.openai.serving_models import BaseModelPath, LoRAModulePath, OpenAIServingModels -from vllm.entrypoints.openai.tool_parsers import ToolParserManager - -# yapf conflicts with isort for this block -# yapf: disable -# yapf: enable +from vllm.entrypoints.openai.orca_metrics import metrics_header +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ErrorResponse, +) +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses +from vllm.entrypoints.openai.serving_transcription import ( + OpenAIServingTranscription, + OpenAIServingTranslation, +) +from vllm.entrypoints.openai.utils import validate_json_request +from vllm.entrypoints.pooling.classify.serving import ServingClassification +from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding +from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling +from vllm.entrypoints.pooling.score.serving import ServingScores +from vllm.entrypoints.serve.disagg.serving import ServingTokens +from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer -from vllm.entrypoints.utils import load_aware_call, with_cancellation +from 
vllm.entrypoints.utils import ( + load_aware_call, + process_chat_template, + process_lora_modules, + with_cancellation, +) from vllm.logger import init_logger -from vllm.tokenizers import MistralTokenizer +from vllm.tasks import POOLING_TASKS +from vllm.tool_parsers import ToolParserManager from vllm.utils.system_utils import decorate_logs from vllm_omni.entrypoints.async_omni import AsyncOmni @@ -57,6 +76,29 @@ logger = init_logger(__name__) +ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format" + + +def _remove_route_from_router(router_obj, path: str, methods: set[str] | None = None): + """Remove a route from the router by path and optionally by methods. + + This is needed because vllm's api_server registers routes when imported, + and we need to override some routes (like /v1/chat/completions) with + omni-specific implementations. + """ + routes_to_remove = [] + for route in router_obj.routes: + if isinstance(route, Route) and route.path == path: + if methods is None or (hasattr(route, "methods") and route.methods & methods): + routes_to_remove.append(route) + + for route in routes_to_remove: + router_obj.routes.remove(route) + + +# Remove vllm's /v1/chat/completions route so we can register our own omni version +_remove_route_from_router(router, "/v1/chat/completions", {"POST"}) + # Server entry points @@ -88,6 +130,10 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: ToolParserManager.import_tool_parser(args.tool_parser_plugin) + if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3: + from vllm.reasoning import ReasoningParserManager + + ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin) # Load logging config for uvicorn if specified log_config = load_log_config(args.log_config_file) @@ -98,11 +144,11 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, args, 
client_config=client_config, ) as engine_client: - maybe_register_tokenizer_info_endpoint(args) app = build_app(args) + await omni_init_app_state(engine_client, app.state, args) + vllm_config = await engine_client.get_vllm_config() - await omni_init_app_state(engine_client, vllm_config, app.state, args) # Check if pure diffusion mode (vllm_config will be None) is_pure_diffusion = vllm_config is None @@ -233,7 +279,6 @@ async def build_async_omni_from_stage_config( async def omni_init_app_state( engine_client: EngineClient, - vllm_config: VllmConfig | None, state: State, args: Namespace, ) -> None: @@ -246,10 +291,12 @@ async def omni_init_app_state( Args: engine_client: Engine client instance (AsyncOmni) - vllm_config: vLLM configuration object (may be None for pure diffusion) state: FastAPI application state object to initialize args: Parsed command-line arguments """ + # Get vllm_config from engine_client (following 0.14.0 pattern) + vllm_config = await engine_client.get_vllm_config() + # Detect if it's pure Diffusion mode (single stage and is Diffusion) is_pure_diffusion = False if hasattr(engine_client, "stage_configs") and engine_client.stage_configs: @@ -273,6 +320,7 @@ async def omni_init_app_state( base_model_paths = [BaseModelPath(name=name, model_path=args.model) for name in served_model_names] state.engine_client = engine_client state.log_stats = not args.disable_log_stats + state.args = args # For omni models state.stage_configs = engine_client.stage_configs if hasattr(engine_client, "stage_configs") else None @@ -302,32 +350,18 @@ async def omni_init_app_state( logger.warning("vllm_config is None, some features may not work correctly") state.vllm_config = vllm_config - if vllm_config is not None: - _model_config = vllm_config.model_config - - resolved_chat_template = load_chat_template(args.chat_template) - if resolved_chat_template is not None and vllm_config is not None: - # Get the tokenizer to check official template - tokenizer = await 
engine_client.get_tokenizer() - - if tokenizer is not None: - if isinstance(tokenizer, MistralTokenizer): - # The warning is logged in resolve_mistral_chat_template. - resolved_chat_template = resolve_mistral_chat_template(chat_template=resolved_chat_template) - else: - hf_chat_template = resolve_hf_chat_template( - tokenizer=tokenizer, - chat_template=None, - tools=None, - model_config=vllm_config.model_config, - ) - if hf_chat_template != resolved_chat_template: - logger.warning( - "Using supplied chat template: %s\nIt is different from official chat template '%s'. This discrepancy may lead to performance degradation.", # noqa: E501 - resolved_chat_template, - args.model, - ) + # Get supported tasks + supported_tasks: set[str] = {"generate"} + if hasattr(engine_client, "get_supported_tasks"): + supported_tasks = set(await engine_client.get_supported_tasks()) + logger.info("Supported tasks: %s", supported_tasks) + + resolved_chat_template = await process_chat_template( + args.chat_template, + engine_client, + vllm_config.model_config if vllm_config is not None else None, + ) if args.tool_server == "demo": tool_server: ToolServer | None = DemoToolServer() @@ -340,23 +374,12 @@ async def omni_init_app_state( tool_server = None # Merge default_mm_loras into the static lora_modules - default_mm_loras = {} - if vllm_config is not None and vllm_config.lora_config is not None: - default_mm_loras = vllm_config.lora_config.default_mm_loras - - lora_modules = args.lora_modules - if default_mm_loras: - default_mm_lora_paths = [ - LoRAModulePath( - name=modality, - path=lora_path, - ) - for modality, lora_path in default_mm_loras.items() - ] - if args.lora_modules is None: - lora_modules = default_mm_lora_paths - else: - lora_modules += default_mm_lora_paths + default_mm_loras = ( + vllm_config.lora_config.default_mm_loras + if vllm_config is not None and vllm_config.lora_config is not None + else {} + ) + lora_modules = process_lora_modules(args.lora_modules, 
default_mm_loras) # Ensure input_processor, io_processor, and model_config exist for OpenAIServingModels compatibility if ( @@ -415,24 +438,182 @@ async def omni_init_app_state( lora_modules=lora_modules, ) await state.openai_serving_models.init_static_loras() - state.openai_serving_chat = OmniOpenAIServingChat( + + state.openai_serving_responses = ( + OpenAIServingResponses( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + tool_server=tool_server, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + state.openai_serving_chat = ( + OmniOpenAIServingChat( + engine_client, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + default_chat_template_kwargs=args.default_chat_template_kwargs, + trust_request_chat_template=args.trust_request_chat_template, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, + tool_parser=args.tool_call_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + enable_log_deltas=args.enable_log_deltas, + 
log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + # Warm up chat template processing to avoid first-request latency + if state.openai_serving_chat is not None: + await state.openai_serving_chat.warmup() + + state.openai_serving_completion = ( + OpenAIServingCompletion( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + state.openai_serving_pooling = ( + OpenAIServingPooling( + engine_client, + state.openai_serving_models, + supported_tasks=supported_tasks, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + if any(task in POOLING_TASKS for task in supported_tasks) + else None + ) + state.openai_serving_embedding = ( + OpenAIServingEmbedding( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + if "embed" in supported_tasks + else None + ) + state.openai_serving_classification = ( + ServingClassification( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + if "classify" in supported_tasks + else None + ) + state.openai_serving_scores = ( 
+ ServingScores( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + score_template=resolved_chat_template, + log_error_stack=args.log_error_stack, + ) + if ("embed" in supported_tasks or "score" in supported_tasks) + else None + ) + state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, state.openai_serving_models, - args.response_role, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, trust_request_chat_template=args.trust_request_chat_template, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_auto_tools=args.enable_auto_tool_choice, - exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, - tool_parser=args.tool_call_parser, - reasoning_parser=args.structured_outputs_config.reasoning_parser, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - enable_log_outputs=args.enable_log_outputs, log_error_stack=args.log_error_stack, ) + state.openai_serving_transcription = ( + OpenAIServingTranscription( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + enable_force_include_usage=args.enable_force_include_usage, + ) + if "transcription" in supported_tasks + else None + ) + state.openai_serving_translation = ( + OpenAIServingTranslation( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + enable_force_include_usage=args.enable_force_include_usage, + ) + if "transcription" in supported_tasks + else None + ) + state.anthropic_serving_messages = ( + AnthropicServingMessages( + engine_client, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, 
+ return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + ) + if "generate" in supported_tasks + else None + ) + state.serving_tokens = ( + ServingTokens( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + log_error_stack=args.log_error_stack, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_log_outputs=args.enable_log_outputs, + force_no_detokenize=args.tokens_only, + ) + if "generate" in supported_tasks + else None + ) state.openai_serving_speech = OmniOpenAIServingSpeech( engine_client, state.openai_serving_models, request_logger=request_logger @@ -463,6 +644,7 @@ def Omnispeech(request: Request) -> OmniOpenAIServingSpeech | None: @with_cancellation @load_aware_call async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): + metrics_header_format = raw_request.headers.get(ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, "") handler = Omnichat(raw_request) if handler is None: return base(raw_request).create_error_response(message="The model does not support Chat Completions API") @@ -474,7 +656,8 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re if isinstance(generator, ErrorResponse): return JSONResponse( - content=generator.model_dump(), status_code=generator.code if hasattr(generator, "code") else 400 + content=generator.model_dump(), + status_code=generator.error.code if generator.error else 400, ) elif isinstance(generator, ChatCompletionResponse): @@ -490,18 +673,27 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re try: # Use serialize_as_any=True to bypass type 
checking response_dict = generator.model_dump(mode="json", serialize_as_any=True, warnings="none") - return JSONResponse(content=response_dict) + return JSONResponse( + content=response_dict, + headers=metrics_header(metrics_header_format), + ) except Exception: # Fallback: convert to JSON string and parse back to avoid any serialization issues try: response_json = generator.model_dump_json(warnings="none", serialize_as_any=True) response_dict = json_lib.loads(response_json) - return JSONResponse(content=response_dict) + return JSONResponse( + content=response_dict, + headers=metrics_header(metrics_header_format), + ) except Exception: # Last resort: regular dump with warnings suppressed with warnings_module.catch_warnings(): warnings_module.filterwarnings("ignore", category=UserWarning) - return JSONResponse(content=generator.model_dump(mode="json", warnings="none")) + return JSONResponse( + content=generator.model_dump(mode="json", warnings="none"), + headers=metrics_header(metrics_header_format), + ) return StreamingResponse(content=generator, media_type="text/event-stream") diff --git a/vllm_omni/entrypoints/openai/protocol/chat_completion.py b/vllm_omni/entrypoints/openai/protocol/chat_completion.py index 9f607624938..d0c83f56f8b 100644 --- a/vllm_omni/entrypoints/openai/protocol/chat_completion.py +++ b/vllm_omni/entrypoints/openai/protocol/chat_completion.py @@ -1,4 +1,6 @@ -from vllm.entrypoints.openai.protocol import ChatCompletionStreamResponse +from vllm.entrypoints.openai.protocol import ( + ChatCompletionStreamResponse, +) class OmniChatCompletionStreamResponse(ChatCompletionStreamResponse): diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 6fb9750ccc1..6320cb73a16 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -29,7 +29,10 @@ make_tool_call_id, resolve_chat_template_content_format, ) -from vllm.entrypoints.harmony_utils import 
get_streamable_parser_for_assistant, parse_chat_output +from vllm.entrypoints.openai.parser.harmony_utils import ( + get_streamable_parser_for_assistant, + parse_chat_output, +) from vllm.entrypoints.openai.protocol import ( ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, @@ -41,29 +44,24 @@ DeltaFunctionCall, DeltaMessage, DeltaToolCall, + ErrorInfo, ErrorResponse, FunctionCall, FunctionDefinition, PromptTokenUsageInfo, RequestResponseMetadata, + ResponsesRequest, ToolCall, UsageInfo, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_engine import ( ChatLikeRequest, - EngineTokensPrompt, - RequestPrompt, - ResponsesRequest, - TextTokensPrompt, clamp_prompt_logprobs, - is_list_of, ) -from vllm.entrypoints.openai.tool_parsers import ToolParser -from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.utils import should_include_usage -from vllm.inputs.data import PromptType +from vllm.inputs.data import PromptType, TokensPrompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput @@ -75,8 +73,10 @@ truncate_tool_call_ids, validate_request_params, ) +from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils.collection_utils import as_list +from vllm.utils.collection_utils import as_list, is_list_of from vllm_omni.entrypoints.chat_utils import parse_chat_messages_futures from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin @@ -177,52 +177,68 @@ async def create_chat_completion( truncate_tool_call_ids(request) validate_request_params(request) - if ( - request.tool_choice == "auto" - and not (self.enable_auto_tools and tool_parser is not None) - and not isinstance(tokenizer, 
MistralTokenizer) - and not self.use_harmony + # Check if tool parsing is unavailable (common condition) + tool_parsing_unavailable = ( + tool_parser is None and not isinstance(tokenizer, MistralTokenizer) and not self.use_harmony + ) + + # Validate tool_choice when tool parsing is required but unavailable + if tool_parsing_unavailable and request.tool_choice not in ( + None, + "none", ): - # for hf tokenizers, "auto" tools requires - # --enable-auto-tool-choice and --tool-call-parser - return self.create_error_response( - '"auto" tool choice requires --enable-auto-tool-choice and --tool-call-parser to be set' - ) + if request.tool_choice == "auto" and not self.enable_auto_tools: + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser + return self.create_error_response( + '"auto" tool choice requires --enable-auto-tool-choice and --tool-call-parser to be set' + ) + elif request.tool_choice != "auto": + # "required" or named tool requires tool parser + return self.create_error_response( + f'tool_choice="{request.tool_choice}" requires --tool-call-parser to be set' + ) if request.tools is None or (request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none): tool_dicts = None else: tool_dicts = [tool.model_dump() for tool in request.tools] - # Common case. - request_chat_template = request.chat_template - chat_template_kwargs = request.chat_template_kwargs - if not self.trust_request_chat_template and ( - request_chat_template is not None - or (chat_template_kwargs and chat_template_kwargs.get("chat_template") is not None) - ): - return self.create_error_response( - "Chat template is passed with request, but --trust-request-chat-template is not set. " - "Refused request with untrusted chat template." 
+ if not self.use_harmony: + error_check_ret = self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, ) - ( - conversation, - request_prompts, - engine_prompts, - ) = await self._preprocess_chat( - request, - tokenizer, - request.messages, - chat_template=request_chat_template or self.chat_template, - chat_template_content_format=self.chat_template_content_format, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, - tool_dicts=tool_dicts, - documents=request.documents, - chat_template_kwargs=request.chat_template_kwargs, - tool_parser=tool_parser, - add_special_tokens=request.add_special_tokens, - ) + if error_check_ret is not None: + return error_check_ret + + chat_template_kwargs = request.chat_template_kwargs or {} + chat_template_kwargs.update(reasoning_effort=request.reasoning_effort) + + ( + conversation, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self.chat_template_content_format, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + tool_dicts=tool_dicts, + documents=getattr(request, "documents", None), + chat_template_kwargs=chat_template_kwargs, + default_chat_template_kwargs=self.default_chat_template_kwargs, + tool_parser=tool_parser, + add_special_tokens=request.add_special_tokens, + ) + else: + should_include_tools = tool_dicts is not None + conversation, engine_prompts = self._make_request_with_harmony(request, should_include_tools) + request_prompts = [engine_prompt.get("prompt_token_ids", []) for engine_prompt in engine_prompts] except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in 
preprocessing prompt inputs") @@ -314,12 +330,13 @@ async def _preprocess_chat( tool_dicts: list[dict[str, Any]] | None = None, documents: list[dict[str, str]] | None = None, chat_template_kwargs: dict[str, Any] | None = None, + default_chat_template_kwargs: dict[str, Any] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, add_special_tokens: bool = False, ) -> tuple[ list[ConversationMessage], - Sequence[RequestPrompt], - list[EngineTokensPrompt], + Sequence[PromptType], + list[TokensPrompt], ]: model_config = self.model_config @@ -333,11 +350,17 @@ async def _preprocess_chat( conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( messages, model_config, - tokenizer, content_format=resolved_content_format, mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None), ) + # Merge default_chat_template_kwargs with request-provided kwargs + # Request kwargs take precedence over defaults + merged_kwargs = self._prepare_extra_chat_template_kwargs( + chat_template_kwargs, + default_chat_template_kwargs, + ) + _chat_template_kwargs: dict[str, Any] = dict( chat_template=chat_template, add_generation_prompt=add_generation_prompt, @@ -345,7 +368,7 @@ async def _preprocess_chat( tools=tool_dicts, documents=documents, ) - _chat_template_kwargs.update(chat_template_kwargs or {}) + _chat_template_kwargs.update(merged_kwargs) request_prompt: str | list[int] @@ -388,7 +411,7 @@ async def _preprocess_chat( "Prompt has to be a string", "when the tokenizer is not initialised", ) - prompt_inputs = TextTokensPrompt(prompt=request_prompt, prompt_token_ids=[1]) + prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1]) elif isinstance(request_prompt, str): prompt_inputs = await self._tokenize_prompt_input_async( request, @@ -399,20 +422,21 @@ async def _preprocess_chat( else: # For MistralTokenizer assert is_list_of(request_prompt, int), "Prompt has to be either a string or a list of token ids" - prompt_inputs = 
TextTokensPrompt( + prompt_inputs = TokensPrompt( prompt=tokenizer.decode(request_prompt), prompt_token_ids=request_prompt, ) - engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"]) + engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"]) if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data if mm_uuids is not None: engine_prompt["multi_modal_uuids"] = mm_uuids - if request.mm_processor_kwargs is not None: - engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + mm_processor_kwargs = getattr(request, "mm_processor_kwargs", None) + if mm_processor_kwargs is not None: + engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs if hasattr(request, "cache_salt") and request.cache_salt is not None: engine_prompt["cache_salt"] = request.cache_salt @@ -513,7 +537,7 @@ def _build_sampling_params_list_from_request( def _log_inputs( self, request_id: str, - inputs: RequestPrompt | PromptType, + inputs: PromptType, params_list: list[SamplingParams] | None, lora_request: LoRARequest | None, ) -> None: @@ -599,9 +623,13 @@ async def chat_completion_stream_generator( try: if self.reasoning_parser: + chat_template_kwargs = self._prepare_extra_chat_template_kwargs( + request.chat_template_kwargs, + self.default_chat_template_kwargs, + ) reasoning_parser = self.reasoning_parser( tokenizer, - chat_template_kwargs=request.chat_template_kwargs, # type: ignore + chat_template_kwargs=chat_template_kwargs, # type: ignore ) except RuntimeError as e: logger.exception("Error in reasoning parser creation.") @@ -1052,10 +1080,9 @@ async def chat_completion_stream_generator( # wasn't ready to send a token, then # get the next token without streaming a chunk if delta_message is None: - if output.finish_reason is None: + if output.finish_reason is None and not request.return_token_ids: continue - else: - delta_message = DeltaMessage() + delta_message = DeltaMessage() # Log streaming delta if output logging is 
enabled if self.enable_log_outputs and self.request_logger: @@ -1438,13 +1465,20 @@ def _create_text_choice( if self.reasoning_parser: try: - reasoning_parser = self.reasoning_parser(tokenizer) + chat_template_kwargs = self._prepare_extra_chat_template_kwargs( + request.chat_template_kwargs, + self.default_chat_template_kwargs, + ) + reasoning_parser = self.reasoning_parser( + tokenizer, + chat_template_kwargs=chat_template_kwargs, # type: ignore + ) except RuntimeError as e: logger.exception("Error in reasoning parser creation.") return self.create_error_response(str(e)) # If the reasoning parser is enabled, # tool calls are extracted exclusively from the content. - reasoning_content, content = reasoning_parser.extract_reasoning_content(output.text, request=request) + reasoning_content, content = reasoning_parser.extract_reasoning(output.text, request=request) if not request.include_reasoning: reasoning_content = None else: @@ -2049,7 +2083,9 @@ def _create_error_response( ) -> ErrorResponse: """Create an error response following OpenAI error format.""" return ErrorResponse( - message=message, - type=err_type, - code=status_code, + error=ErrorInfo( + message=message, + type=err_type, + code=status_code, + ) ) diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 135b0e89ff2..eae3ea7afc4 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -195,6 +195,10 @@ def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None # Update base_engine_args with stage-specific engine_args if they exist if hasattr(stage_arg, "engine_args") and stage_arg.engine_args is not None: base_engine_args_tmp = OmegaConf.merge(base_engine_args_tmp, stage_arg.engine_args) + if hasattr(stage_arg, "runtime") and stage_arg.runtime is not None: + runtime_cfg = stage_arg.runtime + max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1) + base_engine_args_tmp["max_num_seqs"] = max_batch_size 
stage_arg.engine_args = base_engine_args_tmp return stage_args diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py index 37733e2909c..45a1447b3a7 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py @@ -107,6 +107,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) if t2w_token_end_id: self.model.set_suppress_start_id(t2w_token_end_id + 1) + self.requires_raw_input_tokens = True elif self.model_stage == "code2wav": self.thinker = None @@ -125,6 +126,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._token2wav_conds: dict[str, torch.Tensor] = {} self._token2wav_ref_mels: dict[str, torch.Tensor] = {} self.model = self.token2wav + self.requires_raw_input_tokens = True else: raise ValueError("Invalid model stage") @@ -248,14 +250,27 @@ def forward( if is_npu(): # TODO: remove this hack when NPU supports batched inputs properly thinker_input_ids = input_ids[0] if input_ids is not None and added_batch_dim else input_ids - thinker_positions = positions[0] if positions.ndim > 1 else positions + # For MRoPE, positions shape is [3, num_tokens] (T/H/W), don't slice it + if positions.ndim == 2 and positions.shape[0] == 3: + thinker_positions = positions # MRoPE positions, keep as is + else: + thinker_positions = positions[0] if positions.ndim > 1 else positions thinker_inputs_embeds = ( inputs_embeds[0] if inputs_embeds is not None and added_batch_dim else inputs_embeds ) else: - thinker_input_ids = input_ids - thinker_positions = positions[0] - thinker_inputs_embeds = inputs_embeds + # Squeeze back if we added batch dim earlier + thinker_input_ids = input_ids[0] if input_ids is not None and added_batch_dim else input_ids + # For MRoPE, positions shape is [3, num_tokens] (T/H/W), don't slice it + if positions.ndim == 2 and positions.shape[0] == 3: + 
thinker_positions = positions # MRoPE positions, keep as is + elif added_batch_dim: + thinker_positions = positions[0] + else: + thinker_positions = positions + thinker_inputs_embeds = ( + inputs_embeds[0] if inputs_embeds is not None and added_batch_dim else inputs_embeds + ) # Run thinker thinker_output = self.thinker( @@ -288,10 +303,16 @@ def forward( if not hasattr(self, "voice_type"): self.voice_type = voice_type + # For MRoPE, positions shape is [3, num_tokens] (T/H/W), don't slice it + if positions.ndim == 2 and positions.shape[0] == 3: + talker_positions = positions # MRoPE positions, keep as is + else: + talker_positions = positions[0] + with torch.inference_mode(): talker_hidden = self.talker( input_ids=input_ids, - positions=positions[0], + positions=talker_positions, inputs_embeds=inputs_embeds, ) diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py index 46175568e07..927bc552573 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py @@ -9,7 +9,6 @@ # from vllm.attention import AttentionMetadata # unused import from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from vllm.model_executor.models.qwen2_5_omni_thinker import ( Qwen2_5OmniThinkerDummyInputsBuilder, @@ -68,13 +67,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.config = config - self.thinker_to_talker_proj = ColumnParallelLinear( + self.thinker_to_talker_proj = nn.Linear( self.config.embedding_size, self.config.hidden_size, - bias=True, - gather_output=True, - skip_bias_add=False, - quant_config=quant_config, ) self.language_model = init_vllm_registered_model( 
vllm_config=vllm_config, @@ -145,7 +140,7 @@ def forward( input_ids = None # projection - inputs_embeds, _ = self.thinker_to_talker_proj(inputs_embeds) + inputs_embeds = self.thinker_to_talker_proj(inputs_embeds) hidden_states = self.language_model.model( input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds @@ -155,7 +150,18 @@ def forward( def bad_word_processor(self, logits: torch.Tensor) -> torch.Tensor: # suppress token IDs unsupported by token2wav if self.suppress_start_id and self.suppress_start_id < logits.size(-1): - logits[..., self.suppress_start_id : logits.size(-1)] = -1e9 + # skip the end token id. + if hasattr(self.config, "tts_codec_end_token_id"): + end_id = int(getattr(self.config, "tts_codec_end_token_id")) + if self.suppress_start_id == end_id: + logits[..., end_id + 1 : logits.size(-1)] = -1e9 + elif self.suppress_start_id < end_id: + logits[..., self.suppress_start_id : end_id] = -1e9 + logits[..., end_id + 1 : logits.size(-1)] = -1e9 + else: + logits[..., self.suppress_start_id : logits.size(-1)] = -1e9 + else: + raise ValueError("config must have tts_codec_end_token_id attribute") if hasattr(self.config, "tts_codec_start_token_id"): bos_id = int(getattr(self.config, "tts_codec_start_token_id")) diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py index 71c5e8377ac..a2250b82e73 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py @@ -1,4 +1,4 @@ -"""Thin Omni wrapper: reuse upstream Qwen2.5-Omni thinker (v0.12) with minimal overrides.""" +"""Thin Omni wrapper: reuse upstream Qwen2.5-Omni thinker (v0.14) with minimal overrides.""" from collections.abc import Iterable from typing import Any @@ -12,6 +12,7 @@ Qwen2_5OmniAudioEncoder, ) from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context from 
vllm.logger import init_logger from vllm.model_executor.models.interfaces import ( MultiModalEmbeddings, @@ -26,8 +27,6 @@ Qwen2_5OmniThinkerDummyInputsBuilder, Qwen2_5OmniThinkerMultiModalProcessor, Qwen2_5OmniThinkerProcessingInfo, - get_llm_pos_ids_for_vision, - split_list_into_ranges, ) from vllm.model_executor.models.qwen2_5_omni_thinker import ( Qwen2_5OmniConditionalGenerationMixin as Qwen2_5OmniConditionalGenerationMixinBase, @@ -46,7 +45,9 @@ WeightsMapper, init_vllm_registered_model, maybe_prefix, + split_list_into_ranges, ) +from vllm.model_executor.models.vision import get_llm_pos_ids_for_vision from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalFeatureSpec, @@ -166,6 +167,43 @@ def _parse_and_validate_video_input( video_grid_thw=video_grid_thw, ) + def _process_image_input(self, image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + if image_input["type"] == "image_embeds": + return image_input["image_embeds"].type(self.visual.dtype) + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + with set_forward_context(None, self.vllm_config): + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + # Split concatenated embeddings for each image item. 
+ merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs, + video_hashes: list[str] = None, + cached_video_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if video_input["type"] == "video_embeds": + return video_input["video_embeds"].type(self.visual.dtype) + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype) + with set_forward_context(None, self.vllm_config): + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + @MULTIMODAL_REGISTRY.register_processor( Qwen2_5OmniThinkerMultiModalProcessor, @@ -180,8 +218,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration( SupportsMRoPE, Qwen2_5OmniConditionalGenerationMixin, ): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "thinker.lm_head.": "language_model.lm_head.", @@ -250,6 +286,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), + multimodal_config=multimodal_config, ) else: self.visual = None diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py index ce9819b3e68..e04010196a8 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py @@ -3,9 +3,6 @@ import torch from torch import nn from transformers import Qwen2Config -from vllm.attention.backends.abstract import ( - AttentionType, -) from 
vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig @@ -15,7 +12,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding @@ -24,15 +20,14 @@ from vllm.model_executor.models.utils import ( AutoWeightsLoader, PPMissingLayer, - WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, ) from vllm.sequence import IntermediateTensors -from vllm.v1.outputs import PoolerOutput, SamplerOutput -from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.attention.backend import AttentionType +from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler @@ -130,7 +125,6 @@ def __init__( self.rotary_pos_emb = get_rope( head_size=self.head_dim, - rotary_dim=self.head_dim, max_position=max_position, is_neox_style=True, rope_parameters={ @@ -460,66 +454,3 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - -class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - - def __init__(self, *, vllm_config: 
VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - pooler_config = vllm_config.model_config.pooler_config - - self.config = config - self.lora_config = lora_config - - self.quant_config = quant_config - self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - - # TODO: Replace this model class with as_embedding_model( - # Qwen2ForCausalLM) after changing the default pooling method - if pooler_config.pooling_type is None: - logger.warning( - "This embedding model will default to last-token pooling in " - "an upcoming version. To avoid breaking changes, you should " - 'pass `--override-pooler-config \'{"pooling_type": "MEAN"}\'`' - " explicitly." - ) - - self._pooler = Pooler.from_config_with_defaults( - pooler_config, pooling_type=PoolingType.MEAN, normalize=True, softmax=False - ) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - ) -> torch.Tensor: - return self.model(input_ids, positions, intermediate_tensors) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput | None: - return self._pooler(hidden_states, pooling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - weights = self.hf_to_vllm_mapper.apply(weights) - weights = ((name, data) for name, data in weights if not name.startswith("lm_head.")) - self.model.load_weights(weights) diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index c459289af34..17f16d1d929 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -356,18 +356,24 @@ def forward( elif self.model_stage == "code2wav": # Extract codec codes from input 
codes = [] - if input_ids is not None: + if input_ids.shape[0] % 16 == 0: codes.append(input_ids.reshape(1, 16, -1)) - else: - # for profile, we use max length from inputs_embeds - codes.append( - torch.zeros( - (1, 16, inputs_embeds.shape[1]), - dtype=torch.long, - device=inputs_embeds.device, - ) + logger.warning( + ( + "Input_ids length: %s is not divisible by 16, padding " + "with zeros. This should only happen in warm up." + ), + input_ids.shape[0], + ) + input_ids_flatten = input_ids.reshape(-1) + input_ids_flatten = torch.cat( + [ + input_ids_flatten, + torch.zeros(16 - input_ids.shape[0] % 16, dtype=torch.long, device=input_ids.device), + ] ) + codes.append(input_ids_flatten.reshape(1, 16, -1)) # Generate audio from codec codes audio_tensors = [] @@ -575,30 +581,22 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, if input_embeds is None and input_ids is not None: input_embeds = self.talker.embed_input_ids(input_ids) - text_step = torch.zeros( - 1, - self.talker_config.text_config.hidden_size, - device=self._module_device(self.talker), - dtype=torch.bfloat16, - ) - last_talker_hidden = torch.zeros( - 1, - 1, - self.talker_config.text_config.hidden_size, - device=self._module_device(self.talker), - dtype=torch.bfloat16, - ) - span_len = input_ids.shape[0] if span_len > 1: # prefill input_ids, input_embeds, update_dict = self.talker_preprocess_prefill(input_ids, input_embeds, **info_dict) + code_predictor_codes = torch.zeros( + (input_embeds.shape[0], self.talker.num_code_groups), + device=self._module_device(self.talker), + dtype=torch.long, + ) + update_dict["code_predictor_codes"] = code_predictor_codes else: last_talker_hidden, text_step, update_dict = self.talker_preprocess_decode( input_ids, input_embeds, **info_dict ) - update_dict["mtp_inputs"] = last_talker_hidden, text_step + update_dict["mtp_inputs"] = last_talker_hidden, text_step return input_ids, input_embeds, update_dict @@ -610,24 +608,19 @@ def talker_mtp( 
text_step: torch.Tensor, ): # TODO(Peiqi): not support intermediate_tensors now - input_ids = safe_tensor_reshape(input_ids, (1, -1)) + input_ids = safe_tensor_reshape(input_ids, (input_ids.shape[0], -1)) inputs_embeds = safe_tensor_reshape(input_embeds, (-1, self.talker_config.text_config.hidden_size)) - text_step = safe_tensor_reshape(text_step, (1, -1)) - last_talker_hidden = safe_tensor_reshape(last_talker_hidden, (1, 1, self.talker_config.text_config.hidden_size)) + text_step = safe_tensor_reshape(text_step, (-1, self.talker_config.text_config.hidden_size)) + last_talker_hidden = safe_tensor_reshape( + last_talker_hidden, (-1, 1, self.talker_config.text_config.hidden_size) + ) # for profiling if inputs_embeds.shape[-1] == 2048: inputs_embeds = self.text_projection(inputs_embeds) - if inputs_embeds.shape[0] == 1: - code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward( - input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden - ) - inputs_embeds = summed_embeddings.clone() - else: - code_predictor_codes = torch.zeros( - (inputs_embeds.shape[0], self.talker.num_code_groups), - device=self._module_device(self.talker), - dtype=torch.long, - ) + code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward( + input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden + ) + inputs_embeds = summed_embeddings.clone() inputs_embeds = (inputs_embeds + text_step).reshape(-1, self.talker_config.text_config.hidden_size) return inputs_embeds, code_predictor_codes.squeeze(-1) @@ -850,7 +843,7 @@ def talker_preprocess_decode(self, input_ids: torch.Tensor, input_embeds: torch. 
use_vec = q_tail[0:1, :] new_q_tail = ( q_tail[1:, :].detach().to("cpu").contiguous() - if q_tail.shape[1] > 1 + if q_tail.shape[0] > 1 else self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype) ) text_step = use_vec.to(input_embeds.device, dtype=input_embeds.dtype) diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py index 4e8730eab52..2e872340da4 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py @@ -111,7 +111,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() talker_config: Qwen3OmniMoeTalkerConfig = vllm_config.model_config.hf_config talker_config.text_config.rope_parameters = talker_config.text_config.rope_scaling - talker_config.text_config.rope_parameters["rope_theta"] = talker_config.text_config.rope_theta + talker_config.text_config.rope_parameters["rope_theta"] = talker_config.text_config.rope_parameters[ + "rope_theta" + ] self.quant_config = vllm_config.quant_config self.prefix = prefix self.vllm_config = vllm_config @@ -234,13 +236,7 @@ def code_predictor_forward( # Use the corresponding lm_head for this layer logits = self.code_predictor.lm_head[layer_idx](hidden_state[:, -1:, :]) # [batch, 1, vocab_size] - if len(pos_codes) > 1: - input_ids_for_logits_processors = torch.cat(pos_codes[1:], dim=1).to( - device=logits.device, dtype=torch.long - ) - else: - input_ids_for_logits_processors = self.empty_code - logits = logits_processors(input_ids_for_logits_processors, logits.squeeze(0)).unsqueeze(0) + logits = logits_processors(None, logits[:, -1]) # Sample from the filtered distribution probs = F.softmax(logits, dim=-1) @@ -288,7 +284,7 @@ def code_predictor_forward( all_summed_embeddings.append(pos_summed) # Concatenate across positions: [batch, seq_len, hidden_size] - summed_embeddings = 
torch.cat(all_summed_embeddings, dim=1) + summed_embeddings = torch.cat(all_summed_embeddings, dim=1).squeeze(1) return result_codes, summed_embeddings diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py index 7f3320a82eb..2d479062eb2 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py @@ -37,9 +37,6 @@ Qwen3OmniMoeConfig, Qwen3OmniMoeThinkerConfig, ) -from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( - Qwen3OmniMoeAudioEncoder, -) from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import ( Qwen3OmniMoeProcessor, ) @@ -56,6 +53,7 @@ SupportsMultiModal, SupportsPP, ) +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2_5_omni_thinker import ( Qwen2_5OmniAudioFeatureInputs, Qwen2_5OmniThinkerDummyInputsBuilder, @@ -69,6 +67,7 @@ from vllm.model_executor.models.qwen3_moe import Qwen3MoeModel as _Qwen3MoeLLMModel from vllm.model_executor.models.qwen3_omni_moe_thinker import ( Qwen3Omni_VisionTransformer, + Qwen3OmniMoeAudioEncoder, _get_feat_extract_output_lengths, ) from vllm.model_executor.models.utils import ( @@ -248,11 +247,36 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray: # https://github.com/huggingface/transformers/pull/41473 mm_kwargs = dict(mm_kwargs) tok_kwargs = dict(tok_kwargs) + mm_kwargs["audio_kwargs"] = dict(mm_kwargs.get("audio_kwargs") or {}) + mm_kwargs["text_kwargs"] = dict(mm_kwargs.get("text_kwargs") or {}) if Version(TRANSFORMERS_VERSION) < Version("4.58.0"): + # Extract audio_sample_rate before restructuring + audio_sample_rate = mm_kwargs.pop("audio_sample_rate", None) + # move truncation to audio_kwargs level to avoid conflict # with tok_kwargs - mm_kwargs["audio_kwargs"] = {"truncation": mm_kwargs.pop("truncation", False)} 
- mm_kwargs["text_kwargs"] = {"truncation": tok_kwargs.pop("truncation", False)} + mm_kwargs["audio_kwargs"].setdefault("truncation", mm_kwargs.pop("truncation", False)) + mm_kwargs["text_kwargs"].setdefault("truncation", tok_kwargs.pop("truncation", False)) + + # Validate and conditionally pass audio_sample_rate + # WhisperFeatureExtractor has a fixed sampling rate, and vLLM's + # audio loader already resamples audio to the target rate. + # Only pass the value if it matches to avoid unexpected behavior. + if audio_sample_rate is not None: + expected_sr = feature_extractor.sampling_rate + if audio_sample_rate != expected_sr: + logger.warning( + "[%s] audio_sample_rate mismatch: user provided %dHz " + "but model expects %dHz. Ignoring user value. " + "vLLM's audio loader already resampled to %dHz.", + self.__class__.__name__, + audio_sample_rate, + expected_sr, + expected_sr, + ) + else: + # Sample rate matches, safe to pass + mm_kwargs["audio_kwargs"]["audio_sample_rate"] = audio_sample_rate hf_inputs = super()._call_hf_processor( prompt=prompt, @@ -398,11 +422,11 @@ def _get_prompt_updates( if audio_feature_lengths is None and feature_attention_mask is None: audio_output_lengths = [] elif audio_feature_lengths is not None: - _, audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths) + audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths) audio_output_lengths = audio_output_lens.tolist() elif feature_attention_mask is not None: assert isinstance(feature_attention_mask, torch.Tensor) - _, audio_output_lens = _get_feat_extract_output_lengths(feature_attention_mask.sum(-1)) + audio_output_lens = _get_feat_extract_output_lengths(feature_attention_mask.sum(-1)) audio_output_lengths = audio_output_lens.tolist() # number of audios read from video. 
@@ -569,59 +593,21 @@ def _get_raw_input_ids( class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMixin): - def _parse_and_validate_audio_input(self, **kwargs: object) -> Qwen2_5OmniAudioFeatureInputs | None: - input_audio_features = kwargs.pop("input_audio_features", None) - audio_feature_lengths = kwargs.pop("audio_feature_lengths", None) - feature_attention_mask = kwargs.pop("feature_attention_mask", None) - if input_audio_features is None: - return None - if ( - input_audio_features is not None - and isinstance(input_audio_features, torch.Tensor) - and input_audio_features.ndim == 3 - ): - # (batch_size, feature_dim, chunk_size) -> (feature_dim, batch_size * chunk_size) - input_audio_features = input_audio_features.permute(1, 0, 2).flatten(1) - elif input_audio_features is not None and isinstance(input_audio_features, list): - input_audio_features = torch.cat(input_audio_features, dim=-1) - if ( - audio_feature_lengths is not None - and isinstance(audio_feature_lengths, torch.Tensor) - and audio_feature_lengths.ndim == 2 - ): - audio_feature_lengths = audio_feature_lengths.reshape(-1) - elif audio_feature_lengths is not None and isinstance(audio_feature_lengths, list): - audio_feature_lengths = torch.cat(audio_feature_lengths, dim=-1) - if ( - feature_attention_mask is not None - and isinstance(feature_attention_mask, torch.Tensor) - and feature_attention_mask.ndim == 3 - ): - feature_attention_mask = feature_attention_mask.reshape(-1, feature_attention_mask.shape[-1]) - elif feature_attention_mask is not None and isinstance(feature_attention_mask, list): - for i in range(len(feature_attention_mask)): - feature_attention_mask[i] = feature_attention_mask[i].reshape(-1) - return Qwen2_5OmniAudioFeatureInputs( - type="audio_features", - input_features=input_audio_features, - audio_feature_lengths=audio_feature_lengths, - feature_attention_mask=feature_attention_mask, - ) - def _process_audio_input( self, audio_input: 
Qwen2_5OmniAudioFeatureInputs, audio_hashes: list[str] | None = None, cached_audio_features: torch.Tensor | None = None, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, ...]: input_features = audio_input["input_features"] audio_feature_lengths = audio_input["audio_feature_lengths"] - audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths) + audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths) audio_outputs = self.audio_tower( input_features.to(self.audio_tower.dtype), feature_lens=audio_feature_lengths, + aftercnn_lens=audio_output_lengths, ) audio_features = audio_outputs.last_hidden_state return audio_features.split(audio_output_lengths.tolist()) @@ -639,8 +625,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( SupportsMRoPE, Qwen3OmniMoeConditionalGenerationMixin, ): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "thinker.lm_head.": "language_model.lm_head.", @@ -649,6 +633,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( } ) + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): @@ -669,20 +665,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = thinker_config self.multimodal_config = multimodal_config - # force "use_flash_attention_2=True" to audio tower to align - # the results. - if flash_attn is not None: - audio_config = thinker_config.audio_config - audio_config._attn_implementation_autoset = True - audio_config._attn_implementation = "flash_attention_2" - else: - logger.warning( - "flash_attn is not available, the model may not yield the " - "exactly same result as the transformers implementation " - "in the audio tower part." 
- ) - - self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config) + self.audio_tower = Qwen3OmniMoeAudioEncoder( + thinker_config.audio_config, + ) self.visual = Qwen3Omni_VisionTransformer( vision_config=thinker_config.vision_config, @@ -810,8 +795,6 @@ def embed_input_ids( return inputs_embeds deepstack_input_embeds = None - # TODO (ywang96): support overlapping modalitiy embeddings so that - # `use_audio_in_video` will work on V1. # split the feat dim to obtain multi-scale visual feature has_vision_embeddings = [ embeddings.shape[-1] != self.config.text_config.hidden_size for embeddings in multimodal_embeddings @@ -1005,7 +988,7 @@ def get_mrope_input_positions( bos_len = 1 llm_pos_ids_list.append(torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx) st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 - _, audio_len = _get_feat_extract_output_lengths(audio_feature_lengths[audio_idx]) + audio_len = _get_feat_extract_output_lengths(audio_feature_lengths[audio_idx]) llm_pos_ids = torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx llm_pos_ids_list.append(llm_pos_ids) st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 @@ -1077,7 +1060,7 @@ def get_mrope_input_positions( llm_pos_ids_list.append(bos_block) llm_pos_ids_list.append(bos_block) st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 - _, audio_len = _get_feat_extract_output_lengths(audio_feature_lengths[audio_idx]) + audio_len = _get_feat_extract_output_lengths(audio_feature_lengths[audio_idx]) audio_llm_pos_ids = torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx grid_t = video_grid_thw[video_idx][0] grid_hs = video_grid_thw[:, 1] @@ -1121,3 +1104,13 @@ def get_mrope_input_positions( mrope_position_delta = llm_positions.max() + 1 - seq_len return llm_positions, mrope_position_delta + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal 
models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="visual.merger", + tower_model=["visual.", "audio_tower."], + ) diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml index fc84f485ea4..e3d740ad580 100644 --- a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml @@ -76,6 +76,7 @@ stage_args: trust_remote_code: true enable_prefix_caching: false max_num_batched_tokens: 32768 + async_scheduling: false engine_output_type: audio engine_input_source: [1] final_output: true diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml index c63dc563815..d4de078231a 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml @@ -16,7 +16,7 @@ stage_args: worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler gpu_memory_utilization: 0.6 - enforce_eager: false + enforce_eager: true trust_remote_code: true engine_output_type: latent # Output hidden states for talker distributed_executor_backend: "mp" @@ -46,8 +46,8 @@ stage_args: model_arch: Qwen3OmniMoeForConditionalGeneration worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.3 - enforce_eager: false + gpu_memory_utilization: 0.35 + enforce_eager: true trust_remote_code: true engine_output_type: latent # Output codec codes for code2wav # tensor_parallel_size: 2 @@ -80,6 +80,7 @@ stage_args: scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler enforce_eager: true trust_remote_code: true + async_scheduling: false enable_prefix_caching: false engine_output_type: audio # Final output: 
audio waveform gpu_memory_utilization: 0.1 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 246ea2996e8..a1457a9750b 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -160,7 +160,7 @@ def talker2code2wav( # Process each talker output for i, talker_output in enumerate(talker_outputs): output = talker_output.outputs[0] - seq_len = len(output.token_ids) + seq_len = len(output.token_ids) - 1 # Extract codec codes from talker output # Expected shape: [8, seq_len] (8-layer RVQ codes) codec_codes = ( diff --git a/vllm_omni/utils/platform_utils.py b/vllm_omni/utils/platform_utils.py index 5f8259ab83d..fb47018e789 100644 --- a/vllm_omni/utils/platform_utils.py +++ b/vllm_omni/utils/platform_utils.py @@ -53,6 +53,6 @@ def get_diffusion_worker_class() -> type: return NPUWorkerProc else: # Default to GPU worker for cuda and other devices - from vllm_omni.diffusion.worker.gpu_worker import WorkerProc + from vllm_omni.diffusion.worker.gpu_diffusion_worker import WorkerProc return WorkerProc diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index d4e7e195fe8..18c8bc761b1 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -12,10 +12,15 @@ import numpy as np import torch from vllm.config import CUDAGraphMode +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer +from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + RoutedExpertsCapturer, +) from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput -from vllm.v1.outputs import AsyncModelRunnerOutput +from vllm.v1.outputs import 
AsyncModelRunnerOutput, make_empty_encoder_model_runner_output from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import record_function_or_nullcontext @@ -26,18 +31,12 @@ get_pp_group, get_tp_group, has_kv_transfer_group, - ) +from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices from vllm.v1.worker.utils import is_residual_scattered_for_sp -from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( - RoutedExpertsCapturer, -) -from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer -from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group + from vllm_omni.outputs import OmniModelRunnerOutput from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner -from vllm.v1.outputs import make_empty_encoder_model_runner_output -from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices logger = init_logger(__name__) @@ -91,10 +90,7 @@ def execute_model( intermediate_tensors: IntermediateTensors | None = None, ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None: if self.execute_model_state is not None: - raise RuntimeError( - "State error: sample_tokens() must be called " - "after execute_model() returns None." 
- ) + raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.") if self.vllm_config.model_config.enable_return_routed_experts: capturer = RoutedExpertsCapturer.get_instance() @@ -104,9 +100,7 @@ def execute_model( logger.error("RoutedExpertsCapturer not initialized.") if scheduler_output.preempted_req_ids and has_kv_transfer_group(): - get_kv_transfer_group().handle_preemptions( - scheduler_output.preempted_req_ids - ) + get_kv_transfer_group().handle_preemptions(scheduler_output.preempted_req_ids) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with ( @@ -126,8 +120,7 @@ def execute_model( if not num_scheduled_tokens: if ( - self.parallel_config.distributed_executor_backend - == "external_launcher" + self.parallel_config.distributed_executor_backend == "external_launcher" and self.parallel_config.data_parallel_size > 1 ): # this is a corner case when both external launcher @@ -196,9 +189,7 @@ def execute_model( ) num_tokens_padded = batch_desc.num_tokens - num_reqs_padded = ( - batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs - ) + num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( should_ubatch, num_scheduled_tokens_np, @@ -218,19 +209,17 @@ def execute_model( use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices - attn_metadata, spec_decode_common_attn_metadata = ( - self._build_attention_metadata( - num_tokens=num_tokens_unpadded, - num_tokens_padded=num_tokens_padded if pad_attn else None, - num_reqs=num_reqs, - num_reqs_padded=num_reqs_padded if pad_attn else None, - max_query_len=max_num_scheduled_tokens, - ubatch_slices=ubatch_slices_attn, - logits_indices=logits_indices, - use_spec_decode=use_spec_decode, - num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - 
cascade_attn_prefix_lens=cascade_attn_prefix_lens, - ) + attn_metadata, spec_decode_common_attn_metadata = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, + num_reqs=num_reqs, + num_reqs_padded=num_reqs_padded if pad_attn else None, + max_query_len=max_num_scheduled_tokens, + ubatch_slices=ubatch_slices_attn, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, ) ( @@ -240,9 +229,7 @@ def execute_model( intermediate_tensors, model_kwargs, ec_connector_output, - ) = self._preprocess( - scheduler_output, num_tokens_padded, intermediate_tensors - ) + ) = self._preprocess(scheduler_output, num_tokens_padded, intermediate_tensors) # Set cudagraph mode to none if calc_kv_scales is true. # KV scales calculation involves dynamic operations that are incompatible @@ -329,9 +316,7 @@ def execute_model( sample_hidden_states = hidden_states[logits_indices] if not get_pp_group().is_last_rank: all_gather_tensors = { - "residual": not is_residual_scattered_for_sp( - self.vllm_config, num_tokens_padded - ) + "residual": not is_residual_scattered_for_sp(self.vllm_config, num_tokens_padded) } get_pp_group().send_tensor_dict( hidden_states.tensors, @@ -408,9 +393,7 @@ def sample_tokens( # Apply structured output bitmasks if present. 
if grammar_output is not None: - apply_grammar_bitmask( - scheduler_output, grammar_output, self.input_batch, logits - ) + apply_grammar_bitmask(scheduler_output, grammar_output, self.input_batch, logits) with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) @@ -450,23 +433,19 @@ def propose_draft_token_ids(sampled_token_ids): propose_draft_token_ids(sampled_token_ids) elif self.valid_sampled_token_count_event is not None: assert spec_decode_common_attn_metadata is not None - next_token_ids, valid_sampled_tokens_count = ( - self.drafter.prepare_next_token_ids_padded( - spec_decode_common_attn_metadata, - sampled_token_ids, - self.requests, - self.input_batch, - self.discard_request_mask.gpu, - ) - ) - self._copy_valid_sampled_token_count( - next_token_ids, valid_sampled_tokens_count + next_token_ids, valid_sampled_tokens_count = self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, ) + self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count) # Since we couldn't run the drafter, # just use zeros for the draft tokens. 
- self._draft_token_ids = torch.zeros( - 1, device=self.device, dtype=torch.int32 - ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) + self._draft_token_ids = torch.zeros(1, device=self.device, dtype=torch.int32).expand( + len(self.input_batch.req_ids), self.num_spec_tokens + ) self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) else: propose_drafts_after_bookkeeping = input_fits_in_drafter diff --git a/vllm_omni/worker/gpu_ar_worker.py b/vllm_omni/worker/gpu_ar_worker.py index 599dea31f2f..d2dafce6877 100644 --- a/vllm_omni/worker/gpu_ar_worker.py +++ b/vllm_omni/worker/gpu_ar_worker.py @@ -3,19 +3,18 @@ import torch from vllm.logger import init_logger -from vllm.utils.torch_utils import set_random_seed from vllm.platforms import current_platform from vllm.utils.mem_utils import MemorySnapshot, format_gib +from vllm.utils.torch_utils import set_random_seed from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_worker import Worker as GPUWorker from vllm.v1.worker.gpu_worker import init_worker_distributed_environment +from vllm.v1.worker.utils import request_memory +from vllm.v1.worker.workspace import init_workspace_manager from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner -from vllm.v1.worker.workspace import init_workspace_manager -from vllm.v1.worker.utils import request_memory -from vllm.logger import init_logger -logger = init_logger(__name__) +logger = init_logger(__name__) class GPUARWorker(GPUWorker): @@ -31,8 +30,7 @@ def init_device(self): os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) parallel_config = self.parallel_config if ( - parallel_config.distributed_executor_backend - not in ("ray", "external_launcher") + parallel_config.distributed_executor_backend not in ("ray", "external_launcher") and parallel_config.data_parallel_backend != "ray" and parallel_config.nnodes_within_dp == 1 ): @@ -42,8 +40,7 @@ def init_device(self): dp_local_rank = self.parallel_config.data_parallel_index 
tp_pp_world_size = ( - self.parallel_config.pipeline_parallel_size - * self.parallel_config.tensor_parallel_size + self.parallel_config.pipeline_parallel_size * self.parallel_config.tensor_parallel_size ) # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK @@ -51,9 +48,7 @@ def init_device(self): assert self.local_rank < torch.cuda.device_count(), ( f"DP adjusted local rank {self.local_rank} is out of bounds. " ) - visible_device_count = ( - torch.cuda.device_count() if torch.cuda.is_available() else 0 - ) + visible_device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 assert self.parallel_config.local_world_size <= visible_device_count, ( f"local_world_size ({self.parallel_config.local_world_size}) must " f"be less than or equal to the number of visible devices " @@ -87,9 +82,7 @@ def init_device(self): self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device) self.requested_memory = request_memory(init_snapshot, self.cache_config) logger.debug("worker init memory snapshot: %r", self.init_snapshot) - logger.debug( - "worker requested memory: %sGiB", format_gib(self.requested_memory) - ) + logger.debug("worker requested memory: %sGiB", format_gib(self.requested_memory)) else: raise RuntimeError(f"Not support device type: {self.device_config.device}") @@ -102,4 +95,4 @@ def init_device(self): if self.rank == 0: # If usage stat is enabled, collect relevant info. 
- report_usage_stats(self.vllm_config) \ No newline at end of file + report_usage_stats(self.vllm_config) diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 40b26e2ba50..011c74d2e83 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -5,16 +5,23 @@ """ from __future__ import annotations -from copy import copy import gc import logging -from typing import Any +from copy import copy + import numpy as np import torch from vllm.config import CUDAGraphMode +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group +from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + RoutedExpertsCapturer, +) +from vllm.model_executor.models.interfaces import supports_mm_encoder_only from vllm.utils.math_utils import cdiv -from vllm.v1.core.sched.output import SchedulerOutput, GrammarOutput +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput +from vllm.v1.outputs import AsyncModelRunnerOutput, make_empty_encoder_model_runner_output from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.utils import record_function_or_nullcontext from vllm.v1.worker.gpu_model_runner import ( @@ -25,20 +32,13 @@ get_pp_group, set_forward_context, ) -from vllm.model_executor.models.interfaces import supports_mm_encoder_only -from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices +from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs + from vllm_omni.outputs import OmniModelRunnerOutput -from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner -from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( - RoutedExpertsCapturer, -) -from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer -from 
vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group -from vllm.v1.outputs import make_empty_encoder_model_runner_output -from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices -from vllm.v1.outputs import AsyncModelRunnerOutput from vllm_omni.worker.gpu_ar_model_runner import ExecuteModelState +from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner + logger = logging.getLogger(__name__) @@ -57,10 +57,7 @@ def execute_model( intermediate_tensors: IntermediateTensors | None = None, ) -> OmniModelRunnerOutput | IntermediateTensors: if self.execute_model_state is not None: - raise RuntimeError( - "State error: sample_tokens() must be called " - "after execute_model() returns None." - ) + raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.") if self.vllm_config.model_config.enable_return_routed_experts: capturer = RoutedExpertsCapturer.get_instance() @@ -70,10 +67,8 @@ def execute_model( logger.error("RoutedExpertsCapturer not initialized.") if scheduler_output.preempted_req_ids and has_kv_transfer_group(): - get_kv_transfer_group().handle_preemptions( - scheduler_output.preempted_req_ids - ) - + get_kv_transfer_group().handle_preemptions(scheduler_output.preempted_req_ids) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with ( record_function_or_nullcontext("gpu_model_runner: preprocess"), @@ -82,7 +77,7 @@ def execute_model( self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: return EMPTY_MODEL_RUNNER_OUTPUT - + if has_ec_transfer() and get_ec_transfer().is_producer: with self.maybe_get_ec_connector_output( scheduler_output, @@ -93,8 +88,7 @@ def execute_model( if not num_scheduled_tokens: if ( - self.parallel_config.distributed_executor_backend - == "external_launcher" + self.parallel_config.distributed_executor_backend == "external_launcher" and self.parallel_config.data_parallel_size > 1 ): # this is a corner 
case when both external launcher @@ -107,7 +101,7 @@ def execute_model( if not has_kv_transfer_group(): # Return empty ModelRunnerOutput if no work to do. return EMPTY_MODEL_RUNNER_OUTPUT - + return self.kv_connector_no_forward(scheduler_output, self.vllm_config) if self.cache_config.kv_sharing_fast_prefill: @@ -127,7 +121,7 @@ def execute_model( scheduler_output, num_scheduled_tokens_np, ) - + cascade_attn_prefix_lens = None # Disable cascade attention when using microbatching (DBO) if self.cascade_attn_enabled and not self.parallel_config.use_ubatching: @@ -137,7 +131,7 @@ def execute_model( self.input_batch.num_computed_tokens_cpu[:num_reqs], scheduler_output.num_common_prefix_blocks, ) - + ( cudagraph_mode, batch_desc, @@ -163,9 +157,7 @@ def execute_model( ) num_tokens_padded = batch_desc.num_tokens - num_reqs_padded = ( - batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs - ) + num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( should_ubatch, num_scheduled_tokens_np, @@ -185,19 +177,17 @@ def execute_model( use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices - attn_metadata, spec_decode_common_attn_metadata = ( - self._build_attention_metadata( - num_tokens=num_tokens_unpadded, - num_tokens_padded=num_tokens_padded if pad_attn else None, - num_reqs=num_reqs, - num_reqs_padded=num_reqs_padded if pad_attn else None, - max_query_len=max_num_scheduled_tokens, - ubatch_slices=ubatch_slices_attn, - logits_indices=logits_indices, - use_spec_decode=use_spec_decode, - num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - cascade_attn_prefix_lens=cascade_attn_prefix_lens, - ) + attn_metadata, spec_decode_common_attn_metadata = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, 
+ num_reqs=num_reqs, + num_reqs_padded=num_reqs_padded if pad_attn else None, + max_query_len=max_num_scheduled_tokens, + ubatch_slices=ubatch_slices_attn, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, ) ( @@ -260,12 +250,15 @@ def execute_model( ) self.kv_connector_output = kv_connector_output return None - + @torch.inference_mode() def sample_tokens( self, grammar_output: GrammarOutput | None = None, ) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + # NOTE: Even though the model is non-autoregressive, we still need + # this function to match the interface of the engine core. + # In this case, this function kv_connector_output = self.kv_connector_output self.kv_connector_output = None @@ -329,9 +322,7 @@ def sample_tokens( kv_connector_output=kv_connector_output, num_nans_in_logits={}, cudagraph_stats=cudagraph_stats, - ec_connector_output=ec_connector_output - if self.supports_mm_inputs - else None, + ec_connector_output=ec_connector_output if self.supports_mm_inputs else None, ) if not self.use_async_scheduling: @@ -433,10 +424,7 @@ def _dummy_run( # mm encoder dummy run may need to add in the future. return torch.tensor([]), torch.tensor([]) - assert ( - cudagraph_runtime_mode is None - or cudagraph_runtime_mode.valid_runtime_modes() - ) + assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes() # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). 
This means that we are using @@ -497,8 +485,7 @@ def _dummy_run( max_num_scheduled_tokens=max_query_len, use_cascade_attn=False, allow_microbatching=allow_microbatching, - force_eager=is_profile - or (cudagraph_runtime_mode == CUDAGraphMode.NONE), + force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE), # `force_uniform_decode` is used for cudagraph capture; because for # capturing mixed prefill-decode batches, we sometimes use # num_tokens == num_reqs which looks like a uniform decode batch to the @@ -520,9 +507,7 @@ def _dummy_run( ) num_tokens_padded = batch_desc.num_tokens - num_reqs_padded = ( - batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs - ) + num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( should_ubatch, num_scheduled_tokens, @@ -601,17 +586,13 @@ def _dummy_run( intermediate_tensors = None else: if self.intermediate_tensors is None: - self.intermediate_tensors = ( - self.model.make_empty_intermediate_tensors( - batch_size=self.max_num_tokens, - dtype=self.model_config.dtype, - device=self.device, - ) + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device, ) - intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens_padded, None, False - ) + intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False) if ubatch_slices_padded is not None: # Adjust values to reflect a single ubatch. @@ -652,14 +633,8 @@ def _dummy_run( # Therefore only use cudagraphs if the main model uses PIECEWISE # NOTE(lucas): this is a hack, need to clean up. 
use_cudagraphs = ( - ( - is_graph_capturing - and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE - ) - or ( - not is_graph_capturing - and cudagraph_runtime_mode != CUDAGraphMode.NONE - ) + (is_graph_capturing and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE) + or (not is_graph_capturing and cudagraph_runtime_mode != CUDAGraphMode.NONE) ) and not self.speculative_config.enforce_eager # Note(gnovack) - We need to disable cudagraphs for one of the two diff --git a/vllm_omni/worker/gpu_generation_worker.py b/vllm_omni/worker/gpu_generation_worker.py index 6a1a3039211..19f8ab84b99 100644 --- a/vllm_omni/worker/gpu_generation_worker.py +++ b/vllm_omni/worker/gpu_generation_worker.py @@ -2,17 +2,21 @@ import os import torch -from vllm.utils.torch_utils import set_random_seed +from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.mem_utils import MemorySnapshot, format_gib +from vllm.utils.torch_utils import set_random_seed from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_worker import Worker as GPUWorker from vllm.v1.worker.gpu_worker import init_worker_distributed_environment -from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.utils import request_memory +from vllm.v1.worker.workspace import init_workspace_manager + from vllm_omni.worker.gpu_generation_model_runner import GPUGenerationModelRunner -from vllm.logger import init_logger + logger = init_logger(__name__) + + class GPUGenerationWorker(GPUWorker): """GPU Worker for Generation model (non-autoregressive waveform generation). 
@@ -26,8 +30,7 @@ def init_device(self): os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) parallel_config = self.parallel_config if ( - parallel_config.distributed_executor_backend - not in ("ray", "external_launcher") + parallel_config.distributed_executor_backend not in ("ray", "external_launcher") and parallel_config.data_parallel_backend != "ray" and parallel_config.nnodes_within_dp == 1 ): @@ -37,8 +40,7 @@ def init_device(self): dp_local_rank = self.parallel_config.data_parallel_index tp_pp_world_size = ( - self.parallel_config.pipeline_parallel_size - * self.parallel_config.tensor_parallel_size + self.parallel_config.pipeline_parallel_size * self.parallel_config.tensor_parallel_size ) # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK @@ -46,9 +48,7 @@ def init_device(self): assert self.local_rank < torch.cuda.device_count(), ( f"DP adjusted local rank {self.local_rank} is out of bounds. " ) - visible_device_count = ( - torch.cuda.device_count() if torch.cuda.is_available() else 0 - ) + visible_device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 assert self.parallel_config.local_world_size <= visible_device_count, ( f"local_world_size ({self.parallel_config.local_world_size}) must " f"be less than or equal to the number of visible devices " @@ -82,9 +82,7 @@ def init_device(self): self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device) self.requested_memory = request_memory(init_snapshot, self.cache_config) logger.debug("worker init memory snapshot: %r", self.init_snapshot) - logger.debug( - "worker requested memory: %sGiB", format_gib(self.requested_memory) - ) + logger.debug("worker requested memory: %sGiB", format_gib(self.requested_memory)) else: raise RuntimeError(f"Not support device type: {self.device_config.device}") diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 24d4ffd028e..06abd9f3a78 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ 
b/vllm_omni/worker/gpu_model_runner.py @@ -8,7 +8,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.models.interfaces import supports_mrope, supports_mm_encoder_only +from vllm.model_executor.models.interfaces import supports_mm_encoder_only, supports_mrope from vllm.model_executor.models.interfaces_base import VllmModelForPooling from vllm.sampling_params import SamplingType from vllm.utils.import_utils import LazyLoader @@ -17,6 +17,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.gpu_model_runner import GPUModelRunner, IntermediateTensors, PerLayerAttnMetadata from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices + from vllm_omni.model_executor.models.output_templates import OmniOutput if TYPE_CHECKING: @@ -94,7 +95,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState): if use_audio_in_video_value is not None: use_audio_in_video = bool(use_audio_in_video_value.item()) - if supports_mrope(self.model): + if supports_mrope(self.get_model()): req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions( req_state.prompt_token_ids, mm_features=req_state.mm_features, @@ -248,7 +249,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: except Exception as e: logger.error(f"Error decoding additional information: {e}") pass - + if sampling_params and sampling_params.prompt_logprobs is not None: self.num_prompt_logprobs[req_id] = ( self.input_batch.vocab_size @@ -258,11 +259,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: self._init_mrope_positions(req_state) - + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) if self.uses_xdrope_dim > 0: self._init_xdrope_positions(req_state) - + 
reqs_to_add.append(self.requests[req_id]) # Update the states of the running/resumed requests. @@ -281,14 +282,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: resumed_from_preemption = req_id in req_data.resumed_req_ids num_output_tokens = req_data.num_output_tokens[i] req_index = self.input_batch.req_id_to_index.get(req_id) - + if req_state.prev_num_draft_len and self.use_async_scheduling: # prev_num_draft_len is used in async scheduling mode with # spec decode. it indicates if need to update num_computed_tokens # of the request. for example: # fist step: num_computed_tokens = 0, spec_tokens = [], # prev_num_draft_len = 0. - # second step: num_computed_tokens = 100(prompt lenth), + # second step: num_computed_tokens = 100(prompt length), # spec_tokens = [a,b], prev_num_draft_len = 0. # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], # prev_num_draft_len = 2. @@ -305,7 +306,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_rejected = req_state.prev_num_draft_len - num_accepted num_computed_tokens -= num_rejected req_state.output_token_ids.extend([-1] * num_accepted) - + # Update the cached states. req_state.num_computed_tokens = num_computed_tokens @@ -327,12 +328,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # failure. Align the cached state. del req_state.output_token_ids[num_output_tokens:] if req_index is not None: - end_idx = ( - self.input_batch.num_prompt_tokens[req_index] - + num_output_tokens - ) + end_idx = self.input_batch.num_prompt_tokens[req_index] + num_output_tokens self.input_batch.num_tokens_no_spec[req_index] = end_idx - + # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: @@ -372,15 +370,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add new_token_ids to token_ids_cpu. 
start_token_index = num_computed_tokens end_token_index = num_computed_tokens + len(new_token_ids) - self.input_batch.token_ids_cpu[ - req_index, start_token_index:end_token_index - ] = new_token_ids + self.input_batch.token_ids_cpu[req_index, start_token_index:end_token_index] = new_token_ids self.input_batch.num_tokens_no_spec[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens) - # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. for request in reqs_to_add: @@ -457,10 +452,7 @@ def _dummy_run( # mm encoder dummy run may need to add in the future. return torch.tensor([]), torch.tensor([]) - assert ( - cudagraph_runtime_mode is None - or cudagraph_runtime_mode.valid_runtime_modes() - ) + assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes() # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). 
This means that we are using @@ -521,8 +513,7 @@ def _dummy_run( max_num_scheduled_tokens=max_query_len, use_cascade_attn=False, allow_microbatching=allow_microbatching, - force_eager=is_profile - or (cudagraph_runtime_mode == CUDAGraphMode.NONE), + force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE), # `force_uniform_decode` is used for cudagraph capture; because for # capturing mixed prefill-decode batches, we sometimes use # num_tokens == num_reqs which looks like a uniform decode batch to the @@ -544,9 +535,7 @@ def _dummy_run( ) num_tokens_padded = batch_desc.num_tokens - num_reqs_padded = ( - batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs - ) + num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( should_ubatch, num_scheduled_tokens, @@ -625,17 +614,13 @@ def _dummy_run( intermediate_tensors = None else: if self.intermediate_tensors is None: - self.intermediate_tensors = ( - self.model.make_empty_intermediate_tensors( - batch_size=self.max_num_tokens, - dtype=self.model_config.dtype, - device=self.device, - ) + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device, ) - intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens_padded, None, False - ) + intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False) if ubatch_slices_padded is not None: # Adjust values to reflect a single ubatch. 
@@ -657,6 +642,13 @@ def _dummy_run( ubatch_slices=ubatch_slices_padded, ), ): + if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"): + outputs = self.talker_mtp( + self.talker_mtp_input_ids.gpu[:num_tokens_padded], + self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded], + self.last_talker_hidden.gpu[:num_tokens_padded], + self.text_step.gpu[:num_tokens_padded], + ) outputs = self.model( input_ids=input_ids, positions=positions, @@ -676,14 +668,8 @@ def _dummy_run( # Therefore only use cudagraphs if the main model uses PIECEWISE # NOTE(lucas): this is a hack, need to clean up. use_cudagraphs = ( - ( - is_graph_capturing - and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE - ) - or ( - not is_graph_capturing - and cudagraph_runtime_mode != CUDAGraphMode.NONE - ) + (is_graph_capturing and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE) + or (not is_graph_capturing and cudagraph_runtime_mode != CUDAGraphMode.NONE) ) and not self.speculative_config.enforce_eager # Note(gnovack) - We need to disable cudagraphs for one of the two @@ -721,9 +707,7 @@ def _dummy_run( self.eplb_step(is_dummy=True, is_profile=is_profile) logit_indices = np.cumsum(num_scheduled_tokens) - 1 - logit_indices_device = torch.from_numpy(logit_indices).to( - self.device, non_blocking=True - ) + logit_indices_device = torch.from_numpy(logit_indices).to(self.device, non_blocking=True) return hidden_states, hidden_states[logit_indices_device] def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput") -> None: @@ -988,6 +972,7 @@ def _preprocess( if hasattr(self.model, "has_preprocess") and self.model.has_preprocess: # Overlay custom prompt_embeds per request for the prompt portion; # collect additional_information (tensor/list) for prefill portion only + decode_req_ids = [] for req_index, req_id in enumerate(self.input_batch.req_ids): req_state = self.requests.get(req_id) req_infos = getattr(req_state, "additional_information_cpu", 
None) if req_state is not None else None @@ -998,43 +983,17 @@ def _preprocess( span_len = int(e) - int(s) # call the custom process function - try: - req_input_ids, req_embeds, update_dict = self.model.preprocess( - input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos - ) - except Exception as e: - logger.error(f"Error in preprocess for request {req_id}: {e}") - import traceback - traceback.print_exc() - raise e - #TODO: This is Model Specific Code, need to be generalized in the future ZTC - # run talker mtp decode - if hasattr(self.model, "talker_mtp"): - _cudagraph_mode, batch_desc, _, _, _ = self._determine_batch_execution_and_padding( - num_tokens=span_len, - num_reqs=1, - num_scheduled_tokens_np=num_scheduled_tokens_np[req_index], - max_num_scheduled_tokens=1, - force_eager=span_len > 1, - use_cascade_attn=False, - ) + req_input_ids, req_embeds, update_dict = self.model.preprocess( + input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos + ) + if hasattr(self.model, "talker_mtp") and span_len == 1: last_talker_hidden, text_step = update_dict.pop("mtp_inputs") - if _cudagraph_mode != CUDAGraphMode.NONE: - self.talker_mtp_input_ids.gpu[:span_len].copy_(req_input_ids) - self.talker_mtp_inputs_embeds.gpu[:span_len].copy_(req_embeds) - self.last_talker_hidden.gpu[:span_len].copy_(last_talker_hidden) - self.text_step.gpu[:span_len].copy_(text_step) - req_input_ids = self.talker_mtp_input_ids.gpu[:span_len] - req_embeds = self.talker_mtp_inputs_embeds.gpu[:span_len] - last_talker_hidden = self.last_talker_hidden.gpu[:span_len] - text_step = self.text_step.gpu[:span_len] - with set_forward_context( - None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc - ): - req_embeds, code_predictor_codes = self.talker_mtp( - req_input_ids, req_embeds, last_talker_hidden, text_step - ) - update_dict["code_predictor_codes"] = code_predictor_codes + decode_slice = slice(len(decode_req_ids), len(decode_req_ids) 
+ 1) + self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids) + self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds) + self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) + self.text_step.gpu[decode_slice].copy_(text_step) + decode_req_ids.append(req_id) # TODO(Peiqi): the merge stage could move out from the critical path self._merge_additional_information_update(req_id, update_dict) @@ -1045,6 +1004,10 @@ def _preprocess( if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len: input_ids[s : s + seg_len] = req_input_ids + # run talker mtp decode + if hasattr(self.model, "talker_mtp"): + self._talker_mtp_forward(decode_req_ids, inputs_embeds) + return ( input_ids, inputs_embeds, @@ -1054,6 +1017,34 @@ def _preprocess( ec_connector_output, ) + def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Tensor) -> None: + decode_batch_size = len(decode_req_ids) + if decode_batch_size == 0: + return + _cudagraph_mode, batch_desc, _, _, _ = self._determine_batch_execution_and_padding( + num_tokens=decode_batch_size, + num_reqs=decode_batch_size, + num_scheduled_tokens_np=np.ones(decode_batch_size, dtype=np.int32), + max_num_scheduled_tokens=1, + use_cascade_attn=False, + ) + req_input_ids = self.talker_mtp_input_ids.gpu[:decode_batch_size] + req_embeds = self.talker_mtp_inputs_embeds.gpu[:decode_batch_size] + last_talker_hidden = self.last_talker_hidden.gpu[:decode_batch_size] + text_step = self.text_step.gpu[:decode_batch_size] + with set_forward_context( + None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc + ): + req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) + # update the inputs_embeds and code_predictor_codes + code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() + for idx, req_id in enumerate(decode_req_ids): + req_index = 
self.input_batch.req_ids.index(req_id) + start_offset = int(self.query_start_loc.cpu[req_index]) + inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] + update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : idx + 1]} + self._merge_additional_information_update(req_id, update_dict) + def _model_forward( self, input_ids: torch.Tensor | None = None, diff --git a/vllm_omni/worker/npu/npu_model_runner.py b/vllm_omni/worker/npu/npu_model_runner.py index 1083bbe40d4..264d5e3413b 100644 --- a/vllm_omni/worker/npu/npu_model_runner.py +++ b/vllm_omni/worker/npu/npu_model_runner.py @@ -64,7 +64,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState): if use_audio_in_video_value is not None: use_audio_in_video = bool(use_audio_in_video_value.item()) - if supports_mrope(self.model): + if supports_mrope(self.get_model()): req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions( req_state.prompt_token_ids, mm_features=req_state.mm_features,