diff --git a/.gitignore b/.gitignore index 5941d027c76..36d883072b3 100644 --- a/.gitignore +++ b/.gitignore @@ -176,7 +176,6 @@ tests/ # vLLM-omni specific # Model files and checkpoints -models/ checkpoints/ *.bin *.safetensors diff --git a/README.md b/README.md index 66c8725ee51..a1c828be047 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Traditional vLLM systems are limited to text-based, autoregressive generation. v - **Multi-modal Models**: Text, image, video, audio, and sensor data processing - **Non-autoregressive Architectures**: Diffusion Transformers (DiT) and other parallel generation models -- **Heterogeneous Outputs**: Beyond traditional text generation to structured, binary, and streaming outputs +- **Heterogeneous Outputs**: Beyond traditional text generation to multimodal outputs ## 🏗️ Architecture @@ -28,119 +28,48 @@ vLLM-omni is built on a modular architecture that extends vLLM's core functional - **Text**: Advanced tokenization and embedding generation - **Image**: Vision encoder integration (CLIP, etc.) - **Audio**: Speech processing and audio embedding -- **Video**: Frame-by-frame and temporal processing -- **Sensor**: IoT and sensor data interpretation - -### Output Formats - -- **Structured Data**: JSON, XML, and custom formats -- **Binary Outputs**: Images, audio, and video generation -- **Streaming**: Real-time progressive generation -- **Multipart**: Combined multi-modal responses ## 📋 Supported Models ### AR + Diffusion Transformer (DiT) Models -- Qwen-Image (Image generation and editing) - Qwen-omni (Thinker-Talker-Codec structure) -- Custom DiT and hiybrid architectures +- HunyunaImage 3.0 (Ongoing) +- Qwen-Image (Ongoing) ## 🛠️ Installation -### Quick Start - -#### Option 1: Docker (Recommended for macOS) - -```bash -# Clone the repository -git clone https://github.com/hsliuustc0106/vllm-omni.git -cd vllm-omni - -# Run the automated Docker setup -./scripts/docker-setup-macos.sh -``` - -#### Option 2: Local Installation - -```bash -# Clone the repository -git clone https://github.com/hsliuustc0106/vllm-omni.git -cd vllm-omni - -# Run the installation script -./install.sh -``` - -### Prerequisites - -- Python 3.11+ (recommended) -- Conda or Miniconda -- Git -- CUDA 11.8+ (for GPU acceleration) or CPU-only installation - -### Installation Methods - -#### Method 1: Automated Installation (Recommended) +Set up basic environments ```bash -# Using shell script -./install.sh - -# Or using Python script -python install.py +uv venv --python 3.12 --seed +source .venv/bin/activate ``` +Install certain version of vllm with commitid: 808a7b69df479b6b3a16181711cac7ca28a9b941 -#### Method 2: Manual Installation ```bash -# Create conda environment -conda create -n vllm_omni python=3.11 -y -conda activate vllm_omni - -# Install PyTorch (CPU or GPU) -pip install torch>=2.7 --index-url https://download.pytorch.org/whl/cpu # CPU -# pip install torch>=2.7 --index-url https://download.pytorch.org/whl/cu121 # GPU - -# Install dependencies -pip install -r requirements.txt -pip install "vllm>=0.10.2" - -# Install vLLM-omni -pip install -e . +git clone https://github.com/vllm-project/vllm.git +cd vllm +git checkout 808a7b69df479b6b3a16181711cac7ca28a9b941 +VLLM_USE_PRECOMPILED=1 uv pip install --editable . ``` -### Verify Installation +## Run examples (Qwen2.5-omni) +Get into the example folder ```bash -# Test the installation -python test_installation.py - -# Test basic functionality -python -c "import vllm_omni; print('Ready!')" - -# Test CLI -vllm --help +cd vllm_omni +cd examples/offline_inference/qwen2_5_omni ``` - -For detailed installation instructions, see [INSTALL.md](INSTALL.md). - -## 📥 Model Download - -Models are automatically downloaded when first used, or you can pre-download them: - +Modify PYTHONPATH in run.sh as your path of vllm_omni. Then run. ```bash -# Check downloaded models -python scripts/download_models.py --check-cache - -# Download all default models -python scripts/download_models.py --all - -# Download specific models -python scripts/download_models.py --ar-models Qwen/Qwen3-0.6B -python scripts/download_models.py --dit-models stabilityai/stable-diffusion-2-1 +bash run.sh ``` +The output audio is saved in ./output_audio -**Model Storage Location:** -- Default: `~/.cache/huggingface/hub/` -- AR models: 100MB - 1GB each -- DiT models: 2GB - 7GB each +## To-do list +- [x] Offline inference example for Qwen2.5-omni with single request +- [ ] Adaptation from current vllm branch to stable vllm v0.11.0 +- [ ] Offline inference example for Qwen2.5-omni with streaming multiple requests +- [ ] Online inference support +- [ ] Support for other models -For detailed model management, see [MODEL_DOWNLOAD_GUIDE.md](docs/MODEL_DOWNLOAD_GUIDE.md). +For detailed model management, see [vllm_omni_design.md](docs/architecture/vllm_omni_design.md) and [high_level_arch_design.md](docs/architecture/high_level_arch_design.md). diff --git a/docs/architecture/high level arch design.md b/docs/architecture/high_level_arch_design.md similarity index 100% rename from docs/architecture/high level arch design.md rename to docs/architecture/high_level_arch_design.md diff --git a/docs/architecture/implementation_architecture.md b/docs/architecture/implementation_architecture.md deleted file mode 100644 index a83d96d1e91..00000000000 --- a/docs/architecture/implementation_architecture.md +++ /dev/null @@ -1,353 +0,0 @@ -# vLLM-omni Implementation Architecture - -## 1. Package Structure - -``` -vllm_omni/ -├── __init__.py # Main package exports -├── config/ # Configuration management -│ ├── __init__.py -│ ├── stage_config.py # OmniStageConfig implementation -│ ├── dit_config.py # DiT-specific configurations -│ └── cache_config.py # Caching configurations -├── core/ # Core processing components -│ ├── __init__.py -│ ├── stage_manager.py # Multi-stage orchestration -│ ├── dit_cache_manager.py # DiT caching system -│ └── sched/ # Schedulers -│ ├── __init__.py -│ ├── scheduler.py # Base scheduler interface - -├── engine/ # Engine components -│ ├── __init__.py -│ ├── processor.py # Output processing -│ └── output_processor.py # Multimodal output handling -├── executor/ # Executor implementations -│ ├── __init__.py -│ ├── base_executor.py # Base executor interface -│ └── diffusers_executor.py # Diffusers pipeline executor -├── model_executor/ # Model execution components -│ ├── __init__.py -├── worker/ # Worker implementations -│ ├── __init__.py -│ ├── gpu_diffusion_model_runner.py # Loads and runs diffusers pipelines -│ └── gpu_diffusion_worker.py # Worker facade around the diffusion runner -├── entrypoints/ # Entry points and CLI -│ ├── __init__.py -│ ├── omni.py # OmniServeCommand -│ ├── omni_llm.py # OmniLLM and AsyncOmniLLM -│ └── cli/ # CLI integration -│ ├── __init__.py -│ └── main.py # CLI main entry point -├── request.py # Request handling -├── dit_cache_interface.py # DiT cache interface -└── utils/ # Utility functions - ├── __init__.py - ├── multimodal.py # Multimodal utilities - └── vae.py # VAE utilities for image processing -``` - -## 2. Core Module Dependencies - -### 2.1 vLLM Integration Points - -```python -# Key vLLM imports and extensions -from vllm.v1.engine.llm_engine import LLMEngine -from vllm.v1.engine.async_llm import AsyncLLM -from vllm.v1.core.sched.scheduler import SchedulerInterface -from vllm.v1.executor.multiproc_executor import MultiprocExecutor -from vllm.v1.worker.gpu_worker import Worker as GPUWorker -from vllm.v1.worker.gpu_model_runner import GPUModelRunner -from vllm.v1.outputs import ModelRunnerOutput -``` - -### 2.2 Internal Dependencies - -```python -# Internal module dependencies -vllm_omni.config → vllm_omni.core -vllm_omni.core → vllm_omni.engine, vllm_omni.executor, vllm_omni.worker -vllm_omni.engine → vllm_omni.model_executor -vllm_omni.entrypoints → vllm_omni.core -``` - -## 3. Configuration System - -### 3.1 Stage Configuration - -```python -# vllm_omni/config/stage_config.py -@dataclass -class OmniStageConfig: - stage_id: int - engine_type: Literal["AR", "DiT"] - model_path: str - input_modalities: List[str] - output_modalities: List[str] - vllm_config: Optional[VllmConfig] = None - dit_config: Optional[DiTConfig] = None - executor_class: Optional[type[Executor]] = None - stage_output: Optional[Any] = None -``` - -### 3.2 DiT Configuration - -```python -# vllm_omni/config/stage_config.py -@dataclass -class DiTConfig: - model_config: Optional[ModelConfig] = None - scheduler_config: Optional[Any] = None - device_config: Optional[DeviceConfig] = None - load_config: Optional[LoadConfig] = None - compilation_config: Optional[CompilationConfig] = None - dit_cache_config: Optional[DiTCacheConfig] = None - num_inference_steps: int - guidance_scale: float = 7.5 - use_diffusers: bool = False - diffusers_pipeline: Optional[str] = None - height: int = 512 - width: int = 512 -``` - -### 3.3 Cache Configuration - -```python -# vllm_omni/config/cache_config.py -@dataclass -class DiTCacheConfig: - cache_tensors: List[DiTCacheTensor] - max_cache_size: int - cache_strategy: str = "fifo" - enable_optimization: bool = True -``` - -## 4. Core Implementation Details - -### 4.1 OmniLLM Implementation - -```python -# vllm_omni/core/omni_llm.py -class OmniLLM(LLM): - def __init__(self, stage_configs: List[OmniStageConfig]): - super().__init__() - self.stage_configs = stage_configs - self.engine_list: List[LLMEngine] = [] - self.output_processor = MultimodalOutputProcessor() - self._initialize_stage_engines() - - def _initialize_stage_engines(self) -> None: - """Initialize LLMEngine instances for each stage""" - for stage_config in self.stage_configs: - if stage_config.engine_type == "AR": - engine = self._create_ar_engine(stage_config) - elif stage_config.engine_type == "DiT": - engine = self._create_dit_engine(stage_config) - self.engine_list.append(engine) - - def generate(self, stage_args_list: List[Dict], **kwargs) -> List[RequestOutput]: - """Main generation interface - orchestrates multi-stage processing""" - current_output = None - - for i, (stage_config, stage_args) in enumerate(zip(self.stage_configs, stage_args_list)): - stage_engine = self.engine_list[i] - - # Prepare input for this stage - processed_input = self._process_stage_inputs(stage_config, stage_args, current_output) - - # Execute stage - stage_output = self._execute_stage(stage_engine, processed_input) - - # Update for next stage - current_output = stage_output - stage_config.stage_output = stage_output - - # Process final output - return self.output_processor.process_output(current_output) -``` - -### 4.2 AsyncOmniLLM Implementation - -```python -# vllm_omni/core/omni_llm.py -class AsyncOmniLLM(AsyncLLM): - def __init__(self, stage_configs: List[OmniStageConfig]): - super().__init__() - self.stage_configs = stage_configs - self.async_engine_list: List[AsyncLLM] = [] - self.output_processor = MultimodalOutputProcessor() - self._initialize_async_stage_engines() - - async def generate_async(self, stage_args_list: List[Dict], **kwargs) -> List[RequestOutput]: - """Async generation interface""" - # Similar to OmniLLM but with async/await patterns - pass -`` - -### 4.3 Model Runners - -```python -# vllm_omni/model_executor/dit_model_runner.py -class OmniDiffusionModelRunner(GPUModelRunner): - def execute_model( - self, - scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, IntermediateTensors]: - # DiT model execution logic - return ModelRunnerOutput( - req_ids=[...], - req_id_to_index={...}, - sampled_token_ids=[], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[tensor, ...], # DiT output tensors - kv_connector_output=None, - num_nans_in_logits=None, - ) - -# vllm_omni/model_executor/ar_model_runner.py -class OmniARModelRunner(GPUModelRunner): - def execute_model( - self, - scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, IntermediateTensors]: - # AR model execution with hidden state output - return ModelRunnerOutput( - req_ids=[...], - req_id_to_index={...}, - sampled_token_ids=[], - spec_token_ids=None, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[hidden_states, ...], # Hidden states - kv_connector_output=None, - num_nans_in_logits=None, - ) -``` - -### 4.5 Output Processing - -```python -# vllm_omni/engine/output_processor.py -class MultimodalOutputProcessor(OutputProcessor): - def __init__(self): - self.output_handlers: Dict[str, Callable] = { - "image": self._process_image_output, - "text+image": self._process_text_image_output, - "latents": self._process_latents_output, - "text": self._process_text_output, - } - - def process_outputs(self, engine_core_outputs: List[EngineCoreOutput], ...): - """Process multimodal outputs based on type""" - for engine_core_output in engine_core_outputs: - output_type = self._detect_output_type(engine_core_output) - handler = self.output_handlers.get(output_type, self._process_pooling_output) - handler(engine_core_output) -``` - -## 5. CLI Integration - -### 5.1 Entry Point Override - -```python -# vllm_omni/entrypoints/cli/main.py -def main(): - """Main CLI entry point that intercepts vLLM commands""" - if "--omni" in sys.argv: - omni_args = [arg for arg in sys.argv[1:] if arg != "--omni"] - omni_serve = OmniServeCommand() - omni_serve.run(omni_args) - else: - from vllm.entrypoints.cli.main import main as vllm_main - vllm_main() -``` - -### 5.2 Package Configuration - -```toml -# pyproject.toml updates -[project.scripts] -vllm = "vllm_omni.entrypoints.cli.main:main" -vllm-omni = "vllm_omni.entrypoints.cli.main:main" - -[project.entry-points."vllm.plugins"] -omni = "vllm_omni.plugin:OmniPlugin" -``` - -## 6. Testing Strategy - -### 6.1 Unit Tests -- Individual component testing -- Mock vLLM dependencies -- Configuration validation - -### 6.2 Integration Tests -- End-to-end pipeline testing -- vLLM compatibility testing -- Multi-stage processing validation - -### 6.3 Performance Tests -- Benchmarking against native vLLM -- Memory usage profiling -- Latency measurements - -## 7. Installation and Setup - -### 7.1 Package Installation -```bash -pip install vllm>=0.10.2 -pip install vllm-omni -``` - -### 7.2 Development Setup -```bash -git clone https://github.com/hsliuustc0106/vllm-omni -cd vllm-omni -pip install -e ".[dev]" -``` - -### 7.3 Usage -```bash -vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8000 -``` - -## 8. Memory Management - -### 8.1 DiT Cache Management -- Tensor caching for intermediate results -- Memory pooling for efficient allocation -- Cache eviction strategies - -### 8.2 Multi-Stage Memory -- Inter-stage data passing optimization -- Memory sharing between stages -- Garbage collection optimization - -## 9. Error Handling - -### 9.1 Stage Failure Handling -- Graceful degradation on stage failures -- Error propagation and reporting -- Recovery mechanisms - -### 9.2 vLLM Compatibility -- Version compatibility checks -- API change detection -- Fallback mechanisms - -## 10. Monitoring and Logging - -### 10.1 Performance Metrics -- Stage execution times -- Memory usage per stage -- Throughput measurements - -### 10.2 Debug Information -- Request tracing across stages -- Cache hit/miss ratios -- Error logging and reporting diff --git a/docs/architecture/vllm_omni_design.md b/docs/architecture/vllm_omni_design.md new file mode 100644 index 00000000000..33701ed2eea --- /dev/null +++ b/docs/architecture/vllm_omni_design.md @@ -0,0 +1,512 @@ +# vLLM-omni Software Design Document + +## Overview + +vLLM-omni is a multi-modality extension for vLLM that supports non-autoregressive structures and non-textual outputs. This document outlines the key software abstractions, APIs, and dependencies for the system, designed to maximize reuse of vLLM's proven architecture. + +## Architecture Principles + +1. **vLLM V1 Compatibility**: Built on vLLM's Engine V1 architecture with AsyncLLM and EngineCore patterns +2. **Stage-based Processing**: Models are divided into multiple stages, each processed by different Engine Cores +3. **Multiple Engine Core Support**: Each stage can use either AR Engine Core (reusing vLLM) or Diffusion Engine Core (new DiT support), or other new Cores +4. **Worker Process Pattern**: Follows vLLM's multiprocess worker architecture for scalability +5. **Extensibility**: Easy integration of new modalities, model architectures, and output formats + +## Key Data Flow + +API Server --> OmniLM/AsyncOmniLM (New, including multi engines) --> LLMEngine/AsyncLLM --> Engine Core + --> Scheduler (New one for DiT) --> Executor (New one for diffusers) --> Worker (New one for DiT) + --> ModelRunner (New one for AR hiddenstate, New one for DiT) --> RequestState --> OutputProcessoer (New one for final multimodal output) + +## Core Components (Ordered by Data Flow) + +### 1. Installation +```bash +pip install vllm +pip install vllm-omni +``` + +### 2. Online inference launch (Entry Point) + +Keep consistency with the vllm main branch +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8000 +``` + +#### Design Architecture +```mermaid +graph TD + A[vllm serve model --omni] --> B[vLLM-omni CLI Wrapper] + B --> C{Detect --omni flag} + C -->|Yes| D[Parse OmniConfig] + C -->|No| E[Forward to vLLM CLI] + D --> F[Initialize AsyncOmniLM] + F --> G[Start omni Server] + G --> H[Multi-stage Processing] + E --> I[Standard vLLM Pipeline] + + subgraph "vLLM-omni Components" + F + G + H + end + + subgraph "vLLM Components" + I + end +``` + +#### omni Serve Command Implementation +File: vllm_omni/entrypoints/cli/main.py +```python +import sys +import argparse +from typing import List, Optional +from vllm_omni.entrypoints.omni import OmniServeCommand + +def main(): + """Main CLI entry point that intercepts vLLM commands""" + # Check if --omni flag is present + if "--omni" in sys.argv: + # Remove --omni flag and process with vLLM-omni + omni_args = [arg for arg in sys.argv[1:] if arg != "--omni"] + omni_serve = OmniServeCommand() + omni_serve.run(omni_args) + else: + # Forward to original vLLM CLI + from vllm.entrypoints.cli.main import main as vllm_main + vllm_main() + +if __name__ == "__main__": + main() +``` + +#### vLLM Plugin Integration +File: vllm_omni/pyproject.toml +```python +[project.scripts] +# Override vLLM's CLI entry point +vllm = "vllm_omni.cli:main" + +# Add entry point for vLLM integration +[project.entry-points."vllm.plugins"] +omni = "vllm_omni.plugin:OmniPlugin" +``` + +### 3. Offline inference launch (Entry Point) +Design an upper level class to incorporate multi Engines, each engine has a engine call. +```python +from vllm.entrypoints.llm import LLM +from vllm.v1.engine.llm_engine import LLMEngine + +class OmniLM(LLM): + """Extended LLM supporting multiple engines and stage-based processing""" + + def __init__(self, stage_configs: List[StageConfig]): + super().__init__() + self.stage_configs = stage_configs + self.engine_list = [] # List of AsyncLLM instances for each stage + self.output_processor = MultimodalOutputProcessor() + self._initialize_stage_engines() + + def _initialize_stage_engines(self) -> None: + """Initialize LLMEngine instances for each stage""" + for stage_config in self.stage_configs: + stage_llm = LLMEngine.from_vllm_config( + vllm_config=stage_config.vllm_config, + executor_class=self.executor_class, + log_stats=self.log_stats + ) + self.engine_list.append(stage_llm) + + def generate( + self, + stage_args_list: List[stage_args], + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + priority: Optional[list[int]] = None, + ) -> list[RequestOutput]: + """Main generation interface - orchestrates multi-stage processing""" + + # Process through each stage sequentially + for i, stage_config in enumerate(self.stage_configs): + stage_engine = self.engine_list[i] + + # Prepare input for this stage + stage_args = stage_args_list[i] + prompt_str, engine_request, tokenization_kwargs = self._process_stage_inputs(stage_config, **stage_args) + + # Add inputs to Engine + stage_engine.add_request(requesy_id, prompt_str, tokenization_kwargs) + # Run Engine + stage_output = stage_engine.step() + + # Update stage input and output of each stage for later usage + self._update_stage_io(stage_output, stage_config) + + + # Process final output + final_output = self.output_processor.process_output( + self.stage_configs[-1].stage_output + ) + return final_output + + def _process_stage_inputs(stage_config, **stage_args) -> Any: + """Prepare input for specific stage""" + if stage_config.engine_type == "AR": + return self._process_ar_inputs(**stage_args) + elif stage_config.engine_type == "DiT": + return self._process_dit_inputs(**stage_args) + else + raise NotImplementedError + + def _process_dit_inputs(**stage_args)-> Any: + image_latent = self.vae.encode(**stage_args) + + return image_latent +``` + +### 4. Stage config for setting of different model stages +```python +@dataclass +class StageConfig: + """Configuration for a processing stage""" + stage_id: int + engine_type: str # "AR" or "DiT" + model_path: str + input_modalities: List[str] + output_modalities: List[str] + vllm_config: Optional[VllmConfig] = None # For engine config of corresponding stage + executor_class: type[Executor] # For execute class config of corresponding stage + dit_config: Optional[DiTConfig] = None # For diffusion stages + cache_config: Optional[DiTCacheConfig] = None +``` +For AR stage, the setting is: +```python +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.executor.multiproc_executor import MultiprocExecutor +ar_stage_config = StageConfig() +ar_stage_config.vllm_config.scheduler_config.scheduler_cls = Scheduler # original vllm scheduler for AR +ar_stage_config.executor_class = MultiprocExecutor # original vllm executor for AR +``` +For DiT stage, the setting is: +```python +from vllm.v1.executor.multiproc_executor import MultiprocExecutor +dit_stage_config = StageConfig() +dit_stage_config.vllm_config.scheduler_config.scheduler_cls = DiffusionScheduler # New scheduler for DiT + +# For diffusion models without using diffusers +dit_stage_config.executor_class = MultiprocExecutor + +# For diffusion models using diffusers +dit_stage_config.executor_class = DiffusersPipelineExecutor +``` + +### 5. Online inference main class + +Similar to OmniLM in offline inference, add some asynchronous processing, referring to AsyncLLM +```python +from vllm.v1.engine.async_llm improt AsyncLLM + + +class AsyncOmniLM(AsyncLLM): + """Extended AsyncLLM supporting multiple engines and stage-based processing""" + + def __init__(self, stage_configs: List[StageConfig]): + super().__init__() +``` + +### 6. Engine Core +No need to change. The specific executor, scheduler and output type can be transferred into it with new configs. +```python +# class: EngineCore +# EngineCore.step (simplified) +class EngineCore: + def step(self): + scheduler_output = scheduler.schedule() + model_output = executor.execute_model(scheduler_output) + engine_outputs = scheduler.update_from_output( + scheduler_output, model_output + ) + return engine_outputs +``` + +### 7. Scheduler +Create a new Diffusion Scheduler for DiT, which is inherited from original vllm AR scheduler. At first, no complicated strategy. +Just process the request first in first out. Then create a class of DiT Cache Manager for optimization in the future. +```python +from vllm.v1.core.sched.scheduler import Scheduler +class DiffusionScheduler(Scheduler): + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + dit_cache_config: DiTCacheConfig, + structured_output_manager: StructuredOutputManager, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: + + self.dit_cache_manager = DiTCacheManager(dit_cache_config) +``` +For DiT Cache Manager (Can refer to xDiT): +```python +class DiTCacheManager: + """Manages DiT-specific caching""" + + def __init__(self, config: DiTCacheConfig): + self.cache_tensors: Dict[str, torch.Tensor] = {} + self.cache_groups: List[DiTCacheTensor] = config.dit_cache_tensors + + def allocate_cache(self, request_id: str, size: int) -> torch.Tensor + def get_cache(self, request_id: str) -> Optional[torch.Tensor] + def release_cache(self, request_id: str) -> None + def clear_expired_cache(self) -> None +``` + +### 8. Executor for DiT without diffusers + +No need to change. Just import original AR executor. + +### 9. Diffusers Pipeline Executor (no worker) + +A single-process executor that directly runs the Diffusers pipeline without spawning workers or using RPC. Interfaces remain identical to `Executor` so the EngineCore loop is unchanged. + +#### Function map(Pipeline Executor) + +##### 1) Inherited and overridden +```python +from concurrent.futures import Future +from typing import Optional, Union, Callable, Any +import torch +import torch.nn as nn +from vllm.v1.executor.abstract import Executor +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +from vllm.v1.outputs import ModelRunnerOutput +from vllm.tasks import SupportedTask +from diffusers import DiffusionPipeline + +class DiffusersPipelineExecutor(Executor): + supports_pp: bool = False # Single-process, no TP/PP/DP + + def _init_executor(self) -> None: + # Called by ExecutorBase.__init__ + self._failure_callback: Optional[Callable[[], None]] = None + self._device = self._resolve_device() + self._dtype = self._resolve_dtype() + self._pipeline = self._build_pipeline(device=self._device, dtype=self._dtype) + self._profiler = None + self._is_failed = False + self.is_sleeping = False + self.sleeping_tags: set[str] = set() + + # major functions to build/run diffusers pipeline + def _build_pipeline(self, device:torch.device, dtype:torch.dtype)->DiffusionPipeline: + model_name = "Qwen/Qwen-Image" + + self.pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=dtype) + self.pipe = pipe.to(device) + + def _run_pipeline(self, scheduler_output) -> ModelRunnerOutput: ... + positive_magic = { + "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt + "zh": ", 超清,4K,电影级构图." # for chinese prompt + } + + # Generate image + prompt_embeds = self._get_and_process_prompt_embeds(scheduler_output, positive_magic) + negtive_prompt_embeds = self.pipe.embed_prompt(" ") + + width, height = aspect_ratios["16:9"] + + image = pipe( + prompt_embeds=prompt_embeds, + negtive_prompt_embeds=negtive_prompt_embeds, + width=width, + height=height, + num_inference_steps=50, + true_cfg_scale=4.0, + generator=torch.Generator(device="cuda").manual_seed(42) + ).images[0] + + output = self.wrap_image_as_ModelRunnerOutput(image) + return output + + # ---- Internal helpers (implementation-specific, not public API) ---- + def _resolve_device(self): ... + def _resolve_dtype(self): ... + def _get_model(self) -> nn.Module: ... + def _get_and_process_prompt_embeds(self, scheduler_output, positive_magic): + ... + #append the positive_magic to prompt and embed them to prompt embed tensors + + # Functions related to workers should either raise NotImplementedError + # or return empty defaults in this no-worker executor. + + def collective_rpc(self, method, timeout=None, args=(), kwargs=None) -> list[Any]: + # No workers in pipeline executor + raise NotImplementedError("No workers in DiffusersPipelineExecutor") + + def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None: + return # no-op (pipeline already built in _init_executor) + + def register_failure_callback(self, callback): + self._failure_callback = callback + + def determine_available_memory(self) -> list[int]: # bytes + # Single device; return [available_bytes]. If CPU-only, return [0]. + return [self._determine_available_bytes(self._device)] + +``` + +### 10. Worker + +#### Inheritance strategy +Prefer reusing the mature GPU Worker end-to-end. Worker class is selected by configuration (vllm_config). Do not add a new executor-specific worker binding. If customization is needed, override only `init_device` to construct the `DiffusionModelRunner`; all other behaviors (device init details, profiling, sleep/wake, PP/TP comms, execute path) remain from `vllm/v1/worker/gpu_worker.py::Worker`. + + +Inherited and overridden +```python +# Optional: only if you need to plug a custom DiffusionModelRunner. +from vllm.v1.worker.gpu_worker import Worker as GPUWorker +from vllm.v1.worker.diffusion_model_runner import DiffusionModelRunner + +class Worker(GPUWorker): + def init_device(self) -> None: + #those related to device check and init + ... + + self.model_runner = DiffusionModelRunner(self.vllm_config, ...) +``` + +### 11. Model Runner + +### Function map (Model Runner) +#### 1) Inherited and overridden +Those parts relied to the KV Cache will be omitted if we do not register the model to the vllm config. The engine core will view it as do not require KV Cache, and handle it properly + +Reuse `vllm/v1/outputs.py::ModelRunnerOutput`: +- DiffusionModelRunner: Use the `pooler_output=[Tensor,...]` to return multi modal tensors +- ARModelRunner: Use the `pooler_output=[Tensor,...]` to return hidden states. +```python +from typing import Optional, Union +import torch +from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.outputs import ModelRunnerOutput + + +class DiffusionModelRunner(GPUModelRunner): + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, IntermediateTensors]: + ... + return ModelRunnerOutput( + req_ids=[...], + req_id_to_index={...}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[Tensor,...], # Return Hidden states + kv_connector_output=None, + num_nans_in_logits=None, + )# return multi modal tensors via pooler_output=[Tensor,...] + + +class ARModelRunner(GPUModelRunner): + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, IntermediateTensors]: + ... + return ModelRunnerOutput( + req_ids=[...], + req_id_to_index={...}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[Tensor,...], # Return Hidden states + kv_connector_output=None, + num_nans_in_logits=None, + )# return hidden states via pooler_output=[Tensor,...] +``` + +### 12. RequestState Role and Usage +```bash +# RequestState serves as the per-request state tracker in OutputProcessor: +# - Maintains request-specific state (tokens, logprobs, detokenizer, etc.) +# - Converts EngineCoreOutput → RequestOutput/PoolingRequestOutput +# - Manages request lifecycle from registration to completion +AsyncLLM.add_request() +↓ +OutputProcessor.add_request() +↓ +RequestState.from_new_request() → 创建请求状态 + +AsyncLLM.__init__() / AsyncLLM.generate() / AsyncLLM.encode() → 创建一个Background loop 持续从EngineCore获取输出 +↓ +OutputProcessor.process_outputs() → 更新状态并处理输出 +↓ +RequestState.make_request_output() → 转换为最终输出,格式为RequestOutput或者PoolingRequestOutput +↓ +RequestOutputCollector.put() → 推送到队列(AsyncLLM) +``` + +Need to add implementation for an existing method +```python +class RequestState: + def __init__(self, request_id: str, parent_req: Optional[ParentRequest], + request_index: int, lora_name: Optional[str], + output_kind: RequestOutputKind, prompt: Optional[str], + prompt_token_ids: list[int], logprobs_processor: Optional[LogprobsProcessor], + detokenizer: Optional[IncrementalDetokenizer], + max_tokens_param: Optional[int], arrival_time: float, + queue: Optional[RequestOutputCollector], log_stats: bool): + + def _new_pooling_output(self, pooling_output: torch.Tensor) -> PoolingOutput: + """Create PoolingOutput for multimodal/pooling requests""" + return PoolingOutput(data=pooling_output) +``` + + +### 13. OutputProcessor + +For hidden state output, original OutputProcessor already support. + +We only need to add new one for final multimodal output. +```python +class MultimodalOutputProcessor(OutputProcessor): + """Handles multimodal output processing""" + + def __init__(self): + self.output_handlers: Dict[str, OutputProcessor] = {} + + def process_outputs(self, engine_core_outputs: list[EngineCoreOutput], ...): + for engine_core_output in engine_core_outputs: + # Option 1: Use output_type field (if EngineCoreOutput is extended) + if engine_core_output.output_type == "image": + self._process_image_output(engine_core_output) + elif engine_core_output.output_type == "text+image": + self._process_text_image_output(engine_core_output) + elif engine_core_output.output_type == "latents": + self._process_latents_output(engine_core_output) + elif engine_core_output.output_type == "text": + self._process_text_output(engine_core_output) + else: + # Fallback: use existing pooling_output logic + if engine_core_output.pooling_output is not None: + self._process_pooling_output(engine_core_output) + else: + self._process_text_output(engine_core_output) +``` + + + diff --git a/examples/basic/api_client.py b/examples/basic/api_client.py deleted file mode 100644 index 61e45a81be6..00000000000 --- a/examples/basic/api_client.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple API client example for vLLM-omni server. - -This example shows how to interact with the vLLM-omni API server -using HTTP requests. -""" - -import requests -import json -import time - - -def test_health(host="localhost", port=8000): - """Test if the server is healthy.""" - url = f"http://{host}:{port}/health" - try: - response = requests.get(url, timeout=5) - if response.status_code == 200: - print("✅ Server is healthy") - return True - else: - print(f"❌ Server health check failed: {response.status_code}") - return False - except requests.exceptions.RequestException as e: - print(f"❌ Cannot connect to server: {e}") - return False - - -def generate_text(prompt, host="localhost", port=8000, **kwargs): - """Generate text using the API server.""" - url = f"http://{host}:{port}/generate" - - # Default parameters - params = { - "prompts": [prompt], - "max_tokens": kwargs.get("max_tokens", 100), - "temperature": kwargs.get("temperature", 0.7), - "top_p": kwargs.get("top_p", 1.0), - "frequency_penalty": kwargs.get("frequency_penalty", 0.0), - "presence_penalty": kwargs.get("presence_penalty", 0.0), - } - - # Add optional parameters - if "stop" in kwargs: - params["stop"] = kwargs["stop"] - - try: - response = requests.post(url, json=params, timeout=30) - response.raise_for_status() - return response.json() - except requests.exceptions.RequestException as e: - print(f"❌ API request failed: {e}") - return None - - -def main(): - """Main example function.""" - print("vLLM-omni API Client Example") - print("=" * 40) - - # Test server health - if not test_health(): - print("Please start the server first:") - print(" vllm serve Qwen/Qwen3-0.6B --omni --port 8000") - return - - # Example 1: Simple generation - print("\n1. Simple text generation:") - result = generate_text("Hello, how are you?", max_tokens=50) - if result: - text = result["outputs"][0]["text"] - print(f" Input: Hello, how are you?") - print(f" Output: {text}") - - # Example 2: Creative writing - print("\n2. Creative writing:") - result = generate_text( - "Write a short story about a robot learning to paint", - max_tokens=150, - temperature=0.9 - ) - if result: - text = result["outputs"][0]["text"] - print(f" Input: Write a short story about a robot learning to paint") - print(f" Output: {text}") - - # Example 3: Question answering - print("\n3. Question answering:") - result = generate_text( - "What is the capital of France?", - max_tokens=30, - temperature=0.3 - ) - if result: - text = result["outputs"][0]["text"] - print(f" Input: What is the capital of France?") - print(f" Output: {text}") - - # Example 4: Multiple prompts - print("\n4. Multiple prompts:") - prompts = [ - "Tell me a joke", - "What's 2+2?", - "Describe the color blue" - ] - - for prompt in prompts: - result = generate_text(prompt, max_tokens=30) - if result: - text = result["outputs"][0]["text"] - print(f" Q: {prompt}") - print(f" A: {text}") - - # Example 5: Show full response structure - print("\n5. Full response structure:") - result = generate_text("Hello world", max_tokens=20) - if result: - print(" Full API response:") - print(json.dumps(result, indent=2)) - - print("\n" + "=" * 40) - print("API client examples completed!") - - -if __name__ == "__main__": - main() - diff --git a/examples/basic/text_generation.py b/examples/basic/text_generation.py deleted file mode 100644 index a6a9b5179b0..00000000000 --- a/examples/basic/text_generation.py +++ /dev/null @@ -1,137 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple vLLM-omni usage example. - -This example demonstrates how to use vLLM-omni for basic text generation -with a single AR (Autoregressive) stage. -""" - -import asyncio -from vllm_omni.entrypoints.omni_llm import OmniLLM, AsyncOmniLLM -from vllm_omni.config import create_ar_stage_config - - -def sync_example(): - """Synchronous usage example.""" - print("=== Synchronous vLLM-omni Example ===") - - # Create a simple AR stage configuration - stage_config = create_ar_stage_config( - stage_id=0, - model_path="Qwen/Qwen3-0.6B", - input_modalities=["text"], - output_modalities=["text"] - ) - - # Initialize OmniLLM - omni_llm = OmniLLM([stage_config]) - - # Prepare stage arguments - stage_args = [{ - "prompt": "Hello, how are you today?", - "max_tokens": 50, - "temperature": 0.7 - }] - - # Generate text - print(f"Input: {stage_args[0]['prompt']}") - outputs = omni_llm.generate(stage_args) - - # Display results - for i, output in enumerate(outputs): - print(f"Output {i}:") - if hasattr(output, 'outputs') and output.outputs: - for completion in output.outputs: - print(f" Text: {completion.text}") - print(f" Finished: {completion.finish_reason != 'length'}") - print(f" Tokens: {len(completion.token_ids)} tokens") - - -async def async_example(): - """Asynchronous usage example.""" - print("\n=== Asynchronous vLLM-omni Example ===") - - # Create a simple AR stage configuration - stage_config = create_ar_stage_config( - stage_id=0, - model_path="Qwen/Qwen3-0.6B", - input_modalities=["text"], - output_modalities=["text"] - ) - - # Initialize AsyncOmniLLM - omni_llm = AsyncOmniLLM([stage_config]) - - # Prepare stage arguments - stage_args = [{ - "prompt": "What is artificial intelligence?", - "max_tokens": 100, - "temperature": 0.8 - }] - - # Generate text asynchronously - print(f"Input: {stage_args[0]['prompt']}") - outputs = await omni_llm.generate_async(stage_args) - - # Display results - for i, output in enumerate(outputs): - print(f"Output {i}:") - if hasattr(output, 'outputs') and output.outputs: - for completion in output.outputs: - print(f" Text: {completion.text}") - print(f" Finished: {completion.finish_reason != 'length'}") - print(f" Tokens: {len(completion.token_ids)} tokens") - - -def multi_prompt_example(): - """Example with multiple prompts.""" - print("\n=== Multi-Prompt Example ===") - - # Create stage configuration - stage_config = create_ar_stage_config( - stage_id=0, - model_path="Qwen/Qwen3-0.6B", - input_modalities=["text"], - output_modalities=["text"] - ) - - # Initialize OmniLLM - omni_llm = OmniLLM([stage_config]) - - # Multiple prompts - prompts = [ - "Tell me a joke", - "What's the weather like?", - "Explain quantum computing" - ] - - for prompt in prompts: - print(f"\nInput: {prompt}") - stage_args = [{ - "prompt": prompt, - "max_tokens": 30, - "temperature": 0.9 - }] - - outputs = omni_llm.generate(stage_args) - - if outputs and hasattr(outputs[0], 'outputs') and outputs[0].outputs: - completion = outputs[0].outputs[0] - print(f"Output: {completion.text}") - - -if __name__ == "__main__": - print("vLLM-omni Simple Usage Examples") - print("=" * 40) - - # Run synchronous example - sync_example() - - # Run asynchronous example - asyncio.run(async_example()) - - # Run multi-prompt example - multi_prompt_example() - - print("\n" + "=" * 40) - print("Examples completed!") diff --git a/examples/offline_inference/qwen_2_5_omni/README.md b/examples/offline_inference/qwen_2_5_omni/README.md new file mode 100644 index 00000000000..d1dfde20598 --- /dev/null +++ b/examples/offline_inference/qwen_2_5_omni/README.md @@ -0,0 +1,37 @@ +# Offline Example of vLLM-omni for Qwen2.5-omni + +## Installation + +Set up basic environments +```bash +uv venv --python 3.12 --seed +source .venv/bin/activate +``` +Install certain version of vllm with commitid: 808a7b69df479b6b3a16181711cac7ca28a9b941 + +```bash +git clone https://github.com/vllm-project/vllm.git +cd vllm +git checkout 808a7b69df479b6b3a16181711cac7ca28a9b941 +VLLM_USE_PRECOMPILED=1 uv pip install --editable . +``` + +## Run examples + +Get into the example folder +```bash +cd vllm_omni +cd examples/offline_inference/qwen2_5_omni +``` +Modify PYTHONPATH in run.sh as your path of vllm_omni. Then run. +```bash +bash run.sh +``` +The output audio is saved in ./output_audio + +## To-do list +- [x] Offline inference example for Qwen2.5-omni with single request +- [ ] Adaptation from current vllm branch to stable vllm v0.11.0 +- [ ] Offline inference example for Qwen2.5-omni with streaming multiple requests +- [ ] Online inference support +- [ ] Support for other models \ No newline at end of file diff --git a/examples/offline_inference/qwen_2_5_omni/config_test.py b/examples/offline_inference/qwen_2_5_omni/config_test.py new file mode 100644 index 00000000000..0bb4abffd04 --- /dev/null +++ b/examples/offline_inference/qwen_2_5_omni/config_test.py @@ -0,0 +1,9 @@ +from omegaconf import OmegaConf + +config_file = "/home/dyvm6xra/dyvm6xrauser08/gh/vllm_project/vllm/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml" +config_data = OmegaConf.load(config_file) + +stage_configs = config_data.stage_args + +for stage_config in stage_configs: + print(stage_config.stage_id) \ No newline at end of file diff --git a/examples/offline_inference/qwen_2_5_omni/end2end.py b/examples/offline_inference/qwen_2_5_omni/end2end.py new file mode 100644 index 00000000000..cf0d4dff680 --- /dev/null +++ b/examples/offline_inference/qwen_2_5_omni/end2end.py @@ -0,0 +1,133 @@ +import argparse +import os +import soundfile as sf +import random +import numpy as np +import torch + +from vllm.sampling_params import SamplingParams + +import os as _os_env_toggle +_os_env_toggle.environ["VLLM_USE_V1"] = "1" + +from vllm_omni.entrypoints.omni_lm import OmniLM +from utils import make_omni_prompt + + +SEED = 42 +# Set all random seeds +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.cuda.manual_seed_all(SEED) + +# Make PyTorch deterministic +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +# Set environment variables for deterministic behavior +os.environ["PYTHONHASHSEED"] = str(SEED) +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" # Need to discuss with the team + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--model', required=True, help='Path to merged model directory (will be created if downloading).') + parser.add_argument('--thinker-model', type=str, default=None) + parser.add_argument('--talker-model', type=str, default=None) + parser.add_argument('--code2wav-model', type=str, default=None) + parser.add_argument('--hf-hub-id', default='Qwen/Qwen2.5-Omni-7B', help='Hugging Face repo id to download if needed.') + parser.add_argument('--hf-revision', default=None, help='Optional HF revision (branch/tag/commit).') + parser.add_argument('--prompts', required=True, nargs='+', help='Input text prompts.') + parser.add_argument('--voice-type', default='default', help='Voice type, e.g., m02, f030, default.') + parser.add_argument('--code2wav-dir', default=None, help='Path to code2wav folder (contains spk_dict.pt).') + parser.add_argument('--dit-ckpt', default=None, help='Path to DiT checkpoint file (e.g., dit.pt).') + parser.add_argument('--bigvgan-ckpt', default=None, help='Path to BigVGAN checkpoint file.') + parser.add_argument('--dtype', default='bfloat16', choices=['float16', 'bfloat16', 'float32']) + parser.add_argument('--max-model-len', type=int, default=32768) + + parser.add_argument("--thinker-only", action="store_true") + parser.add_argument("--text-only", action="store_true") + parser.add_argument("--do-wave", action="store_true") + parser.add_argument('--prompt_type', + choices=[ + 'text', 'audio', 'audio-long', 'audio-long-chunks', + 'audio-long-expand-chunks', 'image', 'video', + 'video-frames', 'audio-in-video', 'audio-in-video-v2', + "audio-multi-round", "badcase-vl", "badcase-text", + "badcase-image-early-stop", "badcase-two-audios", + "badcase-two-videos", "badcase-multi-round", + "badcase-voice-type", "badcase-voice-type-v2", + "badcase-audio-tower-1", "badcase-audio-only" + ], + default='text') + parser.add_argument('--use-torchvision', action='store_true') + parser.add_argument('--tokenize', action='store_true') + parser.add_argument('--output-wav', default="output.wav", help='Output wav file path.') + parser.add_argument('--thinker-hidden-states-dir', default="thinker_hidden_states", help='Path to thinker hidden states directory.') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + model_name = args.model + omni_lm = OmniLM(model=model_name) + print("omni_lm.stage_configs: ", omni_lm.stage_configs) + print("omni_lm.stage_list: ", omni_lm.stage_list) + thinker_sampling_params = SamplingParams( + temperature=0.0, # Deterministic - no randomness + top_p=1.0, # Disable nucleus sampling + top_k=-1, # Disable top-k sampling + max_tokens=2048, + seed=SEED, # Fixed seed for sampling + detokenize=True, + repetition_penalty=1.1, + ) + talker_sampling_params = SamplingParams( + temperature=0.0, # Deterministic - no randomness + top_p=1.0, # Disable nucleus sampling + top_k=-1, # Disable top-k sampling + max_tokens=2048, + seed=SEED, # Fixed seed for sampling + detokenize=True, + repetition_penalty=1.1, + stop_token_ids=[8294] + ) + code2wav_sampling_params = SamplingParams( + temperature=0.0, # Deterministic - no randomness + top_p=1.0, # Disable nucleus sampling + top_k=-1, # Disable top-k sampling + max_tokens=2048, + seed=SEED, # Fixed seed for sampling + detokenize=True, + repetition_penalty=1.1, + ) + + sampling_params_list = [thinker_sampling_params, + talker_sampling_params, + code2wav_sampling_params] + + prompt = [make_omni_prompt(args, prompt) for prompt in args.prompts] + omni_outputs = omni_lm.generate(prompt, sampling_params_list) + + os.makedirs(args.output_wav, exist_ok=True) + for stage_outputs in omni_outputs: + if stage_outputs.final_output_type == "text": + for output in stage_outputs.request_output: + request_id = output.request_id + text_output = output.outputs[0].text + print(f"Request ID: {request_id}, Text Output: {text_output}") + elif stage_outputs.final_output_type == "audio": + for output in stage_outputs.request_output: + request_id = output.request_id + audio_tensor = output.multimodal_output["audio"] + output_wav = os.path.join(args.output_wav, f"output_{output.request_id}.wav") + sf.write(output_wav, audio_tensor.detach().cpu().numpy(), samplerate=24000) + print(f"Request ID: {request_id}, Saved audio to {output_wav}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/offline_inference/qwen_2_5_omni/processing_omni.py b/examples/offline_inference/qwen_2_5_omni/processing_omni.py new file mode 100644 index 00000000000..1bbe3de9934 --- /dev/null +++ b/examples/offline_inference/qwen_2_5_omni/processing_omni.py @@ -0,0 +1,374 @@ +from __future__ import annotations + +import base64 +import logging +import math +import os +import sys +import time +import warnings +from functools import lru_cache +from io import BytesIO +import requests +import torch +import torchvision +from packaging import version +from PIL import Image +from torchvision import io, transforms +from torchvision.transforms import InterpolationMode + + +logger = logging.getLogger(__name__) + +IMAGE_FACTOR = 28 +MIN_PIXELS = 4 * 28 * 28 +MAX_PIXELS = 16384 * 28 * 28 +MAX_RATIO = 200 + +VIDEO_MIN_PIXELS = 128 * 28 * 28 +VIDEO_MAX_PIXELS = 768 * 28 * 28 +VIDEO_TOTAL_PIXELS = 24576 * 28 * 28 +FRAME_FACTOR = 2 +FPS = 2.0 +FPS_MIN_FRAMES = 4 +FPS_MAX_FRAMES = 768 + +temporal_patch_size = 2 +spatial_patch_size = 14 +spatial_merge_size = 2 + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +def smart_resize(height: int, + width: int, + factor: int = IMAGE_FACTOR, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def fetch_image(ele: dict[str, str | Image.Image], + size_factor: int = IMAGE_FACTOR) -> Image.Image: + if "image" in ele: + image = ele["image"] + else: + image = ele["image_url"] + image_obj = None + if isinstance(image, Image.Image): + image_obj = image + elif image.startswith("http://") or image.startswith("https://"): + image_obj = Image.open(requests.get(image, stream=True).raw) + elif image.startswith("file://"): + image_obj = Image.open(image[7:]) + elif image.startswith("data:image"): + if "base64," in image: + _, base64_data = image.split("base64,", 1) + data = base64.b64decode(base64_data) + image_obj = Image.open(BytesIO(data)) + else: + image_obj = Image.open(image) + if image_obj is None: + raise ValueError( + f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}" + ) + image = image_obj.convert("RGB") + ## resize + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=size_factor, + ) + else: + width, height = image.size + min_pixels = ele.get("min_pixels", MIN_PIXELS) + max_pixels = ele.get("max_pixels", MAX_PIXELS) + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +def smart_nframes( + ele: dict, + total_frames: int, + video_fps: int | float, +) -> int: + """calculate the number of frames for video used for model inputs. + + Args: + ele (dict): a dict contains the configuration of video. + support either `fps` or `nframes`: + - nframes: the number of frames to extract for model inputs. + - fps: the fps to extract frames for model inputs. + - min_frames: the minimum number of frames of the video, only used when fps is provided. + - max_frames: the maximum number of frames of the video, only used when fps is provided. + total_frames (int): the original total number of frames of the video. + video_fps (int | float): the original fps of the video. + + Raises: + ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. + + Returns: + int: the number of frames for video used for model inputs. + """ + assert not ("fps" in ele + and "nframes" in ele), "Only accept either `fps` or `nframes`" + if "nframes" in ele: + nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) + else: + fps = ele.get("fps", FPS) + min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), + FRAME_FACTOR) + max_frames = floor_by_factor( + ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), + FRAME_FACTOR) + nframes = total_frames / video_fps * fps + nframes = min(max(nframes, min_frames), max_frames) + nframes = round_by_factor(nframes, FRAME_FACTOR) + if not (FRAME_FACTOR <= nframes and nframes <= total_frames): + raise ValueError( + f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}." + ) + return nframes + + +def _read_video_torchvision(ele: dict, ) -> torch.Tensor: + """read video using torchvision.io.read_video + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). + """ + video_path = ele["video"] + if version.parse(torchvision.__version__) < version.parse("0.19.0"): + if "http://" in video_path or "https://" in video_path: + warnings.warn( + "torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0." + ) + if "file://" in video_path: + video_path = video_path[7:] + st = time.time() + video, audio, info = io.read_video( + video_path, + start_pts=ele.get("video_start", 0.0), + end_pts=ele.get("video_end", None), + pts_unit="sec", + output_format="TCHW", + ) + total_frames, video_fps = video.size(0), info["video_fps"] + total_duration = round(total_frames / video_fps, 3) + logger.info( + f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, duration={total_duration}s, time={time.time() - st:.3f}s" + ) + nframes = smart_nframes(ele, + total_frames=total_frames, + video_fps=video_fps) + idx = torch.linspace(0, total_frames - 1, nframes).round().long() + video = video[idx] + return video, total_duration, nframes + + +def is_decord_available() -> bool: + import importlib.util + + return importlib.util.find_spec("decord") is not None + + +def _read_video_decord(ele: dict, ) -> torch.Tensor: + """read video using decord.VideoReader + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). + """ + import decord + video_path = ele["video"] + st = time.time() + vr = decord.VideoReader(video_path) + # TODO: support start_pts and end_pts + if 'video_start' in ele or 'video_end' in ele: + raise NotImplementedError( + "not support start_pts and end_pts in decord for now.") + total_frames, video_fps = len(vr), vr.get_avg_fps() + total_duration = round(total_frames / video_fps, 3) + logger.info( + f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s" + ) + nframes = smart_nframes(ele, + total_frames=total_frames, + video_fps=video_fps) + idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() + video = vr.get_batch(idx).asnumpy() + video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format + return video, total_duration, nframes + + +VIDEO_READER_BACKENDS = { + "decord": _read_video_decord, + "torchvision": _read_video_torchvision, +} + +FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) + + +@lru_cache(maxsize=1) +def get_video_reader_backend() -> str: + if FORCE_QWENVL_VIDEO_READER is not None: + video_reader_backend = FORCE_QWENVL_VIDEO_READER + elif is_decord_available(): + video_reader_backend = "decord" + else: + video_reader_backend = "torchvision" + # print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr) + return video_reader_backend + + +def fetch_video( + ele: dict, + image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]: + if isinstance(ele["video"], str): + video_reader_backend = get_video_reader_backend() + video, total_dur, nframes = VIDEO_READER_BACKENDS[ + video_reader_backend](ele) + frame_timestamps = total_dur * torch.arange(1, nframes + 1) / nframes + grid_timestamps = frame_timestamps[::FRAME_FACTOR] + second_per_grid = grid_timestamps[1] - grid_timestamps[0] + nframes, _, height, width = video.shape + factor = spatial_patch_size * spatial_merge_size + min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) + total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) + max_pixels = max( + min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), + int(min_pixels * 1.05)) + max_pixels = ele.get("max_pixels", max_pixels) + # min_pixels = (factor ** 2) * 52 + # max_pixels = (factor ** 2) * min(768, (16384 / nframes * temporal_patch_size)) + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=image_factor, + ) + else: + resized_height, resized_width = smart_resize( + height, + width, + factor=image_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + video = transforms.functional.resize( + video, + [resized_height, resized_width], + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ).float() + return video, total_dur, nframes, second_per_grid + else: + assert isinstance(ele["video"], (list, tuple)) + process_info = ele.copy() + process_info.pop("type", None) + process_info.pop("video", None) + images = [ + fetch_image({ + "image": video_element, + **process_info + }, + size_factor=image_factor) + for video_element in ele["video"] + ] + nframes = ceil_by_factor(len(images), FRAME_FACTOR) + if len(images) < nframes: + images.extend([images[-1]] * (nframes - len(images))) + return images, None, None, None + + +def extract_vision_info( + conversations: list[dict] | list[list[dict]]) -> list[dict]: + vision_infos = [] + if isinstance(conversations[0], dict): + conversations = [conversations] + for conversation in conversations: + for message in conversation: + if isinstance(message["content"], list): + for ele in message["content"]: + if ("image" in ele or "image_url" in ele or "video" in ele + or ele["type"] in ("image", "image_url", "video")): + vision_infos.append(ele) + return vision_infos + + +def process_vision_info( + conversations: list[dict] | list[list[dict]], +) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] + | None]: + vision_infos = extract_vision_info(conversations) + ## Read images or videos + image_inputs = [] + video_inputs = [] + for vision_info in vision_infos: + if "image" in vision_info or "image_url" in vision_info: + image_inputs.append(fetch_image(vision_info)) + elif "video" in vision_info: + video_inputs.append(fetch_video(vision_info)) + else: + raise ValueError("image, image_url or video should in content.") + if len(image_inputs) == 0: + image_inputs = None + if len(video_inputs) == 0: + video_inputs = None + return image_inputs, video_inputs \ No newline at end of file diff --git a/examples/offline_inference/qwen_2_5_omni/run.sh b/examples/offline_inference/qwen_2_5_omni/run.sh new file mode 100644 index 00000000000..b858099ae64 --- /dev/null +++ b/examples/offline_inference/qwen_2_5_omni/run.sh @@ -0,0 +1,9 @@ +export PYTHONPATH=/home/dyvm6xra/dyvm6xrauser08/gh/vllm_project/vllm:$PYTHONPATH +export HF_ENDPOINT=https://hf-mirror.com +python end2end.py --model Qwen/Qwen2.5-Omni-7B \ + --prompts "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." \ + --voice-type "m02" \ + --dit-ckpt none \ + --bigvgan-ckpt none \ + --output-wav output_audio \ + --prompt_type text \ No newline at end of file diff --git a/examples/offline_inference/qwen_2_5_omni/utils.py b/examples/offline_inference/qwen_2_5_omni/utils.py new file mode 100644 index 00000000000..049cd710d0c --- /dev/null +++ b/examples/offline_inference/qwen_2_5_omni/utils.py @@ -0,0 +1,307 @@ +import tempfile +from urllib.request import urlopen +import librosa +import soundfile as sf +import resampy +from typing import Dict, Optional +import torch +import requests +import torchvision.io + +from typing import Union, List +from vllm.inputs import TextPrompt +from vllm.sampling_params import SamplingParams +from vllm_omni.inputs.data import OmniTokensPrompt +from processing_omni import fetch_image, fetch_video + + +def get_system_prompt(): + + return { + 'role': + 'system', + 'content': [{ + 'type': + 'text', + 'text': + 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.' + }] + } + + +def resample_wav_to_16khz(input_filepath): + data, original_sample_rate = sf.read(input_filepath) + # Only use the first channel + if len(data.shape) > 1: + data = data[:, 0] + # resample to 16kHz + data_resampled = resampy.resample(data, + sr_orig=original_sample_rate, + sr_new=16000) + return data_resampled + + +def fetch_and_read_video(args, video_url: str, fps=2): + + def read_video_with_torchvision(video_file_name: str): + video, audio, info = torchvision.io.read_video( + video_file_name, + start_pts=0.0, + end_pts=None, + pts_unit="sec", + output_format="TCHW", + ) + + total_frames, video_fps = video.size(0), info["video_fps"] + total_duration = round(total_frames / video_fps, 3) + nframes = int(total_frames / video_fps * fps) + + frame_timestamps = total_duration * torch.arange(1, + nframes + 1) / nframes + grid_timestamps = frame_timestamps[::2] + second_per_grid = grid_timestamps[1] - grid_timestamps[0] + + idx = torch.linspace(0, video.size(0) - 1, nframes).round().long() + video_height, video_width = video.shape[2:] + video = video[idx] + + if args.legacy_omni_video: + return [video, total_duration, nframes, second_per_grid.item()] + else: + return video + + def read_video_with_transformers(video_file_name: Union[str, List[str]]): + video, total_duration, nframes, second_per_grid = fetch_video( + {'video': video_file_name}) + if total_duration is None and nframes is None: + nframes = len(video) + total_duration = 0.5 * nframes + second_per_grid = 1.0 + if args.legacy_omni_video: + return [video, total_duration, nframes, second_per_grid] + else: + return video + + def read_video(video_file_name: str): + if args.use_torchvision: + return read_video_with_torchvision(video_file_name) + else: + return read_video_with_transformers(video_file_name) + + if isinstance(video_url, str) and video_url.startswith("http"): + with tempfile.NamedTemporaryFile(delete=True) as temp_video_file: + resp = requests.get(video_url) + assert resp.status_code == requests.codes.ok, f"Failed to fetch video from {video_url}, status_code:{resp.status_code}, resp:{resp}" + + temp_video_file.write(urlopen(video_url).read()) + temp_video_file_path = temp_video_file.name + video_file_name = temp_video_file_path + return read_video(video_file_name) + else: + video_file_name = video_url + return read_video(video_file_name) + + +def make_inputs_qwen2_omni( + args, + messages: List[Dict[str, Union[str, List[Dict[str, str]]]]], + use_audio_in_video: Optional[bool] = False, + tokenize: bool = False, +) -> Union[OmniTokensPrompt, TextPrompt]: + + from transformers import AutoConfig, AutoProcessor, AutoTokenizer + processor = AutoProcessor.from_pretrained(args.model) + tokenizer = AutoTokenizer.from_pretrained(args.model) + + try: + config = AutoConfig.from_pretrained(args.model) + if 'Qwen2_5OmniModel' in config.architectures: + args.legacy_omni_video = False + else: + args.legacy_omni_video = True + except: + args.legacy_omni_video = True + + audios, images, videos = [], [], [] + for message in messages: + if not isinstance(message['content'], list): + message['content'] = [{ + 'type': 'text', + 'text': message['content'], + }] + index, num_contents = 0, len(message['content']) + while index < num_contents: + ele = message['content'][index] + if 'type' not in ele: + if 'text' in ele: + ele['type'] = 'text' + elif 'audio' in ele: + ele['type'] = 'audio' + elif 'audio_url' in ele: + ele['type'] = 'audio_url' + elif 'image' in ele: + ele['type'] = 'image' + elif 'image_url' in ele: + ele['type'] = 'image_url' + elif 'video' in ele: + ele['type'] = 'video' + elif 'video_url' in ele: + ele['type'] = 'video_url' + else: + raise ValueError(f'Unknown ele: {ele}') + + if ele['type'] == 'audio' or ele['type'] == 'audio_url': + if 'audio_url' in ele: + audio_key = 'audio_url' + with tempfile.NamedTemporaryFile( + delete=True) as temp_audio_file: + temp_audio_file.write(urlopen(ele[audio_key]).read()) + temp_audio_file_path = temp_audio_file.name + audios.append( + resample_wav_to_16khz(temp_audio_file_path)) + ele['audio'] = temp_audio_file_path + elif 'audio' in ele: + audio_key = 'audio' + audios.append(resample_wav_to_16khz(ele[audio_key])) + else: + raise ValueError(f'Unknown ele {ele}') + elif use_audio_in_video and (ele['type'] == 'video' + or ele['type'] == 'video_url'): + # use video as audio as well + if 'video_url' in ele: + audio_key = 'video_url' + with tempfile.NamedTemporaryFile( + delete=True) as temp_video_file: + temp_video_file.write(urlopen(ele[audio_key]).read()) + temp_video_file_path = temp_video_file.name + ele[audio_key] = temp_video_file_path + audios.append( + librosa.load(temp_video_file_path, sr=16000)[0]) + videos.append( + fetch_and_read_video(args, temp_video_file_path)) + ele['video'] = temp_video_file_path + elif 'video' in ele: + audio_key = 'video' + audios.append(librosa.load(ele[audio_key], sr=16000)[0]) + videos.append(fetch_and_read_video(args, audio_key)) + else: + raise ValueError("Unknown ele {}".format(ele)) + # insert a audio after the video + message['content'].insert(index + 1, { + "type": "audio", + "audio": ele[audio_key], + }) + # no need to load the added audio again + index += 1 + elif ele['type'] == 'video' or ele['type'] == 'video_url': + if 'video_url' in ele: + video_key = 'video_url' + with tempfile.NamedTemporaryFile( + delete=True) as temp_video_file: + temp_video_file.write(urlopen(ele['video_url']).read()) + temp_video_file_path = temp_video_file.name + videos.append(fetch_and_read_video(args, temp_video_file)) + ele['video'] = temp_video_file_path + else: + video_key = 'video' + videos.append(fetch_and_read_video(args, ele[video_key])) + elif ele['type'] == 'image' or ele['type'] == 'image_url': + images.append(fetch_image(ele)) + + # move to the next content + index += 1 + + prompt = processor.apply_chat_template( + messages, + tokenize=tokenize, + add_generation_prompt=True, + add_vision_id=True, + ) + + audios = audios if len(audios) > 0 else None + images = images if len(images) > 0 else None + videos = videos if len(videos) > 0 else None + + multi_modal_data = {} + if audios: + multi_modal_data["audio"] = audios + if images: + multi_modal_data["image"] = images + if videos: + multi_modal_data["video"] = videos + + if isinstance(prompt, list) and isinstance(prompt[0], (list, str)): + prompt = prompt[0] + + if tokenize: + return OmniTokensPrompt( + prompt_token_ids=prompt, + multi_modal_data=multi_modal_data, + ) + else: + return TextPrompt( + prompt=prompt, + multi_modal_data=multi_modal_data, + ) + + +def make_text_prompt(args, prompt): + messages = [ + get_system_prompt(), + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + }, + ] + }, + ] + + prompt = make_inputs_qwen2_omni(args, messages, tokenize=args.tokenize) + return prompt + + +def make_audio_in_video_v2_prompt(args): + messages = [ + { + 'role': + 'system', + 'content': [{ + 'type': + 'text', + 'text': + 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.' + }] + }, + { + "role": + "user", + "content": [ + { + "type": + "video_url", + "video_url": + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/draw_small.mp4" + }, + ] + }, + ] + prompt = make_inputs_qwen2_omni( + args, + messages, + use_audio_in_video=True, + tokenize=args.tokenize, + ) + return prompt + + +def make_omni_prompt(args, prompt = None) -> Union[OmniTokensPrompt, List[OmniTokensPrompt]]: + if args.prompt_type == 'text': + prompt = make_text_prompt(args, prompt) + elif args.prompt_type == 'audio-in-video-v2': + prompt = make_audio_in_video_v2_prompt(args) + else: + raise ValueError(f'Unsupported prompt type: {args.prompt_type}') + return prompt \ No newline at end of file diff --git a/examples/omni/README.md b/examples/omni/README.md deleted file mode 100644 index 69274991ba6..00000000000 --- a/examples/omni/README.md +++ /dev/null @@ -1,31 +0,0 @@ -# Omni Examples - -Examples showcasing multi-stage AR + DiT pipelines powered by vLLM-omni. - -## AR ➜ DiT with Diffusers Backend - -```bash -# Optional: customise models -export AR_MODEL=Qwen/Qwen3-0.6B -export DIT_MODEL=./models/stable-diffusion-2-1/ - -# Run with defaults (20 steps @ 512x512) -python examples/omni/ar_dit_diffusers.py - -# Faster run -python examples/omni/ar_dit_diffusers.py --steps 14 --height 384 --width 384 --guidance 4.5 -``` - -Arguments: -- `--ar-model`: Override AR stage model (default `Qwen/Qwen3-0.6B`). -- `--dit-model`: Override DiT stage model (default env `DIT_MODEL`, else diffusers repo). -- `--steps`: Number of diffusion steps (default 20). -- `--guidance`: Guidance scale (default 5.0). -- `--height` / `--width`: Output resolution (default 512). -- `--seed`: Optional RNG seed. -- `--prompt`, `--temperature`, `--max-tokens`: Control AR stage generation. -- `--output`: Destination path for the generated image file. - -Generated images are saved next to the script as `omni_dit_output.png` by default. - -See the script for additional tweaks (scheduler, prompt chaining, etc.). diff --git a/examples/omni/ar_dit_diffusers.py b/examples/omni/ar_dit_diffusers.py deleted file mode 100644 index 5685b5dc3c4..00000000000 --- a/examples/omni/ar_dit_diffusers.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Run the AR → DiT (diffusers) pipeline using YAML configuration defaults.""" - -from __future__ import annotations - -import argparse -import os -from pathlib import Path -from typing import Dict, List - -import yaml - -from vllm_omni import ( - OmniLLM, - DiTConfig, - create_ar_stage_config, - create_dit_stage_config, -) - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Run a two-stage AR → DiT pipeline with config-derived defaults." - ) - parser.add_argument( - "--config", - type=Path, - default=Path(__file__).resolve().parent / "configs" / "ar_dit_local.yaml", - help="Path to a YAML configuration describing the pipeline stages.", - ) - parser.add_argument( - "--prompt", - default="A scenic watercolor painting of a lighthouse at sunset", - help="Prompt passed to the AR stage.", - ) - parser.add_argument( - "--negative-prompt", - default=None, - help="Optional negative prompt forwarded to the diffusion stage.", - ) - parser.add_argument( - "--seed", - type=int, - default=None, - help="Optional seed for deterministic diffusion sampling.", - ) - parser.add_argument( - "--output", - type=Path, - default=Path("./omni_dit_output.png"), - help="Destination path for the generated image.", - ) - return parser.parse_args() - - -def _normalize_modalities(values: List[str] | None, default: List[str]) -> List[str]: - if not values: - return default - return list(values) - - -def _load_stage_configs_from_yaml(config_path: Path): - with config_path.open("r", encoding="utf-8") as f: - config = yaml.safe_load(f) or {} - - stages = config.get("stages", []) - if not stages: - raise ValueError(f"No stages defined in config: {config_path}") - - stage_configs = [] - for stage in stages: - stage_id = stage["stage_id"] - engine_type = stage["engine_type"] - default_stage_args = stage.get("default_stage_args") - - if engine_type == "AR": - stage_configs.append( - create_ar_stage_config( - stage_id=stage_id, - model_path=stage["model_path"], - input_modalities=_normalize_modalities( - stage.get("input_modalities"), ["text"] - ), - output_modalities=_normalize_modalities( - stage.get("output_modalities"), ["text"] - ), - default_stage_args=default_stage_args, - ) - ) - elif engine_type == "DiT": - dit_cfg_dict: Dict = dict(stage.get("dit_config", {})) - if "num_inference_steps" not in dit_cfg_dict: - raise ValueError( - "DiT stage requires 'num_inference_steps' in dit_config" - ) - dit_cfg = DiTConfig(**dit_cfg_dict) - - stage_configs.append( - create_dit_stage_config( - stage_id=stage_id, - model_path=stage["model_path"], - input_modalities=_normalize_modalities( - stage.get("input_modalities"), ["text"] - ), - output_modalities=_normalize_modalities( - stage.get("output_modalities"), ["image"] - ), - dit_config=dit_cfg, - default_stage_args=default_stage_args, - ) - ) - else: - raise ValueError(f"Unsupported engine_type '{engine_type}' in config") - - return stage_configs - - -def _apply_env_overrides(stage_configs): - ar_override = os.environ.get("AR_MODEL") - dit_override = os.environ.get("DIT_MODEL") - - for stage_config in stage_configs: - if ar_override and stage_config.engine_type == "AR": - stage_config.model_path = ar_override - if dit_override and stage_config.engine_type == "DiT": - stage_config.model_path = dit_override - - -def main(): - args = parse_args() - - stage_configs = _load_stage_configs_from_yaml(args.config) - _apply_env_overrides(stage_configs) - - omni = OmniLLM(stage_configs) - - stage_overrides: Dict[int, Dict[str, object]] = {} - if args.seed is not None or args.negative_prompt: - for stage_config in stage_configs: - if stage_config.engine_type != "DiT": - continue - override = stage_overrides.setdefault(stage_config.stage_id, {}) - if args.seed is not None: - override["seed"] = args.seed - if args.negative_prompt: - override["negative_prompt"] = args.negative_prompt - - outputs = omni.generate( - prompt=args.prompt, - stage_overrides=stage_overrides if stage_overrides else None, - ) - - image = None - for request_output in outputs: - for completion in getattr(request_output, "outputs", []) or []: - candidate = getattr(completion, "image", None) - if candidate is not None: - image = candidate - break - if image is not None: - break - - if image is None: - print("No image found in outputs.") - return - - out_path = args.output.resolve() - try: - image.save(out_path) - print(f"Saved image to {out_path}") - except Exception: - print("Generated output is not a PIL.Image; skipping save.") - - -if __name__ == "__main__": - main() diff --git a/examples/omni/configs/ar_dit_local.yaml b/examples/omni/configs/ar_dit_local.yaml deleted file mode 100644 index 5bae87e3c1e..00000000000 --- a/examples/omni/configs/ar_dit_local.yaml +++ /dev/null @@ -1,28 +0,0 @@ -# Example config for running OmniLLM with local AR + DiT models. -stages: - - stage_id: 0 - engine_type: AR - model_path: Qwen/Qwen3-0.6B - input_modalities: [text] - output_modalities: [text] - default_stage_args: - max_tokens: 16 - temperature: 0.7 - - stage_id: 1 - engine_type: DiT - model_path: stabilityai/stable-diffusion-2-1 - input_modalities: [text] - output_modalities: [image] - default_stage_args: - height: 512 - width: 512 - dit_config: - num_inference_steps: 20 - guidance_scale: 5.0 - use_diffusers: true - diffusers_pipeline: auto - height: 512 - width: 512 - device_config: - device: mps - dtype: float16 diff --git a/scripts/README.md b/scripts/README.md deleted file mode 100644 index 6f4e90bb5e5..00000000000 --- a/scripts/README.md +++ /dev/null @@ -1,109 +0,0 @@ -# vLLM-omni Testing Scripts - -This directory contains automated testing scripts for vLLM-omni serving functionality. - -## Scripts Overview - -### 1. `test_serving.sh` - Comprehensive Test Suite -Full-featured testing script that performs complete validation of the vLLM-omni serving functionality. - -**Features:** -- Model existence validation -- Environment setup verification -- Import functionality testing -- Server startup and health checks -- Text generation testing -- Performance benchmarking -- API client integration testing -- Comprehensive logging and error handling - -**Usage:** -```bash -# Use default model and port -./scripts/test_serving.sh - -# Use specific model and port -./scripts/test_serving.sh ./models/Qwen3-0.6B 8000 - -# Use HuggingFace model -./scripts/test_serving.sh Qwen/Qwen3-0.6B 8001 - -# Show help -./scripts/test_serving.sh --help -``` - -**Test Coverage:** -- ✅ Model loading and server startup -- ✅ Health and info endpoints -- ✅ Text generation functionality -- ✅ Performance metrics -- ✅ API client integration -- ✅ All imports working correctly - -### 2. `quick_test.sh` - Fast Validation -Lightweight script for quick validation after making changes. - -**Features:** -- Fast import testing -- Basic server startup -- Health endpoint validation -- Simple text generation test -- Retry mechanism for reliability - -**Usage:** -```bash -# Use default port (8000) -./scripts/quick_test.sh - -# Use specific port -./scripts/quick_test.sh 8001 -``` - -**Test Coverage:** -- ✅ Import functionality -- ✅ Server startup -- ✅ Health endpoint -- ✅ Basic text generation - -## Prerequisites - -1. **Conda Environment**: Ensure `vllm_omni` environment is activated -2. **Model Available**: Qwen3-0.6B model should be available in `./models/Qwen3-0.6B/` -3. **Dependencies**: All vLLM-omni dependencies should be installed - -## Environment Setup - -```bash -# Activate conda environment -conda activate vllm_omni - -# Verify installation -python -c "import vllm_omni; print('vLLM-omni ready')" -``` - -## Model Setup - -The scripts expect the Qwen3-0.6B model to be available. You can: - -1. **Use local model**: Place model in `./models/Qwen3-0.6B/` -2. **Use HuggingFace model**: Pass `Qwen/Qwen3-0.6B` as model path -3. **Download model**: Use the `download_models.py` script - -```bash -# Download model using the provided script -python scripts/download_models.py -``` - -## Output and Logging - -### Test Results -- **Success**: Green `[SUCCESS]` messages -- **Info**: Blue `[INFO]` messages -- **Warnings**: Yellow `[WARNING]` messages -- **Errors**: Red `[ERROR]` messages - -### Log Files -- `server.log`: Server startup and runtime logs -- `api_client_test.log`: API client test results -- `simple_usage_test.log`: Simple usage test results - diff --git a/scripts/test_serving.sh b/scripts/test_serving.sh deleted file mode 100755 index d3a8d6cda51..00000000000 --- a/scripts/test_serving.sh +++ /dev/null @@ -1,363 +0,0 @@ -#!/bin/bash - -# vLLM-omni Serving Functionality Test Script -# This script tests the complete serving functionality of vLLM-omni -# Usage: ./scripts/test_serving.sh [model_path] [port] - -set -e # Exit on any error - -# Configuration -MODEL_PATH=${1:-"./models/Qwen3-0.6B"} -PORT=${2:-8000} -HOST="localhost" -TIMEOUT=30 -SERVER_STARTUP_WAIT=15 - -# Colors for output -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -# Logging functions -log_info() { - echo -e "${BLUE}[INFO]${NC} $1" -} - -log_success() { - echo -e "${GREEN}[SUCCESS]${NC} $1" -} - -log_warning() { - echo -e "${YELLOW}[WARNING]${NC} $1" -} - -log_error() { - echo -e "${RED}[ERROR]${NC} $1" -} - -# Cleanup function -cleanup() { - log_info "Cleaning up..." - if [ ! -z "$SERVER_PID" ]; then - log_info "Stopping server (PID: $SERVER_PID)..." - kill $SERVER_PID 2>/dev/null || true - wait $SERVER_PID 2>/dev/null || true - fi - # Kill any remaining vllm processes - pkill -f "vllm serve" 2>/dev/null || true - log_info "Cleanup completed" -} - -# Set up trap for cleanup on exit -trap cleanup EXIT - -# Check if model exists -check_model() { - log_info "Checking if model exists at: $MODEL_PATH" - if [ ! -d "$MODEL_PATH" ]; then - log_error "Model directory not found: $MODEL_PATH" - log_info "Available models:" - ls -la models/ 2>/dev/null || log_warning "No models directory found" - exit 1 - fi - log_success "Model found: $MODEL_PATH" -} - -# Check if conda environment is activated -check_environment() { - log_info "Checking conda environment..." - if [[ "$CONDA_DEFAULT_ENV" != "vllm_omni" ]]; then - log_warning "vllm_omni conda environment not activated" - log_info "Attempting to activate vllm_omni environment..." - source $(conda info --base)/etc/profile.d/conda.sh - conda activate vllm_omni - if [[ "$CONDA_DEFAULT_ENV" != "vllm_omni" ]]; then - log_error "Failed to activate vllm_omni environment" - exit 1 - fi - fi - log_success "Environment check passed: $CONDA_DEFAULT_ENV" -} - -# Test import functionality -test_imports() { - log_info "Testing imports..." - python -c " -import vllm_omni -from vllm_omni.entrypoints.omni_llm import OmniLLM, AsyncOmniLLM -print('✅ All imports successful') -" || { - log_error "Import test failed" - exit 1 - } - log_success "Import test passed" -} - -# Start the server -start_server() { - log_info "Starting vLLM-omni server on port $PORT..." - log_info "Command: vllm serve $MODEL_PATH --omni --port $PORT" - - # Start server in background - vllm serve "$MODEL_PATH" --omni --port "$PORT" > server.log 2>&1 & - SERVER_PID=$! - - log_info "Server started with PID: $SERVER_PID" - log_info "Waiting $SERVER_STARTUP_WAIT seconds for server to initialize..." - sleep $SERVER_STARTUP_WAIT - - # Check if server is still running - if ! kill -0 $SERVER_PID 2>/dev/null; then - log_error "Server failed to start" - log_info "Server logs:" - cat server.log - exit 1 - fi - - log_success "Server appears to be running" -} - -# Test health endpoint -test_health() { - log_info "Testing health endpoint..." - for attempt in {1..10}; do - local response=$(curl -s -w "%{http_code}" http://$HOST:$PORT/health 2>/dev/null || echo "000") - local http_code="${response: -3}" - local body="${response%???}" - - if [ "$http_code" = "200" ]; then - log_success "Health check passed: $body" - return 0 - fi - - log_warning "Health check attempt $attempt failed (HTTP $http_code). Retrying in 3s..." - sleep 3 - done - - log_error "Health check failed after multiple attempts" - return 1 -} - -# Test info endpoint -test_info() { - log_info "Testing info endpoint..." - local response=$(curl -s -w "%{http_code}" http://$HOST:$PORT/info 2>/dev/null || echo "000") - local http_code="${response: -3}" - local body="${response%???}" - - if [ "$http_code" = "200" ]; then - log_success "Info endpoint working" - echo "$body" | python -m json.tool 2>/dev/null || log_warning "Info response not valid JSON" - else - log_error "Info endpoint failed (HTTP $http_code): $body" - return 1 - fi -} - -# Test text generation -test_generation() { - log_info "Testing text generation..." - - # Test 1: Simple generation - local response=$(curl -s -X POST http://$HOST:$PORT/generate \ - -H "Content-Type: application/json" \ - -d '{"prompts": ["Test the server functionality"], "max_tokens": 20, "temperature": 0.7}' \ - 2>/dev/null || echo "{}") - - if echo "$response" | python -c "import sys, json; data=json.load(sys.stdin); exit(0 if 'outputs' in data and len(data['outputs']) > 0 else 1)" 2>/dev/null; then - log_success "Text generation test passed" - echo "$response" | python -c "import sys, json; data=json.load(sys.stdin); print('Generated text:', data['outputs'][0]['text'][:100] + '...' if len(data['outputs'][0]['text']) > 100 else data['outputs'][0]['text'])" 2>/dev/null - else - log_error "Text generation test failed" - echo "Response: $response" - return 1 - fi -} - -# Test API client example -test_api_client() { - log_info "Testing API client example..." - if [ -f "examples/basic/api_client.py" ]; then - python examples/basic/api_client.py > api_client_test.log 2>&1 || { - log_error "API client test failed" - log_info "API client logs:" - cat api_client_test.log - return 1 - } - log_success "API client test passed" - else - log_warning "API client example not found, skipping" - fi -} - -# Test AR → DiT pipeline example -test_ar_dit_pipeline() { - log_info "Testing AR → DiT diffusers pipeline example..." - - if ! python - 2>/dev/null <<'PY' -import importlib -for pkg in ("torch", "diffusers"): - if importlib.util.find_spec(pkg) is None: - raise ImportError(pkg) -PY - then - log_warning "Required packages for diffusers pipeline not found, skipping AR → DiT test" - return 0 - fi - - if [ ! -f "examples/omni/ar_dit_diffusers.py" ]; then - log_warning "AR → DiT example not found, skipping" - return 0 - fi - - local output_path="logs/ar_dit_pipeline.png" - local log_path="logs/ar_dit_pipeline.log" - rm -f "$output_path" - - local cmd=(python examples/omni/ar_dit_diffusers.py - --prompt "Smoke test prompt of a lighthouse" - --seed 0 - --output "$output_path") - - local env_prefix=() - if [ -n "$AR_MODEL" ]; then - env_prefix+=("AR_MODEL=$AR_MODEL") - fi - if [ -n "$DIT_MODEL" ]; then - env_prefix+=("DIT_MODEL=$DIT_MODEL") - fi - - if ! env "${env_prefix[@]}" "${cmd[@]}" >"$log_path" 2>&1; then - log_error "AR → DiT pipeline test failed" - log_info "AR → DiT logs:" - cat "$log_path" - return 1 - fi - - if [ ! -f "$output_path" ]; then - log_error "AR → DiT pipeline did not produce an image" - log_info "AR → DiT logs:" - cat "$log_path" - return 1 - fi - - log_success "AR → DiT pipeline test passed (output: $output_path)" -} - -# Test simple usage example -test_simple_usage() { - log_info "Testing simple usage example..." - if [ -f "examples/basic/simple_usage.py" ]; then - # Run with timeout to avoid hanging - timeout 60 python examples/basic/simple_usage.py > simple_usage_test.log 2>&1 || { - local exit_code=$? - if [ $exit_code -eq 124 ]; then - log_warning "Simple usage test timed out (60s), but this might be normal for model loading" - else - log_error "Simple usage test failed (exit code: $exit_code)" - log_info "Simple usage logs:" - cat simple_usage_test.log - return 1 - fi - } - log_success "Simple usage test passed" - else - log_warning "Simple usage example not found, skipping" - fi -} - -# Performance test -test_performance() { - log_info "Running performance test..." - - local start_time=$(date +%s) - local response=$(curl -s -X POST http://$HOST:$PORT/generate \ - -H "Content-Type: application/json" \ - -d '{"prompts": ["Performance test prompt"], "max_tokens": 50, "temperature": 0.7}' \ - 2>/dev/null) - local end_time=$(date +%s) - local duration=$((end_time - start_time)) - - if echo "$response" | python -c "import sys, json; data=json.load(sys.stdin); exit(0 if 'outputs' in data else 1)" 2>/dev/null; then - log_success "Performance test passed (${duration}s response time)" - else - log_error "Performance test failed" - return 1 - fi -} - -# Main test function -run_tests() { - log_info "Starting vLLM-omni serving functionality tests..." - log_info "Model: $MODEL_PATH" - log_info "Port: $PORT" - log_info "Host: $HOST" - echo "==========================================" - - # Run all tests - check_model - check_environment - test_imports - start_server - - # Wait a bit more for server to be fully ready - sleep 5 - - test_health - test_info - test_generation - test_performance - test_api_client - test_ar_dit_pipeline - - # Note: Simple usage test is commented out as it starts its own server - # test_simple_usage - - log_success "All tests completed successfully!" - echo "==========================================" - log_info "Test Summary:" - log_info "✅ Model loading and server startup" - log_info "✅ Health and info endpoints" - log_info "✅ Text generation functionality" - log_info "✅ Performance metrics" - log_info "✅ API client integration" - log_info "✅ AR → DiT diffusers pipeline" - log_info "✅ All imports working correctly" -} - -# Show usage -show_usage() { - echo "Usage: $0 [model_path] [port]" - echo "" - echo "Arguments:" - echo " model_path Path to the model directory (default: ./models/Qwen3-0.6B)" - echo " port Port to run the server on (default: 8000)" - echo "" - echo "Examples:" - echo " $0 # Use default model and port" - echo " $0 ./models/Qwen3-0.6B 8001 # Use specific model and port" - echo " $0 Qwen/Qwen3-0.6B 8000 # Use HuggingFace model" -} - -# Main execution -main() { - # Check for help flag - if [[ "$1" == "-h" || "$1" == "--help" ]]; then - show_usage - exit 0 - fi - - # Change to script directory - cd "$(dirname "$0")/.." - - # Create logs directory if it doesn't exist - mkdir -p logs - - # Run tests - run_tests -} - -# Run main function with all arguments -main "$@" diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index cdec3720174..00000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -Shared fixtures and configuration for vLLM-omni tests. -""" - -import pytest -import torch -from unittest.mock import Mock -from vllm_omni.config import OmniStageConfig, DiTConfig, DiTCacheConfig, create_ar_stage_config, create_dit_stage_config - - -@pytest.fixture(scope="session") -def device(): - """Get available device for testing.""" - return "cuda" if torch.cuda.is_available() else "cpu" - - -@pytest.fixture -def sample_ar_stage_config(): - """Sample AR stage configuration for testing.""" - return create_ar_stage_config( - stage_id=0, - model_path="test-ar-model", - input_modalities=["text"], - output_modalities=["text"] - ) - - -@pytest.fixture -def sample_dit_stage_config(): - """Sample DiT stage configuration for testing.""" - dit_config = DiTConfig( - model_type="dit", - scheduler_type="ddpm", - num_inference_steps=10, - guidance_scale=7.5 - ) - - return create_dit_stage_config( - stage_id=1, - model_path="test-dit-model", - input_modalities=["text"], - output_modalities=["image"], - dit_config=dit_config - ) - - -@pytest.fixture -def sample_stage_configs(sample_ar_stage_config, sample_dit_stage_config): - """Sample stage configurations for testing.""" - return [sample_ar_stage_config, sample_dit_stage_config] - - -@pytest.fixture -def mock_vllm_config(): - """Mock vLLM configuration.""" - config = Mock() - config.model = "test-model" - config.tensor_parallel_size = 1 - config.pipeline_parallel_size = 1 - return config - - -@pytest.fixture -def mock_dit_cache_config(): - """Mock DiT cache configuration.""" - from vllm_omni.config import DiTCacheTensor - - cache_tensors = [ - DiTCacheTensor( - name="test_tensor", - shape=[1, 512, 512], - dtype="float32", - persistent=True - ) - ] - - return DiTCacheConfig( - cache_tensors=cache_tensors, - max_cache_size=1024 * 1024 * 1024, # 1GB - cache_strategy="fifo", - enable_optimization=True - ) \ No newline at end of file diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py deleted file mode 100644 index 1d8229156e3..00000000000 --- a/tests/unit/test_config.py +++ /dev/null @@ -1,179 +0,0 @@ -""" -Unit tests for configuration modules. -""" - -import pytest -from vllm_omni.config import ( - OmniStageConfig, - DiTConfig, - DiTCacheConfig, - DiTCacheTensor, - create_ar_stage_config, - create_dit_stage_config, -) - - -class TestOmniStageConfig: - def test_ar_stage_config_creation(self): - """Test AR stage configuration creation.""" - config = OmniStageConfig( - stage_id=0, - engine_type="AR", - model_path="test-model", - input_modalities=["text"], - output_modalities=["text"] - ) - assert config.engine_type == "AR" - assert config.stage_id == 0 - assert config.model_path == "test-model" - assert config.input_modalities == ["text"] - assert config.output_modalities == ["text"] - - def test_dit_stage_config_creation(self): - """Test DiT stage configuration creation.""" - dit_config = DiTConfig( - model_type="dit", - scheduler_type="ddpm", - num_inference_steps=50 - ) - config = OmniStageConfig( - stage_id=1, - engine_type="DiT", - model_path="test-dit-model", - input_modalities=["text"], - output_modalities=["image"], - dit_config=dit_config - ) - assert config.engine_type == "DiT" - assert config.dit_config is not None - assert config.dit_config.model_type == "dit" - - def test_invalid_engine_type(self): - """Test validation of engine type.""" - with pytest.raises(ValueError): - OmniStageConfig( - stage_id=0, - engine_type="INVALID", - model_path="test-model", - input_modalities=["text"], - output_modalities=["text"] - ) - - def test_empty_modalities(self): - """Test validation of empty modalities.""" - with pytest.raises(ValueError): - OmniStageConfig( - stage_id=0, - engine_type="AR", - model_path="test-model", - input_modalities=[], - output_modalities=["text"] - ) - - with pytest.raises(ValueError): - OmniStageConfig( - stage_id=0, - engine_type="AR", - model_path="test-model", - input_modalities=["text"], - output_modalities=[] - ) - - def test_dit_requires_config(self): - """Test that DiT engine requires dit_config.""" - with pytest.raises(ValueError): - OmniStageConfig( - stage_id=1, - engine_type="DiT", - model_path="test-dit-model", - input_modalities=["text"], - output_modalities=["image"] - ) - - -class TestDiTConfig: - def test_dit_config_defaults(self): - """Test DiT configuration with defaults.""" - config = DiTConfig( - model_type="dit", - scheduler_type="ddpm", - num_inference_steps=50 - ) - assert config.use_diffusers is False - assert config.diffusers_pipeline is None - assert config.guidance_scale == 7.5 - assert config.height == 512 - assert config.width == 512 - - def test_diffusers_config(self): - """Test DiT configuration with diffusers.""" - config = DiTConfig( - model_type="dit", - scheduler_type="ddpm", - num_inference_steps=50, - use_diffusers=True, - diffusers_pipeline="stable-diffusion" - ) - assert config.use_diffusers is True - assert config.diffusers_pipeline == "stable-diffusion" - - -class TestDiTCacheConfig: - def test_cache_config_creation(self): - """Test cache configuration creation.""" - cache_tensors = [ - DiTCacheTensor( - name="test_tensor", - shape=[1, 512, 512], - dtype="float32", - persistent=True - ) - ] - - config = DiTCacheConfig( - cache_tensors=cache_tensors, - max_cache_size=1024 * 1024 * 1024, - cache_strategy="fifo" - ) - - assert len(config.cache_tensors) == 1 - assert config.max_cache_size == 1024 * 1024 * 1024 - assert config.cache_strategy == "fifo" - assert config.enable_optimization is True - - -class TestHelperFunctions: - def test_create_ar_stage_config(self): - """Test create_ar_stage_config helper function.""" - config = create_ar_stage_config( - stage_id=0, - model_path="test-model" - ) - - assert config.stage_id == 0 - assert config.engine_type == "AR" - assert config.model_path == "test-model" - assert config.input_modalities == ["text"] - assert config.output_modalities == ["text"] - - def test_create_dit_stage_config(self): - """Test create_dit_stage_config helper function.""" - dit_config = DiTConfig( - model_type="dit", - scheduler_type="ddpm", - num_inference_steps=50 - ) - - config = create_dit_stage_config( - stage_id=1, - model_path="test-dit-model", - dit_config=dit_config - ) - - assert config.stage_id == 1 - assert config.engine_type == "DiT" - assert config.model_path == "test-dit-model" - assert config.input_modalities == ["text"] - assert config.output_modalities == ["image"] - assert config.dit_config is not None - diff --git a/vllm_omni/__init__.py b/vllm_omni/__init__.py index fb544b0faad..0758904cc4e 100644 --- a/vllm_omni/__init__.py +++ b/vllm_omni/__init__.py @@ -14,13 +14,10 @@ __email__ = "hsliuustc@gmail.com" # Main entry points -from .entrypoints.omni_llm import OmniLLM, AsyncOmniLLM +from . import patch +from .entrypoints.omni_lm import OmniLM from .config import ( - OmniStageConfig, - DiTConfig, - DiTCacheConfig, - create_ar_stage_config, - create_dit_stage_config, + OmniModelConfig, ) __all__ = [ @@ -30,15 +27,10 @@ "__email__", # Main components - "OmniLLM", - "AsyncOmniLLM", + "OmniLM", # Configuration - "OmniStageConfig", - "DiTConfig", - "DiTCacheConfig", - "create_ar_stage_config", - "create_dit_stage_config", + "OmniModelConfig", # All other components are available through their respective modules # processors.*, schedulers.*, executors.*, etc. diff --git a/vllm_omni/config/__init__.py b/vllm_omni/config/__init__.py index e48a874a7ce..a14afbced5e 100644 --- a/vllm_omni/config/__init__.py +++ b/vllm_omni/config/__init__.py @@ -2,20 +2,33 @@ Configuration module for vLLM-omni. """ -from .stage_config import ( - OmniStageConfig, - DiTConfig, - DiTCacheConfig, - DiTCacheTensor, - create_ar_stage_config, - create_dit_stage_config, -) +from vllm.config import ModelConfig +from typing import Optional +from pydantic.dataclasses import dataclass +from pydantic import ConfigDict +from vllm.config import config +import vllm_omni.model_executor.models as me_models + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class OmniModelConfig(ModelConfig): + """Configuration for Omni models, extending the base ModelConfig.""" + + stage_id: int = 0 + model_stage: str = "thinker" + model_arch: str = "Qwen2_5OmniForConditionalGeneration" + engine_output_type: Optional[str] = None + + @property + def registry(self): + return me_models.OmniModelRegistry + + @property + def architectures(self) -> list[str]: + return [self.model_arch] + __all__ = [ - "OmniStageConfig", - "DiTConfig", - "DiTCacheConfig", - "DiTCacheTensor", - "create_ar_stage_config", - "create_dit_stage_config", + "OmniModelConfig", ] \ No newline at end of file diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py deleted file mode 100644 index 8cca98c63e6..00000000000 --- a/vllm_omni/config/stage_config.py +++ /dev/null @@ -1,166 +0,0 @@ -""" -Stage configuration for vLLM-omni multi-stage processing. -""" - -from dataclasses import dataclass -from typing import List, Optional, Type, Literal, Any, Dict -from vllm.config import ( - VllmConfig, - DeviceConfig, - LoadConfig, - ModelConfig, - CompilationConfig, -) -from vllm.executor.executor_base import ExecutorBase as Executor - -@dataclass -class DiTCacheTensor: - """Configuration for DiT cache tensors.""" - name: str - shape: List[int] - dtype: str = "float32" - persistent: bool = True - - -@dataclass -class DiTCacheConfig: - """Configuration for DiT caching system.""" - cache_tensors: List[DiTCacheTensor] - max_cache_size: int = 1024 * 1024 * 1024 # 1GB - cache_strategy: str = "fifo" # fifo, lru, lfu - enable_optimization: bool = True - cache_compression: bool = False - - -@dataclass -class DiTConfig: - """Configuration for DiT (Diffusion Transformer) stages.""" - - num_inference_steps: int - """Number of diffusion inference steps.""" - - model_config: Optional[ModelConfig] = None - """Model configuration.""" - scheduler_config: Optional[Any] = None - """Scheduler configuration.""" - device_config: Optional[DeviceConfig] = None - """Device configuration.""" - load_config: Optional[LoadConfig] = None - """Model loading configuration.""" - compilation_config: Optional[CompilationConfig] = None - """Compilation configuration.""" - dit_cache_config: Optional[DiTCacheConfig] = None - """DiT cache configuration.""" - guidance_scale: float = 7.5 - use_diffusers: bool = False - diffusers_pipeline: Optional[str] = None - height: int = 512 - width: int = 512 - batch_size: int = 1 - - -@dataclass -class OmniStageConfig: - """Configuration for a processing stage in vLLM-omni.""" - - stage_id: int - engine_type: Literal["AR", "DiT"] - model_path: str - input_modalities: List[str] - output_modalities: List[str] - vllm_config: Optional[VllmConfig] = None - dit_config: Optional[DiTConfig] = None - executor_class: Type[Executor] = None # Will be set based on engine_type - default_stage_args: Optional[Dict[str, Any]] = None - stage_output: Optional[Any] = None - - def __post_init__(self): - """Validate configuration after initialization.""" - if self.engine_type not in ["AR", "DiT"]: - raise ValueError(f"Invalid engine_type: {self.engine_type}. Must be 'AR' or 'DiT'") - - if self.engine_type == "DiT" and self.dit_config is None: - raise ValueError("DiT engine requires dit_config") - - if not self.input_modalities: - raise ValueError("input_modalities cannot be empty") - - if not self.output_modalities: - raise ValueError("output_modalities cannot be empty") - - def get_executor_class(self) -> Type[Executor]: - """Get the appropriate executor class for this stage.""" - if self.executor_class is not None: - return self.executor_class - - if self.engine_type == "AR": - from vllm.executor.uniproc_executor import UniProcExecutor - return UniProcExecutor - elif self.engine_type == "DiT": - if self.dit_config and self.dit_config.use_diffusers: - from vllm_omni.executor.diffusers_executor import DiffusersPipelineExecutor - return DiffusersPipelineExecutor - else: - from vllm.executor.uniproc_executor import UniProcExecutor - return UniProcExecutor - - raise ValueError(f"No executor class available for engine_type: {self.engine_type}") - - -def create_ar_stage_config( - stage_id: int, - model_path: str, - input_modalities: List[str] = None, - output_modalities: List[str] = None, - vllm_config: Optional[VllmConfig] = None, - executor_class: Optional[Type[Executor]] = None, - default_stage_args: Optional[Dict[str, Any]] = None, -) -> OmniStageConfig: - """Create a configuration for an AR (Autoregressive) stage.""" - if input_modalities is None: - input_modalities = ["text"] - if output_modalities is None: - output_modalities = ["text"] - - # For now, we'll create minimal configs - # vllm_config and executor_class will be handled by the LLM class - return OmniStageConfig( - stage_id=stage_id, - engine_type="AR", - model_path=model_path, - input_modalities=input_modalities, - output_modalities=output_modalities, - vllm_config=vllm_config, - executor_class=executor_class, - default_stage_args=default_stage_args, - ) - - -def create_dit_stage_config( - stage_id: int, - model_path: str, - input_modalities: Optional[List[str]] = None, - output_modalities: Optional[List[str]] = None, - dit_config: Optional[DiTConfig] = None, - executor_class: Optional[Type[Executor]] = None, - default_stage_args: Optional[Dict[str, Any]] = None, -) -> OmniStageConfig: - """Create a configuration for a DiT (Diffusion Transformer) stage.""" - if input_modalities is None: - input_modalities = ["text"] - if output_modalities is None: - output_modalities = ["image"] - - if dit_config is None: - dit_config = DiTConfig(num_inference_steps=50) - - return OmniStageConfig( - stage_id=stage_id, - engine_type="DiT", - model_path=model_path, - input_modalities=input_modalities, - output_modalities=output_modalities, - dit_config=dit_config, - executor_class=executor_class, - default_stage_args=default_stage_args, - ) diff --git a/vllm_omni/core/__init__.py b/vllm_omni/core/__init__.py index e69de29bb2d..07a546bd026 100644 --- a/vllm_omni/core/__init__.py +++ b/vllm_omni/core/__init__.py @@ -0,0 +1,10 @@ +""" +Core components for vLLM-omni. +""" + +from .dit_cache_manager import DiTCacheManager + +__all__ = [ + "DiTCacheManager", +] + diff --git a/vllm_omni/core/dit_cache_manager.py b/vllm_omni/core/dit_cache_manager.py index d2c800e76f5..d08c13a1f44 100644 --- a/vllm_omni/core/dit_cache_manager.py +++ b/vllm_omni/core/dit_cache_manager.py @@ -5,174 +5,8 @@ to optimize inference performance and memory usage. """ -import time -import torch -from typing import Dict, List, Optional, Any -from dataclasses import dataclass -from ..config import DiTCacheConfig, DiTCacheTensor - - class DiTCacheManager: """Manages DiT-specific caching for optimized inference.""" - def __init__(self, config: DiTCacheConfig): - self.config = config - self.cache_tensors: Dict[str, DiTCacheTensor] = {} - self.cache_groups: List[DiTCacheTensor] = config.cache_tensors - self.max_cache_size = config.max_cache_size - self.current_cache_size = 0 - self.cache_hits = 0 - self.cache_misses = 0 - self.cache_timeout = 3600 # 1 hour default timeout - - def allocate_cache(self, request_id: str, size: int) -> torch.Tensor: - """Allocate cache for a specific request.""" - # Check if we have enough space - if self.current_cache_size + size > self.max_cache_size: - self._evict_cache() - - # Create tensor - tensor = torch.zeros(size, dtype=torch.float32) - - # Store in cache - cache_tensor = DiTCacheTensor( - name=request_id, - tensor=tensor, - timestamp=time.time(), - access_count=0, - size_bytes=size * 4 # Assuming float32 - ) - - self.cache_tensors[request_id] = cache_tensor - self.current_cache_size += cache_tensor.size_bytes - - return tensor - - def get_cache(self, request_id: str) -> Optional[torch.Tensor]: - """Get cached tensor for a request.""" - if request_id in self.cache_tensors: - cache_tensor = self.cache_tensors[request_id] - cache_tensor.access_count += 1 - self.cache_hits += 1 - return cache_tensor.tensor - else: - self.cache_misses += 1 - return None - - def store_cache(self, request_id: str, tensor: torch.Tensor) -> None: - """Store a tensor in the cache.""" - if request_id in self.cache_tensors: - # Update existing cache - old_tensor = self.cache_tensors[request_id] - self.current_cache_size -= old_tensor.size_bytes - - new_size_bytes = tensor.numel() * 4 # Assuming float32 - self.current_cache_size += new_size_bytes - - self.cache_tensors[request_id] = DiTCacheTensor( - name=request_id, - tensor=tensor, - timestamp=time.time(), - access_count=old_tensor.access_count, - size_bytes=new_size_bytes - ) - else: - # Create new cache entry - new_size_bytes = tensor.numel() * 4 - if self.current_cache_size + new_size_bytes > self.max_cache_size: - self._evict_cache() - - self.cache_tensors[request_id] = DiTCacheTensor( - name=request_id, - tensor=tensor, - timestamp=time.time(), - access_count=0, - size_bytes=new_size_bytes - ) - self.current_cache_size += new_size_bytes - - def release_cache(self, request_id: str) -> None: - """Release cache for a request.""" - if request_id in self.cache_tensors: - cache_tensor = self.cache_tensors[request_id] - self.current_cache_size -= cache_tensor.size_bytes - del self.cache_tensors[request_id] - - def clear_expired_cache(self) -> None: - """Clear expired cache entries.""" - current_time = time.time() - expired_keys = [] - - for request_id, cache_tensor in self.cache_tensors.items(): - if current_time - cache_tensor.timestamp > self.cache_timeout: - expired_keys.append(request_id) - - for request_id in expired_keys: - self.release_cache(request_id) - - def _evict_cache(self) -> None: - """Evict cache entries based on strategy.""" - if self.config.cache_strategy == "fifo": - self._evict_fifo() - elif self.config.cache_strategy == "lru": - self._evict_lru() - elif self.config.cache_strategy == "lfu": - self._evict_lfu() - else: - self._evict_fifo() # Default to FIFO - - def _evict_fifo(self) -> None: - """Evict cache entries using FIFO strategy.""" - if not self.cache_tensors: - return - - # Find oldest entry - oldest_key = min(self.cache_tensors.keys(), - key=lambda k: self.cache_tensors[k].timestamp) - self.release_cache(oldest_key) - - def _evict_lru(self) -> None: - """Evict cache entries using LRU strategy.""" - if not self.cache_tensors: - return - - # Find least recently used entry - lru_key = min(self.cache_tensors.keys(), - key=lambda k: self.cache_tensors[k].access_count) - self.release_cache(lru_key) - - def _evict_lfu(self) -> None: - """Evict cache entries using LFU strategy.""" - if not self.cache_tensors: - return - - # Find least frequently used entry - lfu_key = min(self.cache_tensors.keys(), - key=lambda k: self.cache_tensors[k].access_count) - self.release_cache(lfu_key) - - def get_statistics(self) -> Dict[str, Any]: - """Get cache statistics.""" - total_requests = self.cache_hits + self.cache_misses - hit_rate = self.cache_hits / total_requests if total_requests > 0 else 0 - - return { - "cache_size": self.current_cache_size, - "max_cache_size": self.max_cache_size, - "cache_utilization": self.current_cache_size / self.max_cache_size, - "cache_hits": self.cache_hits, - "cache_misses": self.cache_misses, - "hit_rate": hit_rate, - "num_cached_tensors": len(self.cache_tensors) - } - - def clear_all_cache(self) -> None: - """Clear all cache entries.""" - self.cache_tensors.clear() - self.current_cache_size = 0 - self.cache_hits = 0 - self.cache_misses = 0 - - def set_cache_timeout(self, timeout: float) -> None: - """Set cache timeout in seconds.""" - self.cache_timeout = timeout \ No newline at end of file + def __init__(self, config): + pass \ No newline at end of file diff --git a/vllm_omni/core/sched/__init__.py b/vllm_omni/core/sched/__init__.py index e69de29bb2d..3ef88c25db4 100644 --- a/vllm_omni/core/sched/__init__.py +++ b/vllm_omni/core/sched/__init__.py @@ -0,0 +1,14 @@ +""" +Scheduling components for vLLM-omni. +""" + +from .scheduler import OmniScheduler +from .diffusion_scheduler import DiffusionScheduler +from .output import OmniNewRequestData + +__all__ = [ + "OmniScheduler", + "DiffusionScheduler", + "OmniNewRequestData", +] + diff --git a/vllm_omni/core/sched/diffusion_scheduler.py b/vllm_omni/core/sched/diffusion_scheduler.py new file mode 100644 index 00000000000..55d096ac664 --- /dev/null +++ b/vllm_omni/core/sched/diffusion_scheduler.py @@ -0,0 +1,318 @@ +from vllm.v1.core.sched.scheduler import SchedulerOutput, EngineCoreOutputs, Request, RequestStatus, SpecDecodingStats, defaultdict, Optional +from vllm.v1.core.sched.request_queue import create_request_queue +from vllm.v1.engine import EngineCoreEventType +from vllm.distributed.kv_events import KVEventBatch +import time +from vllm_omni.core.sched.scheduler import OmniScheduler +from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.engine import OmniEngineCoreOutput +from vllm_omni.core.sched.output import OmniNewRequestData + + +class DiffusionScheduler(OmniScheduler): + def schedule(self) -> SchedulerOutput: + """扩散快速通道: + - 一次性喂入该请求的全部输入 token(若为 0,则分配 1 个占位 token)。 + - 若无法一次性满足 token 预算,则退回上游 vLLM 的默认调度。 + """ + + # 选出零 prompt 且使用 pooling(扩散结果经 pooler_output 回传)的请求 + token_budget = self.max_num_scheduled_tokens + capacity = self.max_num_running_reqs - len(self.running) + scheduled_timestamp = time.monotonic() + + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + + req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {} + num_scheduled_tokens: dict[str, int] = {} + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + scheduled_encoder_inputs: dict[str, list[int]] = {} + structured_output_request_ids: dict[str, int] = {} + + # 临时队列:保持等待队列顺序,不破坏非扩散请求 + skipped_waiting_requests = create_request_queue(self.policy) + + # 快速通道挑选并调度(所有请求都视为扩散请求,不依赖 pooling_params) + while self.waiting and token_budget > 0 and capacity > 0: + request = self.waiting.peek_request() + # 统一按扩散处理。若未来需要条件开关,可接入配置或请求标记。 + is_diffusion = True + if not is_diffusion: + # 暂存到跳过队列,稍后归还到等待队列头部 + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # 一次性为该请求分配全部输入 token(若为 0,则分配 1 个占位 token) + required_tokens = max(getattr(request, "num_prompt_tokens", 0), 1) + if required_tokens > token_budget: + # 无足够预算一次性完成该请求的输入处理,停止快速通道尝试 + break + num_new_tokens = required_tokens + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens, + ) + if new_blocks is None: + # 无法分配(显存紧张等),停止快速通道尝试,回退到默认调度 + # 将当前 request 放回等待队列头 + # 注意:此处不改变原队列顺序 + break + + # 正式调度该请求 + request = self.waiting.pop_request() + self.running.append(request) + request.status = RequestStatus.RUNNING + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, + scheduled_timestamp) + + req_to_new_block_ids[request.request_id] = new_blocks.get_block_ids() + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + capacity -= 1 + scheduled_new_reqs.append(request) + + # 归还被跳过的等待请求 + if skipped_waiting_requests: + self.waiting.prepend_requests(skipped_waiting_requests) + + # 若快速通道未调度任何请求,则回退到原始调度逻辑 + if not num_scheduled_tokens: + return super().schedule() + + # 计算公共前缀块(与 v1 对齐) + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) + + grammar_bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) + + # 组装 SchedulerOutput + new_reqs_data = [ + OmniNewRequestData.from_request(req, req_to_new_block_ids[req.request_id]) + for req in scheduled_new_reqs + ] + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_block_ids, + ) + + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=cached_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + finished_req_ids=self.finished_req_ids, + free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + structured_output_request_ids=structured_output_request_ids, + grammar_bitmask=grammar_bitmask, + ) + + # KVTransfer:封装元信息 + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + + # 发布 KV 事件(与 v1 对齐) + events = self.kv_cache_manager.take_events() + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + # 更新内部状态(推进 num_computed_tokens,释放 encoder 输入等) + self._update_after_schedule(scheduler_output) + return scheduler_output + """ + Scheduler for the diffusion model. + This scheduler is modified to stop the request immediately for the diffusion model. + This is because the diffusion model can generate the final image/audio in one step. + Note: This is just a minimal modification to the original scheduler, and there should be some further efforts to optimize the scheduler. + The original scheduler is still used for the AR model. + """ + def update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: OmniModelRunnerOutput, + ) -> dict[int, EngineCoreOutputs]: + """Update the scheduler state based on the model runner output. + + This method is modified to stop the request immediately for the diffusion model. + """ + sampled_token_ids = model_runner_output.sampled_token_ids + spec_token_ids = model_runner_output.spec_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + pooler_outputs = model_runner_output.pooler_output + num_nans_in_logits = model_runner_output.num_nans_in_logits + multimodal_outputs = model_runner_output.multimodal_outputs + + outputs: dict[int, list[OmniEngineCoreOutput]] = defaultdict(list) + spec_decoding_stats: Optional[SpecDecodingStats] = None + + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, + # the below loop can be a performance bottleneck. We should do our best + # to avoid expensive operations inside the loop. + stopped_running_reqs: set[Request] = set() + stopped_preempted_reqs: set[Request] = set() + for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): + assert num_tokens_scheduled > 0 + request = self.requests.get(req_id) + if request is None: + # The request is already finished. This can happen if the + # request is aborted while the model is executing it (e.g., + # in pipeline parallelism). + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[ + req_index] if sampled_token_ids else [] + + scheduled_spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + if scheduled_spec_token_ids: + # num_computed_tokens represents the number of tokens + # processed in the current step, considering scheduled + # tokens and rejections. If some tokens are rejected, + # num_computed_tokens is decreased by the number of rejected + # tokens, where is given by: + # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). + num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - + len(generated_token_ids)) + request.num_computed_tokens -= num_tokens_rejected + spec_decoding_stats = self.make_spec_decoding_stats( + spec_decoding_stats, + num_draft_tokens=len(scheduled_spec_token_ids), + num_accepted_tokens=len(generated_token_ids) - 1) + + new_logprobs = None + new_token_ids = generated_token_ids + kv_transfer_params = None + status_before_stop = request.status + pooler_output = None + if pooler_outputs: + pooler_output = pooler_outputs[req_index] + + # 扩散请求:单步完成,直接标记完成并释放资源 + request.status = RequestStatus.FINISHED_STOPPED + # 可选:标注停止原因,便于前端区分(不影响协议) + request.stop_reason = request.stop_reason or "diffusion_done" + kv_transfer_params = self._free_request(request) + if status_before_stop == RequestStatus.RUNNING: + stopped_running_reqs.add(request) + else: + stopped_preempted_reqs.add(request) + + # Extract sample logprobs if needed. + if request.sampling_params is not None \ + and request.sampling_params.logprobs is not None and logprobs: + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + if new_token_ids and self.structured_output_manager.should_advance( + request): + # NOTE: structured_output_request + # should not be None if use_structured_output, we have + # check above, so safe to ignore type warning + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + req_id, new_token_ids) + + # spec_token_ids comes from the model runner output + if num_nans_in_logits is not None and req_id in num_nans_in_logits: + request.num_nans_in_logits = num_nans_in_logits[req_id] + + # Add newly generated spec token ids to the request. + if spec_token_ids is not None: + if self.structured_output_manager.should_advance(request): + metadata = request.structured_output_request + # Needs to happen after new_token_ids are accepted. + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] + spec_token_ids[req_index]) + else: + request.spec_token_ids = spec_token_ids[req_index] + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids or pooler_output is not None \ + or kv_transfer_params: + + # Add EngineCoreOutput for this Request. + outputs[request.client_index].append( + OmniEngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + pooling_output=pooler_output, + stop_reason=request.stop_reason, + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + num_cached_tokens=request.num_cached_tokens, + multimodal_outputs=multimodal_outputs, + output_type=getattr(self.vllm_config.model_config, "engine_output_type", None), + )) + + else: + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + # Remove the stopped requests from the running and waiting queues. + if stopped_running_reqs: + self.running = [ + req for req in self.running if req not in stopped_running_reqs + ] + if stopped_preempted_reqs: + # This is a rare case and unlikely to impact performance. + self.waiting.remove_requests(stopped_preempted_reqs) + + # KV Connector: update state for finished KV Transfers. + if model_runner_output.kv_connector_output: + self._update_from_kv_xfer_finished( + model_runner_output.kv_connector_output) + + # Create EngineCoreOutputs for all clients that have requests with + # outputs in this step. + engine_core_outputs = { + client_index: EngineCoreOutputs(outputs=outs) + for client_index, outs in outputs.items() + } + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids: + # Include ids of requests that finished since last outputs + # were sent. + for client_index, finished_set in finished_req_ids.items(): + # Set finished request set in EngineCoreOutputs for this client. + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs( + finished_requests=finished_set) + finished_req_ids.clear() + + if engine_core_outputs: + # Return stats to only one of the front-ends. + next(iter(engine_core_outputs.values())).scheduler_stats = ( + self.make_stats(spec_decoding_stats)) + + return engine_core_outputs \ No newline at end of file diff --git a/vllm_omni/core/sched/output.py b/vllm_omni/core/sched/output.py new file mode 100644 index 00000000000..bfb395fcc65 --- /dev/null +++ b/vllm_omni/core/sched/output.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from vllm.v1.core.sched.output import NewRequestData +from vllm.v1.request import Request +from typing import Optional + +from vllm_omni.engine import PromptEmbedsPayload, AdditionalInformationPayload + + + +@dataclass +class OmniNewRequestData(NewRequestData): + # Optional serialized prompt embeddings + prompt_embeds: Optional[PromptEmbedsPayload] = None + # Optional serialized additional information + additional_information: Optional[AdditionalInformationPayload] = None + + @classmethod + def from_request( + cls, + request: Request, + block_ids: tuple[list[int], ...], + ) -> NewRequestData: + return cls( + req_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + mm_inputs=request.mm_inputs, + mm_hashes=request.mm_hashes, + mm_positions=request.mm_positions, + sampling_params=request.sampling_params, + pooling_params=request.pooling_params, + block_ids=block_ids, + num_computed_tokens=request.num_computed_tokens, + lora_request=request.lora_request, + prompt_embeds=request.prompt_embeds, + additional_information=request.additional_information, + ) \ No newline at end of file diff --git a/vllm_omni/core/sched/scheduler.py b/vllm_omni/core/sched/scheduler.py index 5bbbfeb05c9..2c66a3a6bca 100644 --- a/vllm_omni/core/sched/scheduler.py +++ b/vllm_omni/core/sched/scheduler.py @@ -1,40 +1,230 @@ +from __future__ import annotations -from vllm_omni.request import OmniRequest -from typing import List -from threading import Lock, Condition -from vllm_omni.config import OmniConfig +from collections import defaultdict +from typing import Optional -from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.core.sched import SchedulerInterface -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.core.sched.output import ( + SchedulerOutput, ) +from vllm.v1.core.sched.utils import check_stop +from vllm.v1.engine import EngineCoreOutputs +from vllm.v1.request import Request, RequestStatus +from vllm.v1.spec_decode.metrics import SpecDecodingStats +from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler +from vllm_omni.engine import OmniEngineCoreOutput +from vllm_omni.outputs import OmniModelRunnerOutput -class OmniScheduler(SchedulerInterface): + +class OmniScheduler(VLLMScheduler): """ OmniScheduler: Scheduler for vLLM-omni multimodal processing. This scheduler extends vLLM's scheduler to support multimodal and non-autoregressive processing with additional fields and methods specific to vLLM-omni. """ - def __init__(self, - omni_config: OmniConfig, - kv_cache_config: KVCacheConfig, - structured_output_manager: StructuredOutputManager, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - include_finished_set: bool = False, - log_stats: bool = False, - ): - super().__init__( - vllm_config=omni_config.vllm_config, - kv_cache_config=kv_cache_config, - multimodal_registry=mm_registry, - structured_output_manager=structured_output_manager, - include_finished_set=include_finished_set, - log_stats=log_stats, - ) - self.omni_config = omni_config - def schedule(self, requests: List[OmniRequest]) -> List[OmniRequest]: - # TODO: Implement scheduling logic - pass \ No newline at end of file + def update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: OmniModelRunnerOutput, + ) -> dict[int, EngineCoreOutputs]: + sampled_token_ids = model_runner_output.sampled_token_ids + spec_token_ids = model_runner_output.spec_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + pooler_outputs = model_runner_output.pooler_output + num_nans_in_logits = model_runner_output.num_nans_in_logits + + outputs: dict[int, list[OmniEngineCoreOutput]] = defaultdict(list) + spec_decoding_stats: Optional[SpecDecodingStats] = None + + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, + # the below loop can be a performance bottleneck. We should do our best + # to avoid expensive operations inside the loop. + stopped_running_reqs: set[Request] = set() + stopped_preempted_reqs: set[Request] = set() + for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): + assert num_tokens_scheduled > 0 + request = self.requests.get(req_id) + if request is None: + # The request is already finished. This can happen if the + # request is aborted while the model is executing it (e.g., + # in pipeline parallelism). + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[ + req_index] if sampled_token_ids else [] + + scheduled_spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + if scheduled_spec_token_ids: + # num_computed_tokens represents the number of tokens + # processed in the current step, considering scheduled + # tokens and rejections. If some tokens are rejected, + # num_computed_tokens is decreased by the number of rejected + # tokens, where is given by: + # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). + num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - + len(generated_token_ids)) + request.num_computed_tokens -= num_tokens_rejected + spec_decoding_stats = self.make_spec_decoding_stats( + spec_decoding_stats, + num_draft_tokens=len(scheduled_spec_token_ids), + num_accepted_tokens=len(generated_token_ids) - 1) + + stopped = False + new_logprobs = None + new_token_ids = generated_token_ids + kv_transfer_params = None + status_before_stop = request.status + + # Check for stop and update request status. + if new_token_ids: + new_token_ids, stopped = self._update_request_with_output( + request, new_token_ids) + + # Stop checking for pooler models. + pooler_output = None + if pooler_outputs: + pooler_output = pooler_outputs[req_index] + stopped = check_stop(request, self.max_model_len, + pooler_output) + + if stopped: + kv_transfer_params = self._free_request(request) + if status_before_stop == RequestStatus.RUNNING: + stopped_running_reqs.add(request) + else: + stopped_preempted_reqs.add(request) + + # Extract sample logprobs if needed. + if request.sampling_params is not None \ + and request.sampling_params.logprobs is not None and logprobs: + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + if new_token_ids and self.structured_output_manager.should_advance( + request): + # NOTE: structured_output_request + # should not be None if use_structured_output, we have + # check above, so safe to ignore type warning + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + req_id, new_token_ids) + + # spec_token_ids comes from the model runner output + if num_nans_in_logits is not None and req_id in num_nans_in_logits: + request.num_nans_in_logits = num_nans_in_logits[req_id] + + # Add newly generated spec token ids to the request. + if spec_token_ids is not None: + if self.structured_output_manager.should_advance(request): + metadata = request.structured_output_request + # Needs to happen after new_token_ids are accepted. + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] + spec_token_ids[req_index]) + else: + request.spec_token_ids = spec_token_ids[req_index] + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids or pooler_output is not None \ + or kv_transfer_params: + # Add EngineCoreOutput for this Request. + outputs[request.client_index].append( + OmniEngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + pooling_output=pooler_output, + stop_reason=request.stop_reason, + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + num_cached_tokens=request.num_cached_tokens, + output_type=getattr(self.vllm_config.model_config, "engine_output_type", None), + )) + + else: + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + # Remove the stopped requests from the running and waiting queues. + if stopped_running_reqs: + self.running = [ + req for req in self.running if req not in stopped_running_reqs + ] + if stopped_preempted_reqs: + # This is a rare case and unlikely to impact performance. + self.waiting.remove_requests(stopped_preempted_reqs) + + # KV Connector: update state for finished KV Transfers. + if model_runner_output.kv_connector_output: + self._update_from_kv_xfer_finished( + model_runner_output.kv_connector_output) + + # Create EngineCoreOutputs for all clients that have requests with + # outputs in this step. + engine_core_outputs = { + client_index: EngineCoreOutputs(outputs=outs) + for client_index, outs in outputs.items() + } + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids: + # Include ids of requests that finished since last outputs + # were sent. + for client_index, finished_set in finished_req_ids.items(): + # Set finished request set in EngineCoreOutputs for this client. + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs( + finished_requests=finished_set) + finished_req_ids.clear() + + if engine_core_outputs: + # Return stats to only one of the front-ends. + next(iter(engine_core_outputs.values())).scheduler_stats = ( + self.make_stats(spec_decoding_stats)) + + return engine_core_outputs + + # Ensure scheduled_new_reqs carry omni-specific payloads (e.g., additional_information) + def schedule(self) -> SchedulerOutput: # type: ignore[override] + scheduler_output = super().schedule() + try: + # Late import to avoid circulars in some launch modes + from .output import OmniNewRequestData + # Rewrap base NewRequestData entries with OmniNewRequestData, enriching with request-level payloads + new_list = [] + for nr in scheduler_output.scheduled_new_reqs: + req_id = getattr(nr, "req_id", None) + request = self.requests.get(req_id) if req_id else None + # Build omni entry preserving all base fields + omni_nr = OmniNewRequestData( + req_id=nr.req_id, + prompt_token_ids=nr.prompt_token_ids, + mm_inputs=nr.mm_inputs, + mm_hashes=nr.mm_hashes, + mm_positions=nr.mm_positions, + sampling_params=nr.sampling_params, + pooling_params=nr.pooling_params, + block_ids=nr.block_ids, + num_computed_tokens=nr.num_computed_tokens, + lora_request=nr.lora_request, + # Enrich with omni payloads from the live request object + prompt_embeds=getattr(request, "prompt_embeds", None) if request else None, + additional_information=getattr(request, "additional_information", None) if request else None, + ) + new_list.append(omni_nr) + + scheduler_output.scheduled_new_reqs = new_list # type: ignore[assignment] + except Exception: + # If anything goes wrong, leave the original output unchanged + pass + + return scheduler_output \ No newline at end of file diff --git a/vllm_omni/distributed/__init__.py b/vllm_omni/distributed/__init__.py index e69de29bb2d..a3bd035ad5d 100644 --- a/vllm_omni/distributed/__init__.py +++ b/vllm_omni/distributed/__init__.py @@ -0,0 +1,8 @@ +""" +Distributed components for vLLM-omni. +""" + +# Currently empty, placeholder for future distributed components + +__all__ = [] + diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py index 406c26b3994..62cb9f66f6f 100644 --- a/vllm_omni/engine/__init__.py +++ b/vllm_omni/engine/__init__.py @@ -1,9 +1,65 @@ """ Engine components for vLLM-omni. """ +import msgspec +from typing import Any -from .output_processor import MultimodalOutputProcessor +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest +from typing import Optional +import torch -__all__ = [ - "MultimodalOutputProcessor", -] + +class OmniEngineCoreOutput(EngineCoreOutput): + #multimodal outputs + multimodal_outputs: Optional[dict[str, torch.Tensor]] = None + + # Output data type hint (e.g., "text", "image", "text+image", "latent"). + output_type: Optional[str] = None + + +class PromptEmbedsPayload(msgspec.Struct): + """Serialized prompt embeddings payload for direct transfer. + + data: raw bytes of the tensor in row-major order + shape: [seq_len, hidden_size] + dtype: torch dtype name (e.g., "float16", "float32") + """ + + data: bytes + shape: list[int] + dtype: str + + +class AdditionalInformationEntry(msgspec.Struct): + """One entry of additional_information. + + Two supported forms are encoded: + - tensor: data/shape/dtype + - list: a Python list (msgspec-serializable) + Exactly one of (tensor_data, list_data) should be non-None. + """ + + # Tensor form + tensor_data: Optional[bytes] = None + tensor_shape: Optional[list[int]] = None + tensor_dtype: Optional[str] = None + + # List form + list_data: Optional[list[Any]] = None + + +class AdditionalInformationPayload(msgspec.Struct): + """Serialized dictionary payload for additional_information. + + Keys are strings; values are encoded as AdditionalInformationEntry. + """ + + entries: dict[str, AdditionalInformationEntry] + + +class OmniEngineCoreRequest(EngineCoreRequest): + # Optional prompt embeddings (direct-transfer version) + prompt_embeds: Optional[PromptEmbedsPayload] = None + # Optional additional information dictionary (serialized) + additional_information: Optional[AdditionalInformationPayload] = None + \ No newline at end of file diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py new file mode 100644 index 00000000000..52253c11873 --- /dev/null +++ b/vllm_omni/engine/arg_utils.py @@ -0,0 +1,50 @@ +from vllm.engine.arg_utils import EngineArgs +from typing import Literal, Optional +from dataclasses import dataclass +from vllm_omni.config import OmniModelConfig +from vllm.utils import FlexibleArgumentParser + + +@dataclass +class OmniEngineArgs(EngineArgs): + stage_id: int = 0 + model_stage: str = "thinker" + model_arch: str = "Qwen2_5OmniForConditionalGeneration" + engine_output_type: Optional[str] = None + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Shared CLI arguments for vLLM engine.""" + parser.add_argument( + "--engine-output-type", + type=str, + default=EngineArgs.engine_output_type, + help=( + "Declare EngineCoreOutput.output_type (e.g., 'text', 'image', " + "'text+image', 'latent'). This will be written into " + "model_config.engine_output_type for schedulers to use." + ), + ) + parser.add_argument("--model-stage", type=str, default=EngineArgs.model_stage, + help="Declare model stage (e.g., 'thinker', 'talker', 'token2wav'). This will be written into model_config.model_stage for schedulers to use.") + return parser + + def create_model_config(self) -> OmniModelConfig: + # First, get the base ModelConfig from the parent class + base_config = super().create_model_config() + + # Create OmniModelConfig by copying all base config attributes + # and adding the new omni-specific fields + config_dict = base_config.__dict__.copy() + + # Add the new omni-specific fields + config_dict['stage_id'] = self.stage_id + config_dict['model_stage'] = self.model_stage + config_dict['model_arch'] = self.model_arch + config_dict['engine_output_type'] = self.engine_output_type + + # Create and return the OmniModelConfig instance + omni_config = OmniModelConfig(**config_dict) + omni_config.hf_config.architectures = omni_config.architectures + + return omni_config \ No newline at end of file diff --git a/vllm_omni/engine/diffusion_engine.py b/vllm_omni/engine/diffusion_engine.py deleted file mode 100644 index ef8a5b51d0a..00000000000 --- a/vllm_omni/engine/diffusion_engine.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Diffusion engine adapters for diffusers-backed DiT stages.""" - -from __future__ import annotations - -from typing import Any, Optional - -from ..config import DiTConfig -from ..worker.gpu_diffusion_worker import ( - DiffusionGPUWorker, - DiffusionRunnerOutput, -) - - -class DiffusionEngine: - """Base diffusion engine storing shared DiT configuration.""" - - def __init__( - self, - dit_config: DiTConfig, - model_path: Optional[str] = None, - log_stats: bool = False, - multiprocess_mode: bool = False, - ) -> None: - self.dit_config = dit_config - - model_cfg = getattr(dit_config, "model_config", None) - config_model_path = None - if model_cfg is not None: - if isinstance(model_cfg, dict): - config_model_path = ( - model_cfg.get("model") - or model_cfg.get("model_path") - ) - else: - config_model_path = ( - getattr(model_cfg, "model", None) - or getattr(model_cfg, "model_path", None) - ) - - self.model_path = model_path or config_model_path - self.model_config = dit_config.model_config - self.cache_config = dit_config.dit_cache_config - - self.height = dit_config.height - self.width = dit_config.width - self.num_inference_steps = dit_config.num_inference_steps - self.guidance_scale = dit_config.guidance_scale - - self.log_stats = log_stats - self.multiprocess_mode = multiprocess_mode - - def generate(self, *args, **kwargs): # pragma: no cover - placeholder - raise NotImplementedError("This method should be implemented in subclasses.") - - -class DiffusersPipelineEngine(DiffusionEngine): - """Adapter that invokes a diffusers pipeline via DiffusionGPUWorker.""" - - def __init__( - self, - dit_config: DiTConfig, - model_path: Optional[str] = None, - log_stats: bool = False, - multiprocess_mode: bool = False, - ) -> None: - super().__init__(dit_config, model_path, log_stats, multiprocess_mode) - - resolved_model_path = self.model_path - if resolved_model_path is None: - raise ValueError("Model path must be provided for DiffusersPipelineEngine.") - - device_cfg = getattr(dit_config, "device_config", None) - model_cfg = getattr(dit_config, "model_config", None) - - device = None - dtype = None - - if device_cfg: - if isinstance(device_cfg, dict): - device = device_cfg.get("device") - dtype = device_cfg.get("dtype") - else: - device = getattr(device_cfg, "device", None) - dtype = getattr(device_cfg, "dtype", None) - - if dtype is None and model_cfg: - if isinstance(model_cfg, dict): - dtype = model_cfg.get("dtype") - else: - dtype = getattr(model_cfg, "dtype", None) - - self.worker = DiffusionGPUWorker( - resolved_model_path, - pipeline_name=dit_config.diffusers_pipeline, - device=device, - dtype=dtype, - ) - - def generate( - self, - prompt: str, - *, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: Optional[int] = None, - guidance_scale: Optional[float] = None, - negative_prompt: Optional[str] = None, - seed: Optional[int] = None, - image: Optional[Any] = None, - ) -> DiffusionRunnerOutput: - height = int(height) if height is not None else self.height - width = int(width) if width is not None else self.width - num_steps = ( - int(num_inference_steps) - if num_inference_steps is not None - else self.num_inference_steps - ) - guidance = ( - float(guidance_scale) - if guidance_scale is not None - else self.guidance_scale - ) - - return self.worker.generate( - prompt=prompt, - height=height, - width=width, - num_inference_steps=num_steps, - guidance_scale=guidance, - negative_prompt=negative_prompt, - seed=seed, - image=image, - ) diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index 8f5fcb3571a..4807b148829 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -1,215 +1,410 @@ -""" -Output processing for multimodal outputs in vLLM-omni. -""" +from typing import Dict, Callable, Optional, Any, Union -from typing import List, Dict, Any, Optional, Callable, Union -from vllm.outputs import RequestOutput, CompletionOutput -from vllm.v1.outputs import ModelRunnerOutput as EngineCoreOutput +import torch +from vllm.v1.engine.output_processor import OutputProcessor as VLLMOutputProcessor +from vllm.v1.engine.output_processor import OutputProcessorOutput, RequestState, RequestOutputCollector +from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.v1.engine import FinishReason +from vllm.v1.metrics.stats import IterationStats +from vllm.sampling_params import RequestOutputKind +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.detokenizer import IncrementalDetokenizer +from vllm.v1.engine.logprobs import LogprobsProcessor +from vllm.v1.engine.parallel_sampling import ParentRequest +from vllm.logger import init_logger +from vllm_omni.engine import OmniEngineCoreOutput -class MultimodalOutputProcessor: - """Handles multimodal output processing for vLLM-omni.""" - - def __init__(self): - self.output_handlers: Dict[str, Callable] = { - "image": self._process_image_output, - "text+image": self._process_text_image_output, - "latents": self._process_latents_output, - "text": self._process_text_output, - "pooling": self._process_pooling_output, - } - - def process_output(self, engine_core_output: Any) -> List[RequestOutput]: - """Process engine core output and return formatted RequestOutput.""" - if engine_core_output is None: - return [] - - # If it's already a RequestOutput, return as is - if isinstance(engine_core_output, RequestOutput): - return [engine_core_output] - - # If it's a list of RequestOutputs, return as is - if isinstance(engine_core_output, list): - return engine_core_output - - # Otherwise, process based on output type - output_type = self._detect_output_type(engine_core_output) - handler = self.output_handlers.get(output_type, self._process_pooling_output) - - return handler(engine_core_output) - - def _build_request_output( + +logger = init_logger(__name__) + + +class OmniRequestState(RequestState): + + def __init__( self, - source: Any, - completion_outputs: List[CompletionOutput], - ) -> RequestOutput: - """Helper to construct RequestOutput with safe defaults.""" - - return RequestOutput( - request_id=getattr(source, "request_id", "unknown"), - prompt=getattr(source, "prompt", ""), - prompt_token_ids=getattr(source, "prompt_token_ids", []), - prompt_logprobs=getattr(source, "prompt_logprobs", None), - outputs=completion_outputs, - finished=getattr(source, "finished", True), + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.mm_type: Optional[str] = None + self.mm_accumulated: Optional[torch.Tensor] = None + + @classmethod + def from_new_request( + cls, + tokenizer: AnyTokenizer, + request: EngineCoreRequest, + prompt: Optional[str], + parent_req: Optional[ParentRequest], + request_index: int, + queue: Optional[Any], + log_stats: bool, + ) -> "OmniRequestState": + + if sampling_params := request.sampling_params: + if not sampling_params.detokenize: + tokenizer = None + output_kind = sampling_params.output_kind + logprobs_processor = LogprobsProcessor.from_new_request( + tokenizer=tokenizer, + request=request, + ) + detokenizer = IncrementalDetokenizer.from_new_request( + tokenizer=tokenizer, + request=request, + ) + max_tokens_param = sampling_params.max_tokens + else: + logprobs_processor = None + detokenizer = None + max_tokens_param = None + assert request.pooling_params is not None + output_kind = request.pooling_params.output_kind + + return cls( + request_id=request.request_id, + parent_req=parent_req, + request_index=request_index, + lora_name=(request.lora_request.name + if request.lora_request is not None else None), + output_kind=output_kind, + prompt=prompt, + prompt_token_ids=request.prompt_token_ids, + logprobs_processor=logprobs_processor, + detokenizer=detokenizer, + max_tokens_param=max_tokens_param, + arrival_time=request.arrival_time, + queue=queue, + log_stats=log_stats, ) - def _detect_output_type(self, output: Any) -> str: - """Detect the type of output based on its content.""" - if hasattr(output, 'output_type'): - return output.output_type - - # Check for image-related attributes - if hasattr(output, 'image') or hasattr(output, 'images'): - if hasattr(output, 'text') or hasattr(output, 'texts'): - return "text+image" + def add_multimodal_tensor(self, tensor: Optional[torch.Tensor], + mm_type: Optional[str]) -> None: + if tensor is None: + return + try: + if mm_type: + self.mm_type = (mm_type or "").lower() + t = tensor.detach() + try: + t = t.to("cpu") + except Exception: + pass + if self.mm_accumulated is None: + self.mm_accumulated = t else: - return "image" - - # Check for latent-related attributes - if hasattr(output, 'latents') or hasattr(output, 'latent_representation'): - return "latents" - - # Check for pooling output - if hasattr(output, 'pooler_output') and output.pooler_output is not None: - return "pooling" - - # Default to text - return "text" - - def _process_text_output(self, output: Any) -> List[RequestOutput]: - """Process text output.""" - if isinstance(output, RequestOutput): - return [output] - - # Create a mock RequestOutput for text - completion_output = CompletionOutput( - index=0, - text=getattr(output, 'text', ''), - token_ids=getattr(output, 'token_ids', []), - cumulative_logprob=getattr(output, 'cumulative_logprob', 0.0), - logprobs=getattr(output, 'logprobs', None), - finish_reason=getattr(output, 'finish_reason', 'length') - ) - - return [self._build_request_output(output, [completion_output])] - - def _process_image_output(self, output: Any) -> List[RequestOutput]: - """Process image output.""" - # For image outputs, we need to create a special RequestOutput - # that can handle image data - - # Extract image data - image_data = getattr(output, 'image', None) - if image_data is None: - image_data = getattr(output, 'images', [None])[0] - - # Create a completion output with image data - completion_output = CompletionOutput( - index=0, - text="", # No text for pure image output - token_ids=[], - cumulative_logprob=0.0, - logprobs=None, - finish_reason="stop" - ) - - # Add image data to the completion output - completion_output.image = image_data - - return [ - self._build_request_output(output, [completion_output]) - ] + self.mm_accumulated = torch.cat([self.mm_accumulated, t], dim=0) + except Exception: + pass + + # Override: do not route to pooling-only path; always create completion + # outputs, and attach pooling_result into the CompletionOutput. + def make_request_output( + self, + new_token_ids: list[int], + pooling_output: Optional[torch.Tensor], + finish_reason: Optional[FinishReason], + stop_reason: Optional[Union[int, str]], + kv_transfer_params: Optional[dict[str, Any]] = None, + num_cached_tokens: Optional[int] = None, + ) -> Optional[Any]: + finished = finish_reason is not None + final_only = self.output_kind == RequestOutputKind.FINAL_ONLY + + if not finished and final_only: + return None + + if num_cached_tokens is not None: + # Keep num_cached_tokens in RequestOutput for compatibility + try: + self.num_cached_tokens = num_cached_tokens # type: ignore[attr-defined] + except Exception: + pass + + request_id = self.request_id + output = self._new_completion_output(new_token_ids, finish_reason, + stop_reason) + + if self.parent_req is None: + outputs = [output] + else: + request_id, outputs, finished = self.parent_req.get_outputs( + request_id, output) + if not outputs: + return None + + return self._new_request_output(request_id, outputs, finished, + kv_transfer_params) + + def _new_completion_output( + self, + token_ids: list[int], + finish_reason: Optional[FinishReason], + stop_reason: Optional[Union[int, str]] + ) -> Any: + # Reuse base text/logprobs logic, then annotate with pooling_result. + base_output = super()._new_completion_output(token_ids, finish_reason, + stop_reason) + try: + if self.mm_accumulated is not None: + tensor = self.mm_accumulated + try: + tensor = tensor.detach().to("cpu") + except Exception: + pass + # Attach on the completion output for downstream consumers. + if not hasattr(base_output, "multimodal_output"): + setattr(base_output, "multimodal_output", {}) + setattr(base_output, "multimodal_output", {self.mm_type: tensor}) + except Exception as e: + logger.warning("Error in _new_completion_output", e) + pass + return base_output + + +class MultimodalOutputProcessor(VLLMOutputProcessor): + """Handles multimodal output processing by normalizing EngineCoreOutput + before delegating to the base vLLM OutputProcessor. + + Strategy: + - Route by EngineCoreOutput.output_type when present + ("image", "text+image", "latents", "text"). + - Fallback to pooling/text heuristics when output_type is absent. + - Mutate EngineCoreOutput in-place to ensure vLLM's base processor can + produce the correct RequestOutput/PoolingRequestOutput. + - Allow custom per-modality handlers via register_handler(). + """ + def __init__(self, tokenizer: TokenizerGroup, log_stats: bool): + super().__init__(tokenizer=tokenizer, log_stats=log_stats) + self.output_handlers: Dict[str, Callable[[OmniEngineCoreOutput], None]] = {} + self._reqid_to_mm_type: Dict[str, str] = {} + self.request_states: dict[str, OmniRequestState] = {} - def _process_text_image_output(self, output: Any) -> List[RequestOutput]: - """Process combined text and image output.""" - # Extract text and image data - text_data = getattr(output, 'text', '') - image_data = getattr(output, 'image', None) - - if image_data is None: - image_data = getattr(output, 'images', [None])[0] - - # Create a completion output with both text and image - completion_output = CompletionOutput( - index=0, - text=text_data, - token_ids=getattr(output, 'token_ids', []), - cumulative_logprob=getattr(output, 'cumulative_logprob', 0.0), - logprobs=getattr(output, 'logprobs', None), - finish_reason="stop" - ) - - # Add image data to the completion output - completion_output.image = image_data - - return [ - self._build_request_output(output, [completion_output]) - ] + def register_handler(self, modality: str, + handler: Callable[[OmniEngineCoreOutput], None]) -> None: + self.output_handlers[modality.lower()] = handler - def _process_latents_output(self, output: Any) -> List[RequestOutput]: - """Process latent representation output.""" - # Extract latent data - latent_data = getattr(output, 'latents', None) - if latent_data is None: - latent_data = getattr(output, 'latent_representation', None) - - # Create a completion output with latent data - completion_output = CompletionOutput( - index=0, - text="", # No text for latent output - token_ids=[], - cumulative_logprob=0.0, - logprobs=None, - finish_reason="stop" - ) - - # Add latent data to the completion output - completion_output.latents = latent_data - - return [ - self._build_request_output(output, [completion_output]) - ] + def add_request( + self, + request: EngineCoreRequest, + prompt: Optional[str], + parent_req: Optional[ParentRequest] = None, + request_index: int = 0, + queue: Optional[RequestOutputCollector] = None, + ) -> None: + request_id = request.request_id + if request_id in self.request_states: + raise ValueError(f"Request id {request_id} already running.") + + tokenizer = None if not self.tokenizer else \ + self.tokenizer.get_lora_tokenizer(request.lora_request) + + req_state = OmniRequestState.from_new_request(tokenizer=tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats) + self.request_states[request_id] = req_state + self.lora_states.add_request(req_state) + if parent_req: + self.parent_requests[parent_req.request_id] = parent_req - def _process_pooling_output(self, output: Any) -> List[RequestOutput]: - """Process pooling output (hidden states, embeddings, etc.).""" - # Extract pooling data - pooling_data = getattr(output, 'pooler_output', None) - if pooling_data is None: - pooling_data = getattr(output, 'hidden_states', None) - - # Create a completion output with pooling data - completion_output = CompletionOutput( - index=0, - text="", # No text for pooling output - token_ids=[], - cumulative_logprob=0.0, - logprobs=None, - finish_reason="stop" + def process_outputs( + self, + engine_core_outputs: list[OmniEngineCoreOutput], + engine_core_timestamp: Optional[float] = None, + iteration_stats: Optional[IterationStats] = None, + ) -> OutputProcessorOutput: + self._reqid_to_mm_type.clear() + for eco in engine_core_outputs: + mm_type = (getattr(eco, "output_type", None) or "").lower() + if mm_type: + self._reqid_to_mm_type[eco.request_id] = mm_type + self._route_and_normalize(eco) + + # Build RequestOutputs without delegating to base, so we can keep ids + request_outputs: list[Any] = [] + reqs_to_abort: list[str] = [] + for eco in engine_core_outputs: + req_id = eco.request_id + req_state = self.request_states.get(req_id) + if req_state is None: + continue + + # 1) Stats + self._update_stats_from_output(req_state, eco, + engine_core_timestamp, + iteration_stats) + + new_token_ids = eco.new_token_ids + pooling_output = eco.pooling_output + finish_reason = eco.finish_reason + stop_reason = eco.stop_reason + kv_transfer_params = eco.kv_transfer_params + num_cached_tokens = eco.num_cached_tokens + req_state.is_prefilling = False + + # 2) Detokenize and logprobs when text path + assert req_state.detokenizer is not None + assert req_state.logprobs_processor is not None + stop_string = req_state.detokenizer.update( + new_token_ids, finish_reason == FinishReason.STOP) + if stop_string: + finish_reason = FinishReason.STOP + stop_reason = stop_string + req_state.logprobs_processor.update_from_output(eco) + + # 2.5) Accumulate multimodal tensors in RequestState + try: + mm_type = (getattr(eco, "output_type", None) or "").lower() + if pooling_output is not None and isinstance(req_state, OmniRequestState): + req_state.add_multimodal_tensor(pooling_output, mm_type) + except Exception: + pass + + # 3) Create RequestOutput objects, forcing combined mode to keep ids + pooling_for_make = pooling_output + if pooling_output is not None and new_token_ids: + # Do not consume pooling path now; keep ids and attach mm later + pooling_for_make = None + + ro = req_state.make_request_output(new_token_ids, pooling_for_make, + finish_reason, stop_reason, + kv_transfer_params, + num_cached_tokens) + if ro: + # Attach accumulated multimodal payload if any + try: + if isinstance(req_state, OmniRequestState) and req_state.mm_accumulated is not None: + mm_key = req_state.mm_type or "latents" + if not hasattr(ro, "multimodal_output"): + setattr(ro, "multimodal_output", {}) + ro.multimodal_output[mm_key] = req_state.mm_accumulated + except Exception as e: + logger.warning("Error in process_outputs", e) + pass + if req_state.queue is not None: + req_state.queue.put(ro) + else: + request_outputs.append(ro) + + # 4) Free completed + if finish_reason is not None: + self.request_states.pop(req_id) + parent_req = req_state.parent_req + if parent_req and not parent_req.child_requests: + self.parent_requests.pop(parent_req.request_id, None) + if not eco.finished: + reqs_to_abort.append(req_id) + self._update_stats_from_finished(req_state, finish_reason, + iteration_stats) + # Cleanup per-request mm state + if isinstance(req_state, OmniRequestState): + req_state.mm_accumulated = None + req_state.mm_type = None + + return OutputProcessorOutput( + request_outputs=request_outputs, + reqs_to_abort=reqs_to_abort, ) - - # Add pooling data to the completion output - completion_output.pooler_output = pooling_data - - return [ - self._build_request_output(output, [completion_output]) - ] - def process_outputs(self, engine_core_outputs: List[EngineCoreOutput], **kwargs) -> List[RequestOutput]: - """Process multiple engine core outputs.""" - all_outputs = [] - - for engine_core_output in engine_core_outputs: - outputs = self.process_output(engine_core_output) - all_outputs.extend(outputs) - - return all_outputs - - def add_output_handler(self, output_type: str, handler: Callable) -> None: - """Add a custom output handler for a specific output type.""" - self.output_handlers[output_type] = handler - - def remove_output_handler(self, output_type: str) -> None: - """Remove an output handler for a specific output type.""" + # ---- routing helpers ---- + def _route_and_normalize(self, eco: OmniEngineCoreOutput) -> None: + output_type = (getattr(eco, "output_type", None) or "").lower() + + # Custom handler first (if registered) if output_type in self.output_handlers: - del self.output_handlers[output_type] + try: + self.output_handlers[output_type](eco) + # Fall through to default fixups in case the handler left gaps + except Exception: + pass + + if output_type == "image": + self._process_image_output(eco) + elif output_type in ("text+image", "text,image", "image+text"): + self._process_text_image_output(eco) + elif output_type in ("latents", "latent"): + self._process_latents_output(eco) + elif output_type in ("audio", "speech"): + self._process_audio_output(eco) + elif output_type == "text": + self._process_text_output(eco) + else: + # Fallback heuristic + if eco.pooling_output is not None: + self._process_pooling_output(eco) + else: + self._process_text_output(eco) + + # ---- modality processors ---- + def _process_image_output(self, eco: OmniEngineCoreOutput) -> None: + """Ensure image tensors are surfaced via pooling_output for vLLM.""" + if eco.pooling_output is None: + tensor = self._extract_from_multimodal_outputs( + eco, keys=("image", "images", "pixel_values", "pixels")) + if tensor is not None: + eco.pooling_output = tensor + + def _process_text_image_output(self, eco: OmniEngineCoreOutput) -> None: + """Allow text+image outputs. Text path stays as new_token_ids; + image/latents route via pooling_output.""" + # Preserve text tokens as-is; ensure pooling_output carries image/latents + if eco.pooling_output is None: + tensor = self._extract_from_multimodal_outputs( + eco, keys=("image", "images", "pixel_values", "pixels", + "latent", "latents", "z")) + if tensor is not None: + eco.pooling_output = tensor + + def _process_latents_output(self, eco: OmniEngineCoreOutput) -> None: + """Ensure latent tensors are surfaced via pooling_output.""" + if eco.pooling_output is None: + tensor = self._extract_from_multimodal_outputs( + eco, keys=("latent", "latents", "z", "posterior")) + if tensor is not None: + eco.pooling_output = tensor + + def _process_audio_output(self, eco: OmniEngineCoreOutput) -> None: + """Ensure audio tensors are surfaced via pooling_output.""" + if eco.pooling_output is None: + tensor = self._extract_from_multimodal_outputs( + eco, keys=("audio", "audios", "wav", "waveform", + "audio_pcm", "pcm")) + if tensor is not None: + eco.pooling_output = tensor + + def _process_text_output(self, eco: OmniEngineCoreOutput) -> None: + """No-op; base processor will detokenize new_token_ids → text.""" + return + + def _process_pooling_output(self, eco: OmniEngineCoreOutput) -> None: + """Optional sanity checks for pooling tensor.""" + if eco.pooling_output is None: + return + if not isinstance(eco.pooling_output, torch.Tensor): + # Best-effort: convert to tensor if it's a list/ndarray-like + try: + eco.pooling_output = torch.as_tensor(eco.pooling_output) + except Exception: + pass + + def _extract_from_multimodal_outputs( + self, eco: OmniEngineCoreOutput, keys: tuple[str, ...] + ) -> Optional[torch.Tensor]: + mm = getattr(eco, "multimodal_outputs", None) + if not isinstance(mm, dict): + return None + for k in keys: + v = mm.get(k) + if isinstance(v, torch.Tensor): + return v + # Try the first tensor in the dict as a fallback + for v in mm.values(): + if isinstance(v, torch.Tensor): + return v + return None \ No newline at end of file diff --git a/vllm_omni/engine/processor.py b/vllm_omni/engine/processor.py index e69de29bb2d..1c5891265b3 100644 --- a/vllm_omni/engine/processor.py +++ b/vllm_omni/engine/processor.py @@ -0,0 +1,216 @@ +import time +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union + +from vllm.inputs import ProcessorInputs, PromptType +from vllm.inputs.parse import split_enc_dec_inputs +from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalKwargs +from vllm.multimodal.inputs import PlaceholderRange +from vllm.multimodal.utils import merge_and_sort_multimodal_metadata +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams +from vllm.config import VllmConfig +from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm_omni.inputs.preprocess import OmniInputPreprocessor + +from vllm.v1.engine.processor import Processor +from vllm_omni.engine import PromptEmbedsPayload, AdditionalInformationPayload, AdditionalInformationEntry, OmniEngineCoreRequest + + +class OmniProcessor(Processor): + def __init__(self, + vllm_config: VllmConfig, + tokenizer: TokenizerGroup, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + super().__init__(vllm_config, tokenizer, mm_registry) + self.input_preprocessor = OmniInputPreprocessor(self.model_config, + self.tokenizer, + mm_registry) + + def process_inputs( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + tokenization_kwargs: Optional[dict[str, Any]] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + data_parallel_rank: Optional[int] = None, + ) -> tuple[Optional[str], OmniEngineCoreRequest]: + + # TODO(woosuk): Support pooling models. + # TODO(woosuk): Support encoder-decoder models. + self._validate_lora(lora_request) + self._validate_params(params, lora_request) + if trace_headers is not None: + raise ValueError("V1 does not support tracing yet.") + + data_parallel_size = self.vllm_config.parallel_config.data_parallel_size + if data_parallel_rank is not None and not (0 <= data_parallel_rank < + data_parallel_size): + raise ValueError(f"data_parallel_rank {data_parallel_rank} " + f"is out of range [0, {data_parallel_size}).") + + if arrival_time is None: + arrival_time = time.time() + + # Process inputs, which includes: + # 1. Tokenize text prompt, with LoRA request if one exists. + # 2. For multimodal models with a merged preprocessor, preprocess + # multimodal data and expand prompt token ids accordingly. + return_mm_hashes = (self.model_config.processor_return_mm_hashes + or bool(self.cache_config.enable_prefix_caching)) + processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( + prompt, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + from vllm.platforms import current_platform + current_platform.validate_request( + prompt=prompt, + params=params, + processed_inputs=processed_inputs, + ) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + + self._validate_model_inputs(processed_inputs, lora_request) + + encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + + # TODO: Impl encoder-decoder + if encoder_inputs is not None: + raise NotImplementedError + + sampling_params = None + pooling_params = None + if isinstance(params, SamplingParams): + # TODO: can we avoid cloning here in multiproc case? + sampling_params = params.clone() + # If unset max tokens, then generate up to the max_model_len. + if sampling_params.max_tokens is None: + sampling_params.max_tokens = ( + self.model_config.max_model_len - + len(decoder_inputs["prompt_token_ids"])) + sampling_params.update_from_generation_config( + self.generation_config_fields, eos_token_id) + if self.tokenizer is not None: + sampling_params.update_from_tokenizer( + self.tokenizer.get_lora_tokenizer(lora_request)) + else: + pooling_params = params.clone() + + # Multimodal related. + sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None + sorted_mm_positions: Optional[list[PlaceholderRange]] = None + sorted_mm_hashes: Optional[list[str]] = None + if decoder_inputs["type"] == "multimodal": + decoder_mm_inputs = decoder_inputs["mm_kwargs"] + + # Merge and flatten multimodal placeholders, hashes and inputs + # from dictionaries to lists, and sort them by each item's position + # in the input sequence. + ( + sorted_item_modalities, + sorted_mm_positions, + sorted_mm_hashes, + ) = merge_and_sort_multimodal_metadata( + decoder_inputs["mm_placeholders"], + decoder_inputs["mm_hashes"] if return_mm_hashes else None, + ) + + # The output of merged multi-modal processor (`decoder_mm_inputs`) + # is a single MultiModalKwargs for all items from all modalities. + # This code flattens kwargs for individual items in a list and + # sorts them by each item's position in the input sequence if there + # are multiple modalities. + unique_modalities = set(sorted_item_modalities) + if len(unique_modalities) > 1: + orig_sorted_mm_inputs = [] + used_indices = {modality: 0 for modality in unique_modalities} + + for modality in sorted_item_modalities: + items = decoder_mm_inputs.get_items(modality) + item = items[used_indices[modality]] + + orig_sorted_mm_inputs.append( + MultiModalKwargs.from_items([item])) + used_indices[modality] += 1 + else: + orig_sorted_mm_inputs = [ + MultiModalKwargs.from_items([item]) for item in + decoder_mm_inputs.get_items(sorted_item_modalities[0]) + ] + + if sorted_mm_hashes is not None: + sorted_mm_inputs = self.mm_input_cache_client.get_and_update( + orig_sorted_mm_inputs, sorted_mm_hashes) + else: + sorted_mm_inputs = orig_sorted_mm_inputs + + # Serialize prompt_embeds and additional_information if provided (direct-transfer path) + prompt_embeds_payload: Optional[PromptEmbedsPayload] = None + additional_information_payload: Optional[AdditionalInformationPayload] = None + if "prompt_embeds" in decoder_inputs: # type: ignore[operator] + import numpy as np + import torch + pe: torch.Tensor = decoder_inputs["prompt_embeds"] # type: ignore[index] + if pe.ndim != 2: + raise ValueError( + "prompt_embeds must be of shape (seq_len, hidden_size)") + # Move to CPU and ensure contiguous memory for stable serialization + pe_cpu = pe.detach().to("cpu").contiguous() + seq_len, hidden_size = pe_cpu.shape + dtype_str = str(pe_cpu.dtype).replace("torch.", "") + data_bytes = pe_cpu.numpy().tobytes() + prompt_embeds_payload = PromptEmbedsPayload( + data=data_bytes, + shape=[int(seq_len), int(hidden_size)], + dtype=dtype_str, + ) + if "additional_information" in decoder_inputs: # type: ignore[operator] + import numpy as np + import torch + raw_info: dict[str, Any] = decoder_inputs["additional_information"] # type: ignore[index] + entries: dict[str, AdditionalInformationEntry] = {} + for key, value in raw_info.items(): + if isinstance(value, torch.Tensor): + v_cpu = value.detach().to("cpu").contiguous() + dtype_str = str(v_cpu.dtype).replace("torch.", "") + data_bytes = v_cpu.numpy().tobytes() + entry = AdditionalInformationEntry( + tensor_data=data_bytes, + tensor_shape=[int(x) for x in list(v_cpu.shape)], + tensor_dtype=dtype_str, + ) + elif isinstance(value, list): + entry = AdditionalInformationEntry(list_data=value) + else: + raise ValueError( + "additional_information values must be Tensor or list") + entries[key] = entry + additional_information_payload = AdditionalInformationPayload( + entries=entries) + + return decoder_inputs.get("prompt"), OmniEngineCoreRequest( + request_id=request_id, + prompt_token_ids=decoder_inputs["prompt_token_ids"], + mm_inputs=sorted_mm_inputs, + mm_hashes=sorted_mm_hashes, + mm_placeholders=sorted_mm_positions, + sampling_params=sampling_params, + pooling_params=pooling_params, + eos_token_id=eos_token_id, + arrival_time=arrival_time, + lora_request=lora_request, + cache_salt=decoder_inputs.get("cache_salt"), + priority=priority, + data_parallel_rank=data_parallel_rank, + prompt_embeds=prompt_embeds_payload, + additional_information=additional_information_payload, + ) \ No newline at end of file diff --git a/vllm_omni/entrypoints/__init__.py b/vllm_omni/entrypoints/__init__.py index e69de29bb2d..1068cc8ef12 100644 --- a/vllm_omni/entrypoints/__init__.py +++ b/vllm_omni/entrypoints/__init__.py @@ -0,0 +1,12 @@ +""" +Entrypoints for vLLM-omni. +""" + +from .omni_lm import OmniLLM +from .utils import load_stage_configs_from_yaml + +__all__ = [ + "OmniLLM", + "load_stage_configs_from_yaml", +] + diff --git a/vllm_omni/entrypoints/api_server.py b/vllm_omni/entrypoints/api_server.py deleted file mode 100644 index af14fbf08db..00000000000 --- a/vllm_omni/entrypoints/api_server.py +++ /dev/null @@ -1,146 +0,0 @@ -""" -API server for vLLM-omni. -""" - -import asyncio -from typing import Dict, Any, List -from fastapi import FastAPI, HTTPException -from fastapi.responses import JSONResponse -from pydantic import BaseModel -import uvicorn - -from .omni_llm import AsyncOmniLLM - - -class GenerateRequest(BaseModel): - """Request model for generation.""" - prompts: List[str] - max_tokens: int = 100 - temperature: float = 0.7 - stage_args: List[Dict[str, Any]] = None - - -class GenerateResponse(BaseModel): - """Response model for generation.""" - outputs: List[Dict[str, Any]] - stage_outputs: List[Dict[str, Any]] = None - - -app = FastAPI( - title="vLLM-omni API", - description="Multi-modality models inference and serving", - version="0.1.0" -) - -# Global omni_llm instance -omni_llm: AsyncOmniLLM = None - - -@app.on_event("startup") -async def startup_event(): - """Initialize the omni_llm instance on startup.""" - global omni_llm - # This will be set by the run_server function - pass - - -@app.post("/generate", response_model=GenerateResponse) -async def generate(request: GenerateRequest): - """Generate text or multimodal content.""" - try: - if omni_llm is None: - raise HTTPException(status_code=500, detail="OmniLLM not initialized") - - # Prepare stage arguments - if request.stage_args is None: - # Create default stage arguments - one per stage config - # For now, we'll process all prompts in the first stage - stage_args = [{ - "prompt": " ".join(request.prompts) if request.prompts else "", - "max_tokens": request.max_tokens, - "temperature": request.temperature - }] - else: - stage_args = request.stage_args - - # Generate using omni_llm - try: - outputs = await omni_llm.generate_async(stage_args) - except Exception as e: - # Fallback to synchronous generation - outputs = omni_llm.generate(stage_args) - - # Convert outputs to response format - response_outputs = [] - for output in outputs: - if hasattr(output, 'outputs') and output.outputs: - for out in output.outputs: - response_outputs.append({ - "text": getattr(out, 'text', ''), - "finished": getattr(out, 'finish_reason', 'length') != 'length', - "tokens": getattr(out, 'token_ids', []) - }) - else: - response_outputs.append({ - "text": "", - "finished": True, - "tokens": [] - }) - - return GenerateResponse( - outputs=response_outputs, - stage_outputs=[{"stage": i, "output": "processed"} for i in range(len(stage_args))] - ) - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@app.get("/health") -async def health_check(): - """Health check endpoint.""" - return {"status": "healthy", "service": "vllm-omni"} - - -@app.get("/info") -async def get_info(): - """Get information about the service.""" - if omni_llm is None: - return {"error": "OmniLLM not initialized"} - - return { - "service": "vllm-omni", - "version": "0.1.0", - "num_stages": omni_llm.get_num_stages() if hasattr(omni_llm, 'get_num_stages') else 0, - "stage_configs": [ - { - "stage_id": config.stage_id, - "engine_type": config.engine_type, - "model_path": config.model_path, - "input_modalities": config.input_modalities, - "output_modalities": config.output_modalities - } - for config in omni_llm.stage_configs - ] - } - - -async def run_server(omni_llm_instance: AsyncOmniLLM, host: str = "0.0.0.0", port: int = 8000): - """Run the API server.""" - global omni_llm - omni_llm = omni_llm_instance - - config = uvicorn.Config( - app=app, - host=host, - port=port, - log_level="info" - ) - - server = uvicorn.Server(config) - await server.serve() - - -if __name__ == "__main__": - # This is for testing purposes - asyncio.run(run_server(None)) diff --git a/vllm_omni/entrypoints/cli/__init__.py b/vllm_omni/entrypoints/cli/__init__.py index 2df4defa767..1efadca0f9d 100644 --- a/vllm_omni/entrypoints/cli/__init__.py +++ b/vllm_omni/entrypoints/cli/__init__.py @@ -1,5 +1,6 @@ """CLI helpers for vLLM-omni entrypoints.""" from .serve import OmniServeCommand +from .main import main -__all__ = ["OmniServeCommand"] +__all__ = ["OmniServeCommand", "main"] diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py deleted file mode 100644 index 8223c6f20e9..00000000000 --- a/vllm_omni/entrypoints/omni_llm.py +++ /dev/null @@ -1,576 +0,0 @@ -""" -Core OmniLLM and AsyncOmniLLM classes for multi-stage processing. -""" - -from typing import List, Dict, Any, Optional, Union, Callable -from vllm.entrypoints.llm import LLM -from vllm.v1.engine.async_llm import AsyncLLM -from vllm.v1.engine.llm_engine import LLMEngine -from vllm.outputs import RequestOutput, LoRARequest - -from ..config import OmniStageConfig -from .stage_manager import StageManager -from ..engine.output_processor import MultimodalOutputProcessor -from ..engine.diffusion_engine import DiffusersPipelineEngine - - - -class OmniLLM(LLM): - """Extended LLM supporting multiple engines and stage-based processing.""" - - def __init__( - self, - stage_configs: List[OmniStageConfig], - log_stats: bool = False, - **kwargs - ): - # Use the first stage's model as the default model for LLM - default_model = stage_configs[0].model_path if stage_configs else "test-model" - - # Track whether we should launch engines in multiprocess mode. - self.multiprocess_mode = kwargs.pop("multiprocess_mode", False) - - # Fix configuration validation issues - # Ensure max_num_batched_tokens is at least as large as max_model_len - if 'max_model_len' in kwargs and 'max_num_batched_tokens' in kwargs: - if kwargs['max_num_batched_tokens'] < kwargs['max_model_len']: - kwargs['max_num_batched_tokens'] = kwargs['max_model_len'] - elif 'max_model_len' in kwargs: - # If max_model_len is set but max_num_batched_tokens is not, set it to max_model_len - kwargs['max_num_batched_tokens'] = kwargs['max_model_len'] - else: - # Set reasonable defaults to avoid validation errors - kwargs['max_model_len'] = 2048 - kwargs['max_num_batched_tokens'] = 2048 - - super().__init__(model=default_model, **kwargs) - self.stage_configs = stage_configs - self.log_stats = log_stats - self.stage_manager = StageManager(stage_configs, log_stats) - self.output_processor = MultimodalOutputProcessor() - self._initialize_stage_engines() - - def _initialize_stage_engines(self) -> None: - """Initialize LLMEngine instances for each stage.""" - self.stage_manager.initialize_engines() - - def generate( - self, - stage_args_list: Optional[List[Dict[str, Any]]] = None, - use_tqdm: Union[bool, Callable[..., Any]] = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - priority: Optional[List[int]] = None, - *, - prompt: Optional[str] = None, - stage_overrides: Optional[Dict[int, Dict[str, Any]]] = None, - ) -> List[RequestOutput]: - """Main generation interface - orchestrates multi-stage processing.""" - if stage_args_list is None: - if prompt is None: - raise ValueError( - "prompt must be provided when stage_args_list is not supplied" - ) - stage_args_list = self._build_stage_args_from_config( - prompt, stage_overrides or {} - ) - - if len(stage_args_list) != len(self.stage_configs): - raise ValueError( - f"Number of stage arguments ({len(stage_args_list)}) must match " - f"number of stage configs ({len(self.stage_configs)})" - ) - - # Process through each stage sequentially - current_output = None - - for i, (stage_config, stage_args) in enumerate(zip(self.stage_configs, stage_args_list)): - stage_engine = self.stage_manager.get_engine(i) - - # Prepare input for this stage - processed_input = self._process_stage_inputs( - stage_config, stage_args or {}, current_output - ) - - # Execute stage - stage_output = self._execute_stage( - stage_engine, processed_input, lora_request, priority, stage_config - ) - - # Update for next stage - current_output = stage_output - stage_config.stage_output = stage_output - - # Process final output - final_output = self.output_processor.process_output(current_output) - return final_output - - def _build_stage_args_from_config( - self, - prompt: str, - stage_overrides: Dict[int, Dict[str, Any]], - ) -> List[Dict[str, Any]]: - """Derive per-stage argument dictionaries from configuration defaults.""" - stage_args: List[Dict[str, Any]] = [] - for stage_config in self.stage_configs: - combined: Dict[str, Any] = dict(stage_config.default_stage_args or {}) - override = stage_overrides.get(stage_config.stage_id) - if override: - combined.update(override) - if stage_config.engine_type == "AR": - combined["prompt"] = prompt - stage_args.append(combined) - return stage_args - - def _process_stage_inputs( - self, - stage_config: OmniStageConfig, - stage_args: Dict[str, Any], - previous_output: Optional[Any] - ) -> Dict[str, Any]: - """Prepare input for specific stage.""" - if stage_config.engine_type == "AR": - return self._process_ar_inputs(stage_config, stage_args, previous_output) - elif stage_config.engine_type == "DiT": - return self._process_dit_inputs(stage_config, stage_args, previous_output) - else: - raise NotImplementedError(f"Unknown engine type: {stage_config.engine_type}") - - def _process_ar_inputs( - self, - stage_config: OmniStageConfig, - stage_args: Dict[str, Any], - previous_output: Optional[Any] - ) -> Dict[str, Any]: - """Process inputs for AR stage.""" - combined = dict(stage_config.default_stage_args or {}) - combined.update(stage_args) - combined.setdefault("prompt", "") - combined.setdefault("max_tokens", 100) - combined.setdefault("temperature", 0.7) - - # If we have previous output (e.g., from a previous AR stage), - # we might want to use it as context - if previous_output is not None: - # Extract text from previous output if available - if hasattr(previous_output, 'outputs') and previous_output.outputs: - last_output = previous_output.outputs[-1] - if hasattr(last_output, 'text'): - combined["prompt"] = last_output.text + " " + combined["prompt"] - - return combined - - def _process_dit_inputs( - self, - stage_config: OmniStageConfig, - stage_args: Dict[str, Any], - previous_output: Optional[Any] - ) -> Dict[str, Any]: - """Process inputs for DiT stage.""" - combined = dict(stage_config.default_stage_args or {}) - combined.update(stage_args) - - dit = stage_config.dit_config - if dit is not None: - combined.setdefault("height", getattr(dit, "height", 512)) - combined.setdefault("width", getattr(dit, "width", 512)) - combined.setdefault( - "num_inference_steps", getattr(dit, "num_inference_steps", 50) - ) - combined.setdefault( - "guidance_scale", getattr(dit, "guidance_scale", 7.5) - ) - else: - combined.setdefault("height", 512) - combined.setdefault("width", 512) - combined.setdefault("num_inference_steps", 50) - combined.setdefault("guidance_scale", 7.5) - - # Handle image inputs if present - if "image" in stage_args: - # For now, we'll pass the image path directly - # In a full implementation, this would involve VAE encoding - combined["image"] = stage_args["image"] - - # If we have previous output from an AR stage, we might want to use it - if previous_output is not None: - # Extract text from previous AR output - if hasattr(previous_output, 'outputs') and previous_output.outputs: - last_output = previous_output.outputs[-1] - if hasattr(last_output, 'text'): - combined["prompt"] = last_output.text - - combined.setdefault("prompt", stage_args.get("prompt", "")) - - return combined - - def _execute_stage( - self, - stage_engine: Optional[Union[LLMEngine, DiffusersPipelineEngine]], - processed_input: Dict[str, Any], - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - priority: Optional[List[int]] = None, - stage_config: Optional[OmniStageConfig] = None, - ) -> Any: - """Execute a single stage.""" - # DiT via diffusers backend - if stage_config and stage_config.engine_type == "DiT": - dit = stage_config.dit_config - if dit and getattr(dit, "use_diffusers", False): - # Lazy-init executor per stage - if not hasattr(self, "_dit_engines"): - self._dit_engines = {} - exec_inst = self._dit_engines.get(stage_config.stage_id) - if exec_inst is None: - exec_inst = DiffusersPipelineEngine( - dit_config=dit, - model_path=stage_config.model_path, - log_stats=self.log_stats, - multiprocess_mode=self.multiprocess_mode, - ) - - self._dit_engines[stage_config.stage_id] = exec_inst - - return exec_inst.generate( - prompt=processed_input.get("prompt", ""), - height=processed_input.get("height", getattr(dit, "height", 512)), - width=processed_input.get("width", getattr(dit, "width", 512)), - num_inference_steps=processed_input.get( - "num_inference_steps", getattr(dit, "num_inference_steps", 30) - ), - guidance_scale=processed_input.get( - "guidance_scale", getattr(dit, "guidance_scale", 5.0) - ), - negative_prompt=processed_input.get("negative_prompt"), - seed=processed_input.get("seed"), - image=processed_input.get("image"), - ) - - # Use the parent LLM's generate method for AR text generation - prompt = processed_input.get("prompt", "") - max_tokens = processed_input.get("max_tokens", 100) - temperature = processed_input.get("temperature", 0.7) - - # Generate using the base LLM class - from vllm.sampling_params import SamplingParams - - sampling_params = SamplingParams( - max_tokens=max_tokens, - temperature=temperature, - top_p=processed_input.get("top_p", 1.0), - frequency_penalty=processed_input.get("frequency_penalty", 0.0), - presence_penalty=processed_input.get("presence_penalty", 0.0), - stop=processed_input.get("stop", None) - ) - - # Use the parent class's generate method - outputs = super().generate([prompt], sampling_params) - - # Return the first output (we're processing one prompt at a time) - if outputs: - return outputs[0] - else: - # Fallback to mock output if generation fails - from vllm.outputs import RequestOutput, CompletionOutput - - mock_output = CompletionOutput( - index=0, - text="Generation failed", - token_ids=[], - cumulative_logprob=0.0, - logprobs=None, - finish_reason="error" - ) - - return RequestOutput( - request_id="fallback_request", - prompt=prompt, - prompt_token_ids=[], - prompt_logprobs=None, - outputs=[mock_output], - finished=True - ) - - -class AsyncOmniLLM(LLM): - """Extended LLM class supporting multiple engines and stage-based processing.""" - - def __init__( - self, - stage_configs: List[OmniStageConfig], - log_stats: bool = False, - **kwargs - ): - # Use the first stage's model for the base LLM - if stage_configs and stage_configs[0].model_path: - model = stage_configs[0].model_path - else: - model = "Qwen/Qwen3-0.6B" - - # Fix configuration validation issues - # Ensure max_num_batched_tokens is at least as large as max_model_len - if 'max_model_len' in kwargs and 'max_num_batched_tokens' in kwargs: - if kwargs['max_num_batched_tokens'] < kwargs['max_model_len']: - kwargs['max_num_batched_tokens'] = kwargs['max_model_len'] - elif 'max_model_len' in kwargs: - # If max_model_len is set but max_num_batched_tokens is not, set it to max_model_len - kwargs['max_num_batched_tokens'] = kwargs['max_model_len'] - else: - # Set reasonable defaults to avoid validation errors - kwargs['max_model_len'] = 2048 - kwargs['max_num_batched_tokens'] = 2048 - - super().__init__(model=model, **kwargs) - self.stage_configs = stage_configs - self.log_stats = log_stats - self.stage_manager = StageManager(stage_configs, log_stats) - self.output_processor = MultimodalOutputProcessor() - - def _initialize_async_stage_engines(self) -> None: - """Initialize AsyncLLM instances for each stage.""" - self.stage_manager.initialize_async_engines() - - async def generate_async( - self, - stage_args_list: Optional[List[Dict[str, Any]]] = None, - use_tqdm: Union[bool, Callable[..., Any]] = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - priority: Optional[List[int]] = None, - *, - prompt: Optional[str] = None, - stage_overrides: Optional[Dict[int, Dict[str, Any]]] = None, - ) -> List[RequestOutput]: - """Async generation interface - orchestrates multi-stage processing.""" - - if stage_args_list is None: - if prompt is None: - raise ValueError( - "prompt must be provided when stage_args_list is not supplied" - ) - stage_args_list = self._build_stage_args_from_config( - prompt, stage_overrides or {} - ) - - if len(stage_args_list) != len(self.stage_configs): - raise ValueError( - f"Number of stage arguments ({len(stage_args_list)}) must match " - f"number of stage configs ({len(self.stage_configs)})" - ) - - # Process through each stage sequentially - current_output = None - - for i, (stage_config, stage_args) in enumerate(zip(self.stage_configs, stage_args_list)): - stage_engine = self.stage_manager.get_async_engine(i) - - # Prepare input for this stage - processed_input = self._process_stage_inputs( - stage_config, stage_args or {}, current_output - ) - - # Execute stage asynchronously - stage_output = await self._execute_stage_async( - stage_engine, processed_input, lora_request, priority, stage_config - ) - - # Update for next stage - current_output = stage_output - stage_config.stage_output = stage_output - - # Process final output - final_output = self.output_processor.process_output(current_output) - return final_output - - def _build_stage_args_from_config( - self, - prompt: str, - stage_overrides: Dict[int, Dict[str, Any]], - ) -> List[Dict[str, Any]]: - stage_args: List[Dict[str, Any]] = [] - for stage_config in self.stage_configs: - combined: Dict[str, Any] = dict(stage_config.default_stage_args or {}) - override = stage_overrides.get(stage_config.stage_id) - if override: - combined.update(override) - if stage_config.engine_type == "AR": - combined["prompt"] = prompt - stage_args.append(combined) - return stage_args - - def _process_stage_inputs( - self, - stage_config: OmniStageConfig, - stage_args: Dict[str, Any], - previous_output: Optional[Any] - ) -> Dict[str, Any]: - """Prepare input for specific stage (same as OmniLLM).""" - if stage_config.engine_type == "AR": - return self._process_ar_inputs(stage_config, stage_args, previous_output) - elif stage_config.engine_type == "DiT": - return self._process_dit_inputs(stage_config, stage_args, previous_output) - else: - raise NotImplementedError(f"Unknown engine type: {stage_config.engine_type}") - - def _process_ar_inputs( - self, - stage_config: OmniStageConfig, - stage_args: Dict[str, Any], - previous_output: Optional[Any] - ) -> Dict[str, Any]: - """Process inputs for AR stage (same as OmniLLM).""" - combined = dict(stage_config.default_stage_args or {}) - combined.update(stage_args) - combined.setdefault("prompt", "") - combined.setdefault("max_tokens", 100) - combined.setdefault("temperature", 0.7) - - if previous_output is not None: - if hasattr(previous_output, 'outputs') and previous_output.outputs: - last_output = previous_output.outputs[-1] - if hasattr(last_output, 'text'): - combined["prompt"] = last_output.text + " " + combined["prompt"] - - return combined - - def _process_dit_inputs( - self, - stage_config: OmniStageConfig, - stage_args: Dict[str, Any], - previous_output: Optional[Any] - ) -> Dict[str, Any]: - """Process inputs for DiT stage (same as OmniLLM).""" - combined = dict(stage_config.default_stage_args or {}) - combined.update(stage_args) - - dit = stage_config.dit_config - if dit is not None: - combined.setdefault("height", getattr(dit, "height", 512)) - combined.setdefault("width", getattr(dit, "width", 512)) - combined.setdefault( - "num_inference_steps", getattr(dit, "num_inference_steps", 50) - ) - combined.setdefault( - "guidance_scale", getattr(dit, "guidance_scale", 7.5) - ) - else: - combined.setdefault("height", 512) - combined.setdefault("width", 512) - combined.setdefault("num_inference_steps", 50) - combined.setdefault("guidance_scale", 7.5) - - if "image" in stage_args: - combined["image"] = stage_args["image"] - - if previous_output is not None: - if hasattr(previous_output, 'outputs') and previous_output.outputs: - last_output = previous_output.outputs[-1] - if hasattr(last_output, 'text'): - combined["prompt"] = last_output.text - - combined.setdefault("prompt", stage_args.get("prompt", "")) - - return combined - - async def _execute_stage_async( - self, - stage_engine: AsyncLLM, - processed_input: Dict[str, Any], - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - priority: Optional[List[int]] = None, - stage_config: Optional[OmniStageConfig] = None, - ) -> Any: - """Execute a single stage asynchronously.""" - # DiT via diffusers backend (sync call inside async for MVP) - if stage_config and stage_config.engine_type == "DiT": - dit = stage_config.dit_config - if dit and getattr(dit, "use_diffusers", False): - if not hasattr(self, "_dit_engines"): - self._dit_engines = {} - exec_inst = self._dit_engines.get(stage_config.stage_id) - if exec_inst is None: - from vllm_omni.engine.diffusion_engine import ( - DiffusersPipelineEngine, - ) - - pipeline_name = getattr(dit, "diffusers_pipeline", None) - device_cfg = getattr(dit, "device_config", None) - model_cfg = getattr(dit, "model_config", None) - if isinstance(device_cfg, dict): - device = device_cfg.get("device") - dtype = device_cfg.get("dtype") - else: - device = getattr(device_cfg, "device", None) - dtype = getattr(device_cfg, "dtype", None) - - if dtype is None: - if isinstance(model_cfg, dict): - dtype = model_cfg.get("dtype") - else: - dtype = getattr(model_cfg, "dtype", None) - - exec_inst = DiffusersPipelineEngine( - model_path=stage_config.model_path, - pipeline_name=pipeline_name, - device=device, - dtype=dtype, - ) - self._dit_engines[stage_config.stage_id] = exec_inst - - return exec_inst.generate( - prompt=processed_input.get("prompt", ""), - height=processed_input.get("height", getattr(dit, "height", 512)), - width=processed_input.get("width", getattr(dit, "width", 512)), - num_inference_steps=processed_input.get( - "num_inference_steps", getattr(dit, "num_inference_steps", 30) - ), - guidance_scale=processed_input.get( - "guidance_scale", getattr(dit, "guidance_scale", 5.0) - ), - negative_prompt=processed_input.get("negative_prompt"), - seed=processed_input.get("seed"), - image=processed_input.get("image"), - ) - - # Use the parent LLM's generate method for AR text generation - prompt = processed_input.get("prompt", "") - max_tokens = processed_input.get("max_tokens", 100) - temperature = processed_input.get("temperature", 0.7) - - # Generate using the base LLM class - from vllm.sampling_params import SamplingParams - - sampling_params = SamplingParams( - max_tokens=max_tokens, - temperature=temperature, - top_p=processed_input.get("top_p", 1.0), - frequency_penalty=processed_input.get("frequency_penalty", 0.0), - presence_penalty=processed_input.get("presence_penalty", 0.0), - stop=processed_input.get("stop", None) - ) - - # Use the parent class's generate method - outputs = super().generate([prompt], sampling_params) - - # Return the first output (we're processing one prompt at a time) - if outputs: - return outputs[0] - else: - # Fallback to mock output if generation fails - from vllm.outputs import RequestOutput, CompletionOutput - - mock_output = CompletionOutput( - index=0, - text="Generation failed", - token_ids=[], - cumulative_logprob=0.0, - logprobs=None, - finish_reason="error" - ) - - return RequestOutput( - request_id="fallback_request", - prompt=prompt, - prompt_token_ids=[], - prompt_logprobs=None, - outputs=[mock_output], - finished=True - ) diff --git a/vllm_omni/entrypoints/omni_lm.py b/vllm_omni/entrypoints/omni_lm.py new file mode 100644 index 00000000000..dfdf3aeb8e0 --- /dev/null +++ b/vllm_omni/entrypoints/omni_lm.py @@ -0,0 +1,155 @@ +from typing import Union, Sequence, Optional, Any +import cloudpickle +from pydantic import ValidationError + +from vllm.inputs import PromptType +from vllm.sampling_params import SamplingParams +from vllm.entrypoints.llm import LLM + +from vllm.v1.engine.llm_engine import LLMEngine +from vllm.engine.arg_utils import HfOverrides +from vllm.usage.usage_lib import UsageContext +from vllm.config import CompilationConfig, is_init_field +from vllm.utils import Counter +from vllm.logger import init_logger +import vllm.envs as envs + +from vllm_omni.entrypoints.utils import load_stage_configs_from_model +from vllm_omni.entrypoints.stage_manager import Stage +from vllm_omni.engine.arg_utils import OmniEngineArgs +from vllm_omni.engine.output_processor import MultimodalOutputProcessor +from vllm_omni.engine.processor import OmniProcessor +from vllm_omni.outputs import OmniRequestOutput + + +logger = init_logger(__name__) + + +class OmniLM: + def __init__(self, model: str, stage_configs = None, log_stats: bool = False, **kwargs): + if stage_configs is None: + self.initialize_stage_configs(model) + else: + self.stage_configs = stage_configs + + self.stage_list = [] + self.initialize_stages(model) + + def initialize_stage_configs(self, model: str): + self.stage_configs = load_stage_configs_from_model(model) + + def initialize_stages(self, model: str): + for stage_config in self.stage_configs: + stage = Stage(stage_config) + omni_llm = OmniLLM(model=model, **stage_config.engine_args) + stage.set_engine(omni_llm) + self.stage_list.append(stage) + + def generate( + self, + prompts: Union[PromptType, Sequence[PromptType]], + sampling_params_list: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + ) -> list[OmniRequestOutput]: + """Generate text outputs for the given prompts.""" + final_outputs: list[OmniRequestOutput] = [] + for stage_id, stage in enumerate(self.stage_list): + if stage_id > 0: + engine_inputs = stage.process_engine_inputs(self.stage_list, prompts) + else: + engine_inputs = prompts + engine_outputs = self._run_generation(stage, sampling_params_list[stage_id], engine_inputs) + stage.set_engine_outputs(engine_outputs) + if hasattr(stage, 'final_output') and stage.final_output: + final_outputs.append(OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, + request_output=engine_outputs)) + return final_outputs + + def _run_generation(self, stage: Stage, sampling_params: SamplingParams, prompts: Union[PromptType, Sequence[PromptType]]): + engine_outputs = [] + for ro in stage.engine.generate(prompts, sampling_params): + engine_outputs.append(ro) + return engine_outputs + + +class OmniLLM(LLM): + def __init__(self, + model: str, + compilation_config: Optional[Union[int, dict[str, Any], + CompilationConfig]] = None, + hf_overrides: Optional[HfOverrides] = None, + **kwargs): + """LLM constructor.""" + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + + if "worker_cls" in kwargs: + worker_cls = kwargs["worker_cls"] + # if the worker_cls is not qualified string name, + # we serialize it using cloudpickle to avoid pickling issues + if isinstance(worker_cls, type): + kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) + + if "kv_transfer_config" in kwargs and isinstance( + kwargs["kv_transfer_config"], dict): + from vllm.config import KVTransferConfig + raw_config_dict = kwargs["kv_transfer_config"] + try: + kwargs["kv_transfer_config"] = KVTransferConfig( + **raw_config_dict) + except ValidationError as e: + logger.error( + "Failed to convert 'kv_transfer_config' dict to " + "KVTransferConfig object. Dict: %s. Error: %s", + raw_config_dict, e) + # Consider re-raising a more specific vLLM error or ValueError + # to provide better context to the user. + raise ValueError( + f"Invalid 'kv_transfer_config' provided: {e}") from e + + if hf_overrides is None: + hf_overrides = {} + + if compilation_config is not None: + if isinstance(compilation_config, int): + compilation_config_instance = CompilationConfig( + level=compilation_config) + elif isinstance(compilation_config, dict): + predicate = lambda x: is_init_field(CompilationConfig, x[0]) + compilation_config_instance = CompilationConfig( + **dict(filter(predicate, compilation_config.items()))) + else: + compilation_config_instance = compilation_config + else: + compilation_config_instance = CompilationConfig() + + engine_args = OmniEngineArgs( + model=model, + hf_overrides=hf_overrides, + compilation_config=compilation_config_instance, + **kwargs, + ) + + # Create the Engine (autoselects V0 vs V1) + self.llm_engine = LLMEngine.from_engine_args( + engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) + self.llm_engine.output_processor = MultimodalOutputProcessor(tokenizer=self.llm_engine.tokenizer, + log_stats=self.llm_engine.log_stats) + self.llm_engine.processor = OmniProcessor(vllm_config=self.llm_engine.vllm_config, + tokenizer=self.llm_engine.tokenizer) + self.engine_class = type(self.llm_engine) + + self.request_counter = Counter() + self.default_sampling_params: Union[dict[str, Any], None] = None + + if envs.VLLM_USE_V1: + supported_tasks = self.llm_engine \ + .get_supported_tasks() # type: ignore + else: + supported_tasks = self.llm_engine.model_config.supported_tasks + + logger.info("Supported_tasks: %s", supported_tasks) + + self.supported_tasks = supported_tasks \ No newline at end of file diff --git a/vllm_omni/entrypoints/stage_manager.py b/vllm_omni/entrypoints/stage_manager.py index 22a03d51c2e..cc75e0d0584 100644 --- a/vllm_omni/entrypoints/stage_manager.py +++ b/vllm_omni/entrypoints/stage_manager.py @@ -2,86 +2,79 @@ Stage manager for orchestrating multiple engines in vLLM-omni. """ -from typing import List, Optional, Union +import importlib +from typing import List, Union from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.engine.async_llm import AsyncLLM -from ..engine.diffusion_engine import DiffusionEngine -from ..config import OmniStageConfig +from vllm_omni.engine import OmniEngineCoreOutput +from vllm.inputs import TextPrompt +from vllm_omni.inputs.data import OmniTokensPrompt -class StageManager: - """Manages multiple stage engines for multi-stage processing.""" - - def __init__(self, stage_configs: List[OmniStageConfig], log_stats: bool = False): - self.stage_configs = stage_configs - self.log_stats = log_stats - self.engine_list: List[Union[LLMEngine, DiffusionEngine]] = [] - self.async_engine_list: List[AsyncLLM] = [] - self._initialized = False - self._async_initialized = False - - def initialize_engines(self) -> None: - """Initialize LLMEngine instances for each stage.""" - if self._initialized: - return - - # For now, create placeholder engines - # In a full implementation, this would create actual engines - for stage_config in self.stage_configs: - # Placeholder - would create actual engine here - self.engine_list.append(None) - - self._initialized = True - - def initialize_async_engines(self) -> None: - """Initialize AsyncLLM instances for each stage.""" - if self._async_initialized: - return - - # For now, create placeholder engines - # In a full implementation, this would create actual engines - for stage_config in self.stage_configs: - # Placeholder - would create actual engine here - self.async_engine_list.append(None) - - self._async_initialized = True - - def get_engine(self, stage_id: int) -> LLMEngine: - """Get the engine for a specific stage.""" - if not self._initialized: - self.initialize_engines() - - if stage_id >= len(self.engine_list): - raise IndexError(f"Stage {stage_id} not found. Available stages: 0-{len(self.engine_list)-1}") - - return self.engine_list[stage_id] - - def get_async_engine(self, stage_id: int) -> AsyncLLM: - """Get the async engine for a specific stage.""" - if not self._async_initialized: - self.initialize_async_engines() - - if stage_id >= len(self.async_engine_list): - raise IndexError(f"Async stage {stage_id} not found. Available stages: 0-{len(self.async_engine_list)-1}") - - return self.async_engine_list[stage_id] - - def get_stage_config(self, stage_id: int) -> OmniStageConfig: - """Get the configuration for a specific stage.""" - if stage_id >= len(self.stage_configs): - raise IndexError(f"Stage config {stage_id} not found. Available stages: 0-{len(self.stage_configs)-1}") - - return self.stage_configs[stage_id] +class Stage: + def __init__(self, stage_config): + self.stage_config = stage_config + self.engine = None + self.async_engine = None + self.stage_id = stage_config.stage_id + self.engine_args = stage_config.engine_args + self.model_stage = stage_config.engine_args.model_stage + if hasattr(stage_config, 'engine_input_source'): + self.engine_input_source = stage_config.engine_input_source + else: + self.engine_input_source = [] + self.engine_output_type = stage_config.engine_args.engine_output_type + self.engine_outputs = None + if hasattr(stage_config, 'custom_process_input_func'): + # Import the module specified in the config (already a full module path) + module_path, func_name = stage_config.custom_process_input_func.rsplit('.', 1) + module = importlib.import_module(module_path) + self.custom_process_input_func = getattr(module, func_name) + else: + self.custom_process_input_func = None + + if hasattr(stage_config, 'final_output'): + self.final_output = stage_config.final_output + else: + self.final_output = False + + if hasattr(stage_config, 'final_output_type'): + self.final_output_type = stage_config.final_output_type + else: + self.final_output_type = None + + def set_engine(self, engine: LLMEngine) -> None: + """Initialize the engine for the stage.""" + self.engine = engine + + def set_async_engine(self, async_engine: AsyncLLM) -> None: + """Initialize the async engine for the stage.""" + self.async_engine = async_engine - def get_num_stages(self) -> int: - """Get the number of stages.""" - return len(self.stage_configs) + def set_engine_outputs(self, engine_outputs: OmniEngineCoreOutput) -> None: + """Set the engine output for the stage.""" + self.engine_outputs = engine_outputs - def cleanup(self) -> None: - """Clean up resources.""" - # Clean up engines if needed - self.engine_list.clear() - self.async_engine_list.clear() - self._initialized = False - self._async_initialized = False + def process_engine_inputs(self, stage_list, prompt: Union[OmniTokensPrompt, TextPrompt] = None) -> List[Union[OmniTokensPrompt, TextPrompt]]: + """Process the engine input for the stage.""" + if self.custom_process_input_func is None: + engine_inputs = [] + if len(self.engine_input_source) == 0: + raise ValueError("engine_input_source is empty") + source_stage_id = self.engine_input_source[0] + source_outputs = stage_list[source_stage_id].engine_outputs + multi_modal_data = {source_output.request_id: + prompt.get('multi_modal_data', None) for source_output, prompt in zip(source_outputs, prompt)} + + for source_output in source_outputs: + engine_input = OmniTokensPrompt( + prompt_token_ids = source_output.outputs[0].token_ids, + multi_modal_data=multi_modal_data[source_output.request_id] if multi_modal_data else None, + ) + engine_inputs.append(engine_input) + return engine_inputs + + else: + engine_input_source = self.engine_input_source + return self.custom_process_input_func(stage_list, engine_input_source, prompt) diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py new file mode 100644 index 00000000000..42a98f1648f --- /dev/null +++ b/vllm_omni/entrypoints/utils.py @@ -0,0 +1,25 @@ +import os +from pathlib import Path +from omegaconf import OmegaConf +from vllm.transformers_utils.config import get_config + +# Get the project root directory (2 levels up from this file) +PROJECT_ROOT = Path(__file__).parent.parent.parent + + +def load_stage_configs_from_model(model: str): + """Load stage configs from model.""" + hf_config = get_config(model, trust_remote_code=True) + model_type = hf_config.model_type + stage_config_file = f"vllm_omni/model_executor/stage_configs/{model_type}.yaml" + stage_config_path = PROJECT_ROOT / stage_config_file + if not os.path.exists(stage_config_path): + raise FileNotFoundError(f"Stage config file {stage_config_path} not found") + stage_configs = load_stage_configs_from_yaml(config_path=str(stage_config_path)) + return stage_configs + + +def load_stage_configs_from_yaml(config_path: str): + """Load stage configs from yaml file.""" + config_data = OmegaConf.load(config_path) + return config_data.stage_args \ No newline at end of file diff --git a/vllm_omni/inputs/__init__.py b/vllm_omni/inputs/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py new file mode 100644 index 00000000000..5682b9dabb6 --- /dev/null +++ b/vllm_omni/inputs/data.py @@ -0,0 +1,56 @@ +from vllm.inputs.data import TokensPrompt, EmbedsPrompt, TokenInputs +from typing import Any, NotRequired, Optional +import torch + + +class OmniTokensPrompt(TokensPrompt): + prompt_embeds: NotRequired[torch.Tensor] + """The embeddings of the prompt.""" + + # New: optional additional information dictionary + # Values may be torch.Tensor or list + additional_information: NotRequired[dict[str, Any]] + + +class OmniTokenInputs(TokenInputs): + # New: optional prompt embeddings aligned with token ids + prompt_embeds: NotRequired[torch.Tensor] + + # New: optional additional information dictionary + # Values may be torch.Tensor or list + additional_information: NotRequired[dict[str, Any]] + + +class OmniEmbedsPrompt(EmbedsPrompt): + # New: optional prompt embeddings aligned with token ids + prompt_embeds: NotRequired[torch.Tensor] + + # New: optional additional information dictionary + # Values may be torch.Tensor or list + additional_information: NotRequired[dict[str, Any]] + + +def token_inputs_omni( + prompt_token_ids: list[int], + token_type_ids: Optional[list[int]] = None, + prompt: Optional[str] = None, + cache_salt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + additional_information: Optional[dict[str, Any]] = None, +) -> OmniTokenInputs: + """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional + values.""" + inputs = OmniTokenInputs(type="token", prompt_token_ids=prompt_token_ids) + + if prompt is not None: + inputs["prompt"] = prompt + if token_type_ids is not None: + inputs["token_type_ids"] = token_type_ids + if cache_salt is not None: + inputs["cache_salt"] = cache_salt + if prompt_embeds is not None: + inputs["prompt_embeds"] = prompt_embeds + if additional_information is not None: + inputs["additional_information"] = additional_information + + return inputs \ No newline at end of file diff --git a/vllm_omni/inputs/parse.py b/vllm_omni/inputs/parse.py new file mode 100644 index 00000000000..a7c72b9a971 --- /dev/null +++ b/vllm_omni/inputs/parse.py @@ -0,0 +1,21 @@ +from vllm.inputs.parse import ParsedSingletonPrompt, ParsedStrPrompt, ParsedEmbedsPrompt, ParsedTokensPrompt, ParsedTextPrompt +from vllm.inputs.data import SingletonPrompt + + +def parse_singleton_prompt_omni(prompt: SingletonPrompt) -> ParsedSingletonPrompt: + if isinstance(prompt, str): + return ParsedStrPrompt(type="str", content=prompt) + elif isinstance(prompt, dict): + # Type ignores are because mypy does not correctly infer the TypedDicts + # Pyright does succeed. + # 优先 tokens:当 tokens 与 embeds 同在时,保留两者并走 tokens 路径 + if "prompt_token_ids" in prompt: + return ParsedTokensPrompt( + type="tokens", content=prompt) # type: ignore[typeddict-item] + elif "prompt_embeds" in prompt: + return ParsedEmbedsPrompt( + type="embeds", content=prompt) # type: ignore[typeddict-item] + elif "prompt" in prompt: + return ParsedTextPrompt(type="text", content=prompt) + raise TypeError( + "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt") \ No newline at end of file diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py new file mode 100644 index 00000000000..88e97b58af9 --- /dev/null +++ b/vllm_omni/inputs/preprocess.py @@ -0,0 +1,166 @@ +from typing import Optional, Any, Union +from typing_extensions import assert_never + +from vllm.lora.request import LoRARequest +from vllm.inputs.preprocess import InputPreprocessor +from vllm.inputs.data import TokensPrompt, SingletonPrompt, SingletonInputs, TextPrompt +from vllm.multimodal.inputs import MultiModalInputs +from vllm_omni.inputs.data import OmniTokenInputs, token_inputs_omni +from vllm_omni.inputs.parse import parse_singleton_prompt_omni + + +class OmniInputPreprocessor(InputPreprocessor): + def _process_tokens( + self, + parsed_content: TokensPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[OmniTokenInputs, MultiModalInputs]: + prompt_token_ids = parsed_content["prompt_token_ids"] + token_type_ids = parsed_content.get("token_type_ids") + prompt_embeds = parsed_content.get("prompt_embeds") + additional_information = parsed_content.get("additional_information") + + inputs: Union[OmniTokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = self._process_multimodal( + prompt_token_ids, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + else: + inputs = token_inputs_omni( + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + prompt_embeds=prompt_embeds, + additional_information=additional_information, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + def _prompt_to_llm_inputs( + self, + prompt: SingletonPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> SingletonInputs: + """ + Extract the singleton inputs from a prompt. + + Arguments: + + * prompt: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + * return_mm_hashes: whether to return multimodal hashes + + Returns: + + * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance + """ + parsed = parse_singleton_prompt_omni(prompt) + + if parsed["type"] == "embeds": + return self._process_embeds(parsed["content"]) + if parsed["type"] == "tokens": + return self._process_tokens( + parsed["content"], + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + if parsed["type"] == "text": + return self._process_text( + parsed["content"], + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + if parsed["type"] == "str": + return self._process_text( + TextPrompt(prompt=parsed["content"]), + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + + assert_never(parsed) + + async def _process_tokens_async( + self, + parsed_content: TokensPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[OmniTokenInputs, MultiModalInputs]: + prompt_token_ids = parsed_content["prompt_token_ids"] + token_type_ids = parsed_content.get("token_type_ids") + prompt_embeds = parsed_content.get("prompt_embeds") + additional_information = parsed_content.get("additional_information") + + inputs: Union[OmniTokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = await self._process_multimodal_async( + prompt_token_ids, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + else: + inputs = token_inputs_omni( + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + prompt_embeds=prompt_embeds, + additional_information=additional_information, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + async def _prompt_to_llm_inputs_async( + self, + prompt: SingletonPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> SingletonInputs: + """ + Async version of + [`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs]. + """ + parsed = parse_singleton_prompt_omni(prompt) + + if parsed["type"] == "embeds": + return await self._process_embeds_async(parsed["content"]) + if parsed["type"] == "tokens": + return await self._process_tokens_async( + parsed["content"], + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + if parsed["type"] == "text": + return await self._process_text_async( + parsed["content"], + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + if parsed["type"] == "str": + return await self._process_text_async( + TextPrompt(prompt=parsed["content"]), + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + + assert_never(parsed) \ No newline at end of file diff --git a/vllm_omni/model_executor/layers/__init__.py b/vllm_omni/model_executor/layers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_omni/model_executor/layers/mrope.py b/vllm_omni/model_executor/layers/mrope.py new file mode 100644 index 00000000000..372fc9440bc --- /dev/null +++ b/vllm_omni/model_executor/layers/mrope.py @@ -0,0 +1,702 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools +from typing import Optional, Union + +import numpy as np +import torch +from transformers import PretrainedConfig + +from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding +from vllm.model_executor.layers.rotary_embedding.common import apply_rotary_emb_dispatch + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + +class MRotaryEmbedding(RotaryEmbedding): + """Rotary Embedding with Multimodal Sections.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: Optional[list[int]] = None, + ) -> None: + # In Qwen2.5-VL, the maximum index value is related to the duration of + # the input video. We enlarge max_position_embeddings to 4 times to get + # a larger the cos and sin cache. + self.cache_max_position_num = max_position_embeddings * 4 + super().__init__(head_size, rotary_dim, self.cache_max_position_num, + base, is_neox_style, dtype) + + self.mrope_section = mrope_section + if self.mrope_section: + print("Warning: mrope_section check is not disabled in Qwen2.5-Omni, this may cause errors, and should be removed in the future.") + # assert sum(self.mrope_section) == rotary_dim // 2 + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """PyTorch-native implementation equivalent to forward(). + + Args: + positions: + [num_tokens,] (text only) or + [3, num_tokens] (T/H/W positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 1 or positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + cos = torch.cat([ + m[i] + for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) + ], + dim=-1) + sin = torch.cat([ + m[i] + for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) + ], + dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + # query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, + # self.is_neox_style) + query_rot = _apply_rotary_emb(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + # key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, + # self.is_neox_style) + key_rot = _apply_rotary_emb(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + @classmethod + def get_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + second_per_grid_ts: Optional[list[float]], + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[list[list[int]], int]: + """Get mrope input positions and delta value.""" + + image_grid_thw = [] if image_grid_thw is None else image_grid_thw + video_grid_thw = [] if video_grid_thw is None else video_grid_thw + second_per_grid_ts = [] if second_per_grid_ts is None else \ + second_per_grid_ts + + llm_positions, mrope_position_delta = \ + cls.get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + + return llm_positions.tolist(), mrope_position_delta + + @classmethod + def get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: list[float], + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + from vllm.transformers_utils.config import thinker_uses_mrope + if thinker_uses_mrope(hf_config): + return cls._omni_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + elif hf_config.model_type in ["glm4v", "glm4v_moe"]: + return cls._glm4v_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + context_len=context_len, + seq_len=seq_len, + ) + else: + return cls._vl_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + ) + + @classmethod + def _glm4v_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for GLM4V.""" + + image_token_id = hf_config.image_token_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1]): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_frame_num, + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + + for t_idx in range(llm_grid_t): + t_index = torch.tensor(t_idx).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view( + 1, -1, 1).expand(1, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view( + 1, 1, -1).expand(1, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + + st_idx) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + return llm_positions, mrope_position_delta + + @classmethod + def _vl_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: list[float], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, + "tokens_per_second", 1.0) + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t * + tokens_per_second).long().flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + @classmethod + def _omni_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: Optional[list[float]] = None, + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value (Qwen2.5-Omni version). + + Differences from MRotaryEmbedding: + 1. Add audio support (and related `audio_feature_lengths`). + 2. Add `use_audio_in_video` option to read audio from video inputs. + In this case, audio and vision position ids will be split into + chunks and interleaved. + + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_tensor. + + thinker_config = hf_config.thinker_config + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if src_item[idx] not in [ + audio_token_id, video_token_id, image_token_id + ]: + if use_audio_in_video and idx > 0: + if src_item[idx] == vision_end_token_id and \ + src_item[idx - 1] == audio_end_token_id: + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif src_item[idx] == audio_start_token_id and \ + src_item[idx - 1] == vision_start_token_id: + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], + dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, + grid_ws) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * + second_per_grid_ts[video_idx] * + tokens_per_second).long() + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, + grid_ws) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + grid_t = video_grid_thw[video_idx][0] + grid_h = video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * + second_per_grid_ts[video_idx] * + tokens_per_second).long() + t_index_split_chunk = cls._split_list_into_ranges( + t_index, t_ntoken_per_chunk) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: list[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len( + t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * + vision_ntoken_per_chunk) + vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_chunk, + grid_hs, grid_ws).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend( + min(t_ntoken_per_chunk, pure_audio_len - + added_audio_len) * [audio_token_id]) + audio_start_idx = start_idx if len( + audio_llm_pos_ids_list + ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 + if min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = (torch.arange( + min(t_ntoken_per_chunk, pure_audio_len - + added_audio_len)).expand(3, -1) + + audio_start_idx).split(1, + dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend( + (pure_audio_len - added_audio_len) * [audio_token_id]) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand( + 3, -1) + llm_pos_ids_list[-1].max() + 1).split( + 1, dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = torch.cat(llm_pos_ids_list, + dim=1).max() + 1 - len(src_item) + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + @staticmethod + def _get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: list[int], + grid_hs: torch.Tensor, + grid_ws: torch.Tensor, + ) -> torch.Tensor: + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand( + len(t_index), -1, llm_grid_w).flatten()) + w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand( + len(t_index), llm_grid_h, -1).flatten()) + t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view( + -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten() + _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + @staticmethod + def _split_list_into_ranges(lst: torch.Tensor, + interval: int) -> list[list[int]]: + ranges: list[list[int]] = [[] + for _ in range((max(lst) // interval) + 1)] + for num in lst: + index = num // interval + ranges[index].append(num) + return ranges + + @staticmethod + def get_next_input_positions( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> list[list[int]]: + return [ + list( + range(context_len + mrope_position_delta, + seq_len + mrope_position_delta)) for _ in range(3) + ] + + @staticmethod + def get_next_input_positions_tensor(out: np.ndarray, out_offset: int, + mrope_position_delta: int, + context_len: int, num_new_tokens: int): + + values = np.arange(mrope_position_delta + context_len, + mrope_position_delta + context_len + num_new_tokens, + dtype=out.dtype) + out[:, out_offset:out_offset + num_new_tokens] = values + + @classmethod + def omni_get_updates_use_audio_in_video( + cls, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: Union[list[int], torch.Tensor], + video_second_per_grid_t: float, + ) -> list[int]: + """Get video prompt updates when `use_audio_in_video` is True. + + In this case, audio and vision update ids will be split into + chunks and interleaved (details in `_omni_get_input_positions_tensor`). + + <|video_bos|><|VIDEO|><|video_eos|> => + <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|> + """ + + audio_token_id = thinker_config.audio_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + + grid_t = video_grid_thw[0] + grid_h = video_grid_thw[1] + grid_w = video_grid_thw[2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * video_second_per_grid_t * + tokens_per_second).long() + t_index_split_chunk = cls._split_list_into_ranges( + t_index, t_ntoken_per_chunk) + + updates = [audio_start_token_id] + added_audio_len = 0 + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // ( + spatial_merge_size**2) + updates.extend([video_token_id] * vision_ntoken_per_chunk) + + audio_chunk_size = min(t_ntoken_per_chunk, + audio_len - added_audio_len) + updates.extend(audio_chunk_size * [audio_token_id]) + added_audio_len += audio_chunk_size + if added_audio_len < audio_len: + updates.extend((audio_len - added_audio_len) * [audio_token_id]) + updates.extend([audio_end_token_id]) + + return updates diff --git a/vllm_omni/model_executor/models/__init__.py b/vllm_omni/model_executor/models/__init__.py new file mode 100644 index 00000000000..27e699e1228 --- /dev/null +++ b/vllm_omni/model_executor/models/__init__.py @@ -0,0 +1 @@ +from .registry import OmniModelRegistry \ No newline at end of file diff --git a/vllm_omni/model_executor/models/qwen2_5_omni.py b/vllm_omni/model_executor/models/qwen2_5_omni.py new file mode 100644 index 00000000000..2fd286c5f6c --- /dev/null +++ b/vllm_omni/model_executor/models/qwen2_5_omni.py @@ -0,0 +1,743 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2.5-Omni model (merged thinker, talker and token2wav dit).""" + +from functools import cached_property +from typing import Iterable, List, Optional, Set, Tuple, Union, NamedTuple, Dict + +import os +import glob +import numpy as np +import torch +import torch.nn as nn +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniConfig, Qwen2_5OmniThinkerConfig, Qwen2_5OmniTalkerConfig) + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.qwen2_5_omni_thinker import ( + Qwen2_5OmniConditionalGenerationMixin, + Qwen2_5OmniThinkerMultiModalProcessor, + Qwen2_5OmniThinkerProcessingInfo, + Qwen2_5OmniThinkerDummyInputsBuilder) +# from vllm.model_executor.models.qwen2_code2wav_dit import Qwen2Code2wav +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors + +from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + init_vllm_registered_model, + maybe_prefix) +from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf +from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights + +class OmniOutput(NamedTuple): + """Output from the merged Omni model containing both text and audio.""" + text_hidden_states: torch.Tensor + multimodal_outputs: dict = {} + intermediate_tensors: Optional[IntermediateTensors] = None + +logger = init_logger(__name__) + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5OmniThinkerMultiModalProcessor, + info=Qwen2_5OmniThinkerProcessingInfo, + dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, +) +class Qwen2_5OmniForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP, + Qwen2_5OmniConditionalGenerationMixin): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.have_multimodal_outputs = True + config: Qwen2_5OmniConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + # keep vllm_config for later submodule init + self.vllm_config = vllm_config + + # Initialize thinker components + thinker_config: Qwen2_5OmniThinkerConfig = config.thinker_config + self.thinker_config = thinker_config + self.multimodal_config = multimodal_config + + # Initialize talker components + talker_config: Qwen2_5OmniTalkerConfig = config.talker_config + self.talker_config = talker_config + + + + self.model_stage = vllm_config.model_config.model_stage + if self.model_stage=="thinker": + # Initialize thinker model (multimodal processing) + self.thinker = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "thinker"), + hf_config=thinker_config, + # Use registry architecture key + architectures=["Qwen2_5OmniThinkerModel"], + ) + self.model = self.thinker + self.talker = None + self.token2wav = None + + elif self.model_stage=="talker": + self.thinker = None + # Initialize talker model wrapper (handles projection + LM) + self.talker = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "talker"), + hf_config=talker_config, + # Use registry architecture key + architectures=["Qwen2_5OmniTalkerModel"], + ) + self.talker.init_multi_modal(thinker_config) + self.model=self.talker + self.token2wav = None + self._init_special_tokens_embeddings() + + elif self.model_stage=="code2wav": + self.thinker = None + self.talker = None + # Initialize token2wav (code->mel->wav) like thinker/talker + self.token2wav_config = getattr(config, 'token2wav_config', None) + self.token2wav = None + if self.token2wav_config is not None: + self.token2wav = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "token2wav"), + hf_config=self.token2wav_config, + architectures=["Qwen2_5OmniToken2WavModel"], + ) + # voice resources (loaded on demand) + self._token2wav_conds: Dict[str, torch.Tensor] = {} + self._token2wav_ref_mels: Dict[str, torch.Tensor] = {} + self.model = self.token2wav + else: + raise ValueError("Invalid model stage") + + # Set up intermediate tensors + self.make_empty_intermediate_tensors = ( + self.thinker.make_empty_intermediate_tensors) if self.model_stage=="thinker" else lambda: None + + self.thinker_output_token_ids = torch.empty(0, dtype=torch.long, device="cuda:0") + self.thinker_hidden_states = torch.empty(0, dtype=torch.long, device="cuda:0") + self.prev_inputs = torch.empty(0, dtype=torch.long, device="cuda:0") + + + # -------------------- Device utilities -------------------- + @staticmethod + def _module_device(module: nn.Module) -> torch.device: + try: + return next(module.parameters()).device + except StopIteration: + # No parameters; fall back to buffers or cpu + for _, buf in module.named_buffers(recurse=True): + return buf.device + return torch.device("cpu") + + def move_submodules_to_devices( + self, + *, + thinker_device: Optional[Union[str, torch.device]] = None, + talker_device: Optional[Union[str, torch.device]] = None, + token2wav_device: Optional[Union[str, torch.device]] = None, + ) -> None: + """Optionally move thinker/talker/token2wav to different devices. + + Example: + model.move_submodules_to_devices( + thinker_device='cuda:0', + talker_device='cuda:1', + token2wav_device='cpu', + ) + """ + if thinker_device is not None and self.thinker is not None: + self.thinker.to(thinker_device) + if talker_device is not None and self.talker is not None: + self.talker.to(talker_device) + if token2wav_device is not None and self.token2wav is not None: + self.token2wav.to(token2wav_device) + + @cached_property + def sampler(self): + if hasattr(self.model, "sampler"): + return self.model.sampler + return get_sampler() + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings=None, + ) -> torch.Tensor: + if self.model_stage == "code2wav": + return torch.zeros_like(input_ids).reshape(-1, 1).repeat(1, self.vllm_config.model_config.get_hidden_size()) + return self.model.get_input_embeddings( + input_ids, multimodal_embeddings) + + + def get_multimodal_embeddings(self, **kwargs): + # Delegate to thinker model for multimodal processing + return self.model.get_multimodal_embeddings(**kwargs) + + def last_index_of(self, list, value): + return len(list) - 1 - list[::-1].index(value) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + generate_audio: bool = True, + voice_type: str = "Chelsie", + codec: Optional[torch.Tensor] = None, + sampling_metadata: Optional[SamplingMetadata] = None, + logits_index: Optional[int] = None, + sampler = None, + additional_information: Optional[dict[str, object]] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors, OmniOutput]: + """ + Workflow: + 1) Thinker: multimodal understanding → text hidden states. + 2) If audio requested and codec not provided, use talker to derive codec. + 3) If audio requested (or codec provided), use token2wav to synthesize waveform. + 4) Return text hidden states (and audio when applicable). + """ + if self.model_stage=="thinker": + # Normalize to batched inputs if caller provides 1D/2D unbatched tensors + added_batch_dim = False + if input_ids is not None and input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + added_batch_dim = True + if positions is not None and positions.ndim == 1: + positions = positions.unsqueeze(0) + added_batch_dim = True + if inputs_embeds is not None and inputs_embeds.ndim == 2: + inputs_embeds = inputs_embeds.unsqueeze(0) + added_batch_dim = True + thinker_dev = self._module_device(self.thinker) + + #if input_ids is None, set it to an zero tenser, in the length of the same as the embedding seq length + if input_ids is None: + input_ids = torch.zeros(inputs_embeds.shape[1], dtype=torch.long, device=thinker_dev).unsqueeze(0) #(1, 0) + added_batch_dim = True + + # 1) Thinker (ensure inputs on thinker's device) + if input_ids is not None and input_ids.device != thinker_dev: + input_ids = input_ids.to(thinker_dev) + if positions is not None and positions.device != thinker_dev: + positions = positions.to(thinker_dev) + if inputs_embeds is not None and inputs_embeds.device != thinker_dev: + inputs_embeds = inputs_embeds.to(thinker_dev) + # Run thinker + thinker_output = self.thinker( + input_ids=input_ids, + positions=positions[0], + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + if isinstance(thinker_output, tuple): + embeds, text_hidden_states = thinker_output + else: + text_hidden_states = thinker_output + + # Text-only path + return OmniOutput( + text_hidden_states=text_hidden_states.squeeze(0) if added_batch_dim else text_hidden_states, + multimodal_outputs=None + ) + + # 2) Talker (if codec not provided) + if self.model_stage=="talker": + + if input_ids is None and additional_information is None: + input_ids = torch.zeros(inputs_embeds.shape[0], dtype=torch.long, device=inputs_embeds.device) + additional_information = {} + self.thinker_reply_part = torch.zeros_like(inputs_embeds) + is_profile = True + else: + is_profile = False + + if input_ids is not None and additional_information is not None and not is_profile: + # read from additional_information dict + thinker_result = None + if additional_information is not None and isinstance(additional_information, dict): + thinker_result = additional_information.get("thinker_result") + prompt_embeds = additional_information.get("prompt_embeds") + prompt_token_ids = additional_information.get("prompt_token_ids") + thinker_output_token_ids = additional_information.get("thinker_output_token_ids") + else: + thinker_result = torch.zeros_like(inputs_embeds) + prompt_embeds = torch.zeros_like(inputs_embeds) + prompt_token_ids = torch.zeros(inputs_embeds.shape[0], dtype=torch.int64, device=inputs_embeds.device) + thinker_output_token_ids = torch.zeros(inputs_embeds.shape[0], dtype=torch.int64, device=inputs_embeds.device) + + if thinker_result is None: + thinker_result = torch.zeros_like(inputs_embeds) + self.thinker_reply_part = thinker_result.squeeze(0) + if self.thinker_reply_part.shape[1] > 1: + self.thinker_reply_part = self.thinker_reply_part[1:, :] + input_ids, inputs_embeds = self._thinker_to_talker_prefill( + voice_type=voice_type, + output_prompt_embeds=thinker_result, + output_token_ids = thinker_output_token_ids, + thinker_prompt_embeds=prompt_embeds, + prompt_token_ids = prompt_token_ids, + ) + elif not is_profile: + input_ids, inputs_embeds = \ + self._thinker_to_talker_decode_one_step( + output_prompt_embeds=self.thinker_reply_part[:1] if self.thinker_reply_part.shape[0]>=1 else torch.zeros(1, self.thinker_reply_part.shape[1]).cuda().to(torch.bfloat16)+(-1.25*2**(-123)), + output_token_ids=input_ids, + ) + + if self.thinker_reply_part.shape[0] >=1: + self.thinker_reply_part = self.thinker_reply_part[1:, :] + + with torch.inference_mode(): + talker_hidden = self.talker( + input_ids=input_ids, + positions=positions[0], + inputs_embeds=inputs_embeds, + + ) + + return OmniOutput( + text_hidden_states=talker_hidden, + multimodal_outputs=None + ) + + if self.model_stage=="code2wav": + code = input_ids if input_ids is not None else torch.zeros(inputs_embeds.shape[0], dtype=torch.long, device=inputs_embeds.device) + audio_tensor = self.generate_audio(code[:-1] if code[-1]==8294 else code, voice_type) + # print("Currently, for debug, we return the audio tensor directly") + return OmniOutput( + text_hidden_states = None, + multimodal_outputs = { + "audio": audio_tensor + } + ) + + return OmniOutput( + text_hidden_states=torch.cat( + [ + torch.zeros([inputs_embeds.shape[0],896], dtype=torch.bfloat16).cuda(), + self.talker.thinker_to_talker_proj(self.talker.get_input_embeddings(torch.tensor([8294,8293]).to(torch.bfloat16).cuda()))[0] + ], + dim=0), + multimodal_outputs=None + ) + + def generate_audio(self, code, voice_type): + # 使用 Token2Wav 的分块接口进行端到端流式合成 + token2wav_dev = self._module_device(self.token2wav) + if isinstance(code, torch.Tensor): + code_tensor = code.to(dtype=torch.long, device=token2wav_dev) + else: + code_tensor = torch.as_tensor(code, dtype=torch.long, device=token2wav_dev) + if code_tensor.ndim == 2 and code_tensor.shape[0] == 1: + code_tensor = code_tensor.squeeze(0) + + audio_tensor = self._codec_to_audio(code_tensor, voice_type) + + return audio_tensor + + def _load_model_embedding( + self, + kind: str, # thinker or talker + ) -> torch.nn.Embedding: + + if kind == 'thinker': + return self.thinker.language_model.model.embed_tokens if self.thinker is not None else torch.load("thinker_embedding.pt", weights_only=False) + elif kind == 'talker': + return self.talker.language_model.model.embed_tokens if self.talker is not None else torch.load("talker_embedding.pt", weights_only=False) + else: + raise ValueError("invalid kind") + + def _init_special_tokens_embeddings( + self, + ): + # thinker and talker embeddings + self.thinker_embedding = self._load_model_embedding('thinker') + self.talker_embedding = self._load_model_embedding('talker') + + # embed_text_bos_token + self.tts_text_spk_token_ids = { + # M02:我是个会说标准普通话、带部分北方口音的男声 + 'm02': 151870, + 'Ethan': 151870, + + # F030:我是你的二次元虚拟女友 + 'f030': 151872, + 'Chelsie': 151872, + } + self.default_tts_text_spk_type = list( + self.tts_text_spk_token_ids.keys())[0] + self.tts_text_spk_token_ids['prefix_caching'] = 151870 + + talker_hf_config = self.talker_config + if hasattr(talker_hf_config, 'talker_config'): + talker_hf_config = talker_hf_config.talker_config + + self.embed_text_bos_token = self.thinker_embedding( + torch.tensor( + [talker_hf_config.tts_text_start_token_id], + dtype=torch.long, + device="cuda:0", + )) + self.embed_text_spk_tokens = { + key: + self.thinker_embedding( + torch.tensor( + [value], + dtype=torch.long, + device="cuda:0", + )) + for key, value in self.tts_text_spk_token_ids.items() + } + self.embed_text_eos_token = self.thinker_embedding( + torch.tensor( + [talker_hf_config.tts_text_end_token_id], + dtype=torch.long, + device="cuda:0", + )) + self.embed_text_pad_token = self.thinker_embedding( + torch.tensor( + [talker_hf_config.tts_text_pad_token_id], + dtype=torch.long, + device="cuda:0", + )) + self.embed_codec_bos_token = self.talker_embedding( + torch.tensor( + [talker_hf_config.tts_codec_start_token_id], + dtype=torch.long, + device="cuda:0", + )) + self.embed_codec_eos_token = self.talker_embedding( + torch.tensor( + [talker_hf_config.tts_codec_end_token_id], + dtype=torch.long, + device="cuda:0", + )) + self.embed_codec_pad_token = self.talker_embedding( + torch.tensor( + [talker_hf_config.tts_codec_pad_token_id], + dtype=torch.long, + device="cuda:0", + )) + return set(["thinker_embedding.weight", "talker_embedding.weight"]) + + def _get_embed_text_spk_token(self, voice_type: str): + if voice_type not in self.embed_text_spk_tokens: + return self.embed_text_bos_token + return self.embed_text_spk_tokens[voice_type] + + def _get_text_spk_token_id(self, voice_type: str): + talker_hf_config = self.talker_config + if hasattr(talker_hf_config, 'talker_config'): + talker_hf_config = talker_hf_config.talker_config + + if voice_type not in self.tts_text_spk_token_ids: + return talker_hf_config.tts_text_start_token_id + return self.tts_text_spk_token_ids[voice_type] + + def _thinker_to_talker_prefill( + self, + voice_type: str, + output_prompt_embeds, + output_token_ids, + thinker_prompt_embeds, + prompt_token_ids + ): + + talker_hf_config = self.talker_config + if hasattr(talker_hf_config, 'talker_config'): + talker_hf_config = talker_hf_config.talker_config + + # if len(output.outputs[0].token_ids) == 2: + # issue request + prompt_embeds = torch.cat([ + thinker_prompt_embeds, + self._get_embed_text_spk_token(voice_type) + + self.embed_codec_pad_token, + output_prompt_embeds[:1] + self.embed_codec_bos_token, + ], + dim=0) + + prompt_token_ids_processed = prompt_token_ids + [ + talker_hf_config.tts_codec_pad_token_id, + output_token_ids[0], + + ] + input_tokens_len = len(prompt_token_ids_processed) + # the code below is from model runner in Qwen, may need to further discuss later + if input_tokens_len > 2: + prompt_token_ids_processed = ( + [self.talker_config.tts_codec_mask_token_id] * + (input_tokens_len - 2) + [ + self.talker_config.tts_codec_pad_token_id, + self.talker_config.tts_codec_start_token_id + ]) + else: + prompt_token_ids_processed = [ + self.talker_config.tts_codec_pad_token_id, + self.talker_config.tts_codec_start_token_id, + ][-input_tokens_len:] + if isinstance(prompt_token_ids_processed,list): + prompt_token_ids_processed = torch.Tensor(prompt_token_ids_processed).to(torch.int64).cuda() + return prompt_token_ids_processed, prompt_embeds + + def _thinker_to_talker_decode_one_step( + self, + output_prompt_embeds, + output_token_ids, + ): + processed_output_token_embeds = output_prompt_embeds + self.talker.get_input_embeddings(output_token_ids) #for decode + return output_token_ids, processed_output_token_embeds + + def compute_logits( + self, + hidden_states: Union[torch.Tensor, OmniOutput], + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + # Handle OmniOutput type + if isinstance(hidden_states, OmniOutput): + hidden_states = hidden_states.text_hidden_states + + # Use thinker model for logits computation + return self.model.compute_logits(hidden_states, sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + # Use thinker model for sampling + return self.model.sample(logits, sampling_metadata) + + def generate_speech(self, text_tokens: torch.Tensor, voice_type: str = "default"): + """ + Generate speech from text tokens using the talker and token2wav models. + This method is kept for backward compatibility and direct speech generation. + + Args: + text_tokens: Text tokens from thinker model + voice_type: Voice type for speech generation + + Returns: + Audio tensor + """ + # Generate codec tokens using talker model + talker_output = self.talker( + input_ids=None, + positions=None, + inputs_embeds=text_tokens + ) + + # Convert talker output to codec tokens + codec_tokens = self._convert_to_codec_tokens(talker_output) + + # Generate audio using token2wav model + return self._codec_to_audio(codec_tokens, voice_type=voice_type) + + + def _convert_to_codec_tokens(self, talker_output: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: + """ + 参考 HF:使用 talker 的 codec 头得到 logits,抑制 BOS,再贪心选取当前步的下一个 codec token。 + """ + with torch.inference_mode(): + logits = self.talker.compute_logits(talker_output, None) + if logits is None: + return torch.zeros((talker_output.size(0), 0), dtype=torch.long, device=talker_output.device) + + # 仅抑制 codec_bos,与 HF generate 的 suppress_tokens 行为一致 + bos_id = None + if hasattr(self, 'talker_config') and hasattr(self.talker_config, 'tts_codec_start_token_id'): + bos_id = int(getattr(self.talker_config, 'tts_codec_start_token_id')) + if bos_id is not None: + logits[..., bos_id] = -1e9 + + # 取最后一步位置的分布并贪心选取 + next_id = self.talker.sample(logits, sampling_metadata).sampled_token_ids + return next_id.to(dtype=torch.long) + + def _init_token2wav_model(self): + """Initialize speaker resources if provided; model is constructed in __init__.""" + if self.token2wav is None or self.token2wav_config is None: + return + device = 'cuda' if torch.cuda.is_available() else 'cpu' + # optional speaker resources + conds = getattr(self.token2wav_config, 'conds', None) + ref_mels = getattr(self.token2wav_config, 'ref_mels', None) + if isinstance(conds, dict) and isinstance(ref_mels, dict): + self._token2wav_conds = {k: torch.as_tensor(v, device=device) for k, v in conds.items()} + self._token2wav_ref_mels = {k: torch.as_tensor(v, device=device) for k, v in ref_mels.items()} + # legacy: load from directory if provided + model_path = "/workspace/model_ckpt/Qwen2.5-Omni-7B"#getattr(self.token2wav_config, 'model_path', None) + if isinstance(model_path, str) and os.path.isdir(model_path): + spk_pt = os.path.join(model_path, 'spk_dict.pt') + if os.path.exists(spk_pt): + data = torch.load(spk_pt, map_location=device) + for key, value in data.items(): + self._token2wav_conds[key] = value["cond"].to(device) + self._token2wav_ref_mels[key] = value["ref_mel"].to(device) + else: + # legacy npy inputs + for f in sorted(glob.glob(os.path.join(model_path, 'inputs', '*spk_emb.npy'))): + key = os.path.basename(f).split('_')[0].lower() + self._token2wav_conds[key] = torch.as_tensor(np.load(f), device=device) + for f in sorted(glob.glob(os.path.join(model_path, 'inputs', '*ref_mel.npy'))): + key = os.path.basename(f).split('_')[0].lower() + self._token2wav_ref_mels[key] = torch.as_tensor(np.load(f), device=device) + + def _codec_to_audio(self, codec_tokens: torch.Tensor, voice_type: str = "default") -> Optional[torch.Tensor]: + if self.token2wav is None: + self._init_token2wav_model() + if self.token2wav is None: + return None + # Normalize voice type + voice = (voice_type or 'default') + # Resolve cond / ref_mel if provided + cond = None + ref_mel = None + if voice in self._token2wav_conds and voice in self._token2wav_ref_mels: + cond = self._token2wav_conds[voice] + ref_mel = self._token2wav_ref_mels[voice] + # Fallback: create dummy cond/ref_mel if not provided + token2wav_dev = self._module_device(self.token2wav) + if cond is None: + cond = torch.zeros((1, self.token2wav_config.dit_config.enc_emb_dim), device=token2wav_dev, dtype=torch.float32) + if ref_mel is None: + ref_mel = torch.zeros((1, 300, self.token2wav_config.dit_config.mel_dim), device=token2wav_dev, dtype=torch.float32) + + # Ensure codec is (1, T) long tensor on correct device + if isinstance(codec_tokens, torch.Tensor): + codec = codec_tokens.to(dtype=torch.long, device=token2wav_dev) + if codec.ndim == 1: + codec = codec.unsqueeze(0) + else: + codec = torch.as_tensor(codec_tokens, dtype=torch.long, device=token2wav_dev).unsqueeze(0) + + # Streaming with chunked process and boundary alignment (rely on token2wav.process_chunk) + factor = getattr(self.token2wav.token2wav.factor, 'factor', 2) + chunk_size = 48 + mel_dim = getattr(self.token2wav.token2wav.code2wav_dit_model, 'mel_dim', self.token2wav_config.dit_config.mel_dim) + total_mel = int(codec.shape[1] * factor) + steps = 10 + + # Prepare initial noise for the whole sequence + y_all = torch.randn((1, total_mel, mel_dim), dtype=ref_mel.dtype, device=token2wav_dev) + + logger.info(f"Currently, we do not use the chunked process, we only use the token2wav.process_chunk for the whole sequence.\ + The stream mode will be implemented in the future.") + + chunk_ends = [] + for i in range(codec.shape[1]): + chunk_code_length = i * 2 - 24 + finished = i==(codec.shape[1]-1) + if (chunk_code_length > 0 and + chunk_code_length % chunk_size == 0) or finished: + chunk_ends.append(i) + + # Number of chunks in mel domain + prev_generated = None + wav_chunks: list = [] + prev_id = 0 + + with torch.inference_mode(): + for n,i in enumerate([0]): + finished = (i == codec.shape[1] - 1) + _, audio_chunk = self.token2wav.process_chunk( + conditioning=cond, + reference_mel=ref_mel, + codec_all=codec, + y_all=y_all, + i=n, + steps=steps, + prev_generated=prev_generated if prev_generated is not None else [], + finished=True, + ) + prev_generated = audio_chunk + wav_chunks.append(audio_chunk.detach().cpu().numpy()) + prev_id = i + + if len(wav_chunks) == 0: + return torch.zeros(0, device=token2wav_dev) + + waveform = np.concatenate(wav_chunks) + return torch.as_tensor(waveform, device=token2wav_dev) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: + """Load weights for all components of the omni model.""" + loaded_weights = set() + thinker_weights = [] + talker_weights = [] + token2wav_weights = [] + for k, v in weights: + if k.startswith('thinker.'): + thinker_weights.append((k, v)) + elif k.startswith('talker.'): + talker_weights.append((k, v)) + elif k.startswith('token2wav.'): + token2wav_weights.append((k, v)) + else: + raise ValueError(f"Unknown weight prefix: {k}") + + # Load thinker weights + if self.thinker: + if thinker_weights: + thinker_loaded = self.thinker.load_weights(thinker_weights) + else: + thinker_loaded = set([k for k,v in thinker_weights]) + thinker_loaded = add_prefix_to_loaded_weights(thinker_loaded, 'thinker') + loaded_weights.update(thinker_loaded) + torch.save(self.thinker.language_model.model.embed_tokens, "thinker_embedding.pt") + + + # Load talker weights + if talker_weights and self.talker is not None: + # Map talker weights to appropriate components + talker_loaded = self.talker.load_weights(talker_weights) + talker_loaded = add_prefix_to_loaded_weights(talker_loaded, 'talker') + loaded_weights.update(talker_loaded) + torch.save(self.talker.language_model.model.embed_tokens, "talker_embedding.pt") + loaded_weights.update(self._init_special_tokens_embeddings()) + + + # Load token2wav weights (if any) + if token2wav_weights and self.token2wav is not None: + self._init_token2wav_model() + hf_model_folder = download_weights_from_hf(self.vllm_config.model_config.model, + self.vllm_config.load_config.download_dir, allow_patterns=["*.safetensors", "*.bin", "*.pt"]) + t2w_loaded = self.token2wav.load_weights(token2wav_weights, os.path.join(hf_model_folder, "spk_dict.pt")) + t2w_loaded = add_prefix_to_loaded_weights(t2w_loaded, 'token2wav') + loaded_weights.update(t2w_loaded) + + return loaded_weights \ No newline at end of file diff --git a/vllm_omni/model_executor/models/qwen2_5_omni_talker.py b/vllm_omni/model_executor/models/qwen2_5_omni_talker.py new file mode 100644 index 00000000000..8e36e416ff7 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen2_5_omni_talker.py @@ -0,0 +1,269 @@ +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" +from functools import cached_property +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import Qwen2_5OmniTalkerConfig +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import Qwen2_5OmniAudioEncoder + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.qwen2_5_omni_thinker import ( + Qwen2_5OmniConditionalGenerationMixin, + Qwen2_5OmniThinkerMultiModalProcessor, + Qwen2_5OmniThinkerProcessingInfo, + Qwen2_5OmniThinkerDummyInputsBuilder) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + init_vllm_registered_model, + maybe_prefix, + merge_multimodal_embeddings) +from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionTransformer + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5OmniThinkerMultiModalProcessor, + info=Qwen2_5OmniThinkerProcessingInfo, + dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, +) +class Qwen2_5OmniTalkerForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP, + Qwen2_5OmniConditionalGenerationMixin): + logger = init_logger(__name__) + # Align to thinker-style static mapper for clarity + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # text LM head/body in talker + "talker.codec_head.": "language_model.lm_head.", + "talker.model.": "language_model.model.", + # projection weights + "talker.thinker_to_talker_proj.": "thinker_to_talker_proj.", + # fallback root + "talker.": "", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: Qwen2_5OmniTalkerConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.prefix = prefix + self.quant_config = quant_config + + if hasattr(config, "talker_config"): + self.config = config.talker_config + vllm_config.model_config.hf_text_config = \ + vllm_config.model_config.hf_config.talker_config + else: + self.config = config + + self.thinker_to_talker_proj = ColumnParallelLinear( + self.config.embedding_size, + self.config.hidden_size, + bias=True, + gather_output=True, + skip_bias_add=False, + quant_config=quant_config, + ) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + hf_config=getattr(self.config, 'text_config', self.config), + architectures=["Qwen2ForCausalLM_old"], + ) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def init_multi_modal(self, thinker_config): + self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config) + self.visual = Qwen2_5_VisionTransformer( + vision_config=thinker_config.vision_config, + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), + quant_config=self.quant_config, + prefix=maybe_prefix(self.prefix, "visual"), + ) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + # def get_input_embeddings( + # self, + # input_ids: torch.Tensor, + # ) -> torch.Tensor: + # return self.language_model.get_input_embeddings(input_ids) + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + + # TODO (ywang96): support overlapping modalitiy embeddings so that + # `use_audio_in_video` will work on V1. + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, [ + self.config.image_token_index, + self.config.video_token_index, + self.config.audio_token_index + ]) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor = None, + positions: torch.Tensor = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + assert input_ids is not None or inputs_embeds is not None, "input_ids or inputs_embeds must be provided" + forward_context: ForwardContext = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if intermediate_tensors is not None: + inputs_embeds = None + elif inputs_embeds is None: + # for profile_run: + inputs_embeds = self.get_input_embeddings(input_ids) + # else: + # # in the decoding stage, the "input_embeds" means "thinker_reply_part" + # if attn_metadata.prefill_metadata is None: + # inputs_embeds += self.get_input_embeddings(input_ids) + + input_ids = None + + # projection + inputs_embeds, _ = self.thinker_to_talker_proj(inputs_embeds) + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=['thinker.', 'token2wav.'], + ) + loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + # Log load summary + try: + total_bytes = 0 + for name, param in self.named_parameters(): + if param is not None and param.data is not None: + total_bytes += param.data.numel() * param.data.element_size() + device = next(self.parameters()).device + self.logger.info( + "[Model Loaded] name=%s, success=%s, size=%.2f MB, device=%s", + self.__class__.__name__, True, total_bytes / (1024**2), str(device)) + except Exception: + pass + multi_model_weights = set() + for name, param in self.visual.named_parameters(): + multi_model_weights.add("visual."+name) + for name, param in self.audio_tower.named_parameters(): + multi_model_weights.add("audio_tower."+name) + loaded.update(multi_model_weights) + return loaded + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", "image_embeds" + ) and "image" not in mm_input_by_modality: + mm_input_by_modality[ + "image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds" + ) and "video" not in mm_input_by_modality: + mm_input_by_modality[ + "video"] = self._parse_and_validate_video_input(**kwargs) + if input_key in ("input_audio_features" + ) and "audio" not in mm_input_by_modality: + mm_input_by_modality[ + "audio"] = self._parse_and_validate_audio_input(**kwargs) + return mm_input_by_modality + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if not mm_input_by_modality: + return [] + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += vision_embeddings + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += video_embeddings + if modality == "audio": + audio_embeddings = self._process_audio_input(multimodal_input) + multimodal_embeddings += audio_embeddings + return multimodal_embeddings \ No newline at end of file diff --git a/vllm_omni/model_executor/models/qwen2_5_omni_thinker.py b/vllm_omni/model_executor/models/qwen2_5_omni_thinker.py new file mode 100644 index 00000000000..a3b5aed47e9 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen2_5_omni_thinker.py @@ -0,0 +1,913 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2.5-Omni model (thinker part).""" + +from collections.abc import Iterable, Mapping, Sequence +from copy import copy +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from transformers.feature_extraction_utils import BatchFeature +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniConfig, Qwen2_5OmniThinkerConfig) +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniAudioEncoder) +from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import ( + Qwen2_5OmniProcessor) +from transformers.models.whisper import WhisperFeatureExtractor + +from vllm.config import VllmConfig +from vllm.logger import init_logger + +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs, + Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, + Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) +from vllm.model_executor.models.qwen2_audio import ( + Qwen2AudioInputs, Qwen2AudioProcessingInfo, + _get_feat_extract_output_lengths) +from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (ImageItem, ModalityData, + MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, + ModalityDataItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + PlaceholderFeaturesInfo, + PromptReplacement, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens + +from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + +from vllm_omni.model_executor.layers.mrope import MRotaryEmbedding + +try: + import flash_attn +except (ImportError, ModuleNotFoundError): + flash_attn = None + +logger = init_logger(__name__) + + +def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): + audio_feature_lengths = hf_inputs.get("audio_feature_lengths", + torch.empty((0, ))) + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + input_audio_features=MultiModalFieldConfig.flat_from_sizes( + "audio", audio_feature_lengths, dim=1), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + audio_feature_lengths=MultiModalFieldConfig.batched("audio"), + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + ) + + +class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): + + def _parse_audio_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={ + "input_audio_features", "audio_feature_lengths" + }, + fields_factory=_qwen2_5_omni_thinker_field_config, + ) + + return super()._parse_audio_data(data) + + +class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo, + Qwen2_5_VLProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen2_5OmniConfig).thinker_config + + def get_hf_processor(self, **kwargs: object) -> Qwen2_5OmniProcessor: + return self.ctx.get_hf_processor( + Qwen2_5OmniProcessor, + use_fast=kwargs.pop("use_fast", True), + **kwargs, + ) + + def get_feature_extractor(self, **kwargs: object): + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None, "image": None, "video": None} + + +class Qwen2_5OmniThinkerDummyInputsBuilder( + BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + hf_processor = self.info.get_hf_processor() + + audio_token: str = hf_processor.audio_token + image_token: str = hf_processor.image_token + video_token: str = hf_processor.video_token + + return (audio_token * num_audios + image_token * num_images + + video_token * num_videos) + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + feature_extractor = self.info.get_feature_extractor() + + target_audio_length = min( + feature_extractor.chunk_length, + 30, + ) * feature_extractor.sampling_rate + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + + mm_data = { + "audio": + self._get_dummy_audios(length=target_audio_length, + num_audios=num_audios), + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos(width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos), + } + + return mm_data + + +class Qwen2_5OmniThinkerMultiModalProcessor( + BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return Qwen2_5OmniThinkerMultiModalDataParser( + target_sr=feature_extractor.sampling_rate) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) + + # NOTE: WhisperFeatureExtractor cannot handle empty list of audios + if audios: + # NOTE: Qwen2.5-Omni processor accept "audio" + mm_data["audio"] = audios + mm_kwargs = dict(**mm_kwargs, ) + + hf_inputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + input_features = hf_inputs.pop('input_features', None) + feature_attention_mask = hf_inputs.get('feature_attention_mask', None) + if ('input_audio_features' not in hf_inputs + and input_features is not None): + if feature_attention_mask is not None: + input_features = input_features.permute( + 0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + hf_inputs['input_audio_features'] = input_features + if ('audio_feature_lengths' not in hf_inputs + and feature_attention_mask is not None): + hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1) + return hf_inputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _qwen2_5_omni_thinker_field_config(hf_inputs) + + def _maybe_apply_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + prompt_ids: list[int], + mm_kwargs: MultiModalKwargs, + is_update_applied: bool, + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + """ + Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. + """ + unbound_prompt_updates = self._get_prompt_updates( + mm_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates) + + mm_item_counts = mm_items.get_all_counts() + self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + + use_audio_in_video = hf_processor_mm_kwargs.get( + "use_audio_in_video", False) + + if is_update_applied: + mm_placeholders = self._find_mm_placeholders( + mm_prompt_updates, + prompt_ids, + mm_item_counts, + ) + self._validate_mm_placeholders( + mm_placeholders, + mm_item_counts, + use_audio_in_video=use_audio_in_video) + + tokenizer = self.info.get_tokenizer() + prompt = decode_tokens(tokenizer, prompt_ids) + else: + ( + prompt_ids, + prompt, + mm_placeholders, + ) = self._apply_prompt_updates( + prompt_ids, + mm_prompt_updates, + mm_item_counts, + ) + self._validate_mm_placeholders( + mm_placeholders, + mm_item_counts, + use_audio_in_video=use_audio_in_video) + + tokenizer = self.info.get_tokenizer() + prompt = decode_tokens(tokenizer, prompt_ids) + + if use_audio_in_video: + mm_kwargs["use_audio_in_video"] = True + + return prompt_ids, prompt, mm_placeholders + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + vocab = tokenizer.get_vocab() + + audio_token = processor.audio_token + image_token = processor.image_token + video_token = processor.video_token + audio_token_id = vocab[audio_token] + image_token_id = vocab[image_token] + video_token_id = vocab[video_token] + + audio_feature_lengths = out_mm_kwargs.get("audio_feature_lengths") + feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") + if audio_feature_lengths is None and feature_attention_mask is None: + audio_output_lengths = [] + elif audio_feature_lengths is not None: + _, audio_output_lens = _get_feat_extract_output_lengths( + audio_feature_lengths) + audio_output_lengths = audio_output_lens.tolist() + elif feature_attention_mask is not None: + assert isinstance(feature_attention_mask, torch.Tensor) + _, audio_output_lens = _get_feat_extract_output_lengths( + feature_attention_mask.sum(-1)) + audio_output_lengths = audio_output_lens.tolist() + + # number of audios read from video. + audio_in_video_item_idx = 0 + + def get_replacement_qwen2_audio(item_idx: int): + item_idx += audio_in_video_item_idx + + num_features = audio_output_lengths[item_idx] + if num_features == 0: + audios = mm_items.get_items("audio", AudioProcessorItems) + audio = audios.get(item_idx) + raise ValueError( + f"The audio {audio} (len={len(audio)}) is too short " + "to be represented inside the model") + + return [audio_token_id] * num_features + + def get_replacement_qwen2_vision(item_idx: int, modality: str): + grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + merge_length = image_processor.merge_size**2 + + token_id = image_token_id if modality == "image" else video_token_id + return [token_id] * (int(grid_thw.prod()) // merge_length) + + use_audio_in_video = hf_processor_mm_kwargs.get( + "use_audio_in_video", False) + thinker_config = self.info.get_hf_config() + + def get_replacement_qwen2_use_audio_in_video(item_idx: int): + nonlocal audio_in_video_item_idx + + audio_num_features = audio_output_lengths[audio_in_video_item_idx + + item_idx] + video_grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + + audio_in_video_item_idx += 1 + + second_per_grid_ts = hf_processor_mm_kwargs.get( + "second_per_grid_ts", None) + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[item_idx] + else: + video_second_per_grid_t = 1.0 + + return MRotaryEmbedding.omni_get_updates_use_audio_in_video( + thinker_config=thinker_config, + audio_len=audio_num_features, + video_grid_thw=video_grid_thw, + video_second_per_grid_t=video_second_per_grid_t, + ) + + video_replacement_fn = ( + get_replacement_qwen2_use_audio_in_video if use_audio_in_video else + partial(get_replacement_qwen2_vision, modality="video")) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_qwen2_audio, + ), + PromptReplacement( + modality="image", + target=image_token, + replacement=partial(get_replacement_qwen2_vision, + modality="image"), + ), + PromptReplacement( + modality="video", + target=video_token, + replacement=video_replacement_fn, + ), + ] + + def _apply_hf_processor_main( + self, + prompt: Union[str, list[int]], + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + *, + enable_hf_prompt_update: bool, + ) -> tuple[list[int], BatchFeature, bool]: + """ + Qwen2.5-Omni reimplements this function to handle text only. + """ + if isinstance(prompt, str): + if enable_hf_prompt_update: + return self._apply_hf_processor_text_mm( + prompt_text=prompt, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) + tokenizer = self.info.get_tokenizer() + prompt_ids = encode_tokens(tokenizer, prompt) + else: + prompt_ids = self._apply_hf_processor_tokens_only(prompt) + + mm_processed_data = self._apply_hf_processor_mm_only( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) + + return prompt_ids, mm_processed_data, False + + def _apply_hf_processor_mm_only( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> BatchFeature: + """ + Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. + """ + mm_counts = mm_items.get_all_counts() + + use_audio_in_video = hf_processor_mm_kwargs.get( + "use_audio_in_video", False) + if use_audio_in_video and "video" in mm_counts: + assert "audio" in mm_counts + mm_counts["audio"] -= mm_counts["video"] + + _, mm_processed_data, _ = self._apply_hf_processor_text_mm( + prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) + + return mm_processed_data + + def _validate_mm_placeholders( + self, + mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], + mm_item_counts: Mapping[str, int], + use_audio_in_video: bool = False, + ) -> None: + if use_audio_in_video: + mm_item_counts = copy(mm_item_counts) + if "video" in mm_item_counts: + assert "audio" in mm_item_counts + mm_item_counts["audio"] -= mm_item_counts["video"] + super()._validate_mm_placeholders(mm_placeholders, mm_item_counts) + + +class Qwen2_5OmniConditionalGenerationMixin: + + def _validate_and_reshape_mm_tensor(self, + mm_input: object, + name: str, + dim: int = 0) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + return torch.concat(list(mm_input), dim=dim) + else: + return torch.concat(mm_input, dim=dim) + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[Qwen2AudioInputs]: + input_audio_features = kwargs.pop('input_audio_features', None) + audio_feature_lengths = kwargs.pop('audio_feature_lengths', None) + feature_attention_mask = kwargs.pop('feature_attention_mask', None) + if input_audio_features is None: + return None + input_audio_features = self._validate_and_reshape_mm_tensor( + input_audio_features, 'input_audio_features', dim=1) + if feature_attention_mask is not None: + feature_attention_mask = self._validate_and_reshape_mm_tensor( + feature_attention_mask, 'feature_attention_mask') + if not isinstance(input_audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio input features. " + f"Got type: {type(input_audio_features)}") + return Qwen2AudioInputs(input_features=input_audio_features, + audio_feature_lengths=audio_feature_lengths, + feature_attention_mask=feature_attention_mask) + + def _parse_and_validate_image_input( + self, + **kwargs: dict[str, Any], + ) -> Optional[Qwen2_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Qwen2_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, + **kwargs: dict[str, Any], + ) -> Optional[Qwen2_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, torch.Tensor): + raise ValueError("Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}") + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw) + + def _process_audio_input( + self, + audio_input: Qwen2AudioInputs, + audio_hashes: list[str] = None, + cached_audio_features: torch.Tensor = None, + ) -> torch.Tensor: + + input_features = audio_input["input_features"] + audio_feature_lengths = audio_input["audio_feature_lengths"] + if input_features.ndim == 3: + assert input_features.shape[0] == 1 + input_features = input_features.squeeze(0) + if audio_feature_lengths.ndim == 2: + assert audio_feature_lengths.shape[ + 0] == 1 or audio_feature_lengths.shape[1] == 1 + if audio_feature_lengths.shape[0] == 1: + audio_feature_lengths = audio_feature_lengths.squeeze(0) + else: + audio_feature_lengths = audio_feature_lengths.squeeze(1) + + audio_feat_lengths, audio_output_lengths = ( + self.audio_tower._get_feat_extract_output_lengths( + audio_feature_lengths)) + + audio_outputs = self.audio_tower( + input_features.to(self.audio_tower.dtype), + feature_lens=audio_feature_lengths, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + return audio_features.split(audio_output_lengths.tolist()) + + def _process_image_input( + self, + image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + if image_input["type"] == "image_embeds": + return image_input["image_embeds"].type(self.visual.dtype) + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs, + video_hashes: list[str] = None, + cached_video_embeds: torch.Tensor = None) -> torch.Tensor: + if video_input["type"] == "video_embeds": + return video_input["video_embeds"].type(self.visual.dtype) + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5OmniThinkerMultiModalProcessor, + info=Qwen2_5OmniThinkerProcessingInfo, + dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, +) +class Qwen2_5OmniThinkerForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, + Qwen2_5OmniConditionalGenerationMixin): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "thinker.lm_head.": "language_model.lm_head.", + "thinker.model.": "language_model.model.", + "thinker.": "", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|vision_start|><|IMAGE|><|vision_end|>" + if modality.startswith("video"): + return "<|vision_start|><|VIDEO|><|vision_end|>" + if modality.startswith("audio"): + return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>" + + raise ValueError("Only image, video or audio modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + thinker_config: Qwen2_5OmniThinkerConfig = ( + vllm_config.model_config.hf_config) + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = thinker_config + self.multimodal_config = multimodal_config + + # force "use_flash_attention_2=True" to audio tower to align + # the results. + if flash_attn is not None: + audio_config = thinker_config.audio_config + audio_config._attn_implementation_autoset = True + audio_config._attn_implementation = "flash_attention_2" + else: + logger.warning( + "flash_attn is not available, the model may not yield the " + "exactly same result as the transformers implementation " + "in the audio tower part.") + + self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config) + self.visual = Qwen2_5_VisionTransformer( + vision_config=thinker_config.vision_config, + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + self.quant_config = quant_config + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + hf_config=thinker_config.text_config, + architectures=["Qwen2ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", "image_embeds" + ) and "image" not in mm_input_by_modality: + mm_input_by_modality[ + "image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds" + ) and "video" not in mm_input_by_modality: + mm_input_by_modality[ + "video"] = self._parse_and_validate_video_input(**kwargs) + if input_key in ("input_audio_features" + ) and "audio" not in mm_input_by_modality: + mm_input_by_modality[ + "audio"] = self._parse_and_validate_audio_input(**kwargs) + return mm_input_by_modality + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if not mm_input_by_modality: + return [] + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += vision_embeddings + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += video_embeddings + if modality == "audio": + audio_embeddings = self._process_audio_input(multimodal_input) + multimodal_embeddings += audio_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + + # TODO (ywang96): support overlapping modalitiy embeddings so that + # `use_audio_in_video` will work on V1. + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, [ + self.config.image_token_index, + self.config.video_token_index, + self.config.audio_token_index + ]) + return inputs_embeds + + def get_multimodal_embeddings_v0( + self, **kwargs: object) -> Optional[NestedTensors]: + audio_input = self._parse_and_validate_audio_input(**kwargs) + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if audio_input is None and image_input is None and video_input is None: + return None + + multimodal_embeddings: list[tuple[NestedTensors, str]] = [] + + if audio_input is not None: + audio_embeds = self._process_audio_input(audio_input) + multimodal_embeddings.append((audio_embeds, "audio")) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + multimodal_embeddings.append((image_embeds, "image")) + if video_input is not None: + video_embeds = self._process_video_input(video_input) + multimodal_embeddings.append((video_embeds, "video")) + return multimodal_embeddings + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + for embeddings, modality in multimodal_embeddings: + if modality == "audio": + placeholder_token_id = self.config.audio_token_index + if modality == "image": + placeholder_token_id = self.config.image_token_index + if modality == "video": + placeholder_token_id = self.config.video_token_index + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, embeddings, placeholder_token_id) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings( + input_ids, multimodal_embeddings) + text_inputs_embeds = self.get_input_embeddings(input_ids, [(torch.zeros_like(embeddings), modality) for embeddings, modality in multimodal_embeddings] if multimodal_embeddings is not None else None) + input_ids = None + else: + text_inputs_embeds = inputs_embeds + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + return text_inputs_embeds, hidden_states.unsqueeze(0) #(1, S, D) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["talker.", "token2wav."], + ) + loaded_weights = loader.load_weights(weights, + mapper=self.hf_to_vllm_mapper) + # Log load summary + try: + total_bytes = 0 + for name, param in self.named_parameters(): + if param is not None and param.data is not None: + total_bytes += param.data.numel() * param.data.element_size() + device = next(self.parameters()).device + logger.info( + "[Model Loaded] name=%s, success=%s, size=%.2f MB, device=%s", + self.__class__.__name__, True, total_bytes / (1024**2), str(device)) + except Exception: + pass + return loaded_weights + diff --git a/vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py b/vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py new file mode 100644 index 00000000000..f6bbcfb2935 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen2_5_omni_token2wav.py @@ -0,0 +1,1693 @@ +############################ +# Start Token2Wav # +############################ + +from typing import Iterable, Optional, Set, Tuple, List, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +from torch.nn import Parameter + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.utils import ( + AutoWeightsLoader as _Vllm_AutoWeightsLoader, + WeightsMapper as _Vllm_WeightsMapper, + init_vllm_registered_model as _Vllm_init_vllm_registered_model, + maybe_prefix as _Vllm_maybe_prefix, +) + +# Bring in HF base classes, configs and utilities used below +from transformers.utils import ModelOutput +from transformers.utils.logging import get_logger as _hf_get_logger +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniPreTrainedModel, + Qwen2_5OmniPreTrainedModelForConditionalGeneration, +) +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniDiTConfig, + Qwen2_5OmniBigVGANConfig, + Qwen2_5OmniToken2WavConfig, +) + +# Provide a no-op auto_docstring decorator to satisfy annotations if missing +def auto_docstring(func=None, **_kwargs): + if func is None: + def wrapper(f): + return f + return wrapper + return func + +# HF logger alias +logger = _hf_get_logger(__name__) + + +# Using custom RoPE, will use LlamaRotaryEmbedding next version +class Qwen2_5OmniDiTRotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000): + super().__init__() + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x): + batch_size, seq_len = x.shape[0], x.shape[1] + t = torch.arange(seq_len, device=x.device) + device_type = x.device.type + device_type = device_type if device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float() + freqs = torch.stack((freqs, freqs), dim=-1) + freqs = freqs.reshape(*freqs.shape[:-2], -1) + freqs = freqs.repeat(batch_size, *([1] * freqs.dim())) + cos = freqs.cos() + sin = freqs.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class TimeDelayNetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + padding="same", + padding_mode="reflect", + ) + self.activation = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor): + return self.activation(self.conv(hidden_states)) + + +class Res2NetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1): + super().__init__() + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.ModuleList( + [ + TimeDelayNetBlock( + in_channel, + hidden_channel, + kernel_size=kernel_size, + dilation=dilation, + ) + for i in range(scale - 1) + ] + ) + self.scale = scale + + def forward(self, hidden_states): + outputs = [] + for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)): + if i == 0: + output_part = hidden_part + elif i == 1: + output_part = self.blocks[i - 1](hidden_part) + else: + output_part = self.blocks[i - 1](hidden_part + output_part) + outputs.append(output_part) + output = torch.cat(outputs, dim=1) + return output + + +class SqueezeExcitationBlock(nn.Module): + def __init__(self, in_channels, se_channels, out_channels): + super().__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=se_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv1d( + in_channels=se_channels, + out_channels=out_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, hidden_states): + hidden_states_mean = hidden_states.mean(dim=2, keepdim=True) + + hidden_states_mean = self.relu(self.conv1(hidden_states_mean)) + hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean)) + + return hidden_states * hidden_states_mean + + +class AttentiveStatisticsPooling(nn.Module): + """This class implements an attentive statistic pooling layer for each channel. + It returns the concatenated mean and std of the input tensor. + """ + + def __init__(self, channels, attention_channels=128): + super().__init__() + + self.eps = 1e-12 + self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = nn.Conv1d( + in_channels=attention_channels, + out_channels=channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def _length_to_mask(self, length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. + max_len : int + Max length for the mask, also the size of the second dimension. + dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. + + Returns + ------- + mask : tensor + The binary mask. + """ + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( + len(length), max_len + ) < length.unsqueeze(1) + + mask = torch.as_tensor(mask, dtype=dtype, device=device) + return mask + + def _compute_statistics(self, x, m, dim=2): + mean = (m * x).sum(dim) + std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)) + return mean, std + + def forward(self, hidden_states): + seq_length = hidden_states.shape[-1] + lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device) + + # Make binary mask of shape [N, 1, L] + mask = self._length_to_mask( + lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device + ) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + total = mask.sum(dim=2, keepdim=True) + + mean, std = self._compute_statistics(hidden_states, mask / total) + mean = mean.unsqueeze(2).repeat(1, 1, seq_length) + std = std.unsqueeze(2).repeat(1, 1, seq_length) + attention = torch.cat([hidden_states, mean, std], dim=1) + + # Apply layers + attention = self.conv(self.tanh(self.tdnn(attention))) + + # Filter out zero-paddings + attention = attention.masked_fill(mask == 0, float("-inf")) + + attention = F.softmax(attention, dim=2) + mean, std = self._compute_statistics(hidden_states, attention) + # Append mean and std of the batch + pooled_stats = torch.cat((mean, std), dim=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SqueezeExcitationRes2NetBlock(nn.Module): + """An implementation of building block in ECAPA-TDNN, i.e., + TDNN-Res2Net-TDNN-SqueezeExcitationBlock. + """ + + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + ): + super().__init__() + self.out_channels = out_channels + self.tdnn1 = TimeDelayNetBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) + self.tdnn2 = TimeDelayNetBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels) + + def forward(self, hidden_state): + residual = hidden_state + + hidden_state = self.tdnn1(hidden_state) + hidden_state = self.res2net_block(hidden_state) + hidden_state = self.tdnn2(hidden_state) + hidden_state = self.se_block(hidden_state) + + return hidden_state + residual + + +class ECAPA_TimeDelayNet(torch.nn.Module): + """An implementation of the speaker embedding model in a paper. + "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in + TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143). + """ + + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( + config.enc_dilations + ): + raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") + self.channels = config.enc_channels + self.blocks = nn.ModuleList() + + # The initial TDNN layer + self.blocks.append( + TimeDelayNetBlock( + config.mel_dim, + config.enc_channels[0], + config.enc_kernel_sizes[0], + config.enc_dilations[0], + ) + ) + + # SE-Res2Net layers + for i in range(1, len(config.enc_channels) - 1): + self.blocks.append( + SqueezeExcitationRes2NetBlock( + config.enc_channels[i - 1], + config.enc_channels[i], + res2net_scale=config.enc_res2net_scale, + se_channels=config.enc_se_channels, + kernel_size=config.enc_kernel_sizes[i], + dilation=config.enc_dilations[i], + ) + ) + + # Multi-layer feature aggregation + self.mfa = TimeDelayNetBlock( + config.enc_channels[-1], + config.enc_channels[-1], + config.enc_kernel_sizes[-1], + config.enc_dilations[-1], + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + config.enc_channels[-1], + attention_channels=config.enc_attention_channels, + ) + + # Final linear transformation + self.fc = nn.Conv1d( + in_channels=config.enc_channels[-1] * 2, + out_channels=config.enc_dim, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def forward(self, hidden_states): + # Minimize transpose for efficiency + hidden_states = hidden_states.transpose(1, 2) + + hidden_states_list = [] + for layer in self.blocks: + hidden_states = layer(hidden_states) + hidden_states_list.append(hidden_states) + + # Multi-layer feature aggregation + hidden_states = torch.cat(hidden_states_list[1:], dim=1) + hidden_states = self.mfa(hidden_states) + + # Attentive Statistical Pooling + hidden_states = self.asp(hidden_states) + + # Final linear transformation + hidden_states = self.fc(hidden_states) + + hidden_states = hidden_states.squeeze(-1) + return hidden_states + + +class DiTInputEmbedding(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + self.proj = nn.Linear( + config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim, + config.hidden_size, + ) + self.spk_encoder = ECAPA_TimeDelayNet(config) + + def forward( + self, + hidden_states: torch.Tensor, + speaker_embedding: torch.Tensor, + condition_vector: torch.Tensor, + code_embed: torch.Tensor, + drop_audio_cond: Optional[bool] = False, + code_embed_uncond: Optional[bool] = None, + apply_cfg: Optional[bool] = True, + ): + if apply_cfg: + hidden_states = torch.cat([hidden_states, hidden_states], dim=0) + speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0) + condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0) + code_embed = torch.cat([code_embed, code_embed_uncond], dim=0) + elif drop_audio_cond: # cfg for cond audio + condition_vector = torch.zeros_like(condition_vector) + speaker_embedding = torch.zeros_like(speaker_embedding) + condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1) + hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1)) + + return hidden_states + + +# Transformer backbone using DiT blocks +class DiTCodecEmbedding(nn.Module): + def __init__(self, codec_num_embeds, codec_dim, repeats): + super().__init__() + self.repeats = repeats + self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim) + + def forward(self, code, drop_code=False): + if drop_code: + code = torch.zeros_like(code) + code_embed = self.codec_embed(code) + + code_embed = torch.repeat_interleave(code_embed, repeats=self.repeats, dim=1) + return code_embed + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation +class Qwen2_5_OmniAdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, hidden_states, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + + hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no more mlp modulation +class Qwen2_5_OmniAdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, hidden_states, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + return hidden_states + + +# FeedForward +class DiTMLP(nn.Module): + def __init__(self, dim, mult=4, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + + self.ff = nn.ModuleList( + [ + nn.Linear(dim, inner_dim), + nn.GELU(approximate="tanh"), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + ] + ) + + def forward(self, hidden_states): + for layer in self.ff: + hidden_states = layer(hidden_states) + return hidden_states + + +# Modified from Llama with a different rotate function, will fixed in next release +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + def rotate_half_codec(x): + # x = rearrange(x, "... (d r) -> ... d r", r=2) + x = x.reshape(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.reshape(*x.shape[:-2], -1) + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half_codec(q) * sin) + k_embed = (k * cos) + (rotate_half_codec(k) * sin) + return q_embed, k_embed + + +class DiTAttention(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + + self.config = config + self.dim = config.hidden_size + self.heads = config.num_attention_heads + self.inner_dim = config.head_dim * config.num_attention_heads + self.dropout = config.dropout + self.is_causal = False + + self.to_q = nn.Linear(config.hidden_size, self.inner_dim) + self.to_k = nn.Linear(config.hidden_size, self.inner_dim) + self.to_v = nn.Linear(config.hidden_size, self.inner_dim) + + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)]) + + def forward( + self, + hidden_states, # noised input x + position_embeddings=None, # rotary position embedding for x + attention_mask=None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # apply rotary position embedding + # Due to training process, only first head is applied with RoPE, will be fixed at next release + cos, sin = position_embeddings + query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) + + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_weights, _ = attention_interface( + self, + query, + key, + value, + attention_mask=attention_mask, + is_causal=False, + ) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim) + attention_weights = attention_weights.to(query.dtype) + + # linear proj + attention_output = self.to_out[0](attention_weights) + attention_output = self.to_out[1](attention_output) + + return attention_output + + +# time step conditioning embedding +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, hidden_states, scale=1000): + device = hidden_states.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * hidden_states.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb.type_as(hidden_states) + + +class DiTTimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)]) + + def forward(self, timestep): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + for layer in self.time_mlp: + time_hidden = layer(time_hidden) # b d + return time_hidden + + +class DiTDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0): + super().__init__() + self.attn_norm = Qwen2_5_OmniAdaLayerNormZero(config.hidden_size) + + self.attn = DiTAttention(config) + self.look_ahead_block = look_ahead_block + self.look_backward_block = look_backward_block + self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout) + + def forward( + self, hidden_states, timestep, position_embeddings=None, block_diff=None + ): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep) + + # attention + attn_output = self.attn( + hidden_states=norm, + position_embeddings=position_embeddings, + attention_mask=(block_diff >= -float(self.look_backward_block)) + & (block_diff <= float(self.look_ahead_block)), + ) + + # process attention output for input x + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + + return hidden_states + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://huggingface.co/papers/2006.08195 + """ + + def __init__(self, in_features, alpha=1.0): + super().__init__() + self.in_features = in_features + + # initialize alpha + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + + self.no_div_by_zero = 0.000000001 + + def forward(self, hidden_states): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + alpha = torch.exp(alpha) + beta = torch.exp(beta) + hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( + torch.sin(hidden_states * alpha), 2 + ) + + return hidden_states + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + """Generates a 1D Kaiser-windowed sinc filter. + + Args: + cutoff (float): Normalized cutoff frequency (0 to 0.5). + half_width (float): Transition bandwidth. + kernel_size (int): Number of filter taps. + + Returns: + torch.Tensor: A tensor of shape (1, 1, kernel_size) representing the filter. + """ + is_even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # Compute Kaiser window parameters + delta_f = 4 * half_width + attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + + if attenuation > 50.0: + beta = 0.1102 * (attenuation - 8.7) + elif attenuation >= 21.0: + beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0) + else: + beta = 0.0 + + kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32) + + # Compute time indices + if is_even: + time_indices = torch.arange(-half_size, half_size) + 0.5 + else: + time_indices = torch.arange(kernel_size) - half_size + + # Compute sinc filter + if cutoff == 0: + return torch.zeros((1, 1, kernel_size), dtype=torch.float32) # Ensures correct shape + + sinc_filter = torch.sinc(2 * cutoff * time_indices) + normalized_filter = 2 * cutoff * kaiser_window * sinc_filter + + # Normalize to ensure sum = 1 (avoid leakage of constant component) + normalized_filter /= normalized_filter.sum() + + return normalized_filter.view(1, 1, kernel_size) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + channels = hidden_states.shape[1] + hidden_states_dtype = hidden_states.dtype + hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate").to(self.filter.dtype) + hidden_states = self.ratio * F.conv_transpose1d( + hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels + ).to(hidden_states_dtype) + hidden_states = hidden_states[..., self.pad_left : -self.pad_right] + + return hidden_states + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + cutoff = 0.5 / ratio + half_width = 0.6 / ratio + + if cutoff < 0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = ratio + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + channels = hidden_states.shape[1] + hidden_states_dtype = hidden_states.dtype + hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate").to(self.filter.dtype) + out = F.conv1d(hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels).to(hidden_states_dtype) + return out + + +class TorchActivation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + if not callable(activation): + raise TypeError("Activation function must be callable") + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, hidden_states): + hidden_states = self.upsample(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.downsample(hidden_states) + + return hidden_states + + +class AMPBlock(torch.nn.Module): + def __init__( + self, + channels, + kernel_size=3, + dilation=(1, 3, 5), + ): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=self._get_padding(kernel_size, dilation[0]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=self._get_padding(kernel_size, dilation[1]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=self._get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + ] + ) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + self.activations = nn.ModuleList( + [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)] + ) + + def _get_padding(self, kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + def forward(self, hidden_states): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2): + residual = hidden_states + hidden_states = act1(hidden_states) + hidden_states = conv1(hidden_states) + hidden_states = act2(hidden_states) + hidden_states = conv2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2WavBigVGAN model. Which take mel spectrogram as input and predict waveform. + """ +) +class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniBigVGANConfig + + def __init__(self, config: Qwen2_5OmniBigVGANConfig): + super().__init__(config) + self.num_residual_blocks = len(config.resblock_kernel_sizes) + self.num_upsample_layers = len(config.upsample_rates) + + self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3) + + # Removing extra ModuleList breaks official state dict + ups = [ + nn.ModuleList( + [ + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**layer_idx), + config.upsample_initial_channel // (2 ** (layer_idx + 1)), + kernel_size, + stride, + padding=(kernel_size - stride) // 2, + ) + ] + ) + for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)) + ] + self.ups = nn.ModuleList(ups) + + self.resblocks = nn.ModuleList( + [ + AMPBlock(config.upsample_initial_channel // (2 ** (layer_idx + 1)), kernel_size, dilation) + for layer_idx in range(self.num_upsample_layers) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes) + ] + ) + + self.activation_post = TorchActivation1d( + activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers)) + ) + self.conv_post = nn.Conv1d( + config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False + ) + + def normalize_spectrogram(self, spectrogram, max_value, min_db): + return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value) + + def amplitude_to_db(self, amplitude, min_db_level): + min_level = torch.exp( + torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype) + ) + return 20 * torch.log10(torch.clamp(amplitude, min=min_level)) + + def process_mel_spectrogram(self, mel_spectrogram): + amplitude_spectrum = torch.exp(mel_spectrogram) + decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20 + return self.normalize_spectrogram(decibel_spectrum, 1, -115) + + def forward(self, mel_spectrogram): + processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram) + hidden_representation = self.conv_pre(processed_spectrogram) + + for layer_index in range(self.num_upsample_layers): + hidden_representation = self.ups[layer_index][0](hidden_representation) + residual_output = sum( + self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation) + for block_index in range(self.num_residual_blocks) + ) + residual_output = residual_output / self.num_residual_blocks + hidden_representation = residual_output + + hidden_representation = self.activation_post(hidden_representation) + output_waveform = self.conv_post(hidden_representation) + return torch.clamp(output_waveform, min=-1.0, max=1.0).squeeze().cpu() + + +class RungeKutta4ODESolver: + def __init__(self, function, initial_value): + self.function = function + self.initial_value = initial_value + + self._one_third = 1 / 3 + self._two_thirds = 2 / 3 + + def _rk4_step(self, function, time_start, time_step, time_end, value_start, function_value_start=None): + k1 = function_value_start if function_value_start is not None else function(time_start, value_start) + k2 = function(time_start + time_step * self._one_third, value_start + time_step * k1 * self._one_third) + k3 = function(time_start + time_step * self._two_thirds, value_start + time_step * (k2 - k1 * self._one_third)) + k4 = function(time_end, value_start + time_step * (k1 - k2 + k3)) + return (k1 + 3 * (k2 + k3) + k4) * time_step / 8 + + def _compute_step(self, function, time_start, time_step, time_end, value_start): + function_value_start = function(time_start, value_start) + return self._rk4_step( + function, time_start, time_step, time_end, value_start, function_value_start=function_value_start + ), function_value_start + + def _linear_interpolation(self, time_start, time_end, value_start, value_end, time_point): + if time_point == time_start: + return value_start + if time_point == time_end: + return value_end + weight = (time_point - time_start) / (time_end - time_start) + return value_start + weight * (value_end - value_start) + + def integrate(self, time_points): + solution = torch.empty( + len(time_points), + *self.initial_value.shape, + dtype=self.initial_value.dtype, + device=self.initial_value.device, + ) + solution[0] = self.initial_value + + current_index = 1 + current_value = self.initial_value + for time_start, time_end in zip(time_points[:-1], time_points[1:]): + time_step = time_end - time_start + delta_value, _ = self._compute_step(self.function, time_start, time_step, time_end, current_value) + next_value = current_value + delta_value + + while current_index < len(time_points) and time_end >= time_points[current_index]: + solution[current_index] = self._linear_interpolation( + time_start, time_end, current_value, next_value, time_points[current_index] + ) + current_index += 1 + + current_value = next_value + + return solution + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2WavDiT model. Which take speech tokens as input and predict mel spectrogram. + """ +) +class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniDiTConfig + _no_split_modules = ["DiTDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__(config) + self.mel_dim = config.mel_dim + self.repeats = config.repeats + self.time_embed = DiTTimestepEmbedding(config.hidden_size) + + self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats) + self.input_embed = DiTInputEmbedding(config) + + self.rotary_embed = Qwen2_5OmniDiTRotaryEmbedding(config.head_dim) + + self.hidden_size = config.hidden_size + self.layers = config.num_hidden_layers + self.block_size = config.block_size + self.num_attention_heads = config.num_attention_heads + + self.transformer_blocks = nn.ModuleList() + for i in range(config.num_hidden_layers): + self.transformer_blocks.append( + DiTDecoderLayer( + config, + look_ahead_block=1 if i in config.look_ahead_layers else 0, + look_backward_block=1 if i in config.look_backward_layers else 0, + ) + ) + + self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final(config.hidden_size) # final modulation + self.proj_out = nn.Linear(config.hidden_size, config.mel_dim) + + def _create_block_diff(self, hidden_states): + batch, seq_len = hidden_states.shape[0], hidden_states.shape[1] + block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size # [seq_length] + + block_i = block_indices.unsqueeze(1) # [seq_length, 1] + block_j = block_indices.unsqueeze(0) # [1, seq_length] + block_diff = block_j - block_i # (n, n) + + return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len) + + def forward( + self, + hidden_states, + condition_vector, + speaker_embedding, + quantized_code, + time_step, + drop_audio_conditioning=False, + drop_code=False, + apply_cfg=True, + ): + batch_size = hidden_states.shape[0] + if time_step.ndim == 0: + time_step = time_step.repeat(batch_size) + + # Compute embeddings + time_embedding = self.time_embed(time_step) + text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code) + text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None + + hidden_states = self.input_embed( + hidden_states, + speaker_embedding, + condition_vector, + text_embedding, + drop_audio_cond=drop_audio_conditioning, + code_embed_uncond=text_embedding_unconditioned, + apply_cfg=apply_cfg, + ) + + # Compute positional encodings + position_embeddings = self.rotary_embed(hidden_states) + blockwise_difference = self._create_block_diff(hidden_states) + + # Transformer blocks + for transformer_block in self.transformer_blocks: + hidden_states = transformer_block( + hidden_states, + time_embedding, + position_embeddings=position_embeddings, + block_diff=blockwise_difference, + ) + + hidden_states = self.norm_out(hidden_states, time_embedding) + output = self.proj_out(hidden_states) + + return output + + @torch.no_grad() + def sample( + self, + conditioning_vector, + reference_mel_spectrogram, + quantized_code, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + ): + noise_initialization = torch.randn([1, 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype) + maximum_duration = quantized_code.shape[1] * self.repeats + initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device) + batch_size = reference_mel_spectrogram.shape[0] + conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1) + + if batch_size != 1: + raise ValueError("Only batch size = 1 is currently supported") + + def ode_function(time_step, hidden_states): + if guidance_scale < 1e-5: + prediction = self( + hidden_states=hidden_states, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + quantized_code=quantized_code, + time_step=time_step, + drop_audio_conditioning=False, + drop_code=False, + ) + return prediction + + model_output = self( + hidden_states=hidden_states, + quantized_code=quantized_code, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + time_step=time_step, + apply_cfg=True, + ) + guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0) + return guided_prediction + (guided_prediction - null_prediction) * guidance_scale + + initial_time = 0 + time_embedding = torch.linspace( + initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype + ) + + if sway_coefficient is not None: + time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding) + + ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state) + solution_trajectory = ode_solver.integrate(time_embedding) + + generated_waveform = solution_trajectory[-1] + generated_mel_spectrogram = generated_waveform.permute(0, 2, 1) + return generated_mel_spectrogram + + @torch.no_grad() + def fast_block_sample( + self, + conditioning_vector, + reference_mel_spectrogram, + quantized_code, + y0: torch.Tensor, + num_steps=10, + guidance_scale=0.5, + sway_coefficient: Optional[float] = -1.0, + ): + """ + Block-wise ODE sampling starting from provided initial state y0. + + Args: + conditioning_vector: (B, enc_emb_dim) + reference_mel_spectrogram: (B, T_ref, mel_dim) + quantized_code: (B, T_code) + y0: (B, T_target, mel_dim) initial state for ODE + Returns: + mel: (B, mel_dim, T_target) + """ + initial_state = y0.to(quantized_code.device) + batch_size = reference_mel_spectrogram.shape[0] + conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, initial_state.shape[1], 1) + + if batch_size != 1: + raise ValueError("Only batch size = 1 is currently supported") + + def ode_function(time_step, hidden_states): + if guidance_scale < 1e-5: + prediction = self( + hidden_states=hidden_states, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + quantized_code=quantized_code, + time_step=time_step, + drop_audio_conditioning=False, + drop_code=False, + ) + return prediction + + model_output = self( + hidden_states=hidden_states, + quantized_code=quantized_code, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + time_step=time_step, + apply_cfg=True, + ) + guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0) + return guided_prediction + (guided_prediction - null_prediction) * guidance_scale + + initial_time = 0 + time_embedding = torch.linspace( + initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype + ) + + if sway_coefficient is not None: + time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding) + + ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state) + solution_trajectory = ode_solver.integrate(time_embedding) + + generated_waveform = solution_trajectory[-1] + generated_mel_spectrogram = generated_waveform.permute(0, 2, 1) + return generated_mel_spectrogram + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2Wav model. Consists a DiT model take speech tokens as input and predict mel spectrogram and a BigVGAN vocoder take mel spectrogram as input and predict waveform. + """ +) +class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel): + config: Qwen2_5OmniToken2WavConfig + base_model_prefix = "model" + _no_split_modules = ["Qwen2_5OmniToken2WavDiTModel", "Qwen2_5OmniToken2WavBigVGANModel"] + + def __init__(self, config: Qwen2_5OmniToken2WavConfig): + super().__init__(config) + attn_impl = config._attn_implementation + if config._attn_implementation == "flash_attention_2": + logger.warning_once( + "Qwen2_5OmniToken2WavModel must inference with fp32, but flash_attention_2 only supports fp16 and bf16, " + "attention implementation of Qwen2_5OmniToken2WavModel will fallback to sdpa." + ) + attn_impl = "sdpa" + elif config._attn_implementation == "eager": + logger.warning_once( + "Qwen2_5OmniToken2WavModel does not support eager attention implementation, fall back to sdpa" + ) + attn_impl = "sdpa" + self.code2wav_dit_model = Qwen2_5OmniToken2WavDiTModel._from_config( + config.dit_config, attn_implementation=attn_impl + ) + self.code2wav_bigvgan_model = Qwen2_5OmniToken2WavBigVGANModel._from_config( + config.bigvgan_config, attn_implementation=attn_impl + ) + + # Streaming-related parameters aligned with Qwen2Code2wav + self.factor = self.code2wav_dit_model.repeats # 50Hz=2, 200Hz=4 + # default bs_mel depends on factor + self.bs_mel = 24 if self.factor == 2 else 32 + self.bs_codec = self.bs_mel // self.factor + self.past_cache_size = self.bs_mel * self.factor + self.future_cache_size = self.bs_mel * 1 + self.batched_chunk = 3 + self.chunk_size = self.bs_mel * self.batched_chunk + self.future_size = 20 if self.factor == 2 else 13 + + # codec embedding size for masking EOS out-of-range + try: + self.codec_embed_size = self.code2wav_dit_model.text_embed.codec_embed.weight.size(0) + except Exception: + self.codec_embed_size = -1 + + # vocoder hop length inferred from upsample rates + try: + ups = self.code2wav_bigvgan_model.config.upsample_rates + hop = 1 + for r in ups: + hop *= int(r) + self.vocoder_hop = int(hop) + except Exception: + # fallback to commonly used value + self.vocoder_hop = 240 + + def forward( + self, + code, + conditioning, + reference_mel, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + **kwargs, + ): + """Generates a waveform from input code and conditioning parameters.""" + + mel_spectrogram = self.code2wav_dit_model.sample( + conditioning, + reference_mel, + code, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + ).to(self.code2wav_bigvgan_model.dtype) + + waveform = self.code2wav_bigvgan_model(mel_spectrogram).to(self.dtype) + + return waveform + + # ============== Chunked processing helpers (compat with qwen2_code2wav_dit) ============== + @torch.inference_mode() + def process_chunk_dit_batch( + self, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + code: torch.Tensor, + y0: torch.Tensor, + steps: int, + ) -> torch.Tensor: + """Block-wise DiT: generate mel from initial state y0 for the given code slice.""" + # prevent codec out-of-range (eos) + if self.codec_embed_size > 0: + code = code.clone() + code[code >= self.codec_embed_size] = 0 + mel = self.code2wav_dit_model.fast_block_sample( + conditioning_vector=conditioning, + reference_mel_spectrogram=reference_mel, + quantized_code=code, + y0=y0, + num_steps=steps, + ) + return mel.to(self.code2wav_bigvgan_model.dtype) + + @torch.inference_mode() + def process_chunk_bigvgan_batch(self, mel_batch: torch.Tensor) -> torch.Tensor: + """Vocoder batch: mel -> waveform.""" + return self.code2wav_bigvgan_model(mel_batch) + + @torch.inference_mode() + def process_little_chunk( + self, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + codec_all: torch.Tensor, + y_all: torch.Tensor, + i: int, + steps: int, + prev_generated: torch.Tensor, + finished: bool = False, + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + """Streaming per small chunk: returns (mel_or_None, audio_slice).""" + start_index = max(i * self.chunk_size - self.past_cache_size, 0) + end_index = min((i + 1) * self.chunk_size + self.future_cache_size, + codec_all.shape[1] * self.factor) + + y0 = y_all[:, start_index:end_index].reshape(1, -1, self.code2wav_dit_model.mel_dim).contiguous() + codec = codec_all[:, start_index // self.factor:end_index // self.factor].reshape(1, -1).contiguous() + + # generate mel for current window (B, mel_dim, T) + generated = self.process_chunk_dit_batch( + conditioning=conditioning, + reference_mel=reference_mel, + code=codec, + y0=y0, + steps=steps, + ) + + # splice and vocode with 50Hz-style rules + return self._process_chunk_for_50hz( + i=i, + start_index=start_index, + end_index=end_index, + finished=finished, + prev_generated=prev_generated, + generated=generated, + ) + + @torch.inference_mode() + def process_chunk( + self, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + codec_all: torch.Tensor, + y_all: torch.Tensor, + i: int, + steps: int, + prev_generated: Union[torch.Tensor, List[torch.Tensor]], + finished: bool = False, + ) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], torch.Tensor]: + """High-level chunk API aligning to qwen2_code2wav_dit signature.""" + if not isinstance(prev_generated, torch.Tensor): + prev_generated = prev_generated[0] if len(prev_generated) > 0 else None + _mel, audio = self.process_little_chunk( + conditioning=conditioning, + reference_mel=reference_mel, + codec_all=codec_all, + y_all=y_all, + i=i, + steps=steps, + prev_generated=prev_generated, + finished=finished, + ) + return _mel if _mel is not None else prev_generated, audio + + @torch.inference_mode() + def _process_chunk_for_50hz( + self, + i: int, + start_index: int, + end_index: int, + finished: bool, + prev_generated: Optional[torch.Tensor], + generated: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Align mel and audio boundaries for 50Hz-like streaming. + + Shapes: + - generated: (B, mel_dim, T_window) + - prev_generated: (B, mel_dim, T_prev) + Returns: + - mel_chunk: (B, mel_dim, T_chunk) + - audio_slice: (T_audio_chunk,) + """ + # Normalize dtype + generated = generated.to(torch.float32) + if i == 0: + mel = generated[:, :, :self.chunk_size] + elif finished: + mel_trim = generated[:, :, self.past_cache_size:] + mel = torch.cat([prev_generated[:, :, -self.future_size * 2:], mel_trim], dim=2) + else: + if start_index == 0: + mel_trim = generated[:, :, i * self.chunk_size:-self.future_cache_size] + else: + mel_trim = generated[:, :, self.past_cache_size:-self.future_cache_size] + mel = torch.cat([prev_generated[:, :, -self.future_size * 2:], mel_trim], dim=2) + + audio = self.code2wav_bigvgan_model(mel) + if i == 0: + audio_output = audio[:-self.future_size * self.vocoder_hop] + elif finished: + audio_output = audio[self.future_size * self.vocoder_hop:] + else: + audio_output = audio[self.future_size * self.vocoder_hop:-self.future_size * self.vocoder_hop] + return mel, audio_output + + +# ================= vLLM-style wrapper for Token2Wav ================= + +class Qwen2_5OmniToken2WavForConditionalGenerationVLLM(nn.Module, SupportsPP): + logger = init_logger(__name__) + + # Map HF weights -> vLLM module names + hf_to_vllm_mapper = _Vllm_WeightsMapper( + orig_to_new_prefix={ + # HF root is 'model.' + "model.": "token2wav_model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + # Expect hf_config to be Token2Wav config + self.config = vllm_config.model_config.hf_config + + # Initialize underlying HF Token2Wav model via registry + self.token2wav = _Vllm_init_vllm_registered_model( + vllm_config=vllm_config, + prefix=_Vllm_maybe_prefix(prefix, "token2wav_model"), + hf_config=self.config, + architectures=["Qwen2_5OmniToken2WavDiTModel"], + ) + + # Provide placeholder to align with vLLM runner expectations + def _empty_intermediate_tensors(): + return None + self.make_empty_intermediate_tensors = _empty_intermediate_tensors + + def get_language_model(self) -> torch.nn.Module: + return self.token2wav + + @property + def sampler(self): + # Token2Wav does not use sampler; return vLLM default for API parity + return get_sampler() + + def forward( + self, + code: torch.Tensor, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + num_steps: int = 10, + guidance_scale: float = 0.5, + sway_coefficient: float = -1.0, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs, + ) -> torch.Tensor: + # Delegate to HF token2wav model + return self.token2wav( + code=code, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + **kwargs, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + # Token2Wav outputs waveform; logits are not applicable + return hidden_states + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return None + + def load_weights_without_buffers(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: + loader = _Vllm_AutoWeightsLoader(self) + loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + # Log load summary + try: + total_bytes = 0 + for _, param in self.named_parameters(): + if param is not None and param.data is not None: + total_bytes += param.data.numel() * param.data.element_size() + device = next(self.parameters()).device + self.logger.info( + "[Model Loaded] name=%s, success=%s, size=%.2f MB, device=%s", + self.__class__.__name__, True, total_bytes / (1024**2), str(device)) + except Exception: + pass + return loaded + + def find_all_registers(model, prefix=''): + """ + Find all registered buffers in a PyTorch model. + + Args: + model: PyTorch model (nn.Module) + prefix: prefix for nested modules (used in recursion) + + Returns: + dict: Dictionary with buffer names as keys and their properties as values + """ + registers = {} + + # Get all named buffers + for name, buf in model.named_buffers(): + if name in model.state_dict(): + registers[name] = { + 'name': name, + 'buffer': buf + } + return registers + + + #remove buffers from the weights and reload them after loading weights + def remove_buffers_from_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], buffers: dict): + weights_to_load = [] + for key, value in weights: + if key in buffers: + buffers[key]['buffer'] = value + continue + weights_to_load.append((key, value)) + return weights_to_load + + def reload_buffers_to_model(self, buffers: dict): + """ + reload stored buffers from weights to model + """ + loaded_buffers = set() + for name, buf_val in self.named_buffers(): + if name in buffers: + buf_val.copy_(buffers[name]['buffer']) + loaded_buffers.add(name) + return loaded_buffers + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], spk_dict_path: str) -> Set[str]: + buffers = self.find_all_registers(self) + weights_to_load = self.remove_buffers_from_weights(weights, buffers) + loaded = self.load_weights_without_buffers(weights_to_load) + loaded_buffers = self.reload_buffers_to_model(buffers) + #merge loaded and loaded_buffers + loaded.update(loaded_buffers) + self.spk_dict = torch.load(spk_dict_path) + return loaded + + # ============== Optional chunked helpers for API parity ============== + @torch.inference_mode() + def process_chunk_dit_batch( + self, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + code: torch.Tensor, + y0: torch.Tensor, + steps: int, + ) -> torch.Tensor: + return self.token2wav( + code=code, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=steps, + ) + + @torch.inference_mode() + def process_chunk_bigvgan_batch(self, mel_batch: torch.Tensor) -> Optional[torch.Tensor]: + # BigVGAN is not part of this wrapper; return None for parity. + return None + + @torch.inference_mode() + def process_little_chunk( + self, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + codec_all: torch.Tensor, + y_all: torch.Tensor, + i: int, + steps: int, + prev_generated: torch.Tensor, + finished: bool = False, + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + mel = self.token2wav( + code=codec_all, + conditioning=conditioning, + reference_mel=reference_mel, + num_steps=steps, + ) + return None, mel + + @torch.inference_mode() + def process_chunk( + self, + conditioning: torch.Tensor, + reference_mel: torch.Tensor, + codec_all: torch.Tensor, + y_all: torch.Tensor, + i: int, + steps: int, + prev_generated: Union[torch.Tensor, List[torch.Tensor]], + finished: bool = False, + ) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], torch.Tensor]: + _mel, out = self.process_little_chunk( + conditioning=conditioning, + reference_mel=reference_mel, + codec_all=codec_all, + y_all=y_all, + i=i, + steps=steps, + prev_generated=prev_generated if isinstance(prev_generated, torch.Tensor) else None, + finished=finished, + ) + return _mel if _mel is not None else prev_generated, out \ No newline at end of file diff --git a/vllm_omni/model_executor/models/qwen2_old.py b/vllm_omni/model_executor/models/qwen2_old.py new file mode 100644 index 00000000000..9e33268dfc7 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen2_old.py @@ -0,0 +1,564 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2 model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.v1.pool.metadata import PoolingMetadata +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Qwen2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen2Attention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + head_dim: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen2DecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + # By default, Qwen2 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = Qwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + head_dim=getattr(config, "head_dim", None), + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class Qwen2Model(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + raise ValueError("Sliding window for some but all layers is not " + "supported. This model uses sliding window " + "but `max_window_layers` = {} is less than " + "`num_hidden_layers` = {}. Please open an issue " + "to discuss this feature.".format( + config.max_window_layers, + config.num_hidden_layers, + )) + + self.config = config + self.quant_config = quant_config + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + getattr(config, "embedding_size", config.hidden_size), + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + # Use the provided decoder layer type or default to Qwen2DecoderLayer + decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + +class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + pooler_config = vllm_config.model_config.pooler_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + # TODO: Replace this model class with as_embedding_model( + # Qwen2ForCausalLM) after changing the default pooling method + if pooler_config.pooling_type is None: + logger.warning( + "This embedding model will default to last-token pooling in " + "an upcoming version. To avoid breaking changes, you should " + "pass `--override-pooler-config '{\"pooling_type\": \"MEAN\"}'`" + " explicitly.") + + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.MEAN, + normalize=True, + softmax=False) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + return self.model(input_ids, positions, intermediate_tensors) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + weights = self.hf_to_vllm_mapper.apply(weights) + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + self.model.load_weights(weights) \ No newline at end of file diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py new file mode 100644 index 00000000000..f04b70a67a0 --- /dev/null +++ b/vllm_omni/model_executor/models/registry.py @@ -0,0 +1,37 @@ +from vllm.model_executor.models.registry import ModelRegistry, _VLLM_MODELS, _ModelRegistry, _LazyRegisteredModel + + +_OMNI_MODELS = { + "Qwen2_5OmniForConditionalGeneration": ( + "qwen2_5_omni", + "Qwen2_5OmniForConditionalGeneration", + ), + "Qwen2_5OmniThinkerModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 + "Qwen2_5OmniTalkerModel": ("qwen2_5_omni_talker", "Qwen2_5OmniTalkerForConditionalGeneration"), # noqa: E501 + "Qwen2_5OmniToken2WavModel": ("qwen2_5_omni_token2wav", "Qwen2_5OmniToken2WavForConditionalGenerationVLLM"), + "Qwen2_5OmniToken2WavDiTModel": ("qwen2_5_omni_token2wav", "Qwen2_5OmniToken2WavModel"), + "Qwen2ForCausalLM_old": ("qwen2_old", "Qwen2ForCausalLM") # need to discuss +} + +_VLLM_OMNI_MODELS = { + **_VLLM_MODELS, + **_OMNI_MODELS, +} + + +OmniModelRegistry = _ModelRegistry({ + **{ + model_arch: _LazyRegisteredModel( + module_name=f"vllm.model_executor.models.{mod_relname}", + class_name=cls_name, + ) + for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items() + }, + **{ + model_arch: _LazyRegisteredModel( + module_name=f"vllm_omni.model_executor.models.{mod_relname}", + class_name=cls_name, + ) + for model_arch, (mod_relname, cls_name) in _OMNI_MODELS.items() + } +}) \ No newline at end of file diff --git a/vllm_omni/model_executor/models/utils.py b/vllm_omni/model_executor/models/utils.py new file mode 100644 index 00000000000..7e915d9373d --- /dev/null +++ b/vllm_omni/model_executor/models/utils.py @@ -0,0 +1,8 @@ +from vllm.model_executor.models.utils import maybe_prefix + + +def add_prefix_to_loaded_weights(weights: set[str], prefix: str) -> set[str]: + """ + Add a prefix to the names of the loaded weights. + """ + return {maybe_prefix(prefix, name) for name in weights} \ No newline at end of file diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml new file mode 100644 index 00000000000..aa38bae5654 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml @@ -0,0 +1,41 @@ +# stage config for running qwen2.5-omni with architecture of OmniLLM. +stage_args: + - stage_id: 0 + engine_args: + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_cls: vllm_omni.worker.AR_gpu_worker.ARGPUWorker + scheduler_cls: vllm_omni.core.sched.scheduler.OmniScheduler + gpu_memory_utilization: 0.32 + enforce_eager: true # need to discuss + trust_remote_code: true + engine_output_type: latent # change the param name,such as pooling_output + final_output: true + final_output_type: text + - stage_id: 1 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_cls: vllm_omni.worker.AR_gpu_worker.ARGPUWorker + scheduler_cls: vllm_omni.core.sched.scheduler.OmniScheduler + gpu_memory_utilization: 0.32 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + + - stage_id: 2 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_cls: vllm_omni.worker.diffusion_gpu_worker.DiffusionGPUWorker + scheduler_cls: vllm_omni.core.sched.diffusion_scheduler.DiffusionScheduler + gpu_memory_utilization: 0.3 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio + engine_input_source: [1] + final_output: true + final_output_type: audio \ No newline at end of file diff --git a/vllm_omni/model_executor/stage_input_processors/__init__.py b/vllm_omni/model_executor/stage_input_processors/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py new file mode 100644 index 00000000000..a854fc5627b --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -0,0 +1,32 @@ +from vllm.inputs import TextPrompt +from typing import Union +import torch +from vllm_omni.inputs.data import OmniTokensPrompt + +def thinker2talker(stage_list, engine_input_source, prompt: Union[OmniTokensPrompt, TextPrompt] = None): + source_stage_id = engine_input_source[0] + thinker_outputs = stage_list[source_stage_id].engine_outputs + talker_inputs = [] + multi_modal_data = {thinker_output.request_id: + prompt.get('multi_modal_data', None) for thinker_output, prompt in zip(thinker_outputs, prompt)} + + for i, thinker_output in enumerate(thinker_outputs): + output = thinker_output.outputs[0] + prompt_token_ids = thinker_output.prompt_token_ids + thinker_output_ids = output.token_ids + prompt_token_ids_len = len(prompt_token_ids) + thinker_hidden_states = output.multimodal_output["latent"].clone().detach().cuda() + talker_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0] * (len(prompt_token_ids) + 2), # add 2 for codec pad and start token + additional_information={ + "thinker_result": thinker_hidden_states[prompt_token_ids_len:].to(torch.float32), + "prompt_embeds": thinker_hidden_states[:prompt_token_ids_len].to(torch.float32), + "prompt_token_ids": prompt_token_ids, + "thinker_output_token_ids": thinker_output_ids, + }, + multi_modal_data=multi_modal_data[thinker_output.request_id] if multi_modal_data is not None else None, + mm_processor_kwargs=None, + ) + ) + return talker_inputs \ No newline at end of file diff --git a/vllm_omni/multimodal/__init__.py b/vllm_omni/multimodal/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py new file mode 100644 index 00000000000..5d289184775 --- /dev/null +++ b/vllm_omni/outputs.py @@ -0,0 +1,15 @@ +from vllm.v1.outputs import ModelRunnerOutput +from vllm.outputs import RequestOutput +from typing import Optional +from dataclasses import dataclass +import torch + + +class OmniModelRunnerOutput(ModelRunnerOutput): + multimodal_outputs: Optional[dict[str, torch.Tensor]] = None + +@dataclass +class OmniRequestOutput(RequestOutput): + stage_id: int + final_output_type: str + request_output: RequestOutput \ No newline at end of file diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py new file mode 100644 index 00000000000..783796ca76e --- /dev/null +++ b/vllm_omni/patch.py @@ -0,0 +1,21 @@ +import sys + +from vllm.inputs.data import TokensPrompt as _OriginalTokensPrompt +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding as _OriginalMRotaryEmbedding +from vllm.v1.request import Request as _OriginalRequest +from vllm.v1.engine import EngineCoreRequest as _OriginalEngineCoreRequest + +from vllm_omni.inputs.data import OmniTokensPrompt +from vllm_omni.model_executor.layers.mrope import MRotaryEmbedding +from vllm_omni.request import OmniRequest +from vllm_omni.engine import OmniEngineCoreRequest + +for module_name, module in sys.modules.items(): + if hasattr(module, 'TokensPrompt') and module.TokensPrompt == _OriginalTokensPrompt: + module.TokensPrompt = OmniTokensPrompt + if hasattr(module, 'MRotaryEmbedding') and module.MRotaryEmbedding == _OriginalMRotaryEmbedding: + module.MRotaryEmbedding = MRotaryEmbedding + if hasattr(module, 'Request') and module.Request == _OriginalRequest: + module.Request = OmniRequest + if hasattr(module, 'EngineCoreRequest') and module.EngineCoreRequest == _OriginalEngineCoreRequest: + module.EngineCoreRequest = OmniEngineCoreRequest \ No newline at end of file diff --git a/vllm_omni/request.py b/vllm_omni/request.py index 4bdc62b2533..b0da21e6398 100644 --- a/vllm_omni/request.py +++ b/vllm_omni/request.py @@ -1,299 +1,47 @@ -""" -OmniRequest: Extended request class for vLLM-omni multimodal processing. +from vllm.v1.request import Request +from vllm.v1.structured_output.request import StructuredOutputRequest +from vllm.multimodal.inputs import MultiModalKwargs +from vllm.utils import is_list_of +from typing import Optional +from vllm_omni.engine import PromptEmbedsPayload, AdditionalInformationPayload, OmniEngineCoreRequest -This class extends vLLM's Request to support multimodal and non-autoregressive -processing with additional fields and methods specific to vLLM-omni. -""" -import enum -import time -from typing import Optional, Dict, Any, List, Union -from dataclasses import dataclass, field -from vllm.v1.request import Request as vLLMRequest - - -class RequestType(enum.Enum): - """Types of requests supported by vLLM-omni.""" - TEXT = "text" - IMAGE = "image" - AUDIO = "audio" - VIDEO = "video" - MULTIMODAL = "multimodal" - DIFFUSION = "diffusion" - - -class ProcessingStage(enum.Enum): - """Processing stages for multi-stage models.""" - PREPROCESSING = "preprocessing" - AR_GENERATION = "ar_generation" - DIFFUSION_GENERATION = "diffusion_generation" - POSTPROCESSING = "postprocessing" - COMPLETED = "completed" - - -@dataclass -class MultimodalData: - """Container for multimodal input data.""" - data_type: str # "image", "audio", "video", etc. - data: Any # The actual data (numpy array, bytes, etc.) - metadata: Dict[str, Any] = field(default_factory=dict) - processed: bool = False - - -@dataclass -class DiffusionParams: - """Parameters specific to diffusion model generation.""" - num_inference_steps: int = 50 - guidance_scale: float = 7.5 - seed: Optional[int] = None - scheduler: str = "ddpm" - strength: float = 1.0 - eta: float = 0.0 - - -class OmniRequest(vLLMRequest): - """ - Extended request class for vLLM-omni multimodal and non-autoregressive processing. - - This class extends vLLM's Request with additional fields and methods to support: - - Multimodal input processing - - Non-autoregressive generation (diffusion models) - - Multi-stage processing pipelines - - Enhanced caching for different model types - """ - +class OmniRequest(Request): def __init__( - self, - request_id: str, - prompt: Optional[str] = None, - prompt_token_ids: Optional[List[int]] = None, - sampling_params: Optional[Any] = None, - arrival_time: Optional[float] = None, - lora_request: Optional[Any] = None, - multi_modal_data: Optional[Dict[str, Any]] = None, - multi_modal_placeholders: Optional[Dict[str, str]] = None, - priority: int = 0, - # vLLM-omni specific parameters - request_type: RequestType = RequestType.TEXT, - processing_stage: ProcessingStage = ProcessingStage.PREPROCESSING, - multimodal_inputs: Optional[List[MultimodalData]] = None, - diffusion_params: Optional[DiffusionParams] = None, - output_format: str = "text", - cache_key: Optional[str] = None, - stage_configs: Optional[Dict[str, Any]] = None, - **kwargs - ): - # Initialize parent class - super().__init__( - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params, - arrival_time=arrival_time or time.time(), - lora_request=lora_request, - multi_modal_data=multi_modal_data, - multi_modal_placeholders=multi_modal_placeholders, - priority=priority, - **kwargs - ) - - # vLLM-omni specific attributes - self.request_type = request_type - self.processing_stage = processing_stage - self.multimodal_inputs = multimodal_inputs or [] - self.diffusion_params = diffusion_params or DiffusionParams() - self.output_format = output_format - self.cache_key = cache_key - self.stage_configs = stage_configs or {} - - # Processing state - self.current_stage = 0 - self.stage_results = {} - self.hidden_states = None - self.intermediate_outputs = [] - - # Timing and metrics - self.stage_timings = {} - self.total_processing_time = 0.0 - - # Error handling - self.errors = [] - self.retry_count = 0 - self.max_retries = 3 - - def add_multimodal_input(self, data: MultimodalData) -> None: - """Add a multimodal input to the request.""" - self.multimodal_inputs.append(data) - - def get_multimodal_inputs_by_type(self, data_type: str) -> List[MultimodalData]: - """Get all multimodal inputs of a specific type.""" - return [inp for inp in self.multimodal_inputs if inp.data_type == data_type] - - def update_processing_stage(self, stage: ProcessingStage) -> None: - """Update the current processing stage.""" - self.processing_stage = stage - self.stage_timings[stage.value] = time.time() - - def add_stage_result(self, stage: str, result: Any) -> None: - """Add a result from a processing stage.""" - self.stage_results[stage] = result - self.intermediate_outputs.append({ - 'stage': stage, - 'result': result, - 'timestamp': time.time() - }) - - def get_stage_result(self, stage: str) -> Any: - """Get a result from a specific processing stage.""" - return self.stage_results.get(stage) - - def set_hidden_states(self, hidden_states: Any) -> None: - """Set hidden states for the request.""" - self.hidden_states = hidden_states - - def get_hidden_states(self) -> Any: - """Get hidden states for the request.""" - return self.hidden_states - - def add_error(self, error: str) -> None: - """Add an error to the request.""" - self.errors.append({ - 'error': error, - 'timestamp': time.time(), - 'stage': self.processing_stage.value - }) - - def has_errors(self) -> bool: - """Check if the request has any errors.""" - return len(self.errors) > 0 - - def can_retry(self) -> bool: - """Check if the request can be retried.""" - return self.retry_count < self.max_retries - - def increment_retry(self) -> None: - """Increment the retry count.""" - self.retry_count += 1 - - def generate_cache_key(self) -> str: - """Generate a cache key for the request.""" - if self.cache_key: - return self.cache_key - - # Generate cache key based on request content - key_parts = [ - self.request_id, - str(self.request_type.value), - str(self.prompt_token_ids) if self.prompt_token_ids else str(self.prompt), - str(self.diffusion_params.num_inference_steps) if self.diffusion_params else "0", - str(len(self.multimodal_inputs)) - ] - return "_".join(key_parts) - - def get_processing_time(self) -> float: - """Get the total processing time for the request.""" - if self.stage_timings: - return time.time() - min(self.stage_timings.values()) - return 0.0 - - def is_completed(self) -> bool: - """Check if the request is completed.""" - return self.processing_stage == ProcessingStage.COMPLETED - - def to_dict(self) -> Dict[str, Any]: - """Convert the request to a dictionary for serialization.""" - return { - 'request_id': self.request_id, - 'request_type': self.request_type.value, - 'processing_stage': self.processing_stage.value, - 'prompt': self.prompt, - 'prompt_token_ids': self.prompt_token_ids, - 'output_format': self.output_format, - 'multimodal_inputs_count': len(self.multimodal_inputs), - 'diffusion_params': self.diffusion_params.__dict__ if self.diffusion_params else None, - 'stage_results': list(self.stage_results.keys()), - 'has_errors': self.has_errors(), - 'retry_count': self.retry_count, - 'processing_time': self.get_processing_time() - } - - def __repr__(self) -> str: - """String representation of the request.""" - return (f"OmniRequest(id={self.request_id}, " - f"type={self.request_type.value}, " - f"stage={self.processing_stage.value}, " - f"multimodal_inputs={len(self.multimodal_inputs)})") - - -# Factory functions for creating different types of requests -def create_text_request( - request_id: str, - prompt: str, - sampling_params: Optional[Any] = None, - **kwargs -) -> OmniRequest: - """Create a text-only request.""" - return OmniRequest( - request_id=request_id, - prompt=prompt, - request_type=RequestType.TEXT, - sampling_params=sampling_params, - **kwargs - ) - - -def create_image_request( - request_id: str, - prompt: str, - image_data: Any, - diffusion_params: Optional[DiffusionParams] = None, - **kwargs -) -> OmniRequest: - """Create an image generation request.""" - multimodal_input = MultimodalData( - data_type="image", - data=image_data, - metadata={"is_input": True} - ) - - return OmniRequest( - request_id=request_id, - prompt=prompt, - request_type=RequestType.IMAGE, - multimodal_inputs=[multimodal_input], - diffusion_params=diffusion_params, - output_format="image", - **kwargs - ) - - -def create_multimodal_request( - request_id: str, - prompt: str, - multimodal_inputs: List[MultimodalData], - **kwargs -) -> OmniRequest: - """Create a multimodal request.""" - return OmniRequest( - request_id=request_id, - prompt=prompt, - request_type=RequestType.MULTIMODAL, - multimodal_inputs=multimodal_inputs, - **kwargs - ) - - -def create_diffusion_request( - request_id: str, - prompt: str, - diffusion_params: DiffusionParams, - **kwargs -) -> OmniRequest: - """Create a diffusion model request.""" - return OmniRequest( - request_id=request_id, - prompt=prompt, - request_type=RequestType.DIFFUSION, - diffusion_params=diffusion_params, - **kwargs - ) \ No newline at end of file + self, + prompt_embeds: Optional[PromptEmbedsPayload] = None, + additional_information: Optional[AdditionalInformationPayload] = None, + *args, **kwargs): + super().__init__(*args, **kwargs) + # Serialized prompt embeddings payload (optional) + self.prompt_embeds: Optional[PromptEmbedsPayload] = prompt_embeds + # Serialized additional information payload (optional) + self.additional_information: Optional[AdditionalInformationPayload] = additional_information + + @classmethod + def from_engine_core_request(cls, request: OmniEngineCoreRequest) -> "Request": + if request.mm_inputs is not None: + assert isinstance(request.mm_inputs, list) + assert is_list_of(request.mm_inputs, MultiModalKwargs), ( + "mm_inputs was not updated in EngineCore.add_request") + + return cls( + request_id=request.request_id, + client_index=request.client_index, + prompt_token_ids=request.prompt_token_ids, + multi_modal_inputs=request.mm_inputs, + multi_modal_hashes=request.mm_hashes, + multi_modal_placeholders=request.mm_placeholders, + sampling_params=request.sampling_params, + pooling_params=request.pooling_params, + eos_token_id=request.eos_token_id, + arrival_time=request.arrival_time, + lora_request=request.lora_request, + structured_output_request=StructuredOutputRequest( + sampling_params=request.sampling_params) \ + if request.sampling_params else None, + cache_salt=request.cache_salt, + priority=request.priority, + prompt_embeds=request.prompt_embeds, + additional_information=request.additional_information, + ) \ No newline at end of file diff --git a/vllm_omni/sample/__init__.py b/vllm_omni/sample/__init__.py index e69de29bb2d..2b582946f3d 100644 --- a/vllm_omni/sample/__init__.py +++ b/vllm_omni/sample/__init__.py @@ -0,0 +1,8 @@ +""" +Sample components for vLLM-omni. +""" + +# Currently empty, placeholder for sample code and examples + +__all__ = [] + diff --git a/vllm_omni/utils/__init__.py b/vllm_omni/utils/__init__.py index e69de29bb2d..f3ef4214800 100644 --- a/vllm_omni/utils/__init__.py +++ b/vllm_omni/utils/__init__.py @@ -0,0 +1,8 @@ +""" +Utility functions for vLLM-omni. +""" + +# Currently empty, placeholder for future utility functions + +__all__ = [] + diff --git a/vllm_omni/worker/AR_gpu_model_runner.py b/vllm_omni/worker/AR_gpu_model_runner.py new file mode 100644 index 00000000000..c7b88daec89 --- /dev/null +++ b/vllm_omni/worker/AR_gpu_model_runner.py @@ -0,0 +1,578 @@ +"""AR GPU Model Runner for vLLM-omni. + +Exposes per-request hidden representations via ModelRunnerOutput.pooler_output +and also outputs sampled tokens. +""" + +from __future__ import annotations + +from typing import Optional, Union, Any, List +import numpy as np + +import torch + +from vllm import envs +from vllm.v1.worker.gpu_model_runner import ( + EMPTY_MODEL_RUNNER_OUTPUT, + IntermediateTensors, + get_pp_group, + get_tp_group, + has_kv_transfer_group, + set_forward_context, +) +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.multimodal.inputs import MultiModalKwargs +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.spec_decode.eagle import EagleProposer + +from vllm_omni.engine import PromptEmbedsPayload, AdditionalInformationPayload +from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner + + +class ARModelRunner(OmniGPUModelRunner): + """Autoregressive GPU model runner that returns hidden states per request. + + This runner follows the same preparation and forward path as GPUModelRunner + (inputs assembly, multi-modal handling, TP/PP/DP integration, CUDA graphs), + and additionally performs lightweight sampling so that sampled tokens are + available in outputs. Hidden representations are taken at the same indices + that GPUModelRunner would use for sampling/logits (i.e. `logits_indices`). + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_ids = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=self.device) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[OmniModelRunnerOutput, IntermediateTensors]: + # Update internal state with the new schedule + self._update_states(scheduler_output) + + # Decode per-request prompt_embeds / additional_hidden_states payloads (if present) into CPU tensors + try: + new_reqs = getattr(scheduler_output, "scheduled_new_reqs", []) + if new_reqs: + import numpy as np + import torch + for nr in new_reqs: + req_id = getattr(nr, "req_id", None) or getattr(nr, "request_id", None) + if req_id is None: + continue + # prompt_embeds + payload_pe = getattr(nr, "prompt_embeds", None) + if payload_pe is not None: + if isinstance(payload_pe, torch.Tensor): + pe_cpu = payload_pe.detach().to("cpu").contiguous() + elif isinstance(payload_pe, PromptEmbedsPayload): + dt = np.dtype(getattr(payload_pe, "dtype", "float32")) + arr = np.frombuffer(payload_pe.data, dtype=dt) + arr = arr.reshape(payload_pe.shape) + pe_cpu = torch.from_numpy(arr) + else: + pe_cpu = None + if pe_cpu is not None and req_id in self.requests: + setattr(self.requests[req_id], "prompt_embeds_cpu", pe_cpu) + # additional_information + payload_info = getattr(nr, "additional_information", None) + if payload_info is not None: + info_dict = {} + if isinstance(payload_info, dict): + # Already decoded + info_dict = payload_info + elif isinstance(payload_info, AdditionalInformationPayload): + for k, entry in payload_info.entries.items(): + if entry.tensor_data is not None: + dt = np.dtype(getattr(entry, "tensor_dtype", "float32")) + arr = np.frombuffer(entry.tensor_data, dtype=dt) + arr = arr.reshape(entry.tensor_shape) + info_dict[k] = torch.from_numpy(arr) + else: + info_dict[k] = entry.list_data + if info_dict and req_id in self.requests: + setattr(self.requests[req_id], "additional_information_cpu", info_dict) + except Exception: + pass + + # If there's no work to do, either return empty output or kv-only path + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output, + self.vllm_config) + + # Prepare decoder inputs and attention metadata + (attn_metadata, attention_cuda_graphs, logits_indices, + spec_decode_metadata, num_scheduled_tokens_np, + spec_decode_common_attn_metadata) = self._prepare_inputs( + scheduler_output) + + # Determine number of input tokens for this iteration + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if (self.use_cuda_graph + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_scheduled_tokens) + else: + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if (self.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1): + from vllm.utils import round_up # lazy local import + num_input_tokens = round_up(num_scheduled_tokens, tp_size) + else: + num_input_tokens = num_scheduled_tokens + + # DP padding + num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) + num_input_tokens += num_pad + + # Multimodal handling (encode and gather embeddings if needed) + if self.is_multimodal_model: + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + else: + mm_embeds = [] + + # Always assemble inputs_embeds on first PP rank; overlay per-request prompt_embeds and collect additional_hidden_states for prefill only + if get_pp_group().is_first_rank: + inputs_embeds_scheduled = self.model.get_input_embeddings( + input_ids=self.input_ids[:num_scheduled_tokens], + multimodal_embeddings=(mm_embeds or None) + if self.is_multimodal_model else None, + ) + + # Copy into persistent buffer to enable CUDA Graph capture + self.inputs_embeds[:num_scheduled_tokens].copy_( + inputs_embeds_scheduled) + + # Reset per-step additional information collector + if hasattr(self, "_forward_additional_information"): + self._forward_additional_information = None + + # Overlay custom prompt_embeds per request for the prompt portion; collect additional_information (tensor/list) for prefill portion only + for req_index, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + pe_cpu = getattr(req_state, "prompt_embeds_cpu", None) + addi_cpu = getattr(req_state, "additional_information_cpu", None) + num_computed_tokens = int( + self.input_batch.num_computed_tokens_cpu[req_index]) + prompt_len = len(req_state.prompt_token_ids) + prompt_remaining = max(0, prompt_len - num_computed_tokens) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + overlay_len = min(sched_tokens, prompt_remaining) + if overlay_len <= 0: + continue + if pe_cpu is not None: + src = pe_cpu[num_computed_tokens: + num_computed_tokens + overlay_len].to( + dtype=self.dtype, + device=self.device, + non_blocking=True) + start_offset = int(self.query_start_loc_cpu[req_index]) + self.inputs_embeds[start_offset:start_offset + overlay_len] \ + .copy_(src) + # For additional_information: handle arbitrary keys + if addi_cpu is not None and isinstance(addi_cpu, dict): + # Lazy init collector dict + if not hasattr(self, "_forward_additional_information") or \ + self._forward_additional_information is None: + self._forward_additional_information = {} + # Process tensors (slice by scheduled prompt range) and lists (append per-request) + for k, v in addi_cpu.items(): + if isinstance(v, torch.Tensor): + # Slice along token dimension for prefill part + try: + seg = v[num_computed_tokens: + num_computed_tokens + overlay_len].to( + dtype=self.dtype, + device=self.device, + non_blocking=True) + except Exception: + # Fallback: move whole tensor if slicing fails + seg = v.to(dtype=self.dtype, + device=self.device, + non_blocking=True) + prev_val = self._forward_additional_information.get(k) + self._forward_additional_information[k] = ( + torch.cat([prev_val, seg], dim=0) + if isinstance(prev_val, torch.Tensor) else seg.clone()) + elif isinstance(v, list): + prev_val = self._forward_additional_information.get(k) + if prev_val is None: + self._forward_additional_information[k] = [v] + elif isinstance(prev_val, list): + self._forward_additional_information[k].append(v) + else: + # Mixed types: wrap existing into list + self._forward_additional_information[k] = [prev_val, v] + + input_ids = self.input_ids[:num_input_tokens] # preserved for APIs + inputs_embeds = self.inputs_embeds[:num_input_tokens] + model_mm_kwargs = (self._extract_mm_kwargs(scheduler_output) + if self.is_multimodal_model else {}) + else: + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + model_mm_kwargs = {} + + # Positions/mRoPE + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + else: + positions = self.positions[:num_input_tokens] + + # Handle pipeline-parallel intermediate tensors + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_input_tokens, intermediate_tensors, True) + + # Some attention backends only support CUDA Graphs in pure decode. + skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs + + # Forward pass + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + skip_cuda_graphs=skip_cuda_graphs, + ), self.maybe_get_kv_connector_output( + scheduler_output) as kv_connector_output: + + model_kwargs_extra = {} + # Only pass additional_information for the prefill part + if hasattr(self, "_forward_additional_information") and \ + self._forward_additional_information is not None and \ + isinstance(self._forward_additional_information, dict): + model_kwargs_extra["additional_information"] = self._forward_additional_information + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_mm_kwargs, + device=self.device, + ), + sampling_metadata=self.input_batch.sampling_metadata, + logits_index=logits_indices, + sampler=self.sampler, + **model_kwargs_extra, + ) + + if self.use_aux_hidden_state_outputs: + hidden_states, _aux_hidden_states = model_output + else: + hidden_states = model_output + + text_hidden_states, multimodal_outputs = ( + self.extract_multimodal_outputs(hidden_states)) + + # Mid PP stages return intermediate tensors unmodified + if not get_pp_group().is_last_rank: + assert isinstance(text_hidden_states, IntermediateTensors) + text_hidden_states.kv_connector_output = kv_connector_output + return text_hidden_states + + # Broadcast PP output for external_launcher (torchrun) + broadcast_pp_output = \ + self.parallel_config.distributed_executor_backend \ + == "external_launcher" and len(get_pp_group().ranks) > 0 + if not get_pp_group().is_last_rank: + assert isinstance(text_hidden_states, IntermediateTensors) + if not broadcast_pp_output: + text_hidden_states.kv_connector_output = kv_connector_output + return text_hidden_states + get_pp_group().send_tensor_dict(text_hidden_states.tensors, + all_gather_group=get_tp_group()) + logits = None + else: + if self.input_batch.pooling_params: + return self._pool(text_hidden_states, num_scheduled_tokens, + num_scheduled_tokens_np, kv_connector_output) + + sample_hidden_states = text_hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + if broadcast_pp_output: + model_output_broadcast_data = { + "logits": logits.contiguous(), + } if logits is not None else {} + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + self.apply_grammar_bitmask(scheduler_output, logits) + + # Sample the next token and get logprobs if needed (with spec decode) + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + assert logits is not None + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + target_logits = logits[spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + + num_nans_in_logits = {} + if envs.VLLM_COMPUTE_NANS_IN_LOGITS: + num_nans_in_logits = self._get_nans_in_logits(logits) + + # Handle partial prefill: discard sampled tokens and rewind RNG + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len < req_state.num_tokens: + generator = self.input_batch.generators.get(i) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + discard_sampled_tokens_req_indices.append(i) + + # Move CPU sync parts + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None + + # Prompt logprobs if needed + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + text_hidden_states[:num_scheduled_tokens], + scheduler_output, + ) + + # Parse valid sampled tokens + sampled_token_ids = sampler_output.sampled_token_ids + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + + # Cache sampled tokens + for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): + if not sampled_ids: + continue + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + req_id = self.input_batch.req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + # Speculative decoding draft tokens if configured + if not self.speculative_config: + spec_token_ids = None + else: + assert spec_decode_common_attn_metadata is not None + spec_token_ids = self.propose_draft_token_ids( + scheduler_output, + valid_sampled_token_ids, + sampling_metadata, + text_hidden_states, + sample_hidden_states, + _aux_hidden_states if '_aux_hidden_states' in locals() else None, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + + # Convert to per-request tensors on CPU + pooler_output: list[Optional[torch.Tensor]] = [] + prev_logits_index = 0 + for logits_index in logits_indices: + pooler_output.append(text_hidden_states[prev_logits_index:logits_index+1]) + prev_logits_index = logits_index + 1 + + + self.eplb_step() + + return OmniModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None, + kv_connector_output=kv_connector_output, + num_nans_in_logits=num_nans_in_logits, + ) + + @torch.inference_mode() + def extract_multimodal_outputs(self, hidden_states: Union[torch.Tensor, List[torch.Tensor]]) -> dict: + if hasattr(self.model, "have_multimodal_outputs") and self.model.have_multimodal_outputs: + text_hidden_states = hidden_states.text_hidden_states + multimodal_outputs = hidden_states.multimodal_outputs + + elif isinstance(hidden_states, torch.Tensor): + text_hidden_states = hidden_states + multimodal_outputs = {} + elif isinstance(hidden_states, List): + text_hidden_states = hidden_states[0] + multimodal_outputs = {} + else: + raise ValueError(f"Invalid hidden states type: {type(hidden_states)}") + return text_hidden_states, multimodal_outputs + + @torch.inference_mode() + def _dummy_run( + self, + num_tokens: int, + capture_attn_cudagraph: bool = False, + skip_eplb: bool = False, + is_profile: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + + # Padding for DP + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + num_tokens += num_pad + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. + assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) + + attn_metadata: Optional[dict[str, Any]] = None + if capture_attn_cudagraph: + attn_metadata = {} + + # Make sure max_model_len is used at the graph capture time. + self.seq_lens_np[:num_reqs] = self.max_model_len + self.seq_lens_np[num_reqs:] = 0 + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], + non_blocking=True) + + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + + 1], + seq_lens=self.seq_lens[:num_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + num_computed_tokens_cpu=self.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + block_table_tensor=self.input_batch.block_table[ + kv_cache_group_id].get_device_tensor()[:num_reqs], + slot_mapping=self.input_batch. + block_table[kv_cache_group_id].slot_mapping[:num_tokens], + causal=True) + + for attn_group in self.attn_groups[kv_cache_group_id]: + attn_metadata_i = attn_group.metadata_builder\ + .build_for_cudagraph_capture(common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + with self.maybe_dummy_run_with_lora(self.lora_config, + num_scheduled_tokens): + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + model_mm_kwargs = self._dummy_mm_kwargs(num_reqs) + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + model_mm_kwargs = {} + + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device)) + + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_tokens, None, False) + + with self.maybe_randomize_inputs(input_ids), set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp): + outputs = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_mm_kwargs, + device=self.device, + ), + ) + + if self.use_aux_hidden_state_outputs: + hidden_states, _ = outputs + else: + hidden_states = outputs + + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + self.drafter.dummy_run(num_tokens) + + # This is necessary to avoid blocking DP. + # For dummy runs, we typically skip EPLB since we don't have any real + # requests to process. + # However, in DP settings, there may be cases when some DP ranks do + # not have any requests to process, so they're executing dummy batches. + # In such cases, we still have to trigger EPLB to make sure + # ranks execute the rearrangement in synchronization. + if not skip_eplb: + self.eplb_step(is_dummy=True, is_profile=is_profile) + + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) + return hidden_states, hidden_states[logit_indices] + diff --git a/vllm_omni/worker/AR_gpu_worker.py b/vllm_omni/worker/AR_gpu_worker.py new file mode 100644 index 00000000000..07e08088423 --- /dev/null +++ b/vllm_omni/worker/AR_gpu_worker.py @@ -0,0 +1,67 @@ +from vllm.v1.worker.gpu_worker import Worker as GPUWorker +import torch +import os +import gc +from vllm.utils import GiB_bytes, MemorySnapshot +from vllm.platforms import current_platform +from vllm.worker.worker import _check_if_gpu_supports_dtype +from vllm.v1.worker.gpu_worker import init_worker_distributed_environment +from vllm.model_executor import set_random_seed +from vllm.v1.utils import report_usage_stats + +from vllm_omni.worker.AR_gpu_model_runner import ARModelRunner + + +class ARGPUWorker(GPUWorker): + def init_device(self): + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + current_platform.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + gc.collect() + torch.cuda.empty_cache() + + # take current memory snapshot + self.init_snapshot = MemorySnapshot() + self.requested_memory = (self.init_snapshot.total_memory * + self.cache_config.gpu_memory_utilization) + if self.init_snapshot.free_memory < self.requested_memory: + GiB = lambda b: round(b / GiB_bytes, 2) + raise ValueError( + f"Free memory on device " + f"({GiB(self.init_snapshot.free_memory)}/" + f"{GiB(self.init_snapshot.total_memory)} GiB) on startup " + f"is less than desired GPU memory utilization " + f"({self.cache_config.gpu_memory_utilization}, " + f"{GiB(self.requested_memory)} GiB). Decrease GPU memory " + f"utilization or reduce GPU memory used by other processes." + ) + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + init_worker_distributed_environment(self.vllm_config, self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend) + # Set random seed. + set_random_seed(self.model_config.seed) + + # Construct the model runner + self.model_runner: ARModelRunner = ARModelRunner( + self.vllm_config, self.device) + + if self.rank == 0: + # If usage stat is enabled, collect relevant info. + report_usage_stats(self.vllm_config) \ No newline at end of file diff --git a/vllm_omni/worker/__init__.py b/vllm_omni/worker/__init__.py index e69de29bb2d..cc05bcd4bf7 100644 --- a/vllm_omni/worker/__init__.py +++ b/vllm_omni/worker/__init__.py @@ -0,0 +1,9 @@ +"""Worker modules for vLLM-omni.""" + +from .AR_gpu_worker import ARGPUWorker +from .diffusion_gpu_worker import DiffusionGPUWorker + +__all__ = [ + "ARGPUWorker", + "DiffusionGPUWorker", +] diff --git a/vllm_omni/worker/diffusion_gpu_worker.py b/vllm_omni/worker/diffusion_gpu_worker.py new file mode 100644 index 00000000000..a9767f01374 --- /dev/null +++ b/vllm_omni/worker/diffusion_gpu_worker.py @@ -0,0 +1,67 @@ +from vllm.v1.worker.gpu_worker import Worker as GPUWorker +import torch +import os +import gc +from vllm.utils import GiB_bytes, MemorySnapshot +from vllm.platforms import current_platform +from vllm.worker.worker import _check_if_gpu_supports_dtype +from vllm.v1.worker.gpu_worker import init_worker_distributed_environment +from vllm.model_executor import set_random_seed +from vllm.v1.utils import report_usage_stats + +from vllm_omni.worker.diffusion_model_runner import DiffusionModelRunner + + +class DiffusionGPUWorker(GPUWorker): + def init_device(self): + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + current_platform.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + gc.collect() + torch.cuda.empty_cache() + + # take current memory snapshot + self.init_snapshot = MemorySnapshot() + self.requested_memory = (self.init_snapshot.total_memory * + self.cache_config.gpu_memory_utilization) + if self.init_snapshot.free_memory < self.requested_memory: + GiB = lambda b: round(b / GiB_bytes, 2) + raise ValueError( + f"Free memory on device " + f"({GiB(self.init_snapshot.free_memory)}/" + f"{GiB(self.init_snapshot.total_memory)} GiB) on startup " + f"is less than desired GPU memory utilization " + f"({self.cache_config.gpu_memory_utilization}, " + f"{GiB(self.requested_memory)} GiB). Decrease GPU memory " + f"utilization or reduce GPU memory used by other processes." + ) + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + init_worker_distributed_environment(self.vllm_config, self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend) + # Set random seed. + set_random_seed(self.model_config.seed) + + # Construct the model runner + self.model_runner: DiffusionModelRunner = DiffusionModelRunner( + self.vllm_config, self.device) + + if self.rank == 0: + # If usage stat is enabled, collect relevant info. + report_usage_stats(self.vllm_config) \ No newline at end of file diff --git a/vllm_omni/worker/diffusion_model_runner.py b/vllm_omni/worker/diffusion_model_runner.py new file mode 100644 index 00000000000..f582f0449b8 --- /dev/null +++ b/vllm_omni/worker/diffusion_model_runner.py @@ -0,0 +1,376 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Optional, Union, List +import numpy as np +import gc +import logging + +import torch + +from vllm.v1.worker.gpu_model_runner import ( + GPUModelRunner, + EMPTY_MODEL_RUNNER_OUTPUT, + IntermediateTensors, + get_pp_group, + has_kv_transfer_group, + set_forward_context, +) +from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs +from vllm.multimodal.inputs import MultiModalKwargs + +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.attention.backends.utils import CommonAttentionMetadata + +from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner + + +logger = logging.getLogger(__name__) + + +class DiffusionModelRunner(OmniGPUModelRunner): + """Diffusion model runner for vLLM-omni (non-autoregressive). + + - Reuses GPUModelRunner preparation, multimodal handling, and TP/PP/DP glue. + - Does not compute logits or perform token sampling. + - Executes diffusion process and returns tensors via `pooler_output`. + """ + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[OmniModelRunnerOutput, IntermediateTensors]: + self._update_states(scheduler_output) + + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output, + self.vllm_config) + + # Prepare decoder inputs and attention metadata (for batch/order mapping) + (attn_metadata, attention_cuda_graphs, logits_indices, + spec_decode_metadata, num_scheduled_tokens_np, + spec_decode_common_attn_metadata) = self._prepare_inputs( + scheduler_output) + + # Input token count for this iteration (not used by diffusion, but + # retained to keep DP padding/ordering consistent) + num_input_tokens = scheduler_output.total_num_scheduled_tokens + num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) + num_input_tokens += num_pad + + # Multimodal conditioning (e.g., text/audio/video encoders) + if self.is_multimodal_model: + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + else: + mm_embeds = [] + + # Build inputs to mirror AR runner: input_ids/positions/embeds + if self.is_multimodal_model and get_pp_group().is_first_rank: + inputs_embeds_scheduled = self.model.get_input_embeddings( + input_ids=self.input_ids[:scheduler_output.total_num_scheduled_tokens], + multimodal_embeddings=mm_embeds or None, + ) + self.inputs_embeds[:scheduler_output.total_num_scheduled_tokens].copy_( + inputs_embeds_scheduled) + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = self.inputs_embeds[:num_input_tokens] + model_mm_kwargs = self._extract_mm_kwargs(scheduler_output) + else: + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + model_mm_kwargs = {} + + # Positions (mrope or standard) + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + else: + positions = self.positions[:num_input_tokens] + + # Intermediate tensors sync for PP (if any) + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_input_tokens, intermediate_tensors, True) + + # Set forward context mainly for resource management and kv connector + skip_cuda_graphs = True # diffusion path does not rely on cuda graphs here + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + skip_cuda_graphs=skip_cuda_graphs, + ), self.maybe_get_kv_connector_output( + scheduler_output) as kv_connector_output: + + if not get_pp_group().is_last_rank: + # For non-last PP stages, pass through intermediate tensors. + assert intermediate_tensors is not None + intermediate_tensors.kv_connector_output = kv_connector_output + return intermediate_tensors + + outputs = self._run_diffusion( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + multimodal_kwargs=model_mm_kwargs, + logits_indices=logits_indices, + ) + _, multimodal_outputs = ( + self.extract_multimodal_outputs(outputs)) + + # Ensure one tensor per request, map to CPU for output struct + pooler_output: List[Optional[torch.Tensor]] = [] + if isinstance(multimodal_outputs, torch.Tensor): + # If model returned a single stacked tensor, split by requests + assert outputs.shape[0] == self.input_batch.num_reqs + for i in range(self.input_batch.num_reqs): + pooler_output.append(outputs[i].detach().cpu()) + elif isinstance(multimodal_outputs, list): + for out in outputs: + pooler_output.append(out.detach().cpu() if out is not None else None) + elif isinstance(multimodal_outputs, dict): + for out in multimodal_outputs.values(): + pooler_output.append(out.detach().cpu() if out is not None else None) + else: + raise RuntimeError("Unsupported diffusion output type") + + self.eplb_step() + + return OmniModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=pooler_output, + kv_connector_output=kv_connector_output, + num_nans_in_logits={}, + ) + + def _run_diffusion( + self, + *, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor], + multimodal_kwargs: dict, + logits_indices: torch.Tensor, + ) -> Union[torch.Tensor, list[torch.Tensor]]: + """Runs the diffusion process and returns per-request tensors. + + Tries model interfaces in the following order for maximal compatibility: + 1) model.sample(condition=..., **kwargs) + 2) model.forward(condition=..., **kwargs) + 3) model.diffuse(condition=..., **kwargs) + """ + # Keep inputs identical to AR runner + kwargs = dict( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs(multimodal_kwargs, device=self.device), + sampling_metadata=self.input_batch.sampling_metadata, + logits_index=logits_indices, + sampler=self.sampler, + ) + + # For Qwen 2.5 Omni's current implementation, we only support the forward method + if hasattr(self.model, "forward"): + return self.model.forward(**kwargs) + + # if hasattr(self.model, "sample"): + # return self.model.sample(**kwargs) + # if hasattr(self.model, "forward"): + # return self.model.forward(**kwargs) + # if hasattr(self.model, "diffuse"): + # return self.model.diffuse(**kwargs) + + raise RuntimeError( + "The loaded model does not expose diffusion interfaces 'sample', " + "'forward', or 'diffuse'. Please implement one of them or adapt the runner.") + + + + + @torch.inference_mode() + def _dummy_run( + self, + num_tokens: int, + capture_attn_cudagraph: bool = False, + skip_eplb: bool = False, + is_profile: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Padding for DP + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + num_tokens += num_pad + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. + assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) + + attn_metadata: Optional[dict[str, dict]] = None + if capture_attn_cudagraph: + attn_metadata = {} + + # Make sure max_model_len is used at the graph capture time. + self.seq_lens_np[:num_reqs] = self.max_model_len + self.seq_lens_np[num_reqs:] = 0 + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], + non_blocking=True) + + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens=self.seq_lens[:num_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + num_computed_tokens_cpu=self.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + block_table_tensor=self.input_batch.block_table[ + kv_cache_group_id].get_device_tensor()[:num_reqs], + slot_mapping=self.input_batch.block_table[ + kv_cache_group_id].slot_mapping[:num_tokens], + causal=True) + + for attn_group in self.attn_groups[kv_cache_group_id]: + attn_metadata_i = attn_group.metadata_builder \ + .build_for_cudagraph_capture(common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + with self.maybe_dummy_run_with_lora(self.lora_config, + num_scheduled_tokens): + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + model_mm_kwargs = self._dummy_mm_kwargs(num_reqs) + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + model_mm_kwargs = {} + + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device)) + + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_tokens, None, False) + + # Diffusion path: avoid CUDA graphs; we only use context for resource wiring + with self.maybe_randomize_inputs(input_ids), set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp): + outputs = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_mm_kwargs, + device=self.device, + ), + sampler=None, + ) + + if self.use_aux_hidden_state_outputs: + hidden_states, _ = outputs + else: + hidden_states = outputs + + # Extract multimodal outputs if present; we ignore them here because + # dummy run returns tensors only. The actual diffusion runner returns + # multimodal outputs via pooler_output in execute_model. + text_hidden_states, _ = self.extract_multimodal_outputs(hidden_states) + + if not skip_eplb: + self.eplb_step(is_dummy=True, is_profile=is_profile) + + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + return text_hidden_states, None + + @torch.inference_mode() + def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None: + logger.warning("Dummy sampler run is not implemented for diffusion model") + return None + + def profile_run(self) -> None: + # Profile with multimodal encoder & encoder cache, similar to base but + # without any logits/sampler warming. + if self.is_multimodal_model: + mm_budget = self.mm_budget + assert mm_budget is not None + + # TODO: handle encoder-decoder models once supported. + if (encoder_budget := mm_budget.get_encoder_budget()) > 0: + ( + dummy_modality, + max_tokens, + ) = mm_budget.get_modality_with_max_tokens() + ( + max_mm_items_per_prompt, + max_mm_items_per_batch, + ) = mm_budget.get_max_items(dummy_modality, max_tokens) + + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) + + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + + self.encoder_cache["tmp"] = dict( + enumerate(dummy_encoder_outputs)) + + hidden_states, _ = self._dummy_run(self.max_num_tokens, is_profile=True) + if get_pp_group().is_last_rank: + pass # No sampler/pooler warmup for diffusion + self._sync_device() + del hidden_states + self.encoder_cache.clear() + gc.collect() diff --git a/vllm_omni/worker/gpu_diffusion_model_runner.py b/vllm_omni/worker/gpu_diffusion_model_runner.py deleted file mode 100644 index e464922945b..00000000000 --- a/vllm_omni/worker/gpu_diffusion_model_runner.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Model runner that wraps a diffusers text-to-image pipeline.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, Optional - -@dataclass -class DiffusionRunnerOutput: - """Structured output returned by the diffusion runner.""" - - prompt: str - images: list - request_id: str = "diffusion_request" - finished: bool = True - output_type: str = "image" - - -class DiffusionModelRunner: - """Lightweight runner that loads and executes diffusers pipelines.""" - - def __init__( - self, - model_path: str, - *, - pipeline_name: Optional[str] = None, - device: Optional[str] = None, - dtype: Optional[str] = None, - ) -> None: - self.model_path = model_path - self.pipeline_name = pipeline_name or "auto" - self.device, self.torch_dtype = self._resolve_device_and_dtype(device, dtype) - self._pipeline = self._load_pipeline() - - def _resolve_device_and_dtype( - self, - device: Optional[str], - dtype: Optional[str], - ): - import torch - - resolved_device = device - if resolved_device is None: - if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - resolved_device = "mps" - elif torch.cuda.is_available(): - resolved_device = "cuda" - else: - resolved_device = "cpu" - - if dtype is not None: - resolved_dtype = getattr(torch, dtype, None) - else: - if resolved_device in {"cuda", "mps"}: - resolved_dtype = torch.float16 - else: - resolved_dtype = torch.float32 - - return resolved_device, resolved_dtype - - def _load_pipeline(self): - import diffusers - - if self.pipeline_name != "auto": - PipelineCls = getattr(diffusers, self.pipeline_name) - else: - PipelineCls = diffusers.AutoPipelineForText2Image - - pipeline = PipelineCls.from_pretrained( - self.model_path, - torch_dtype=self.torch_dtype, - ) - - try: - pipeline = pipeline.to(self.device) - except Exception: - pipeline = pipeline.to("cpu") - - self._apply_optimisations(pipeline) - return pipeline - - @staticmethod - def _apply_optimisations(pipeline) -> None: - # Enable attention slicing / VAE tiling where available to reduce memory. - try: - if hasattr(pipeline, "enable_attention_slicing"): - pipeline.enable_attention_slicing() - except Exception: - pass - try: - vae = getattr(pipeline, "vae", None) - if vae and hasattr(vae, "enable_tiling"): - vae.enable_tiling() - except Exception: - pass - - def generate( - self, - prompt: str, - *, - height: int = 512, - width: int = 512, - num_inference_steps: int = 30, - guidance_scale: float = 5.0, - negative_prompt: Optional[str] = None, - seed: Optional[int] = None, - image: Optional[Any] = None, - ) -> DiffusionRunnerOutput: - import torch - - generator = None - if seed is not None: - try: - generator = torch.Generator(device=self.device).manual_seed(int(seed)) - except Exception: - # Some backends (e.g., CPU without full support) may not accept the device. - generator = torch.Generator().manual_seed(int(seed)) - - output = self._pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - num_inference_steps=int(num_inference_steps), - guidance_scale=float(guidance_scale), - height=int(height), - width=int(width), - generator=generator, - image=image, - ) - - return DiffusionRunnerOutput( - prompt=prompt, - images=getattr(output, "images", []) or [], - ) diff --git a/vllm_omni/worker/gpu_diffusion_worker.py b/vllm_omni/worker/gpu_diffusion_worker.py deleted file mode 100644 index 74f2474c4c9..00000000000 --- a/vllm_omni/worker/gpu_diffusion_worker.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Worker abstraction for executing diffusion model runners.""" - -from __future__ import annotations - -from typing import Any, Optional - -from vllm.v1.worker.worker_base import WorkerBase - -from .gpu_diffusion_model_runner import DiffusionModelRunner, DiffusionRunnerOutput - - -class DiffusionGPUWorker(WorkerBase): - """Minimal worker wrapping DiffusionModelRunner. - - The worker interface mirrors the shape expected by vLLM executors: it owns - a model runner instance and exposes a `generate` method that can be called - by an executor implementation. - """ - - def __init__( - self, - model_path: str, - *, - pipeline_name: Optional[str] = None, - device: Optional[str] = None, - dtype: Optional[str] = None, - ) -> None: - self.model_runner = DiffusionModelRunner( - model_path, - pipeline_name=pipeline_name, - device=device, - dtype=dtype, - ) - - def generate( - self, - prompt: str, - *, - height: int = 512, - width: int = 512, - num_inference_steps: int = 30, - guidance_scale: float = 5.0, - negative_prompt: Optional[str] = None, - seed: Optional[int] = None, - image: Optional[Any] = None, - ) -> DiffusionRunnerOutput: - return self.model_runner.generate( - prompt=prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - negative_prompt=negative_prompt, - seed=seed, - image=image, - ) diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py new file mode 100644 index 00000000000..acbda86482f --- /dev/null +++ b/vllm_omni/worker/gpu_model_runner.py @@ -0,0 +1,758 @@ + +from typing import TYPE_CHECKING, Any, Optional, Union, cast + +import numpy as np +import torch + +import vllm.envs as envs +from vllm.distributed.parallel_state import get_pp_group, get_tp_group +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces_base import VllmModelForPooling +from vllm.multimodal.inputs import MultiModalKwargs +from vllm.sequence import IntermediateTensors +from vllm.utils import LazyLoader +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +from vllm.v1.worker.gpu_input_batch import CachedRequestState +from vllm.sampling_params import SamplingType +from vllm.distributed.kv_transfer import has_kv_transfer_group +from vllm.utils import round_up +from vllm.v1.spec_decode.eagle import EagleProposer + +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.model_executor.layers.mrope import MRotaryEmbedding + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + xgr_torch_compile = LazyLoader( + "xgr_torch_compile", globals(), + "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile") + +logger = init_logger(__name__) + + +class OmniGPUModelRunner(GPUModelRunner): + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + The SamplingMetadata is updated and copied to the GPU if there is a + new/resumed/paused/finished request in the batch. + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + for req_id in scheduler_output.finished_req_ids: + self.input_batch.remove_request(req_id) + + # Free the cached encoder outputs. + for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + self.input_batch.remove_request(req_id) + + req_ids_to_add: list[str] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + pooling_params = new_req_data.pooling_params + + if sampling_params and \ + sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + if pooling_params: + assert (task := pooling_params.task) is not None, ( + "You did not set `task` in the API") + + model = cast(VllmModelForPooling, self.model) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(pooling_params) + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + pooling_params=pooling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) + + # If prompt embeddings are provided, decode and attach to inter_data + try: + if getattr(new_req_data, "prompt_embeds", None) is not None: + payload = new_req_data.prompt_embeds + import numpy as np + dtype = getattr(np, payload.dtype) + arr = np.frombuffer(payload.data, dtype=dtype) + arr = arr.reshape(payload.shape) + pe_cpu = torch.from_numpy(arr) + # Store temporarily on CPU; later moved to device in builder + setattr(self.requests[req_id], "prompt_embeds_cpu", pe_cpu) + # Also replace payload with Tensor for user visibility in scheduler_output + try: + new_req_data.prompt_embeds = pe_cpu # type: ignore[assignment] + except Exception: + pass + except Exception: + pass + # Decode additional_information payloads (dictionary) + try: + if getattr(new_req_data, "additional_information", None) is not None: + payload_info = new_req_data.additional_information + info_dict = {} + if isinstance(payload_info, dict): + info_dict = payload_info + else: + from vllm.v1.engine import AdditionalInformationPayload + if isinstance(payload_info, AdditionalInformationPayload): + import numpy as np + for k, entry in payload_info.entries.items(): + if entry.tensor_data is not None: + dt = np.dtype(getattr(entry, "tensor_dtype", "float32")) + arr = np.frombuffer(entry.tensor_data, dtype=dt) + arr = arr.reshape(entry.tensor_shape) + info_dict[k] = torch.from_numpy(arr) + else: + info_dict[k] = entry.list_data + if info_dict: + setattr(self.requests[req_id], "additional_information_cpu", info_dict) + except Exception: + pass + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_input in self.requests[req_id].mm_inputs: + if mm_input.get("image_grid_thw") is not None: + image_grid_thw.extend( + mm_input["image_grid_thw"].tolist()) + if mm_input.get("video_grid_thw") is not None: + video_grid_thw.extend( + mm_input["video_grid_thw"].tolist()) + if mm_input.get("second_per_grid_ts") is not None: + second_per_grid_ts.extend( + mm_input["second_per_grid_ts"]) + if mm_input.get("audio_feature_lengths") is not None: + audio_feature_lengths.extend( + mm_input["audio_feature_lengths"]) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True + + hf_config = self.model_config.hf_config + + self.requests[req_id].mrope_positions, \ + self.requests[req_id].mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + self.requests[req_id].prompt_token_ids, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + + req_ids_to_add.append(req_id) + + # Update the states of the running/resumed requests. + is_last_rank = get_pp_group().is_last_rank + req_data = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] + + # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens + + if not is_last_rank: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = (num_computed_tokens + len(new_token_ids) - + req_state.num_tokens) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:]) + + # Update the block IDs. + if not resumed_from_preemption: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) + else: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + req_ids_to_add.append(req_id) + continue + + # Update the persistent batch. + self.input_batch.num_computed_tokens_cpu[req_index] = ( + num_computed_tokens) + self.input_batch.block_table.append_row(new_block_ids, req_index) + + # For the last rank, we don't need to update the token_ids_cpu + # because the sampled tokens are already cached. + if not is_last_rank: + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, + start_token_index:end_token_index] = new_token_ids + self.input_batch.num_tokens_no_spec[ + req_index] = end_token_index + self.input_batch.num_tokens[req_index] = end_token_index + + # Add spec_token_ids to token_ids_cpu. + spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) + if spec_token_ids: + num_spec_tokens = len(spec_token_ids) + start_index = self.input_batch.num_tokens_no_spec[req_index] + end_token_index = start_index + num_spec_tokens + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec tokens. + self.input_batch.num_tokens[req_index] += num_spec_tokens + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + self.input_batch.add_request(req_state) + + # Condense the batched states if there are gaps left by removed requests + self.input_batch.condense() + # Allow attention backend to reorder the batch, potentially + self._may_reorder_batch(scheduler_output) + # Refresh batch metadata with any pending updates. + self.input_batch.refresh_metadata() + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[OmniModelRunnerOutput, IntermediateTensors]: + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + return self.kv_connector_no_forward(scheduler_output, + self.vllm_config) + + # Prepare the decoder inputs. + (attn_metadata, attention_cuda_graphs, logits_indices, + spec_decode_metadata, num_scheduled_tokens_np, + spec_decode_common_attn_metadata) = ( + self._prepare_inputs(scheduler_output)) + + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if (self.use_cuda_graph + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_scheduled_tokens) + else: + # Eager mode. + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if self.compilation_config.pass_config. \ + enable_sequence_parallelism and tp_size > 1: + num_input_tokens = round_up(num_scheduled_tokens, tp_size) + else: + num_input_tokens = num_scheduled_tokens + + # Padding for DP + num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) + num_input_tokens += num_pad + + # _prepare_inputs may reorder the batch, so we must gather multi + # modal outputs after that to ensure the correct order + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + else: + mm_embeds = [] + + if get_pp_group().is_first_rank: + # Always prepare inputs_embeds on the first PP rank so both + # multimodal and text-only models can consume embeddings. + inputs_embeds_scheduled = self.model.get_input_embeddings( + input_ids=self.input_ids[:num_scheduled_tokens], + multimodal_embeddings=(mm_embeds or None) + if self.is_multimodal_model else None, + ) + + # TODO(woosuk): Avoid the copy. Optimize. + self.inputs_embeds[:num_scheduled_tokens].copy_( + inputs_embeds_scheduled) + + # Overlay per-request custom prompt_embeds for the prompt portion + for req_index, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + pe_cpu = getattr(req_state, "prompt_embeds_cpu", None) + if pe_cpu is None: + continue + num_computed_tokens = int( + self.input_batch.num_computed_tokens_cpu[req_index]) + prompt_len = len(req_state.prompt_token_ids) + prompt_remaining = max(0, prompt_len - num_computed_tokens) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + overlay_len = min(sched_tokens, prompt_remaining) + if overlay_len <= 0: + continue + src = pe_cpu[num_computed_tokens: + num_computed_tokens + overlay_len].to( + dtype=self.dtype, + device=self.device, + non_blocking=True) + start_offset = int(self.query_start_loc_cpu[req_index]) + self.inputs_embeds[start_offset:start_offset + overlay_len]\ + .copy_(src) + + input_ids = self.input_ids[:num_input_tokens] # preserved for APIs + inputs_embeds = self.inputs_embeds[:num_input_tokens] + model_mm_kwargs = (self._extract_mm_kwargs(scheduler_output) + if self.is_multimodal_model else {}) + else: + # For non-first PP ranks, use token ids as usual; embeddings are + # only consumed on the first rank. + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + model_mm_kwargs = {} + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + else: + positions = self.positions[:num_input_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_input_tokens, intermediate_tensors, True) + + # Some attention backends only support CUDA Graphs in pure decode. + # If attention doesn't support CUDA Graphs for this batch, but we + # compiled with full CUDA graphs, we have to skip them entirely. + skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs + + # Run the model. + # Use persistent buffers for CUDA graphs. + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + skip_cuda_graphs=skip_cuda_graphs, + ), self.maybe_get_kv_connector_output( + scheduler_output) as kv_connector_output: + + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_mm_kwargs, + device=self.device, + ), + sampling_metadata=self.input_batch.sampling_metadata, + logits_index=logits_indices, + sampler=self.sampler, + ) + + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = model_output + else: + hidden_states = model_output + aux_hidden_states = None + + text_hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) + + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + broadcast_pp_output = \ + self.parallel_config.distributed_executor_backend \ + == "external_launcher" and len(get_pp_group().ranks) > 0 + if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden states. + assert isinstance(text_hidden_states, IntermediateTensors) + if not broadcast_pp_output: + text_hidden_states.kv_connector_output = kv_connector_output + return text_hidden_states + get_pp_group().send_tensor_dict(text_hidden_states.tensors, + all_gather_group=get_tp_group()) + logits = None + else: + if self.input_batch.pooling_params: + return self._pool(text_hidden_states, num_scheduled_tokens, + num_scheduled_tokens_np, kv_connector_output) + + sample_hidden_states = text_hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + if broadcast_pp_output: + model_output_broadcast_data = { + "logits": logits.contiguous(), + } if logits is not None else {} + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + self.apply_grammar_bitmask(scheduler_output, logits) + + # Sample the next token and get logprobs if needed. + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + assert logits is not None + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. + target_logits = logits[spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + + num_nans_in_logits = {} + if envs.VLLM_COMPUTE_NANS_IN_LOGITS: + num_nans_in_logits = self._get_nans_in_logits(logits) + + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len < req_state.num_tokens: + # Ignore the sampled token for partial prefills. + # Rewind the generator state as if the token was not sampled. + # This relies on cuda-specific torch-internal impl details + generator = self.input_batch.generators.get(i) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + # Record the index of the request that should not be sampled, + # so that we could clear the sampled tokens before returning. + discard_sampled_tokens_req_indices.append(i) + + # NOTE: GPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + text_hidden_states[:num_scheduled_tokens], + scheduler_output, + ) + + # Get the valid generated tokens. + import os + sampled_token_ids = sampler_output.sampled_token_ids if os.environ.get("model_stage") != "code2wav" else torch.tensor([[8294]]).to(torch.int32).cuda() + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. + for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.max_model_len, ( + "Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.max_model_len}") + + self.input_batch.token_ids_cpu[req_idx, + start_idx:end_idx] = sampled_ids + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + req_id = self.input_batch.req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + if not self.speculative_config: + # Speculative decoding is not enabled. + spec_token_ids = None + else: + assert spec_decode_common_attn_metadata is not None + spec_token_ids = self.propose_draft_token_ids( + scheduler_output, + valid_sampled_token_ids, + sampling_metadata, + text_hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + + self.eplb_step() + + return OmniModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + kv_connector_output=kv_connector_output, + num_nans_in_logits=num_nans_in_logits, + multimodal_outputs=multimodal_outputs, + ) + + @torch.inference_mode() + def extract_multimodal_outputs(self, hidden_states: torch.Tensor) -> dict: + if hasattr(self.model, "have_multimodal_outputs") and self.model.have_multimodal_outputs: + text_hidden_states = hidden_states.text_hidden_states + multimodal_outputs = hidden_states.multimodal_outputs + + else: + text_hidden_states = hidden_states + multimodal_outputs = {} + return text_hidden_states, multimodal_outputs + + @torch.inference_mode() + def _dummy_run( + self, + num_tokens: int, + capture_attn_cudagraph: bool = False, + skip_eplb: bool = False, + is_profile: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + + # Padding for DP + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + num_tokens += num_pad + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. + assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) + + attn_metadata: Optional[dict[str, Any]] = None + if capture_attn_cudagraph: + attn_metadata = {} + + # Make sure max_model_len is used at the graph capture time. + self.seq_lens_np[:num_reqs] = self.max_model_len + self.seq_lens_np[num_reqs:] = 0 + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], + non_blocking=True) + + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + + 1], + seq_lens=self.seq_lens[:num_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + num_computed_tokens_cpu=self.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + block_table_tensor=self.input_batch.block_table[ + kv_cache_group_id].get_device_tensor()[:num_reqs], + slot_mapping=self.input_batch. + block_table[kv_cache_group_id].slot_mapping[:num_tokens], + causal=True) + + for attn_group in self.attn_groups[kv_cache_group_id]: + attn_metadata_i = attn_group.metadata_builder\ + .build_for_cudagraph_capture(common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + with self.maybe_dummy_run_with_lora(self.lora_config, + num_scheduled_tokens): + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + model_mm_kwargs = self._dummy_mm_kwargs(num_reqs) + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + model_mm_kwargs = {} + + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device)) + + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_tokens, None, False) + + with self.maybe_randomize_inputs(input_ids), set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp): + outputs = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_mm_kwargs, + device=self.device, + ), + sampler = None + ) + + if self.use_aux_hidden_state_outputs: + hidden_states, _ = outputs + else: + hidden_states = outputs + + logger.warning(f"Multimodal outputs are not returned in the dummy run, need to double check the implementation!") + text_hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) + + + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + self.drafter.dummy_run(num_tokens) + + # This is necessary to avoid blocking DP. + # For dummy runs, we typically skip EPLB since we don't have any real + # requests to process. + # However, in DP settings, there may be cases when some DP ranks do + # not have any requests to process, so they're executing dummy batches. + # In such cases, we still have to trigger EPLB to make sure + # ranks execute the rearrangement in synchronization. + if not skip_eplb: + self.eplb_step(is_dummy=True, is_profile=is_profile) + + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + return text_hidden_states, text_hidden_states[logit_indices] \ No newline at end of file diff --git a/vllm_omni/worker/model_runner.py b/vllm_omni/worker/model_runner.py new file mode 100644 index 00000000000..ca7b78d98ca --- /dev/null +++ b/vllm_omni/worker/model_runner.py @@ -0,0 +1,199 @@ +from dataclasses import dataclass +from typing import Optional, Dict, Any, List +import itertools + +import torch + +from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder + +from vllm.utils import flatten_2d_lists, async_tensor_h2d +from vllm.lora.layers import LoRAMapping +from vllm.sequence import MultiModalKwargs + + +@dataclass(frozen=True) +class OmniModelInputForGPU(ModelInputForGPU): + # New: optional additional information dict for current scheduled tokens + additional_information: Optional[Dict[str, Any]] = None + + +class OmniModelInputForGPUBuilder(ModelInputForGPUBuilder): + def build(self) -> OmniModelInputForGPU: + """Finalize the builder intermediate data and + create on-device tensors. + """ + # Combine and flatten intermediate data. + input_tokens = list[int]() + inputs_embeds_list = list[torch.Tensor]() + token_types = list[int]() + for inter_data in self.inter_data_list: + for cur_input_tokens in inter_data.input_tokens: + input_tokens.extend(cur_input_tokens) + for cur_token_types in inter_data.token_types: + token_types.extend(cur_token_types) + if inter_data.inputs_embeds is not None: + inputs_embeds_list.append( + inter_data.inputs_embeds.to( + dtype=self.runner.model_config.dtype, + device=self.runner.device)) + # Support v1 direct-transfer prompt embeds attached on request state + if inter_data.inputs_embeds is None: + try: + # Locate req_id -> request state to fetch decoded CPU embeds + req_id = inter_data.request_id # type: ignore[attr-defined] + req_state = getattr(self.runner, "requests", {}).get(req_id) + pe_cpu = getattr(req_state, "prompt_embeds_cpu", None) + if pe_cpu is not None: + inter_data.inputs_embeds = pe_cpu.to( + dtype=self.runner.model_config.dtype, + device=self.runner.device) + inputs_embeds_list.append(inter_data.inputs_embeds) + except Exception: + pass + inputs_embeds: Optional[torch.Tensor] + if len(inputs_embeds_list) == 0: + inputs_embeds = None + else: + inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to( + dtype=self.runner.model_config.dtype, + device=self.runner.device) + assert len(inputs_embeds) == len(input_tokens) + + if not input_tokens and inputs_embeds is None: + # This may happen when all prefill requests hit + # prefix caching and there is no decode request. + return self.model_input_cls() + + mrope_input_positions: Optional[List[List[int]]] = None + if any(inter_data.mrope_input_positions is not None + for inter_data in self.inter_data_list): + mrope_input_positions = [[] for _ in range(3)] + for idx in range(3): + for inter_data in self.inter_data_list: + msections = inter_data.mrope_input_positions + if msections is None: + for _seq_input_positions in inter_data.input_positions: + mrope_input_positions[idx].extend( + _seq_input_positions) + else: + for _seq_mrope_input_positions in msections: + mrope_input_positions[idx].extend( + _seq_mrope_input_positions[idx]) + input_positions = None + else: + input_positions = [] + for inter_data in self.inter_data_list: + for cur_input_positions in inter_data.input_positions: + input_positions.extend(cur_input_positions) + + seq_lens = [] + query_lens = [] + max_decode_seq_len = 0 + max_encoder_seq_len = 0 + for inter_data in self.inter_data_list: + seq_lens.extend(inter_data.seq_lens) + query_lens.extend(inter_data.query_lens) + if not inter_data.is_prompt: + max_decode_seq_len = max(max_decode_seq_len, + max(inter_data.seq_lens)) + if self.runner.model_config.is_encoder_decoder: + max_encoder_seq_len = max(max_encoder_seq_len, + inter_data.encoder_seq_len) + + # Mapping from request IDs to sequence IDs. Used for Jamba models + # that manages the cache by itself. + request_ids_to_seq_ids = { + data.request_id: data.seq_ids + for data in self.inter_data_list + } + + cuda_graph_pad_size = self._get_cuda_graph_pad_size( + num_seqs=len(seq_lens), + max_decode_seq_len=max_decode_seq_len, + max_encoder_seq_len=max_encoder_seq_len) + + batch_size = len(input_tokens) + if cuda_graph_pad_size != -1: + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. + batch_size += cuda_graph_pad_size + + # Tokens and positions. + if cuda_graph_pad_size: + input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) + assert self.runner.device is not None + input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, + self.runner.device, + self.runner.pin_memory) + + token_types_tensor = async_tensor_h2d(token_types, torch.long, + self.runner.device, + self.runner.pin_memory) \ + if token_types else None + + if mrope_input_positions is not None: + for idx in range(3): + mrope_input_positions[idx].extend( + itertools.repeat(0, cuda_graph_pad_size)) + input_positions_tensor = async_tensor_h2d(mrope_input_positions, + torch.long, + self.runner.device, + self.runner.pin_memory) + else: + input_positions.extend(itertools.repeat(0, cuda_graph_pad_size)) + input_positions_tensor = async_tensor_h2d(input_positions, + torch.long, + self.runner.device, + self.runner.pin_memory) + # Sequence and query lengths. + if cuda_graph_pad_size: + seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) + + # Attention metadata. + attn_metadata = self.attn_metadata_builder.build( + seq_lens, query_lens, cuda_graph_pad_size, batch_size) + + # LoRA data. + lora_requests = set() + lora_mapping = None + if self.enable_lora: + lora_requests = set(r for data in self.inter_data_list + for r in data.lora_requests) + lora_index_mapping = flatten_2d_lists([ + flatten_2d_lists(inter_data.lora_index_mapping) + for inter_data in self.inter_data_list + ]) + if cuda_graph_pad_size: + lora_index_mapping.extend( + itertools.repeat(0, cuda_graph_pad_size)) + lora_prompt_mapping = flatten_2d_lists([ + flatten_2d_lists(inter_data.lora_prompt_mapping) + for inter_data in self.inter_data_list + ]) + + lora_mapping = LoRAMapping( + **dict(index_mapping=lora_index_mapping, + prompt_mapping=lora_prompt_mapping, + is_prefill=not self.decode_only)) + + # Multi-modal data. + multi_modal_kwargs_list = [ + data.multi_modal_kwargs for data in self.inter_data_list + if data.multi_modal_kwargs is not None + ] + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + + return self.model_input_cls( + input_tokens=input_tokens_tensor, + inputs_embeds=inputs_embeds, + input_positions=input_positions_tensor, + token_types=token_types_tensor, + attn_metadata=attn_metadata, + seq_lens=seq_lens, + query_lens=query_lens, + lora_mapping=lora_mapping, + lora_requests=lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + request_ids_to_seq_ids=request_ids_to_seq_ids, + finished_requests_ids=self.finished_requests_ids) \ No newline at end of file