diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cae61aef16..643cf7a465 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: - id: ruff-format - repo: https://github.com/crate-ci/typos - rev: v1.35.5 + rev: v1.38.1 hooks: - id: typos # only for staged files diff --git a/docs/user_guide/cache_dit_acceleration.md b/docs/user_guide/cache_dit_acceleration.md index 9565893ebb..35aeae3c96 100644 --- a/docs/user_guide/cache_dit_acceleration.md +++ b/docs/user_guide/cache_dit_acceleration.md @@ -201,7 +201,7 @@ omni = Omni( You can customize the configuration by modifying the `cache_config` dictionary to use only specific methods (e.g., DBCache only, DBCache + SCM, etc.) based on your quality and speed requirements. -To test another model, you can modify `--model` with the target model identifier like `Tongyi-MAI/Z-Image-Turbo` and update `cache_confg` according the model architecture (e.g., number of transformer blocks). +To test another model, you can modify `--model` with the target model identifier like `Tongyi-MAI/Z-Image-Turbo` and update `cache_config` according to the model architecture (e.g., number of transformer blocks). ## Additional Resources diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py index 1de671c2e3..9312a95541 100644 --- a/examples/offline_inference/qwen2_5_omni/end2end.py +++ b/examples/offline_inference/qwen2_5_omni/end2end.py @@ -6,7 +6,7 @@ """ import os -from typing import NamedTuple, Optional +from typing import NamedTuple import librosa import numpy as np @@ -58,9 +58,9 @@ def get_text_query(question: str = None) -> QueryResult: def get_mixed_modalities_query( - video_path: Optional[str] = None, - image_path: Optional[str] = None, - audio_path: Optional[str] = None, + video_path: str | None = None, + image_path: str | None = None, + audio_path: str | None = None, num_frames: int = 16, sampling_rate: int = 16000, ) -> QueryResult: @@ -114,7 +114,7 @@ def get_mixed_modalities_query( def get_use_audio_in_video_query( - video_path: Optional[str] = None, num_frames: int = 16, sampling_rate: int = 16000 + video_path: str | None = None, num_frames: int = 16, sampling_rate: int = 16000 ) -> QueryResult: question = "Describe the content of the video, then convert what the baby say into text." prompt = ( @@ -151,7 +151,7 @@ def get_use_audio_in_video_query( ) -def get_multi_audios_query(audio_path: Optional[str] = None, sampling_rate: int = 16000) -> QueryResult: +def get_multi_audios_query(audio_path: str | None = None, sampling_rate: int = 16000) -> QueryResult: question = "Are these two audio clips the same?" prompt = ( f"<|im_start|>system\n{default_system}<|im_end|>\n" @@ -190,7 +190,7 @@ def get_multi_audios_query(audio_path: Optional[str] = None, sampling_rate: int ) -def get_image_query(question: str = None, image_path: Optional[str] = None) -> QueryResult: +def get_image_query(question: str = None, image_path: str | None = None) -> QueryResult: if question is None: question = "What is the content of this image?" prompt = ( @@ -219,7 +219,7 @@ def get_image_query(question: str = None, image_path: Optional[str] = None) -> Q ) -def get_video_query(question: str = None, video_path: Optional[str] = None, num_frames: int = 16) -> QueryResult: +def get_video_query(question: str = None, video_path: str | None = None, num_frames: int = 16) -> QueryResult: if question is None: question = "Why is this video funny?" 
prompt = ( @@ -247,7 +247,7 @@ def get_video_query(question: str = None, video_path: Optional[str] = None, num_ ) -def get_audio_query(question: str = None, audio_path: Optional[str] = None, sampling_rate: int = 16000) -> QueryResult: +def get_audio_query(question: str = None, audio_path: str | None = None, sampling_rate: int = 16000) -> QueryResult: if question is None: question = "What is the content of this audio?" prompt = ( diff --git a/examples/offline_inference/qwen2_5_omni/extract_prompts.py b/examples/offline_inference/qwen2_5_omni/extract_prompts.py index 574dd05446..dce0788dbf 100644 --- a/examples/offline_inference/qwen2_5_omni/extract_prompts.py +++ b/examples/offline_inference/qwen2_5_omni/extract_prompts.py @@ -1,9 +1,8 @@ #!/usr/bin/env python3 import argparse -from typing import Optional -def extract_prompt(line: str) -> Optional[str]: +def extract_prompt(line: str) -> str | None: # Extract the content between the first '|' and the second '|' i = line.find("|") if i == -1: diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py index 5db68784b1..98f033dca5 100644 --- a/examples/offline_inference/qwen3_omni/end2end.py +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -6,7 +6,7 @@ """ import os -from typing import NamedTuple, Optional +from typing import NamedTuple import librosa import numpy as np @@ -57,7 +57,7 @@ def get_text_query(question: str = None) -> QueryResult: ) -def get_video_query(question: str = None, video_path: Optional[str] = None, num_frames: int = 16) -> QueryResult: +def get_video_query(question: str = None, video_path: str | None = None, num_frames: int = 16) -> QueryResult: if question is None: question = "Why is this video funny?" prompt = ( @@ -85,7 +85,7 @@ def get_video_query(question: str = None, video_path: Optional[str] = None, num_ ) -def get_image_query(question: str = None, image_path: Optional[str] = None) -> QueryResult: +def get_image_query(question: str = None, image_path: str | None = None) -> QueryResult: if question is None: question = "What is the content of this image?" prompt = ( @@ -114,7 +114,7 @@ def get_image_query(question: str = None, image_path: Optional[str] = None) -> Q ) -def get_audio_query(question: str = None, audio_path: Optional[str] = None, sampling_rate: int = 16000) -> QueryResult: +def get_audio_query(question: str = None, audio_path: str | None = None, sampling_rate: int = 16000) -> QueryResult: if question is None: question = "What is the content of this audio?" 
prompt = ( diff --git a/examples/online_serving/qwen2_5_omni/gradio_demo.py b/examples/online_serving/qwen2_5_omni/gradio_demo.py index deacbd9cc7..fcf14411b7 100644 --- a/examples/online_serving/qwen2_5_omni/gradio_demo.py +++ b/examples/online_serving/qwen2_5_omni/gradio_demo.py @@ -5,7 +5,7 @@ import sys from pathlib import Path from types import SimpleNamespace -from typing import Any, Optional +from typing import Any import gradio as gr import numpy as np @@ -175,16 +175,16 @@ def create_prompt_args(base_args: argparse.Namespace) -> SimpleNamespace: def process_audio_file( - audio_file: Optional[Any], -) -> Optional[tuple[np.ndarray, int]]: + audio_file: Any | None, +) -> tuple[np.ndarray, int] | None: """Normalize Gradio audio input to (np.ndarray, sample_rate).""" if audio_file is None: return None - sample_rate: Optional[int] = None - audio_np: Optional[np.ndarray] = None + sample_rate: int | None = None + audio_np: np.ndarray | None = None - def _load_from_path(path_str: str) -> Optional[tuple[np.ndarray, int]]: + def _load_from_path(path_str: str) -> tuple[np.ndarray, int] | None: if not path_str: return None path = Path(path_str) @@ -237,7 +237,7 @@ def _load_from_path(path_str: str) -> Optional[tuple[np.ndarray, int]]: return audio_np.astype(np.float32), sample_rate -def process_image_file(image_file: Optional[Image.Image]) -> Optional[Image.Image]: +def process_image_file(image_file: Image.Image | None) -> Image.Image | None: """Process image file from Gradio input. Returns: @@ -252,10 +252,10 @@ def process_image_file(image_file: Optional[Image.Image]) -> Optional[Image.Imag def process_video_file( - video_file: Optional[str], + video_file: str | None, enable_audio_in_video: bool = False, max_frames: int = 32, -) -> Optional[tuple[np.ndarray, dict[str, Any], Optional[tuple[np.ndarray, int]]]]: +) -> tuple[np.ndarray, dict[str, Any], tuple[np.ndarray, int] | None] | None: """Process video file and optionally extract audio track.""" if video_file is None: return None @@ -272,7 +272,7 @@ def process_video_file( print(f"Failed to decode video {video_path}: {exc}") return None - audio_tuple: Optional[tuple[np.ndarray, int]] = None + audio_tuple: tuple[np.ndarray, int] | None = None if enable_audio_in_video: try: import librosa # type: ignore import @@ -290,9 +290,9 @@ async def run_inference_async_omni( sampling_params: list[SamplingParams], prompt_args_template: SimpleNamespace, user_prompt: str, - audio_file: Optional[tuple[str, tuple[int, np.ndarray]]] = None, - image_file: Optional[Image.Image] = None, - video_file: Optional[str] = None, + audio_file: tuple[str, tuple[int, np.ndarray]] | None = None, + image_file: Image.Image | None = None, + video_file: str | None = None, use_audio_in_video: bool = False, ): """Run inference using AsyncOmni directly with multimodal support.""" @@ -420,9 +420,9 @@ def build_interface( async def run_inference( user_prompt: str, - audio_file: Optional[tuple[str, tuple[int, np.ndarray]]], - image_file: Optional[Image.Image], - video_file: Optional[str], + audio_file: tuple[str, tuple[int, np.ndarray]] | None, + image_file: Image.Image | None, + video_file: str | None, use_audio_in_video: bool, ): return await run_inference_async_omni( diff --git a/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py index 8ab718b30c..65ca8b5b19 100644 --- 
a/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py +++ b/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py @@ -1,6 +1,5 @@ import base64 import os -from typing import Optional import requests from openai import OpenAI @@ -38,7 +37,7 @@ def encode_base64_content_from_file(file_path: str) -> str: return result -def get_video_url_from_path(video_path: Optional[str]) -> str: +def get_video_url_from_path(video_path: str | None) -> str: """Convert a video path (local file or URL) to a video URL format for the API. If video_path is None or empty, returns the default URL. @@ -77,7 +76,7 @@ def get_video_url_from_path(video_path: Optional[str]) -> str: return f"data:{mime_type};base64,{video_base64}" -def get_image_url_from_path(image_path: Optional[str]) -> str: +def get_image_url_from_path(image_path: str | None) -> str: """Convert an image path (local file or URL) to an image URL format for the API. If image_path is None or empty, returns the default URL. @@ -114,7 +113,7 @@ def get_image_url_from_path(image_path: Optional[str]) -> str: return f"data:{mime_type};base64,{image_base64}" -def get_audio_url_from_path(audio_path: Optional[str]) -> str: +def get_audio_url_from_path(audio_path: str | None) -> str: """Convert an audio path (local file or URL) to an audio URL format for the API. If audio_path is None or empty, returns the default URL. @@ -169,7 +168,7 @@ def get_system_prompt(): } -def get_text_query(custom_prompt: Optional[str] = None): +def get_text_query(custom_prompt: str | None = None): question = ( custom_prompt or "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." ) @@ -186,10 +185,10 @@ def get_text_query(custom_prompt: Optional[str] = None): def get_mixed_modalities_query( - video_path: Optional[str] = None, - image_path: Optional[str] = None, - audio_path: Optional[str] = None, - custom_prompt: Optional[str] = None, + video_path: str | None = None, + image_path: str | None = None, + audio_path: str | None = None, + custom_prompt: str | None = None, ): question = ( custom_prompt or "What is recited in the audio? What is the content of this image? Why is this video funny?" @@ -222,7 +221,7 @@ def get_mixed_modalities_query( return prompt -def get_use_audio_in_video_query(video_path: Optional[str] = None, custom_prompt: Optional[str] = None): +def get_use_audio_in_video_query(video_path: str | None = None, custom_prompt: str | None = None): question = custom_prompt or "Describe the content of the video, then convert what the baby say into text." video_url = get_video_url_from_path(video_path) @@ -246,7 +245,7 @@ def get_use_audio_in_video_query(video_path: Optional[str] = None, custom_prompt return prompt -def get_multi_audios_query(audio_path: Optional[str] = None, custom_prompt: Optional[str] = None): +def get_multi_audios_query(audio_path: str | None = None, custom_prompt: str | None = None): question = custom_prompt or "Are these two audio clips the same?" 
audio_url = get_audio_url_from_path(audio_path) prompt = { diff --git a/examples/online_serving/qwen3_omni/gradio_demo.py b/examples/online_serving/qwen3_omni/gradio_demo.py index 3286361176..20182a88fc 100644 --- a/examples/online_serving/qwen3_omni/gradio_demo.py +++ b/examples/online_serving/qwen3_omni/gradio_demo.py @@ -5,7 +5,7 @@ import sys from pathlib import Path from types import SimpleNamespace -from typing import Any, Optional +from typing import Any import gradio as gr import numpy as np @@ -178,16 +178,16 @@ def create_prompt_args(base_args: argparse.Namespace) -> SimpleNamespace: def process_audio_file( - audio_file: Optional[Any], -) -> Optional[tuple[np.ndarray, int]]: + audio_file: Any | None, +) -> tuple[np.ndarray, int] | None: """Normalize Gradio audio input to (np.ndarray, sample_rate).""" if audio_file is None: return None - sample_rate: Optional[int] = None - audio_np: Optional[np.ndarray] = None + sample_rate: int | None = None + audio_np: np.ndarray | None = None - def _load_from_path(path_str: str) -> Optional[tuple[np.ndarray, int]]: + def _load_from_path(path_str: str) -> tuple[np.ndarray, int] | None: if not path_str: return None path = Path(path_str) @@ -240,7 +240,7 @@ def _load_from_path(path_str: str) -> Optional[tuple[np.ndarray, int]]: return audio_np.astype(np.float32), sample_rate -def process_image_file(image_file: Optional[Image.Image]) -> Optional[Image.Image]: +def process_image_file(image_file: Image.Image | None) -> Image.Image | None: """Process image file from Gradio input. Returns: @@ -255,10 +255,10 @@ def process_image_file(image_file: Optional[Image.Image]) -> Optional[Image.Imag def process_video_file( - video_file: Optional[str], + video_file: str | None, enable_audio_in_video: bool = False, max_frames: int = 32, -) -> Optional[tuple[np.ndarray, dict[str, Any], Optional[tuple[np.ndarray, int]]]]: +) -> tuple[np.ndarray, dict[str, Any], tuple[np.ndarray, int] | None] | None: """Process video file and optionally extract audio track.""" if video_file is None: return None @@ -275,7 +275,7 @@ def process_video_file( print(f"Failed to decode video {video_path}: {exc}") return None - audio_tuple: Optional[tuple[np.ndarray, int]] = None + audio_tuple: tuple[np.ndarray, int] | None = None if enable_audio_in_video: try: import librosa # type: ignore import @@ -293,9 +293,9 @@ async def run_inference_async_omni( sampling_params: list[SamplingParams], prompt_args_template: SimpleNamespace, user_prompt: str, - audio_file: Optional[tuple[str, tuple[int, np.ndarray]]] = None, - image_file: Optional[Image.Image] = None, - video_file: Optional[str] = None, + audio_file: tuple[str, tuple[int, np.ndarray]] | None = None, + image_file: Image.Image | None = None, + video_file: str | None = None, use_audio_in_video: bool = False, ): """Run inference using AsyncOmni directly with multimodal support.""" @@ -426,9 +426,9 @@ def build_interface( async def run_inference( user_prompt: str, - audio_file: Optional[tuple[str, tuple[int, np.ndarray]]], - image_file: Optional[Image.Image], - video_file: Optional[str], + audio_file: tuple[str, tuple[int, np.ndarray]] | None, + image_file: Image.Image | None, + video_file: str | None, use_audio_in_video: bool, ): return await run_inference_async_omni( diff --git a/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py index ea2af62bd6..18b89b2541 100644 --- 
a/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py +++ b/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py @@ -1,6 +1,6 @@ import base64 import os -from typing import NamedTuple, Optional +from typing import NamedTuple import requests from openai import OpenAI @@ -43,7 +43,7 @@ def encode_base64_content_from_file(file_path: str) -> str: return result -def get_video_url_from_path(video_path: Optional[str]) -> str: +def get_video_url_from_path(video_path: str | None) -> str: """Convert a video path (local file or URL) to a video URL format for the API. If video_path is None or empty, returns the default URL. @@ -82,7 +82,7 @@ def get_video_url_from_path(video_path: Optional[str]) -> str: return f"data:{mime_type};base64,{video_base64}" -def get_image_url_from_path(image_path: Optional[str]) -> str: +def get_image_url_from_path(image_path: str | None) -> str: """Convert an image path (local file or URL) to an image URL format for the API. If image_path is None or empty, returns the default URL. @@ -119,7 +119,7 @@ def get_image_url_from_path(image_path: Optional[str]) -> str: return f"data:{mime_type};base64,{image_base64}" -def get_audio_url_from_path(audio_path: Optional[str]) -> str: +def get_audio_url_from_path(audio_path: str | None) -> str: """Convert an audio path (local file or URL) to an audio URL format for the API. If audio_path is None or empty, returns the default URL. @@ -174,7 +174,7 @@ def get_system_prompt(): } -def get_text_query(custom_prompt: Optional[str] = None): +def get_text_query(custom_prompt: str | None = None): question = ( custom_prompt or "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." ) @@ -197,7 +197,7 @@ def get_text_query(custom_prompt: Optional[str] = None): ) -def get_video_query(video_path: Optional[str] = None, custom_prompt: Optional[str] = None): +def get_video_query(video_path: str | None = None, custom_prompt: str | None = None): question = custom_prompt or "Why is this video funny?" video_url = get_video_url_from_path(video_path) prompt = { @@ -216,7 +216,7 @@ def get_video_query(video_path: Optional[str] = None, custom_prompt: Optional[st return prompt -def get_image_query(image_path: Optional[str] = None, custom_prompt: Optional[str] = None): +def get_image_query(image_path: str | None = None, custom_prompt: str | None = None): question = custom_prompt or "What is the content of this image?" image_url = get_image_url_from_path(image_path) prompt = { @@ -235,7 +235,7 @@ def get_image_query(image_path: Optional[str] = None, custom_prompt: Optional[st return prompt -def get_audio_query(audio_path: Optional[str] = None, custom_prompt: Optional[str] = None): +def get_audio_query(audio_path: str | None = None, custom_prompt: str | None = None): question = custom_prompt or "What is the content of this audio?" 
audio_url = get_audio_url_from_path(audio_path) prompt = { diff --git a/pyproject.toml b/pyproject.toml index 50ba7b6838..9d1e6fd482 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "vllm-omni" version = "0.11.0rc1" description = "A framework for efficient model inference with omni-modality models" readme = "README.md" -requires-python = ">=3.9,<3.14" # Align with vLLM v0.11 supported Python versions +requires-python = ">=3.10,<3.14" license = {text = "Apache-2.0"} authors = [ {name = "vLLM-Omni Team"} @@ -18,11 +18,10 @@ classifiers = [ "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development :: Libraries :: Python Modules", ] @@ -48,10 +47,8 @@ dev = [ "pytest>=7.0.0", "pytest-asyncio>=0.21.0", "pytest-cov>=4.0.0", - "black>=23.0.0", - "isort>=5.12.0", - "mypy>=1.0.0", - "pre-commit>=3.0.0", + "mypy==1.11.1", + "pre-commit==4.0.1", ] docs = [ @@ -65,8 +62,8 @@ docs = [ "mkdocs-git-revision-date-localized-plugin", "mkdocs-minify-plugin", "regex", + "ruff", "pydantic", - "black>=23.0.0", ] diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index ef9710e900..8fc929c38d 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -1,7 +1,7 @@ import json import warnings from importlib.util import find_spec -from typing import Any, Literal, Optional +from typing import Any, Literal import torch import vllm.envs as envs @@ -65,8 +65,8 @@ class OmniModelConfig(ModelConfig): stage_id: int = 0 model_stage: str = "thinker" model_arch: str = "Qwen2_5OmniForConditionalGeneration" - engine_output_type: Optional[str] = None - hf_config_name: Optional[str] = None + engine_output_type: str | None = None + hf_config_name: str | None = None @property def registry(self): @@ -87,16 +87,16 @@ def draw_hf_text_config(self): def __post_init__( self, # Multimodal config init vars - limit_mm_per_prompt: Optional[dict[str, int]], - media_io_kwargs: Optional[dict[str, dict[str, Any]]], - mm_processor_kwargs: Optional[dict[str, Any]], - mm_processor_cache_gb: Optional[float], - mm_processor_cache_type: Optional[MMCacheType], - mm_shm_cache_max_object_size_mb: Optional[int], - mm_encoder_tp_mode: Optional[MMEncoderTPMode], - interleave_mm_strings: Optional[bool], - skip_mm_profiling: Optional[bool], - video_pruning_rate: Optional[float], + limit_mm_per_prompt: dict[str, int] | None, + media_io_kwargs: dict[str, dict[str, Any]] | None, + mm_processor_kwargs: dict[str, Any] | None, + mm_processor_cache_gb: float | None, + mm_processor_cache_type: MMCacheType | None, + mm_shm_cache_max_object_size_mb: int | None, + mm_encoder_tp_mode: MMEncoderTPMode | None, + interleave_mm_strings: bool | None, + skip_mm_profiling: bool | None, + video_pruning_rate: float | None, ) -> None: # Set the default seed to 0 in V1. 
if envs.VLLM_USE_V1 and self.seed is None: diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 0916487ab6..430154a556 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -1,6 +1,5 @@ import time from collections import defaultdict -from typing import Optional from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.request_queue import create_request_queue @@ -170,7 +169,7 @@ def update_from_output( kv_connector_output = model_runner_output.kv_connector_output outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) - spec_decoding_stats: Optional[SpecDecodingStats] = None + spec_decoding_stats: SpecDecodingStats | None = None kv_connector_stats = kv_connector_output.kv_connector_stats if kv_connector_output else None # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, diff --git a/vllm_omni/core/sched/output.py b/vllm_omni/core/sched/output.py index 0d564fcdab..bbb38d099a 100644 --- a/vllm_omni/core/sched/output.py +++ b/vllm_omni/core/sched/output.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional from vllm.v1.core.sched.output import NewRequestData from vllm.v1.request import Request @@ -21,9 +20,9 @@ class OmniNewRequestData(NewRequestData): """ # Optional serialized prompt embeddings - prompt_embeds: Optional[PromptEmbedsPayload] = None + prompt_embeds: PromptEmbedsPayload | None = None # Optional serialized additional information - additional_information: Optional[AdditionalInformationPayload] = None + additional_information: AdditionalInformationPayload | None = None @classmethod def from_request( diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index 63362159e2..5365466adb 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -7,7 +7,8 @@ pipelines in vllm-omni, supporting both single and dual-transformer architectures. """ -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional from vllm.logger import init_logger @@ -311,8 +312,8 @@ def __init__(self, cache_config: Any = None): super().__init__(config) # Cache-dit specific attributes - self._refresh_func: Optional[Callable[[Any, int, bool], None]] = None - self._last_num_inference_steps: Optional[int] = None + self._refresh_func: Callable[[Any, int, bool], None] | None = None + self._last_num_inference_steps: int | None = None def enable(self, pipeline: Any) -> None: """Enable cache-dit on the pipeline if configured. 
diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py index 700c5e99be..5a0cac6261 100644 --- a/vllm_omni/diffusion/cache/teacache/config.py +++ b/vllm_omni/diffusion/cache/teacache/config.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional # Model-specific polynomial coefficients for rescaling L1 distances # These coefficients account for model-specific characteristics in how embeddings change @@ -51,7 +50,7 @@ class TeaCacheConfig: """ rel_l1_thresh: float = 0.2 - coefficients: Optional[list[float]] = None + coefficients: list[float] | None = None model_type: str = "QwenImagePipeline" def __post_init__(self) -> None: diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py index 7bd574383a..49bcb76ca1 100644 --- a/vllm_omni/diffusion/cache/teacache/extractors.py +++ b/vllm_omni/diffusion/cache/teacache/extractors.py @@ -13,8 +13,9 @@ transformer execution, and postprocessing logic. """ +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, Union +from typing import Any import torch import torch.nn as nn @@ -146,7 +147,7 @@ def extract_qwen_context( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor, - timestep: Union[torch.Tensor, float, int], + timestep: torch.Tensor | float | int, img_shapes: torch.Tensor, txt_seq_lens: torch.Tensor, guidance: torch.Tensor | None = None, diff --git a/vllm_omni/diffusion/cache/teacache/state.py b/vllm_omni/diffusion/cache/teacache/state.py index 42a55bede2..a6429e5401 100644 --- a/vllm_omni/diffusion/cache/teacache/state.py +++ b/vllm_omni/diffusion/cache/teacache/state.py @@ -7,8 +7,6 @@ This module manages the state for TeaCache hooks across diffusion timesteps. 
""" -from typing import Optional - import torch @@ -27,9 +25,9 @@ def __init__(self): # Caching state self.accumulated_rel_l1_distance = 0.0 - self.previous_modulated_input: Optional[torch.Tensor] = None - self.previous_residual: Optional[torch.Tensor] = None - self.previous_residual_encoder: Optional[torch.Tensor] = None + self.previous_modulated_input: torch.Tensor | None = None + self.previous_residual: torch.Tensor | None = None + self.previous_residual_encoder: torch.Tensor | None = None def reset(self) -> None: """Reset all state variables for a new inference run.""" diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index c7853f8a26..8210bedab5 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -4,8 +4,9 @@ import enum import os import random +from collections.abc import Callable from dataclasses import dataclass, field, fields -from typing import Any, Callable +from typing import Any import torch from vllm.logger import init_logger diff --git a/vllm_omni/diffusion/hooks.py b/vllm_omni/diffusion/hooks.py index b82676fad1..2296aa9a03 100644 --- a/vllm_omni/diffusion/hooks.py +++ b/vllm_omni/diffusion/hooks.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable +from typing import Any import torch.nn as nn diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 84967c6cf7..e5b99840dc 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -7,7 +7,7 @@ import math import os from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any import numpy as np import torch @@ -75,10 +75,10 @@ def calculate_shift( def retrieve_timesteps( scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[list[int]] = None, - sigmas: Optional[list[float]] = None, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, **kwargs, ) -> tuple[torch.Tensor, int]: r""" @@ -186,7 +186,7 @@ def get_timestep_embedding( def apply_rotary_emb_qwen( x: torch.Tensor, - freqs_cis: Union[torch.Tensor, tuple[torch.Tensor]], + freqs_cis: torch.Tensor | tuple[torch.Tensor], use_real: bool = True, use_real_unbind_dim: int = -1, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -359,8 +359,8 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor def _get_qwen_prompt_embeds( self, - prompt: Union[str, list[str]] = None, - dtype: Optional[torch.dtype] = None, + prompt: str | list[str] = None, + dtype: torch.dtype | None = None, ): dtype = dtype or self.text_encoder.dtype @@ -400,10 +400,10 @@ def _get_qwen_prompt_embeds( def encode_prompt( self, - prompt: Union[str, list[str]], + prompt: str | list[str], num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, ): r""" @@ -601,23 +601,23 @@ def diffuse( def forward( self, req: OmniDiffusionRequest, - prompt: Union[str, list[str]] = "", - negative_prompt: Union[str, list[str]] = "", + prompt: str | list[str] = "", + 
negative_prompt: str | list[str] = "", true_cfg_scale: float = 4.0, height: int | None = None, width: int | None = None, num_inference_steps: int = 50, - sigmas: Optional[list[float]] = None, + sigmas: list[float] | None = None, guidance_scale: float = 1.0, num_images_per_prompt: int = 1, - generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds_mask: Optional[torch.Tensor] = None, - output_type: Optional[str] = "pil", - attention_kwargs: Optional[dict[str, Any]] = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], max_sequence_length: int = 512, ) -> DiffusionOutput: diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 07cbb9fdc2..c60623d81e 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -7,7 +7,7 @@ import math import os from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any import numpy as np import PIL.Image @@ -134,10 +134,10 @@ def calculate_dimensions(target_area: float, ratio: float): def retrieve_timesteps( scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[list[int]] = None, - sigmas: Optional[list[float]] = None, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, **kwargs, ) -> tuple[torch.Tensor, int]: r""" @@ -173,7 +173,7 @@ def retrieve_timesteps( def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "argmax" + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "argmax" ): """Retrieve latents from VAE encoder output.""" if hasattr(encoder_output, "latent_dist"): @@ -305,9 +305,9 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor def _get_qwen_prompt_embeds( self, - prompt: Union[str, list[str]] = None, - image: Optional[torch.Tensor] = None, - dtype: Optional[torch.dtype] = None, + prompt: str | list[str] = None, + image: torch.Tensor | None = None, + dtype: torch.dtype | None = None, ): dtype = dtype or self.text_encoder.dtype @@ -350,9 +350,9 @@ def _get_qwen_prompt_embeds( def _get_qwen_prompt_embeds( self, - prompt: Union[str, list[str]] = None, - image: Optional[Union[PIL.Image.Image, torch.Tensor]] = None, - dtype: Optional[torch.dtype] = None, + prompt: str | list[str] = None, + image: PIL.Image.Image | torch.Tensor | None = None, + dtype: torch.dtype | None = None, ): """Get prompt embeddings with image support for editing.""" dtype = dtype or self.text_encoder.dtype @@ -397,11 +397,11 @@ def _get_qwen_prompt_embeds( def encode_prompt( self, - 
prompt: Union[str, list[str]], - image: Optional[torch.Tensor] = None, + prompt: str | list[str], + image: torch.Tensor | None = None, num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, ): r""" @@ -638,24 +638,24 @@ def diffuse( def forward( self, req: OmniDiffusionRequest, - prompt: Union[str, list[str]] = "", - negative_prompt: Union[str, list[str]] = "", - image: Optional[Union[PIL.Image.Image, torch.Tensor]] = None, + prompt: str | list[str] = "", + negative_prompt: str | list[str] = "", + image: PIL.Image.Image | torch.Tensor | None = None, true_cfg_scale: float = 4.0, height: int | None = None, width: int | None = None, num_inference_steps: int = 50, - sigmas: Optional[list[float]] = None, + sigmas: list[float] | None = None, guidance_scale: float = 1.0, num_images_per_prompt: int = 1, - generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds_mask: Optional[torch.Tensor] = None, - output_type: Optional[str] = "pil", - attention_kwargs: Optional[dict[str, Any]] = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], max_sequence_length: int = 512, ) -> DiffusionOutput: diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index e62400668d..5bcbed0c10 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -3,7 +3,7 @@ import functools from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any import torch import torch.nn as nn @@ -26,7 +26,7 @@ def apply_rotary_emb_qwen( x: torch.Tensor, - freqs_cis: Union[torch.Tensor, tuple[torch.Tensor]], + freqs_cis: torch.Tensor | tuple[torch.Tensor], use_real: bool = True, use_real_unbind_dim: int = -1, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -393,8 +393,8 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor, temb: torch.Tensor, - image_rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - joint_attention_kwargs: Optional[dict[str, Any]] = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: # Get modulation parameters for both streams img_mod_params = self.img_mod(temb) # [B, 6*dim] @@ -490,7 +490,7 @@ def __init__( od_config: OmniDiffusionConfig, patch_size: int = 2, in_channels: int = 64, - out_channels: Optional[int] = 16, + out_channels: int | None = 16, num_layers: int = 60, attention_head_dim: int = 128, num_attention_heads: int = 24, @@ -537,12 +537,12 @@ def forward( encoder_hidden_states: 
torch.Tensor = None, encoder_hidden_states_mask: torch.Tensor = None, timestep: torch.LongTensor = None, - img_shapes: Optional[list[tuple[int, int, int]]] = None, - txt_seq_lens: Optional[list[int]] = None, + img_shapes: list[tuple[int, int, int]] | None = None, + txt_seq_lens: list[int] | None = None, guidance: torch.Tensor = None, # TODO: this should probably be removed - attention_kwargs: Optional[dict[str, Any]] = None, + attention_kwargs: dict[str, Any] | None = None, return_dict: bool = True, - ) -> Union[torch.Tensor, Transformer2DModelOutput]: + ) -> torch.Tensor | Transformer2DModelOutput: """ The [`QwenTransformer2DModel`] forward method. diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index d65c6839ed..4210ad72ec 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any import torch import torch.nn as nn @@ -127,7 +127,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens class WanImageEmbedding(nn.Module): """Image embedding module for I2V tasks.""" - def __init__(self, in_features: int, out_features: int, pos_embed_seq_len: Optional[int] = None): + def __init__(self, in_features: int, out_features: int, pos_embed_seq_len: int | None = None): super().__init__() self.norm1 = FP32LayerNorm(in_features) @@ -159,8 +159,8 @@ def __init__( time_freq_dim: int, time_proj_dim: int, text_embed_dim: int, - image_embed_dim: Optional[int] = None, - pos_embed_seq_len: Optional[int] = None, + image_embed_dim: int | None = None, + pos_embed_seq_len: int | None = None, ): super().__init__() @@ -178,9 +178,9 @@ def forward( self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, - encoder_hidden_states_image: Optional[torch.Tensor] = None, - timestep_seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + encoder_hidden_states_image: torch.Tensor | None = None, + timestep_seq_len: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: timestep = self.timesteps_proj(timestep) if timestep_seq_len is not None: timestep = timestep.unflatten(0, (-1, timestep_seq_len)) @@ -250,7 +250,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: # Fused QKV projection qkv, _ = self.to_qkv(hidden_states) @@ -296,7 +296,7 @@ def __init__( head_dim: int, eps: float = 1e-5, dropout: float = 0.0, - added_kv_proj_dim: Optional[int] = None, + added_kv_proj_dim: int | None = None, ): super().__init__() @@ -413,7 +413,7 @@ def __init__( ffn_dim: int, num_heads: int, eps: float = 1e-6, - added_kv_proj_dim: Optional[int] = None, + added_kv_proj_dim: int | None = None, cross_attn_norm: bool = False, ): super().__init__() @@ -528,10 +528,10 @@ def __init__( num_layers: int = 40, cross_attn_norm: bool = True, eps: float = 1e-6, - image_dim: Optional[int] = None, - added_kv_proj_dim: Optional[int] = None, + image_dim: int | None = None, + added_kv_proj_dim: int | None = None, rope_max_seq_len: int = 1024, - pos_embed_seq_len: Optional[int] = None, + pos_embed_seq_len: int | None = None, ): super().__init__() @@ -598,10 +598,10 @@ def forward( 
hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, - encoder_hidden_states_image: Optional[torch.Tensor] = None, + encoder_hidden_states_image: torch.Tensor | None = None, return_dict: bool = True, - attention_kwargs: Optional[dict[str, Any]] = None, - ) -> Union[torch.Tensor, Transformer2DModelOutput]: + attention_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor | Transformer2DModelOutput: batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size post_patch_num_frames = num_frames // p_t diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index e9da3690d4..81539808f3 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -18,8 +18,8 @@ import inspect import json import os -from collections.abc import Iterable -from typing import Any, Callable +from collections.abc import Callable, Iterable +from typing import Any import torch import torch.nn as nn diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index 912e3e7a16..288fa98189 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -17,7 +17,6 @@ import math from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -232,7 +231,7 @@ def forward( x: torch.Tensor, attn_mask: torch.Tensor, freqs_cis: torch.Tensor, - adaln_input: Optional[torch.Tensor] = None, + adaln_input: torch.Tensor | None = None, ): if self.modulation: assert adaln_input is not None diff --git a/vllm_omni/distributed/omni_connectors/adapter.py b/vllm_omni/distributed/omni_connectors/adapter.py index fc75996dd9..97d9765b1f 100644 --- a/vllm_omni/distributed/omni_connectors/adapter.py +++ b/vllm_omni/distributed/omni_connectors/adapter.py @@ -4,7 +4,8 @@ # and vllm_omni.entrypoints.omni_llm.py import time -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any from .utils.logging import get_connector_logger @@ -90,7 +91,7 @@ def try_recv_via_connector( task: dict[str, Any], connectors: dict[Any, Any], stage_id: int, -) -> tuple[Any, Optional[dict[str, Any]]]: +) -> tuple[Any, dict[str, Any] | None]: """ Attempts to resolve input data from either connector or IPC. Returns (engine_inputs, rx_metrics) or (None, None) if failed/skipped. diff --git a/vllm_omni/distributed/omni_connectors/connectors/base.py b/vllm_omni/distributed/omni_connectors/connectors/base.py index 0c21878c02..6163e7b86b 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/base.py +++ b/vllm_omni/distributed/omni_connectors/connectors/base.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from ..utils.logging import get_connector_logger @@ -15,7 +15,7 @@ class OmniConnectorBase(ABC): @abstractmethod def put( self, from_stage: str, to_stage: str, request_id: str, data: Any - ) -> tuple[bool, int, Optional[dict[str, Any]]]: + ) -> tuple[bool, int, dict[str, Any] | None]: """Store Python object, internal serialization handled by connector. 
Args: @@ -32,8 +32,8 @@ def put( @abstractmethod def get( - self, from_stage: str, to_stage: str, request_id: str, metadata: Optional[dict[str, Any]] = None - ) -> Optional[tuple[Any, int]]: + self, from_stage: str, to_stage: str, request_id: str, metadata: dict[str, Any] | None = None + ) -> tuple[Any, int] | None: """Retrieve Python object and payload size (bytes). Args: diff --git a/vllm_omni/distributed/omni_connectors/connectors/mooncake_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mooncake_connector.py index 00358912a1..2e12d3a403 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/mooncake_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/mooncake_connector.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time -from typing import Any, Optional +from typing import Any from ..utils.logging import get_connector_logger from .base import OmniConnectorBase @@ -36,8 +36,8 @@ def __init__(self, config: dict[str, Any]): self.proto = config.get("proto", "tcp") self.rdma = config.get("rdma", "") - self.store: Optional[MooncakeDistributedStore] = None - self.pin: Optional[ReplicateConfig] = None + self.store: MooncakeDistributedStore | None = None + self.pin: ReplicateConfig | None = None self._metrics = { "puts": 0, @@ -74,7 +74,7 @@ def _init_store(self): def put( self, from_stage: str, to_stage: str, request_id: str, data: Any - ) -> tuple[bool, int, Optional[dict[str, Any]]]: + ) -> tuple[bool, int, dict[str, Any] | None]: if not self.store: logger.error("Store not initialized") return False, 0, None @@ -102,8 +102,8 @@ def put( return False, 0, None def get( - self, from_stage: str, to_stage: str, request_id: str, metadata: Optional[dict[str, Any]] = None - ) -> Optional[tuple[Any, int]]: + self, from_stage: str, to_stage: str, request_id: str, metadata: dict[str, Any] | None = None + ) -> tuple[Any, int] | None: if not self.store: logger.error("Store not initialized") return None diff --git a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py index 6921f9a56c..9f1ca994f4 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Any from vllm_omni.entrypoints.stage_utils import shm_read_bytes, shm_write_bytes @@ -31,7 +31,7 @@ def __init__(self, config: dict[str, Any]): def put( self, from_stage: str, to_stage: str, request_id: str, data: Any - ) -> tuple[bool, int, Optional[dict[str, Any]]]: + ) -> tuple[bool, int, dict[str, Any] | None]: try: # Always serialize first to check size (and for SHM writing) # Note: For extremely large objects in "inline" mode (e.g. 
Ray), @@ -63,8 +63,8 @@ def put( return False, 0, None def get( - self, from_stage: str, to_stage: str, request_id: str, metadata: Optional[dict[str, Any]] = None - ) -> Optional[tuple[Any, int]]: + self, from_stage: str, to_stage: str, request_id: str, metadata: dict[str, Any] | None = None + ) -> tuple[Any, int] | None: if not metadata: logger.error(f"SharedMemoryConnector get called without metadata for req {request_id}") return None diff --git a/vllm_omni/distributed/omni_connectors/utils/config.py b/vllm_omni/distributed/omni_connectors/utils/config.py index e2df14e04c..9316c7e116 100644 --- a/vllm_omni/distributed/omni_connectors/utils/config.py +++ b/vllm_omni/distributed/omni_connectors/utils/config.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, field -from typing import Any, Optional +from typing import Any from .logging import get_connector_logger @@ -28,9 +28,9 @@ class OmniTransferConfig: # Direct mapping: (from_stage, to_stage) -> connector connectors: dict[tuple[str, str], ConnectorSpec] = field(default_factory=dict) - default_connector: Optional[ConnectorSpec] = None + default_connector: ConnectorSpec | None = None - def get_connector_for_edge(self, from_stage: str, to_stage: str) -> Optional[ConnectorSpec]: + def get_connector_for_edge(self, from_stage: str, to_stage: str) -> ConnectorSpec | None: """Get connector spec for a specific edge.""" edge_key = (from_stage, to_stage) return self.connectors.get(edge_key, self.default_connector) diff --git a/vllm_omni/distributed/omni_connectors/utils/initialization.py b/vllm_omni/distributed/omni_connectors/utils/initialization.py index 3263b273ba..cef91915e9 100644 --- a/vllm_omni/distributed/omni_connectors/utils/initialization.py +++ b/vllm_omni/distributed/omni_connectors/utils/initialization.py @@ -6,7 +6,7 @@ import json import sys from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from ..factory import OmniConnectorFactory from .config import ConnectorSpec, OmniTransferConfig @@ -21,8 +21,8 @@ def initialize_connectors_from_config( - config_path: Optional[Union[str, Path]] = None, default_shm_threshold: int = 65536 -) -> tuple[Optional[OmniTransferConfig], dict[tuple[str, str], OmniConnectorBase]]: + config_path: str | Path | None = None, default_shm_threshold: int = 65536 +) -> tuple[OmniTransferConfig | None, dict[tuple[str, str], OmniConnectorBase]]: """ Initialize connectors from configuration file. @@ -64,9 +64,7 @@ def create_connectors_from_config( return connectors -def get_connectors_config_for_stage( - transfer_config: Optional[OmniTransferConfig], stage_id: Union[str, int] -) -> dict[str, Any]: +def get_connectors_config_for_stage(transfer_config: OmniTransferConfig | None, stage_id: str | int) -> dict[str, Any]: """ Extract connector configurations relevant for a specific stage worker. 
@@ -98,10 +96,10 @@ def get_connectors_config_for_stage( def load_omni_transfer_config( - config_path: Optional[Union[str, Path]] = None, - config_dict: Optional[dict[str, Any]] = None, + config_path: str | Path | None = None, + config_dict: dict[str, Any] | None = None, default_shm_threshold: int = 65536, -) -> Optional[OmniTransferConfig]: +) -> OmniTransferConfig | None: """Load OmniTransferConfig from file or dict.""" if config_path is None and config_dict is None: # Even if no config provided, we might want to return a default config with SHM connectors @@ -239,8 +237,8 @@ def load_omni_transfer_config( def initialize_orchestrator_connectors( - config_path: Optional[str], worker_backend: Optional[str] = "multi_process", shm_threshold_bytes: int = 65536 -) -> tuple[Optional[OmniTransferConfig], dict[tuple[str, str], OmniConnectorBase]]: + config_path: str | None, worker_backend: str | None = "multi_process", shm_threshold_bytes: int = 65536 +) -> tuple[OmniTransferConfig | None, dict[tuple[str, str], OmniConnectorBase]]: """Initialize connectors shared at orchestrator level. Args: config_path: The path to the configuration file. @@ -259,7 +257,7 @@ def initialize_orchestrator_connectors( def get_stage_connector_config( - transfer_config: Optional[OmniTransferConfig], + transfer_config: OmniTransferConfig | None, stage_id: int, ) -> dict[str, Any]: """Return the serialized connector config payload for a specific stage.""" @@ -280,7 +278,7 @@ def get_stage_connector_config( def build_stage_connectors( stage_id: int, connectors_config: dict[str, Any], -) -> Optional[dict[tuple[str, str], Any]]: +) -> dict[tuple[str, str], Any] | None: """Instantiate OmniConnectors for a stage based on config.""" if not connectors_config: return {} diff --git a/vllm_omni/distributed/ray_utils/utils.py b/vllm_omni/distributed/ray_utils/utils.py index 8c68269cf0..07513b6601 100644 --- a/vllm_omni/distributed/ray_utils/utils.py +++ b/vllm_omni/distributed/ray_utils/utils.py @@ -4,7 +4,7 @@ import logging import os from contextlib import contextmanager -from typing import Any, Optional +from typing import Any import torch @@ -96,7 +96,7 @@ def get_ray_queue_class(): return lambda: RayQueue(maxsize=0) -def initialize_ray_cluster(address: Optional[str] = None): +def initialize_ray_cluster(address: str | None = None): if not RAY_AVAILABLE: logger.warning("Ray is not available, skipping initialization.") return @@ -107,9 +107,7 @@ def initialize_ray_cluster(address: Optional[str] = None): ray.init(address=address, ignore_reinit_error=True, runtime_env=runtime_env) -def create_placement_group( - number_of_stages: int, address: Optional[str] = None, strategy: str = "PACK" -) -> PlacementGroup: +def create_placement_group(number_of_stages: int, address: str | None = None, strategy: str = "PACK") -> PlacementGroup: """Create a placement group for the given number of stages. Args: number_of_stages: The number of stages to create the placement group for. 
diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py index 8ce5f043a7..1b3cec6ae3 100644 --- a/vllm_omni/engine/__init__.py +++ b/vllm_omni/engine/__init__.py @@ -4,7 +4,7 @@ import time from collections.abc import Mapping -from typing import Any, Optional, Union +from typing import Any import msgspec import torch @@ -42,12 +42,12 @@ class AdditionalInformationEntry(msgspec.Struct): """ # Tensor form - tensor_data: Optional[bytes] = None - tensor_shape: Optional[list[int]] = None - tensor_dtype: Optional[str] = None + tensor_data: bytes | None = None + tensor_shape: list[int] | None = None + tensor_dtype: str | None = None # List form - list_data: Optional[list[Any]] = None + list_data: list[Any] | None = None class AdditionalInformationPayload(msgspec.Struct): @@ -74,9 +74,9 @@ class OmniEngineCoreRequest(EngineCoreRequest): """ # Optional prompt embeddings (direct-transfer version) - prompt_embeds: Optional[PromptEmbedsPayload] = None + prompt_embeds: PromptEmbedsPayload | None = None # Optional additional information dictionary (serialized) - additional_information: Optional[AdditionalInformationPayload] = None + additional_information: AdditionalInformationPayload | None = None class OmniEngineCoreOutput( @@ -88,17 +88,17 @@ class OmniEngineCoreOutput( request_id: str new_token_ids: list[int] - new_logprobs: Optional[LogprobsLists] = None - new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None + new_logprobs: LogprobsLists | None = None + new_prompt_logprobs_tensors: LogprobsTensors | None = None - pooling_output: Optional[dict[str, torch.Tensor]] = None + pooling_output: dict[str, torch.Tensor] | None = None - finish_reason: Optional[FinishReason] = None - stop_reason: Union[int, str, None] = None - events: Optional[list[EngineCoreEvent]] = None - kv_transfer_params: Optional[dict[str, Any]] = None + finish_reason: FinishReason | None = None + stop_reason: int | str | None = None + events: list[EngineCoreEvent] | None = None + kv_transfer_params: dict[str, Any] | None = None - trace_headers: Optional[Mapping[str, str]] = None + trace_headers: Mapping[str, str] | None = None # The number of tokens with prefix cache hits. num_cached_tokens: int = 0 @@ -120,18 +120,18 @@ class OmniEngineCoreOutputs( # [num_reqs] outputs: list[OmniEngineCoreOutput] = [] - scheduler_stats: Optional[SchedulerStats] = None + scheduler_stats: SchedulerStats | None = None timestamp: float = 0.0 - utility_output: Optional[UtilityOutput] = None - finished_requests: Optional[set[str]] = None + utility_output: UtilityOutput | None = None + finished_requests: set[str] | None = None # In DP case, used to signal that the current wave of requests # has finished and the engines are paused. - wave_complete: Optional[int] = None + wave_complete: int | None = None # In DP case, used to signal that a request was received for an # "old" wave, so the next wave needs to be started in other engines. 
- start_wave: Optional[int] = None + start_wave: int | None = None def __post_init__(self): if self.timestamp == 0.0: diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index 0a38f65de7..d3e2268eaa 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import Qwen3OmniMoeTextConfig from vllm.engine.arg_utils import EngineArgs @@ -30,8 +29,8 @@ class OmniEngineArgs(EngineArgs): stage_id: int = 0 model_stage: str = "thinker" model_arch: str = "Qwen2_5OmniForConditionalGeneration" - engine_output_type: Optional[str] = None - hf_config_name: Optional[str] = None + engine_output_type: str | None = None + hf_config_name: str | None = None def draw_hf_text_config(self, config_dict: dict) -> Qwen3OmniMoeTextConfig: # transformers' get_text_config method is used to get the text config from thinker_config. @@ -91,8 +90,8 @@ class AsyncOmniEngineArgs(AsyncEngineArgs): stage_id: int = 0 model_stage: str = "thinker" model_arch: str = "Qwen2_5OmniForConditionalGeneration" - engine_output_type: Optional[str] = None - hf_config_name: Optional[str] = None + engine_output_type: str | None = None + hf_config_name: str | None = None def draw_hf_text_config(self, config_dict: dict) -> Qwen3OmniMoeTextConfig: # transformers' get_text_config method is used to get the text config from thinker_config. diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index 63e5ecdab0..5b317e4fcf 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -1,5 +1,6 @@ from ast import Dict -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any import torch from vllm.logger import init_logger @@ -33,18 +34,18 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) - self.mm_type: Optional[str] = None - self.mm_accumulated: Optional[Dict[str, Any]] = None + self.mm_type: str | None = None + self.mm_accumulated: Dict[str, Any] | None = None @classmethod def from_new_request( cls, tokenizer: AnyTokenizer, request: EngineCoreRequest, - prompt: Optional[str], - parent_req: Optional[ParentRequest], + prompt: str | None, + parent_req: ParentRequest | None, request_index: int, - queue: Optional[Any], + queue: Any | None, log_stats: bool, ) -> "OmniRequestState": if sampling_params := request.sampling_params: @@ -93,7 +94,7 @@ def from_new_request( log_stats=log_stats, ) - def add_multimodal_tensor(self, payload: Optional[Any], mm_type: Optional[str]) -> None: + def add_multimodal_tensor(self, payload: Any | None, mm_type: str | None) -> None: if payload is None: return try: @@ -167,11 +168,11 @@ def _to_cpu(x): def make_request_output( self, new_token_ids: list[int], - pooling_output: Optional[torch.Tensor], - finish_reason: Optional[FinishReason], - stop_reason: Optional[Union[int, str]], - kv_transfer_params: Optional[dict[str, Any]] = None, - ) -> Optional[Union[OmniRequestOutput, PoolingRequestOutput]]: + pooling_output: torch.Tensor | None, + finish_reason: FinishReason | None, + stop_reason: int | str | None, + kv_transfer_params: dict[str, Any] | None = None, + ) -> OmniRequestOutput | PoolingRequestOutput | None: """Create a request output from generation results. 
Creates a RequestOutput or PoolingRequestOutput from the generated @@ -210,8 +211,8 @@ def make_request_output( def _new_completion_output( self, token_ids: list[int], - finish_reason: Optional[FinishReason], - stop_reason: Optional[Union[int, str]], + finish_reason: FinishReason | None, + stop_reason: int | str | None, ) -> Any: # Reuse base text/logprobs logic, then annotate with pooling_result. base_output = super()._new_completion_output(token_ids, finish_reason, stop_reason) @@ -248,7 +249,7 @@ def __init__( self, tokenizer: AnyTokenizer, log_stats: bool, - engine_core_output_type: Optional[str] = None, + engine_core_output_type: str | None = None, ): """Initialize the multimodal output processor. @@ -281,10 +282,10 @@ def register_handler(self, modality: str, handler: Callable[[EngineCoreOutput], def add_request( self, request: EngineCoreRequest, - prompt: Optional[str], - parent_req: Optional[ParentRequest] = None, + prompt: str | None, + parent_req: ParentRequest | None = None, request_index: int = 0, - queue: Optional[RequestOutputCollector] = None, + queue: RequestOutputCollector | None = None, ) -> None: """Add a new request to be processed. @@ -322,8 +323,8 @@ def add_request( def process_outputs( self, engine_core_outputs: list[EngineCoreOutput], - engine_core_timestamp: Optional[float] = None, - iteration_stats: Optional[IterationStats] = None, + engine_core_timestamp: float | None = None, + iteration_stats: IterationStats | None = None, ) -> OutputProcessorOutput: """Process engine core outputs into request outputs. @@ -524,7 +525,7 @@ def _process_pooling_output(self, eco: EngineCoreOutput) -> None: except Exception: pass - def _extract_from_multimodal_outputs(self, eco: EngineCoreOutput, keys: tuple[str, ...]) -> Optional[torch.Tensor]: + def _extract_from_multimodal_outputs(self, eco: EngineCoreOutput, keys: tuple[str, ...]) -> torch.Tensor | None: mm = getattr(eco, "multimodal_outputs", None) if not isinstance(mm, dict): return None diff --git a/vllm_omni/engine/processor.py b/vllm_omni/engine/processor.py index a93c98b7cd..0474f9f4ab 100644 --- a/vllm_omni/engine/processor.py +++ b/vllm_omni/engine/processor.py @@ -1,6 +1,6 @@ import time from collections.abc import Mapping -from typing import Any, Optional, Union +from typing import Any import torch from vllm.config import VllmConfig @@ -90,14 +90,14 @@ def process_inputs( self, request_id: str, prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, - trace_headers: Optional[Mapping[str, str]] = None, + params: SamplingParams | PoolingParams, + arrival_time: float | None = None, + lora_request: LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, + trace_headers: Mapping[str, str] | None = None, priority: int = 0, - data_parallel_rank: Optional[int] = None, - ) -> tuple[Optional[str], OmniEngineCoreRequest]: + data_parallel_rank: int | None = None, + ) -> tuple[str | None, OmniEngineCoreRequest]: """Process input prompt into an engine core request. Converts a prompt (text, tokens, or multimodal) into an @@ -185,7 +185,7 @@ def process_inputs( # discriminated unions of TypedDicts, because of how it handles # inheritance of TypedDict. If we explicitly extract the items we want # we can avoid type errors from using `dict.get` later in the method. 
- prompt_str: Optional[str] = None if decoder_inputs["type"] == "embeds" else decoder_inputs.get("prompt") + prompt_str: str | None = None if decoder_inputs["type"] == "embeds" else decoder_inputs.get("prompt") prompt_token_ids = decoder_inputs["prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs["type"] == "embeds" else None @@ -205,7 +205,7 @@ def process_inputs( pooling_params = params.clone() # Multimodal related. - mm_features: Optional[list[MultiModalFeatureSpec]] = None + mm_features: list[MultiModalFeatureSpec] | None = None if decoder_inputs["type"] == "multimodal": decoder_mm_inputs = decoder_inputs["mm_kwargs"] @@ -230,8 +230,8 @@ def process_inputs( # Serialize prompt_embeds and additional_information if provided # (direct-transfer path) - prompt_embeds_payload: Optional[PromptEmbedsPayload] = None - additional_information_payload: Optional[AdditionalInformationPayload] = None + prompt_embeds_payload: PromptEmbedsPayload | None = None + additional_information_payload: AdditionalInformationPayload | None = None if "prompt_embeds" in decoder_inputs: # type: ignore[operator] pe: torch.Tensor = decoder_inputs["prompt_embeds"] # type: ignore[index] if pe.ndim != 2: diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 489bbfef64..7de5ddd322 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -6,7 +6,7 @@ from argparse import Namespace from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Optional, Union +from typing import Any import torch @@ -261,11 +261,11 @@ async def generate( self, prompt: PromptType, request_id: str, - sampling_params_list: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, + sampling_params_list: SamplingParams | Sequence[SamplingParams] | None = None, + lora_request: LoRARequest | None = None, + trace_headers: Mapping[str, str] | None = None, priority: int = 0, - data_parallel_rank: Optional[int] = None, + data_parallel_rank: int | None = None, ) -> AsyncGenerator[OmniRequestOutput, None]: """Generate outputs for the given prompt asynchronously. 
@@ -577,7 +577,7 @@ def errored(self) -> bool: def dead_error(self) -> BaseException: return EngineDeadError() - async def abort(self, request_id: Union[str, Iterable[str]]) -> None: + async def abort(self, request_id: str | Iterable[str]) -> None: pass async def get_vllm_config(self) -> VllmConfig: @@ -620,13 +620,13 @@ async def check_health(self) -> None: async def reset_mm_cache(self) -> None: pass - async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: + async def reset_prefix_cache(self, device: Device | None = None) -> None: pass async def sleep(self, level: int = 1) -> None: pass - async def wake_up(self, tags: Optional[list[str]] = None) -> None: + async def wake_up(self, tags: list[str] | None = None) -> None: pass async def is_sleeping(self) -> bool: @@ -689,8 +689,8 @@ def __init__( use_cached_outputs: bool = False, log_requests: bool = True, start_engine_loop: bool = True, - stat_loggers: Optional[list[StatLoggerFactory]] = None, - client_addresses: Optional[dict[str, str]] = None, + stat_loggers: list[StatLoggerFactory] | None = None, + client_addresses: dict[str, str] | None = None, client_count: int = 1, client_index: int = 0, ) -> None: @@ -771,7 +771,7 @@ def __init__( ) # Loggers. - self.logger_manager: Optional[StatLoggerManager] = None + self.logger_manager: StatLoggerManager | None = None if self.log_stats: self.logger_manager = StatLoggerManager( vllm_config=vllm_config, @@ -782,7 +782,7 @@ def __init__( ) self.logger_manager.log_engine_initialized() - self.output_handler: Optional[asyncio.Task] = None + self.output_handler: asyncio.Task | None = None try: # Start output handler eagerly if we are in the asyncio eventloop. asyncio.get_running_loop() @@ -819,10 +819,10 @@ def from_vllm_config( engine_args: AsyncOmniEngineArgs, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[StatLoggerFactory]] = None, + stat_loggers: list[StatLoggerFactory] | None = None, enable_log_requests: bool = False, disable_log_stats: bool = False, - client_addresses: Optional[dict[str, str]] = None, + client_addresses: dict[str, str] | None = None, client_count: int = 1, client_index: int = 0, disable_log_requests: bool = True, # Deprecated, will be removed diff --git a/vllm_omni/entrypoints/chat_utils.py b/vllm_omni/entrypoints/chat_utils.py index 3f84eb0db2..6517ae666f 100644 --- a/vllm_omni/entrypoints/chat_utils.py +++ b/vllm_omni/entrypoints/chat_utils.py @@ -1,5 +1,5 @@ from collections.abc import Awaitable, Iterable -from typing import Any, Optional, Union, cast +from typing import Any, cast import numpy as np from openai.types.chat import ChatCompletionContentPartTextParam @@ -33,13 +33,13 @@ def create_parser(self) -> "BaseMultiModalContentParser": class OmniAsyncMultiModalContentParser(AsyncMultiModalContentParser): def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: super().__init__(tracker=tracker) - self._mm_processor_kwargs: Optional[dict[str, Any]] = None + self._mm_processor_kwargs: dict[str, Any] | None = None - def set_mm_processor_kwargs(self, mm_processor_kwargs: Optional[dict[str, Any]]) -> None: + def set_mm_processor_kwargs(self, mm_processor_kwargs: dict[str, Any] | None) -> None: """Set mm_processor_kwargs for use in parsing.""" self._mm_processor_kwargs = mm_processor_kwargs - def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None: + def parse_video(self, video_url: str | None, uuid: str | None = None) -> None: video = 
self._connector.fetch_video_async(video_url=video_url) if video_url else None placeholder = self._tracker.add("video", video, uuid) @@ -51,7 +51,7 @@ def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> N audio_placeholder = self._tracker.add("audio", audio_coro, uuid) self._add_placeholder("audio", audio_placeholder) - async def _extract_audio_from_video_async(self, video_url: str) -> tuple[np.ndarray, Union[int, float]]: + async def _extract_audio_from_video_async(self, video_url: str) -> tuple[np.ndarray, int | float]: """ Extract audio from video URL using librosa. Returns tuple of (audio_array, sample_rate) compatible with audio format. @@ -79,7 +79,7 @@ def _write_temp_file_sync(data: bytes, suffix: str) -> str: temp_file.write(data) return temp_file.name - def _load_audio_sync(file_path: str) -> tuple[np.ndarray, Union[int, float]]: + def _load_audio_sync(file_path: str) -> tuple[np.ndarray, int | float]: """Synchronous audio loading with librosa - runs in thread pool.""" import librosa @@ -131,11 +131,11 @@ def parse_chat_messages_futures( model_config: ModelConfig, tokenizer: AnyTokenizer, content_format: _ChatTemplateContentFormat, - mm_processor_kwargs: Optional[dict[str, Any]] = None, + mm_processor_kwargs: dict[str, Any] | None = None, ) -> tuple[ list[ConversationMessage], - Awaitable[Optional[MultiModalDataDict]], - Optional[MultiModalUUIDDict], + Awaitable[MultiModalDataDict | None], + MultiModalUUIDDict | None, ]: conversation: list[ConversationMessage] = [] mm_tracker = OmniAsyncMultiModalItemTracker(model_config, tokenizer) @@ -165,7 +165,7 @@ def _parse_chat_message_content( mm_tracker: BaseMultiModalItemTracker, content_format: _ChatTemplateContentFormat, interleave_strings: bool, - mm_processor_kwargs: Optional[dict[str, Any]] = None, + mm_processor_kwargs: dict[str, Any] | None = None, ) -> list[ConversationMessage]: role = message["role"] content = message.get("content") @@ -210,7 +210,7 @@ def _parse_chat_message_content_parts( *, wrap_dicts: bool, interleave_strings: bool, - mm_processor_kwargs: Optional[dict[str, Any]] = None, + mm_processor_kwargs: dict[str, Any] | None = None, ) -> list[ConversationMessage]: content = list[_ContentPart]() diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index 0e360e09c4..6cf64a97f8 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -4,7 +4,7 @@ import uuid from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Optional, Union +from typing import Any import cloudpickle from pydantic import ValidationError @@ -89,9 +89,9 @@ class OmniLLM: def __init__( self, model: str, - stage_configs_path: Optional[str] = None, + stage_configs_path: str | None = None, log_stats: bool = False, - log_file: Optional[str] = None, + log_file: str | None = None, init_sleep_seconds: int = 20, shm_threshold_bytes: int = 65536, batch_timeout: int = 10, @@ -228,8 +228,8 @@ def __del__(self) -> None: # best-effort def generate( self, - prompts: Union[PromptType, Sequence[PromptType]], - sampling_params_list: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, + prompts: PromptType | Sequence[PromptType], + sampling_params_list: SamplingParams | Sequence[SamplingParams] | None = None, ) -> list[OmniRequestOutput]: """Generate outputs for the given prompts. 
@@ -262,8 +262,8 @@ def generate( def _run_generation( self, - prompts: Union[PromptType, Sequence[PromptType]], - sampling_params_list: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, + prompts: PromptType | Sequence[PromptType], + sampling_params_list: SamplingParams | Sequence[SamplingParams] | None = None, ) -> list[OmniRequestOutput]: logger.debug("[Orchestrator] generate() called") if sampling_params_list is None: @@ -559,9 +559,9 @@ class OmniStageLLM(LLM): def __init__( self, model: str, - compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, - hf_overrides: Optional[dict[str, Any]] = None, - structured_outputs_config: Optional[Union[dict[str, Any], StructuredOutputsConfig]] = None, + compilation_config: int | dict[str, Any] | CompilationConfig | None = None, + hf_overrides: dict[str, Any] | None = None, + structured_outputs_config: dict[str, Any] | StructuredOutputsConfig | None = None, **kwargs: Any, ): """LLM constructor.""" @@ -633,7 +633,7 @@ def __init__( self.engine_class = type(self.llm_engine) self.request_counter = Counter() - self.default_sampling_params: Union[dict[str, Any], None] = None + self.default_sampling_params: dict[str, Any] | None = None supported_tasks = self.llm_engine.get_supported_tasks() # type: ignore diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 5fc05dfba8..7812af68b7 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -16,7 +16,7 @@ import multiprocessing as mp import os import sys -from typing import Any, Optional, Union +from typing import Any from vllm.inputs import TextPrompt from vllm.inputs.preprocess import InputPreprocessor @@ -81,10 +81,10 @@ def __init__(self, stage_config: Any): default_sampling_params = getattr(stage_config, "default_sampling_params", {}) self.default_sampling_params = SamplingParams(**_to_dict(default_sampling_params)) # Runtime orchestration state (added) - self._in_q: Optional[mp.Queue] = None - self._out_q: Optional[mp.Queue] = None - self._proc: Optional[mp.Process] = None - self._log_file: Optional[str] = None + self._in_q: mp.Queue | None = None + self._out_q: mp.Queue | None = None + self._proc: mp.Process | None = None + self._log_file: str | None = None self._shm_threshold_bytes: int = 65536 self._logger = logging.getLogger(__name__) @@ -160,11 +160,11 @@ def init_stage_worker( model: str, *, is_async: bool = False, - log_file: Optional[str] = None, + log_file: str | None = None, shm_threshold_bytes: int = 65536, - ctx: Optional[mp.context.BaseContext] = None, + ctx: mp.context.BaseContext | None = None, batch_timeout: int = 10, - connectors_config: Optional[dict] = None, + connectors_config: dict | None = None, worker_backend: str = "multi_process", **kwargs: Any, ) -> None: @@ -303,7 +303,7 @@ def submit(self, payload: dict[str, Any]) -> None: assert self._in_q is not None self._in_q.put(payload) - def try_collect(self) -> Optional[dict[str, Any]]: + def try_collect(self) -> dict[str, Any] | None: """Try to collect a result from the stage worker without blocking. 
Returns: @@ -317,8 +317,8 @@ def try_collect(self) -> Optional[dict[str, Any]]: return None def process_engine_inputs( - self, stage_list: list[Any], prompt: Union[OmniTokensPrompt, TextPrompt] = None - ) -> list[Union[OmniTokensPrompt, TextPrompt]]: + self, stage_list: list[Any], prompt: OmniTokensPrompt | TextPrompt | None = None + ) -> list[OmniTokensPrompt | TextPrompt]: """Process engine inputs for this stage from upstream stage outputs. Derives inputs for this stage from outputs of upstream stages. @@ -372,7 +372,7 @@ def _stage_worker( stage_payload: dict[str, Any], in_q: mp.Queue, out_q: mp.Queue, - log_file: Optional[str] = None, + log_file: str | None = None, batch_timeout: int = 10, ) -> None: """Stage worker entry: device setup, LLM init, batching, SHM IPC.""" diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index c9b8dd2a15..b40a04e287 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Any, Optional +from typing import Any import vllm.envs as envs from fastapi import Depends, HTTPException, Request @@ -111,8 +111,8 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, async def build_async_omni( args: Namespace, *, - disable_frontend_multiprocessing: Optional[bool] = None, - client_config: Optional[dict[str, Any]] = None, + disable_frontend_multiprocessing: bool | None = None, + client_config: dict[str, Any] | None = None, ) -> AsyncIterator[EngineClient]: """Build an AsyncOmni instance from command-line arguments. @@ -152,7 +152,7 @@ async def build_async_omni_from_stage_config( args: Namespace, *, disable_frontend_multiprocessing: bool = False, - client_config: Optional[dict[str, Any]] = None, + client_config: dict[str, Any] | None = None, ) -> AsyncIterator[EngineClient]: """Create AsyncOmni from stage configuration. @@ -182,7 +182,7 @@ async def build_async_omni_from_stage_config( "To disable frontend multiprocessing, set VLLM_USE_V1=0."
) - async_omni: Optional[EngineClient] = None + async_omni: EngineClient | None = None try: async_omni = AsyncOmni(model=args.model, cli_args=args) @@ -257,7 +257,7 @@ async def omni_init_app_state( ) if args.tool_server == "demo": - tool_server: Optional[ToolServer] = DemoToolServer() + tool_server: ToolServer | None = DemoToolServer() assert isinstance(tool_server, DemoToolServer) await tool_server.init_and_validate() elif args.tool_server: @@ -314,7 +314,7 @@ async def omni_init_app_state( state.server_load_metrics = 0 -def Omnichat(request: Request) -> Optional[OmniOpenAIServingChat]: +def Omnichat(request: Request) -> OmniOpenAIServingChat | None: return request.app.state.openai_serving_chat diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 0cd522e14b..90ba059a3f 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -3,10 +3,10 @@ import json import time import uuid -from collections.abc import AsyncGenerator, AsyncIterator, Sequence +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence from datetime import datetime, timedelta, timezone from io import BytesIO -from typing import Any, Callable, Optional, Union +from typing import Any import jinja2 from fastapi import Request @@ -78,8 +78,8 @@ class OmniOpenAIServingChat(OpenAIServingChat): async def create_chat_completion( self, request: ChatCompletionRequest, - raw_request: Optional[Request] = None, - ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, ErrorResponse]: + raw_request: Request | None = None, + ) -> AsyncGenerator[str, None] | ChatCompletionResponse | ErrorResponse: """ Chat Completion API similar to OpenAI's API. @@ -236,17 +236,17 @@ async def create_chat_completion( async def _preprocess_chat( self, - request: Union[ChatLikeRequest, ResponsesRequest], + request: ChatLikeRequest | ResponsesRequest, tokenizer: AnyTokenizer, messages: list[ChatCompletionMessageParam], - chat_template: Optional[str], + chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, add_generation_prompt: bool = True, continue_final_message: bool = False, - tool_dicts: Optional[list[dict[str, Any]]] = None, - documents: Optional[list[dict[str, str]]] = None, - chat_template_kwargs: Optional[dict[str, Any]] = None, - tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, + tool_dicts: list[dict[str, Any]] | None = None, + documents: list[dict[str, str]] | None = None, + chat_template_kwargs: dict[str, Any] | None = None, + tool_parser: Callable[[AnyTokenizer], ToolParser] | None = None, add_special_tokens: bool = False, ) -> tuple[ list[ConversationMessage], @@ -279,7 +279,7 @@ async def _preprocess_chat( ) _chat_template_kwargs.update(chat_template_kwargs or {}) - request_prompt: Union[str, list[int]] + request_prompt: str | list[int] if tokenizer is None: request_prompt = "placeholder" @@ -365,9 +365,9 @@ def _to_sampling_params_list(self, sampling_params_list: list[dict]) -> list[Sam def _log_inputs( self, request_id: str, - inputs: Union[RequestPrompt, PromptType], - params_list: Optional[list[SamplingParams]], - lora_request: Optional[LoRARequest], + inputs: RequestPrompt | PromptType, + params_list: list[SamplingParams] | None, + lora_request: LoRARequest | None, ) -> None: if self.request_logger is None: return @@ -399,9 +399,9 @@ async def chat_completion_full_generator( conversation: list[ConversationMessage], tokenizer: AnyTokenizer, 
request_metadata: RequestResponseMetadata, - ) -> Union[ErrorResponse, ChatCompletionResponse]: + ) -> ErrorResponse | ChatCompletionResponse: created_time = int(time.time()) - final_res: Optional[RequestOutput] = None + final_res: RequestOutput | None = None final_outputs: list[OmniRequestOutput] = [] try: @@ -698,7 +698,7 @@ def _create_text_choice( choices.append(choice_data) if request.echo: - last_msg_content: Union[str, list[dict[str, str]]] = "" + last_msg_content: str | list[dict[str, str]] = "" if conversation and "content" in conversation[-1] and conversation[-1].get("role") == role: last_msg_content = conversation[-1]["content"] or "" if isinstance(last_msg_content, list): diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index f7e7856a7b..e092325f8f 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -2,7 +2,7 @@ from collections import Counter from dataclasses import asdict, is_dataclass from pathlib import Path -from typing import Any, Optional +from typing import Any from omegaconf import OmegaConf from vllm.transformers_utils.config import get_config @@ -97,7 +97,7 @@ def resolve_model_config_path(model: str) -> str: return str(stage_config_path) -def load_stage_configs_from_model(model: str, base_engine_args: Optional[dict] = None) -> list: +def load_stage_configs_from_model(model: str, base_engine_args: dict | None = None) -> list: """Load stage configurations from model's default config file. Loads stage configurations based on the model type and device type. @@ -120,7 +120,7 @@ def load_stage_configs_from_model(model: str, base_engine_args: Optional[dict] = return stage_configs -def load_stage_configs_from_yaml(config_path: str, base_engine_args: Optional[dict] = None) -> list: +def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None = None) -> list: """Load stage configurations from a YAML file. Args: diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py index 01a70d6a82..ec291e43e2 100644 --- a/vllm_omni/inputs/data.py +++ b/vllm_omni/inputs/data.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any try: from typing import NotRequired @@ -73,10 +73,10 @@ class OmniEmbedsPrompt(EmbedsPrompt): def token_inputs_omni( prompt_token_ids: list[int], - prompt: Optional[str] = None, - cache_salt: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - additional_information: Optional[dict[str, Any]] = None, + prompt: str | None = None, + cache_salt: str | None = None, + prompt_embeds: torch.Tensor | None = None, + additional_information: dict[str, Any] | None = None, ) -> OmniTokenInputs: """Construct token inputs with optional embeddings and metadata. 
diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py index 3078d3457d..262b18bf7f 100644 --- a/vllm_omni/inputs/preprocess.py +++ b/vllm_omni/inputs/preprocess.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any from typing_extensions import assert_never from vllm.inputs.data import SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt @@ -23,15 +23,15 @@ class OmniInputPreprocessor(InputPreprocessor): def _process_tokens( self, parsed_content: TokensPrompt, - tokenization_kwargs: Optional[dict[str, Any]] = None, + tokenization_kwargs: dict[str, Any] | None = None, *, - mm_uuids: Optional[MultiModalUUIDDict] = None, - ) -> Union[OmniTokenInputs, MultiModalInputs]: + mm_uuids: MultiModalUUIDDict | None = None, + ) -> OmniTokenInputs | MultiModalInputs: prompt_token_ids = self._truncate_inputs(parsed_content["prompt_token_ids"], tokenization_kwargs) prompt_embeds = parsed_content.get("prompt_embeds") additional_information = parsed_content.get("additional_information") - inputs: Union[OmniTokenInputs, MultiModalInputs] + inputs: OmniTokenInputs | MultiModalInputs if multi_modal_data := parsed_content.get("multi_modal_data"): inputs = self._process_multimodal( prompt_token_ids, @@ -55,9 +55,9 @@ def _process_tokens( def _prompt_to_llm_inputs( self, prompt: SingletonPrompt, - tokenization_kwargs: Optional[dict[str, Any]] = None, + tokenization_kwargs: dict[str, Any] | None = None, *, - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> SingletonInputs: """ Extract the singleton inputs from a prompt. diff --git a/vllm_omni/model_executor/layers/mrope.py b/vllm_omni/model_executor/layers/mrope.py index 4636534702..9ca6a36e23 100644 --- a/vllm_omni/model_executor/layers/mrope.py +++ b/vllm_omni/model_executor/layers/mrope.py @@ -1,5 +1,4 @@ import itertools -from typing import Optional, Union import numpy as np import torch @@ -50,11 +49,11 @@ def __init__( base: float, is_neox_style: bool, dtype: torch.dtype, - mrope_section: Optional[list[int]] = None, + mrope_section: list[int] | None = None, mrope_interleaved: bool = False, # YaRN parameters. *, - scaling_factor: Optional[float] = None, + scaling_factor: float | None = None, extrapolation_factor: float = 1, attn_factor: float = 1, beta_fast: int = 32, @@ -85,8 +84,8 @@ def forward( self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: """PyTorch-native implementation equivalent to forward(). 
Args: @@ -138,12 +137,12 @@ def get_input_positions( cls, input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], - video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], - second_per_grid_ts: Optional[list[float]], + image_grid_thw: list[list[int]] | torch.Tensor | None, + video_grid_thw: list[list[int]] | torch.Tensor | None, + second_per_grid_ts: list[float] | None, context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[list[list[int]], int]: """Get mrope input positions and delta value.""" @@ -171,12 +170,12 @@ def get_input_positions_tensor( cls, input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, second_per_grid_ts: list[float], context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: from vllm.transformers_utils.config import thinker_uses_mrope @@ -218,10 +217,10 @@ def _glm4v_get_input_positions_tensor( cls, input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, context_len: int = 0, - seq_len: Optional[int] = None, + seq_len: int | None = None, ) -> tuple[torch.Tensor, int]: """Get mrope input positions and delta value for GLM4V.""" @@ -319,11 +318,11 @@ def _vl_get_input_positions_tensor( cls, input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, second_per_grid_ts: list[float], context_len: int = 0, - seq_len: Optional[int] = None, + seq_len: int | None = None, ) -> tuple[torch.Tensor, int]: """Get mrope input positions and delta value.""" @@ -417,12 +416,12 @@ def _omni_get_input_positions_tensor( cls, input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - second_per_grid_ts: Optional[list[float]] = None, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + second_per_grid_ts: list[float] | None = None, context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: """Get mrope input positions and delta value (Qwen2.5-Omni version). @@ -642,7 +641,7 @@ def omni_get_updates_use_audio_in_video( cls, thinker_config: PretrainedConfig, audio_len: int, - video_grid_thw: Union[list[int], torch.Tensor], + video_grid_thw: list[int] | torch.Tensor, video_second_per_grid_t: float, ) -> list[int]: """Get video prompt updates when `use_audio_in_video` is True. 
diff --git a/vllm_omni/model_executor/model_loader/weight_utils.py b/vllm_omni/model_executor/model_loader/weight_utils.py index bd27781a99..7432ad9a2a 100644 --- a/vllm_omni/model_executor/model_loader/weight_utils.py +++ b/vllm_omni/model_executor/model_loader/weight_utils.py @@ -1,6 +1,5 @@ import time from pathlib import Path -from typing import Optional, Union import huggingface_hub import vllm.envs as envs @@ -17,10 +16,10 @@ def download_weights_from_hf_specific( model_name_or_path: str, - cache_dir: Optional[str], + cache_dir: str | None, allow_patterns: list[str], - revision: Optional[str] = None, - ignore_patterns: Optional[Union[str, list[str]]] = None, + revision: str | None = None, + ignore_patterns: str | list[str] | None = None, ) -> str: """Download model weights from Hugging Face Hub. Users can specify the allow_patterns to download only the necessary weights. diff --git a/vllm_omni/model_executor/models/output_templates.py b/vllm_omni/model_executor/models/output_templates.py index 6fcd53eb30..2ed2098065 100644 --- a/vllm_omni/model_executor/models/output_templates.py +++ b/vllm_omni/model_executor/models/output_templates.py @@ -1,4 +1,4 @@ -from typing import NamedTuple, Optional +from typing import NamedTuple import torch from vllm.sequence import IntermediateTensors @@ -8,6 +8,6 @@ class OmniOutput(NamedTuple): """Output from the merged Omni model containing both text and audio.""" text_hidden_states: torch.Tensor - multimodal_outputs: Optional[dict] = None - intermediate_tensors: Optional[IntermediateTensors] = None - next_token_id: Optional[torch.Tensor] = None + multimodal_outputs: dict | None = None + intermediate_tensors: IntermediateTensors | None = None + next_token_id: torch.Tensor | None = None diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py index f6f5608cef..331594f959 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py @@ -2,7 +2,6 @@ import os from collections.abc import Iterable from functools import cached_property -from typing import Optional, Union import numpy as np import torch @@ -136,9 +135,9 @@ def _module_device(module: nn.Module) -> torch.device: def move_submodules_to_devices( self, *, - thinker_device: Optional[Union[str, torch.device]] = None, - talker_device: Optional[Union[str, torch.device]] = None, - token2wav_device: Optional[Union[str, torch.device]] = None, + thinker_device: str | torch.device | None = None, + talker_device: str | torch.device | None = None, + token2wav_device: str | torch.device | None = None, ) -> None: """Optionally move thinker/talker/token2wav to different devices. 
@@ -182,17 +181,17 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, generate_audio: bool = True, voice_type: str = "Chelsie", - codec: Optional[torch.Tensor] = None, - sampling_metadata: Optional[SamplingMetadata] = None, - logits_index: Optional[int] = None, + codec: torch.Tensor | None = None, + sampling_metadata: SamplingMetadata | None = None, + logits_index: int | None = None, sampler=None, - additional_information: Optional[dict[str, object]] = None, + additional_information: dict[str, object] | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors, OmniOutput]: + ) -> torch.Tensor | IntermediateTensors | OmniOutput: """ Workflow: 1) Thinker: multimodal understanding → text hidden states. @@ -292,9 +291,9 @@ def forward( inputs_embeds = self.talker.get_input_embeddings(input_ids) # ------- Request-scoped additional information (no cross-request concat) ------- - request_ids: Optional[list[str]] = kwargs.get("request_ids") # ordered - request_token_spans: Optional[list[tuple[int, int]]] = kwargs.get("request_token_spans") - addi_by_req: Optional[dict] = kwargs.get("additional_information_by_req_id") + request_ids: list[str] | None = kwargs.get("request_ids") # ordered + request_token_spans: list[tuple[int, int]] | None = kwargs.get("request_token_spans") + addi_by_req: dict | None = kwargs.get("additional_information_by_req_id") runtime_addi = kwargs.get("runtime_additional_information") # Normalize runtime_addi into a mapping by request_id for convenience @@ -796,7 +795,7 @@ def _thinker_to_talker_decode_one_step( ) # for decode return output_token_ids, processed_output_token_embeds - def compute_logits(self, hidden_states: Union[torch.Tensor, OmniOutput]) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor | OmniOutput) -> torch.Tensor | None: # Handle OmniOutput type if isinstance(hidden_states, OmniOutput): hidden_states = hidden_states.text_hidden_states @@ -808,7 +807,7 @@ def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: + ) -> SamplerOutput | None: # Use thinker model for sampling return self.model.sample(logits, sampling_metadata) @@ -891,7 +890,7 @@ def _init_token2wav_model(self, hf_model_folder): key = os.path.basename(f).split("_")[0].lower() self._token2wav_ref_mels[key] = torch.as_tensor(np.load(f), device=device) - def _codec_to_audio(self, codec_tokens: torch.Tensor, voice_type: str = "default") -> Optional[torch.Tensor]: + def _codec_to_audio(self, codec_tokens: torch.Tensor, voice_type: str = "default") -> torch.Tensor | None: if self.token2wav is None: self._init_token2wav_model() if self.token2wav is None: diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py index 4d906f79a6..1a0edbed97 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py @@ -1,6 +1,5 @@ from collections.abc import Iterable from functools import cached_property -from typing import Optional, Union import torch import torch.nn as nn @@ -107,7 +106,7 @@ def sampler(self): def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: 
Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None and len(multimodal_embeddings) != 0: @@ -129,10 +128,10 @@ def forward( self, input_ids: torch.Tensor = None, positions: torch.Tensor = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: assert input_ids is not None or inputs_embeds is not None, "input_ids or inputs_embeds must be provided" # forward_context: ForwardContext = get_forward_context() # unused variable @@ -158,7 +157,7 @@ def bad_word_processor(self, logits: torch.Tensor) -> torch.Tensor: logits[..., bos_id] = -1e9 return logits - def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: logits = self.language_model.compute_logits(hidden_states) logits = self.bad_word_processor(logits) return logits @@ -167,7 +166,7 @@ def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: + ) -> SamplerOutput | None: return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py index 28107cb455..405c18e748 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py @@ -1,7 +1,7 @@ from collections import defaultdict -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial -from typing import Annotated, Any, Callable, Literal, Optional, Union +from typing import Annotated, Any, Literal import torch import torch.nn as nn @@ -86,7 +86,7 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema): type: Literal["audio_features"] input_features: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("nmb", "tsl"), ] @@ -155,7 +155,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): def _parse_audio_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + data: dict[str, torch.Tensor] | ModalityData[ImageItem], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): return DictEmbeddingItems( @@ -185,7 +185,7 @@ def get_feature_extractor(self, **kwargs: object): assert isinstance(feature_extractor, WhisperFeatureExtractor) return feature_extractor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"audio": None, "image": None, "video": None} @@ -608,7 +608,7 @@ def _derive_audio_from_video_placeholders( def _apply_hf_processor_main( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], @@ 
-676,7 +676,7 @@ def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str, dim: int else: return torch.concat(mm_input, dim=dim) - def _parse_and_validate_audio_input(self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]: + def _parse_and_validate_audio_input(self, **kwargs: object) -> Qwen2_5OmniAudioFeatureInputs | None: input_audio_features = kwargs.pop("input_audio_features", None) audio_feature_lengths = kwargs.pop("audio_feature_lengths", None) feature_attention_mask = kwargs.pop("feature_attention_mask", None) @@ -699,7 +699,7 @@ def _parse_and_validate_audio_input(self, **kwargs: object) -> Optional[Qwen2_5O def _parse_and_validate_image_input( self, **kwargs: dict[str, Any], - ) -> Optional[Qwen2_5_VLImageInputs]: + ) -> Qwen2_5_VLImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -731,7 +731,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: dict[str, Any], - ) -> Optional[Qwen2_5_VLVideoInputs]: + ) -> Qwen2_5_VLVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -841,7 +841,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|vision_start|><|IMAGE|><|vision_end|>" if modality.startswith("video"): @@ -940,7 +940,7 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None and len(multimodal_embeddings) != 0: @@ -958,7 +958,7 @@ def get_input_embeddings( ) return inputs_embeds - def get_multimodal_embeddings_v0(self, **kwargs: object) -> Optional[NestedTensors]: + def get_multimodal_embeddings_v0(self, **kwargs: object) -> NestedTensors | None: audio_input = self._parse_and_validate_audio_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs) @@ -982,7 +982,7 @@ def get_multimodal_embeddings_v0(self, **kwargs: object) -> Optional[NestedTenso def get_input_embeddings_v0( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, + multimodal_embeddings: NestedTensors | None = None, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is None or len(multimodal_embeddings) == 0: @@ -1002,10 +1002,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -1031,7 +1031,7 @@ def forward( ) return text_inputs_embeds, hidden_states.unsqueeze(0) # (1, S, D) - def compute_logits(self, 
hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py index a3241928ec..6ea444916d 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_token2wav.py @@ -4,7 +4,6 @@ import math from collections.abc import Iterable -from typing import Optional, Union import numpy as np import torch @@ -395,9 +394,9 @@ def forward( speaker_embedding: torch.Tensor, condition_vector: torch.Tensor, code_embed: torch.Tensor, - drop_audio_cond: Optional[bool] = False, - code_embed_uncond: Optional[bool] = None, - apply_cfg: Optional[bool] = True, + drop_audio_cond: bool | None = False, + code_embed_uncond: bool | None = None, + apply_cfg: bool | None = True, ): if apply_cfg: hidden_states = torch.cat([hidden_states, hidden_states], dim=0) @@ -1324,7 +1323,7 @@ def fast_block_sample( y0: torch.Tensor, num_steps: int = 10, guidance_scale: float = 0.5, - sway_coefficient: Optional[float] = -1.0, + sway_coefficient: float | None = -1.0, ) -> torch.Tensor: """ Block-wise ODE sampling starting from provided initial state y0. @@ -1521,7 +1520,7 @@ def process_little_chunk( steps: int, prev_generated: torch.Tensor, finished: bool = False, - ) -> tuple[Optional[torch.Tensor], torch.Tensor]: + ) -> tuple[torch.Tensor | None, torch.Tensor]: """Streaming per small chunk: returns (mel_or_None, audio_slice).""" start_index = max(i * self.chunk_size - self.past_cache_size, 0) end_index = min( @@ -1560,9 +1559,9 @@ def process_chunk( y_all: torch.Tensor, i: int, steps: int, - prev_generated: Union[torch.Tensor, list[torch.Tensor]], + prev_generated: torch.Tensor | list[torch.Tensor], finished: bool = False, - ) -> tuple[Union[torch.Tensor, list[torch.Tensor]], torch.Tensor]: + ) -> tuple[torch.Tensor | list[torch.Tensor], torch.Tensor]: """High-level chunk API aligning to qwen2_code2wav_dit signature.""" if not isinstance(prev_generated, torch.Tensor): prev_generated = prev_generated[0] if len(prev_generated) > 0 else None @@ -1585,7 +1584,7 @@ def _process_chunk_for_50hz( start_index: int, end_index: int, finished: bool, - prev_generated: Optional[torch.Tensor], + prev_generated: torch.Tensor | None, generated: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -1671,7 +1670,7 @@ def forward( num_steps: int = 10, guidance_scale: float = 0.5, sway_coefficient: float = -1.0, - intermediate_tensors: Optional[IntermediateTensors] = None, + intermediate_tensors: IntermediateTensors | None = None, **kwargs, ) -> torch.Tensor: # Delegate to HF token2wav model @@ -1685,7 +1684,7 @@ def forward( **kwargs, ) - def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: # Token2Wav outputs waveform; logits are not applicable return hidden_states @@ -1693,7 +1692,7 @@ def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: + ) -> SamplerOutput | None: return None def load_weights_without_buffers(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -1782,7 +1781,7 @@ def process_chunk_dit_batch( ) 
@torch.inference_mode() - def process_chunk_bigvgan_batch(self, mel_batch: torch.Tensor) -> Optional[torch.Tensor]: + def process_chunk_bigvgan_batch(self, mel_batch: torch.Tensor) -> torch.Tensor | None: # BigVGAN is not part of this wrapper; return None for parity. return None @@ -1797,7 +1796,7 @@ def process_little_chunk( steps: int, prev_generated: torch.Tensor, finished: bool = False, - ) -> tuple[Optional[torch.Tensor], torch.Tensor]: + ) -> tuple[torch.Tensor | None, torch.Tensor]: mel = self.token2wav( code=codec_all, conditioning=conditioning, @@ -1815,9 +1814,9 @@ def process_chunk( y_all: torch.Tensor, i: int, steps: int, - prev_generated: Union[torch.Tensor, list[torch.Tensor]], + prev_generated: torch.Tensor | list[torch.Tensor], finished: bool = False, - ) -> tuple[Union[torch.Tensor, list[torch.Tensor]], torch.Tensor]: + ) -> tuple[torch.Tensor | list[torch.Tensor], torch.Tensor]: _mel, out = self.process_little_chunk( conditioning=conditioning, reference_mel=reference_mel, diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py index c0145860b6..f6f05357a5 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py @@ -1,5 +1,4 @@ from collections.abc import Iterable -from typing import Optional, Union import torch from torch import nn @@ -43,7 +42,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -80,10 +79,10 @@ def __init__( num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - head_dim: Optional[int] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[tuple] = None, + head_dim: int | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + rope_scaling: tuple | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, ) -> None: @@ -161,8 +160,8 @@ class Qwen2DecoderLayer(nn.Module): def __init__( self, config: Qwen2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -207,7 +206,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -302,9 +301,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -428,16 +427,16 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + 
inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits @@ -445,7 +444,7 @@ def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: + ) -> SamplerOutput | None: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens @@ -503,7 +502,7 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, + intermediate_tensors: IntermediateTensors | None = None, ) -> torch.Tensor: return self.model(input_ids, positions, intermediate_tensors) @@ -511,7 +510,7 @@ def pooler( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: + ) -> PoolerOutput | None: return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 4650d381e4..12faccd2fe 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -6,7 +6,6 @@ from collections import defaultdict from collections.abc import Iterable from functools import cached_property -from typing import Optional, Union import torch import torch.nn as nn @@ -198,10 +197,10 @@ def _module_device(module: nn.Module) -> torch.device: def move_submodules_to_devices( self, *, - thinker_device: Optional[Union[str, torch.device]] = None, - talker_device: Optional[Union[str, torch.device]] = None, - code_predictor_device: Optional[Union[str, torch.device]] = None, - code2wav_device: Optional[Union[str, torch.device]] = None, + thinker_device: str | torch.device | None = None, + talker_device: str | torch.device | None = None, + code_predictor_device: str | torch.device | None = None, + code2wav_device: str | torch.device | None = None, ) -> None: """ Optionally move thinker/talker/code2wav to different devices. @@ -257,16 +256,16 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, generate_audio: bool = True, voice_type: str = "ethan", - codec: Optional[torch.Tensor] = None, - sampling_metadata: Optional[SamplingMetadata] = None, - logits_index: Optional[int] = None, - additional_information: Optional[dict[str, object]] = None, + codec: torch.Tensor | None = None, + sampling_metadata: SamplingMetadata | None = None, + logits_index: int | None = None, + additional_information: dict[str, object] | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors, OmniOutput]: + ) -> torch.Tensor | IntermediateTensors | OmniOutput: """ Unified forward pass for all model stages. 
@@ -383,9 +382,9 @@ def forward( inputs_embeds = self.talker.get_input_embeddings(input_ids) # ------- Request-scoped additional information (no cross-request concat) ------- - request_ids: Optional[list[str]] = kwargs.get("request_ids") # ordered - request_token_spans: Optional[list[tuple[int, int]]] = kwargs.get("request_token_spans") - addi_by_req: Optional[dict] = kwargs.get("additional_information_by_req_id") + request_ids: list[str] | None = kwargs.get("request_ids") # ordered + request_token_spans: list[tuple[int, int]] | None = kwargs.get("request_token_spans") + addi_by_req: dict | None = kwargs.get("additional_information_by_req_id") runtime_addi = kwargs.get("runtime_additional_information") # Normalize runtime_addi into a mapping by request_id for convenience @@ -808,14 +807,14 @@ def _thinker_to_talker_prefill( self, thinker_embed: torch.Tensor, thinker_hidden: torch.Tensor, - multimodal_mask: Optional[torch.Tensor], + multimodal_mask: torch.Tensor | None, input_ids: torch.Tensor, thinker_result_ids: torch.Tensor, speaker_id, - tts_bos_thinker: Optional[torch.Tensor] = None, - tts_eos_thinker: Optional[torch.Tensor] = None, - tts_pad_thinker: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + tts_bos_thinker: torch.Tensor | None = None, + tts_eos_thinker: torch.Tensor | None = None, + tts_pad_thinker: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """ Project thinker outputs to talker inputs during prefill stage. @@ -845,7 +844,7 @@ def _ensure_1x1(x: torch.Tensor) -> torch.Tensor: return x[-1] return x.view(1, 1, -1) - def _proj_from_thinker(x_opt: Optional[torch.Tensor]) -> torch.Tensor: + def _proj_from_thinker(x_opt: torch.Tensor | None) -> torch.Tensor: if isinstance(x_opt, torch.Tensor) and x_opt.numel() > 0: xin = _ensure_1x1(x_opt).to(talker_dev) else: @@ -863,7 +862,7 @@ def _proj_from_thinker(x_opt: Optional[torch.Tensor]) -> torch.Tensor: talker_input_embeds = [] # [1 t d] talker_input_ids = [] - trailing_text_hidden_all: Optional[torch.Tensor] = None + trailing_text_hidden_all: torch.Tensor | None = None # For every chatml parts for i in range(len(im_start_indexes) - 1): im_start_index = im_start_indexes[i].item() @@ -986,8 +985,8 @@ def _get_talker_assistant_parts( def _talker_to_code_predictor( self, - talker_hidden_states: Optional[torch.Tensor], - layer0_token_ids: Optional[torch.Tensor], + talker_hidden_states: torch.Tensor | None, + layer0_token_ids: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Project talker outputs to code predictor inputs. 
@@ -1025,9 +1024,9 @@ def _talker_to_code_predictor( def compute_logits( self, - hidden_states: Union[torch.Tensor, OmniOutput], + hidden_states: torch.Tensor | OmniOutput, sampling_metadata: SamplingMetadata = None, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """Compute logits from hidden states.""" # Handle OmniOutput type from vllm.v1.sample.logits_processor import LogitsProcessors @@ -1083,7 +1082,7 @@ def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: + ) -> SamplerOutput | None: """Sample from logits.""" return self.model.sample(logits, sampling_metadata) diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py index 5c5c5b15d1..bd990977a0 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py @@ -7,7 +7,7 @@ """ from collections import namedtuple -from typing import Any, Optional +from typing import Any import torch import torch.nn as nn @@ -136,10 +136,10 @@ def forward( causal_mask: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - past_key_values: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, use_cache: bool = False, - position_ids: Optional[torch.LongTensor] = None, + position_ids: torch.LongTensor | None = None, ) -> torch.Tensor: bsz, seq_len, _ = hidden_states.shape @@ -253,8 +253,8 @@ def __init__( prefix: str, model_config: ModelConfig, layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.layer_idx = layer_idx @@ -284,10 +284,10 @@ def mtp_block( causal_mask: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - past_key_values: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, use_cache: bool = False, - position_ids: Optional[torch.LongTensor] = None, + position_ids: torch.LongTensor | None = None, ) -> torch.Tensor: # Self-attention with residual residual = hidden_states @@ -310,7 +310,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, spec_step_index: int = 0, ) -> torch.Tensor: assert inputs_embeds is not None, "inputs_embeds required for MTP" @@ -396,11 +396,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, inputs_embeds: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Any] = None, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Any | None = None, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, **kwargs: Any, ) -> Any: """ diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py index 
ce2d281fb7..e994589c4d 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -1,5 +1,3 @@ -from typing import Union - import torch from vllm.inputs import TextPrompt @@ -13,7 +11,7 @@ def thinker2talker( stage_list, engine_input_source, - prompt: Union[OmniTokensPrompt, TextPrompt] = None, + prompt: OmniTokensPrompt | TextPrompt | None = None, requires_multimodal_data: bool = False, ): if not engine_input_source: diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index b3b5500ef1..cfa2d37ad7 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -3,7 +3,7 @@ # Copyright 2025 The Qwen team. """Stage input processor for Qwen3 Omni MoE: Thinker → Talker transition.""" -from typing import Any, Union +from typing import Any import torch from vllm.inputs import TextPrompt @@ -49,7 +49,7 @@ def _compute_talker_prompt_ids_length(info): def thinker2talker( stage_list: list[Any], engine_input_source: list[int], - prompt: Union[OmniTokensPrompt, TextPrompt, None] = None, + prompt: OmniTokensPrompt | TextPrompt | None = None, requires_multimodal_data: bool = False, ) -> list[OmniTokensPrompt]: """ @@ -114,7 +114,7 @@ def thinker2talker( def talker2code2wav( stage_list: list[Any], engine_input_source: list[int], - prompt: Union[OmniTokensPrompt, TextPrompt, None] = None, + prompt: OmniTokensPrompt | TextPrompt | None = None, requires_multimodal_data: bool = False, ) -> list[OmniTokensPrompt]: """ diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index 3e1e2e0366..47b58f7f3b 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional import torch from vllm.outputs import RequestOutput @@ -17,7 +16,7 @@ class OmniModelRunnerOutput(ModelRunnerOutput): output tensors (e.g., {"image": tensor, "audio": tensor}) """ - multimodal_outputs: Optional[dict[str, torch.Tensor]] = None + multimodal_outputs: dict[str, torch.Tensor] | None = None @dataclass diff --git a/vllm_omni/request.py b/vllm_omni/request.py index 31f90f94b9..6190fad001 100644 --- a/vllm_omni/request.py +++ b/vllm_omni/request.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Callable, Optional +from collections.abc import Callable +from typing import TYPE_CHECKING from vllm.v1.request import Request from vllm.v1.structured_output.request import StructuredOutputRequest @@ -25,22 +26,22 @@ class OmniRequest(Request): def __init__( self, - prompt_embeds: Optional[PromptEmbedsPayload] = None, - additional_information: Optional[AdditionalInformationPayload] = None, + prompt_embeds: PromptEmbedsPayload | None = None, + additional_information: AdditionalInformationPayload | None = None, *args, **kwargs, ): super().__init__(*args, **kwargs) # Serialized prompt embeddings payload (optional) - self.prompt_embeds: Optional[PromptEmbedsPayload] = prompt_embeds + self.prompt_embeds: PromptEmbedsPayload | None = prompt_embeds # Serialized additional information payload (optional) - self.additional_information: Optional[AdditionalInformationPayload] = additional_information + self.additional_information: AdditionalInformationPayload | None = additional_information @classmethod def from_engine_core_request( cls, request: OmniEngineCoreRequest, - block_hasher: Optional[Callable[["Request"], list["BlockHash"]]],
+ block_hasher: Callable[["Request"], list["BlockHash"]] | None, ) -> "Request": """Create an OmniRequest from an OmniEngineCoreRequest. diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 12a4f80ec0..dfe6947cc3 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np import torch @@ -307,7 +307,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_metadata() @torch.inference_mode() - def extract_multimodal_outputs(self, hidden_states: Union[torch.Tensor, list[torch.Tensor]]) -> dict: + def extract_multimodal_outputs(self, hidden_states: torch.Tensor | list[torch.Tensor]) -> dict: if hasattr(self.model, "have_multimodal_outputs") and self.model.have_multimodal_outputs: text_hidden_states = hidden_states.text_hidden_states multimodal_outputs = hidden_states.multimodal_outputs @@ -326,7 +326,7 @@ def extract_multimodal_outputs(self, hidden_states: Union[torch.Tensor, list[tor def _dummy_run( self, num_tokens: int, - cudagraph_runtime_mode: Optional[CUDAGraphMode] = None, + cudagraph_runtime_mode: CUDAGraphMode | None = None, force_attention: bool = False, uniform_decode: bool = False, allow_microbatching: bool = True, @@ -444,7 +444,7 @@ def _dummy_run( num_tokens_across_dp = num_tokens_after_padding num_tokens_after_padding = int(num_tokens_after_padding[0].item()) - attn_metadata: Optional[PerLayerAttnMetadata] = None + attn_metadata: PerLayerAttnMetadata | None = None # If force_attention is True, we always capture attention. Otherwise, # it only happens for cudagraph_runtime_mode=FULL. 
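(Reviewer note, not part of the patch: the vllm_omni/request.py hunk above also swaps typing.Callable for collections.abc.Callable, the preferred spelling since PEP 585, and the new `X | None` annotations are ordinary types.UnionType objects at runtime. A minimal, self-contained sketch of that behavior with hypothetical names, nothing below is taken from the repository:)

from collections.abc import Callable  # preferred over typing.Callable since PEP 585 (Python 3.9+)

# PEP 604 union built from a collections.abc generic; requires Python 3.10+.
BlockHasher = Callable[[str], list[int]] | None


def hash_blocks(request_id: str, hasher: BlockHasher = None) -> list[int]:
    """Hypothetical helper; only the annotation style mirrors the patch."""
    # Unions written with `|` are types.UnionType objects and work with isinstance().
    assert isinstance(request_id, str | bytes)
    return hasher(request_id) if hasher is not None else []


print(hash_blocks("req-1"))                      # -> []
print(hash_blocks("req-1", lambda s: [len(s)]))  # -> [5]
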
@@ -612,19 +612,19 @@ def _preprocess( self, scheduler_output: "SchedulerOutput", num_scheduled_tokens_np: np.ndarray, - intermediate_tensors: Optional[IntermediateTensors] = None, - ubatch_slices: Optional[UBatchSlices] = None, - num_tokens_after_padding: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + ubatch_slices: UBatchSlices | None = None, + num_tokens_after_padding: torch.Tensor | None = None, ) -> tuple[ int, int, - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, torch.Tensor, - Optional[IntermediateTensors], + IntermediateTensors | None, dict[str, Any], - Optional[dict[str, dict]], + dict[str, dict] | None, ]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if ubatch_slices: @@ -638,7 +638,7 @@ def _preprocess( # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order - per_req_additional_information: Optional[dict[str, dict]] = None + per_req_additional_information: dict[str, dict] | None = None if self.supports_mm_inputs and get_pp_group().is_first_rank and not self.model_config.is_encoder_decoder: # Build multimodal inputs and overlay prompt embeds; collect per-request info per_req_additional_information = self._build_mm_inputs_and_overlays( @@ -784,7 +784,7 @@ def _compute_request_token_spans(self, num_scheduled_tokens_np) -> list[tuple[in def _build_model_kwargs_extra( self, - per_req_additional_information: Optional[dict[str, dict]], + per_req_additional_information: dict[str, dict] | None, num_scheduled_tokens_np, ) -> dict: """Build extra keyword arguments passed to the model for this step, including: diff --git a/vllm_omni/worker/npu/npu_model_runner.py b/vllm_omni/worker/npu/npu_model_runner.py index ee6ef83f45..c57c6df488 100644 --- a/vllm_omni/worker/npu/npu_model_runner.py +++ b/vllm_omni/worker/npu/npu_model_runner.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np import torch @@ -293,8 +293,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: @torch.inference_mode() def extract_multimodal_outputs( - self, hidden_states: Union[torch.Tensor, list[torch.Tensor]] - ) -> tuple[torch.Tensor, Union[torch.Tensor, list[torch.Tensor], dict]]: + self, hidden_states: torch.Tensor | list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor | list[torch.Tensor] | dict]: """Extract multimodal outputs from hidden states.""" if hasattr(self.model, "have_multimodal_outputs") and self.model.have_multimodal_outputs: text_hidden_states = hidden_states.text_hidden_states @@ -373,7 +373,7 @@ def _dummy_run( num_tokens: int, with_prefill: bool = False, is_torchair_compile: bool = False, - aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, + aclgraph_runtime_mode: CUDAGraphMode | None = None, force_attention: bool = False, uniform_decode: bool = False, ) -> torch.Tensor:
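(Reviewer note, not part of the patch: every hunk in this change applies the same mechanical rewrite, dropping the Optional/Union imports from typing and spelling the annotations as PEP 604 unions. A before/after sketch of the pattern, using a hypothetical function rather than any signature from the diff:)

# Before (typing-module spelling):
#
#   from typing import Optional, Union
#
#   def run(positions: Optional[list[float]] = None,
#           device: Optional[Union[str, int]] = None) -> Optional[dict[str, float]]:
#       ...

# After (PEP 604 spelling, Python 3.10+); equivalent for type checkers:
def run(
    positions: list[float] | None = None,  # Optional[X]           -> X | None
    device: str | int | None = None,       # Optional[Union[A, B]] -> A | B | None
) -> dict[str, float] | None:              # Optional[dict[...]]   -> dict[...] | None
    """Hypothetical example; it illustrates the annotation rewrite only."""
    if positions is None:
        return None
    total = float(sum(positions))
    return {"total": total, "device": float(device) if isinstance(device, int) else 0.0}
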