diff --git a/README.md b/README.md
index 75b1bcf23b4..13da76012f3 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,63 @@
 # vLLM-omni: Multi-modal Extension for vLLM
 vLLM-omni is designed to extend vLLM capabilities to support multi-modality model inference and serving, particularly focusing on non-autoregressive structures and non-textual outputs.
+
+## 🎯 Overview
+
+Traditional vLLM systems are limited to text-based, autoregressive generation. vLLM-omni addresses this limitation by enabling support for:
+
+- **Multi-modal Models**: Text, image, video, audio, and sensor data processing
+- **Non-autoregressive Architectures**: Diffusion Transformers (DiT) and other parallel generation models
+- **Heterogeneous Outputs**: Beyond traditional text generation to structured, binary, and streaming outputs
+
+## πŸ—οΈ Architecture
+
+vLLM-omni is built on a modular architecture that extends vLLM's core functionality.
+
+## πŸš€ Key Features
+
+### Multi-Engine Support
+
+- **Autoregressive Engine**: Traditional text generation with enhanced KV-caching
+- **Diffusion Engine**: Support for DiT models and iterative generation
+- **Hybrid Engine**: Combined AR+DiT processing pipelines
+
+### Modality Processing
+
+- **Text**: Advanced tokenization and embedding generation
+- **Image**: Vision encoder integration (CLIP, etc.)
+- **Audio**: Speech processing and audio embedding
+- **Video**: Frame-by-frame and temporal processing
+- **Sensor**: IoT and sensor data interpretation
+
+### Output Formats
+
+- **Structured Data**: JSON, XML, and custom formats
+- **Binary Outputs**: Images, audio, and video generation
+- **Streaming**: Real-time progressive generation
+- **Multipart**: Combined multi-modal responses
+
+## πŸ“‹ Supported Models
+
+### AR + Diffusion Transformer (DiT) Models
+- Qwen-Image (Image generation and editing)
+- Qwen-omni (Thinker-Talker-Codec structure)
+- Custom DiT and hybrid architectures
+
+## πŸ› οΈ Installation
+
+### Prerequisites
+
+- Python 3.8+
+- PyTorch 2.0+
+- CUDA 11.8+ (for GPU acceleration)
+
+### Install from Source
+
+```bash
+git clone https://github.com/your-org/vllm-omni.git
+cd vllm-omni
+pip install -r requirements.txt
+pip install -e .
+```
diff --git a/docs/PRD.md b/docs/PRD.md
new file mode 100644
index 00000000000..7b785041950
--- /dev/null
+++ b/docs/PRD.md
@@ -0,0 +1,179 @@
+# vLLM-omni Product Requirements Document (PRD)
+
+## 1. Product Overview
+
+### 1.1 Product Name
+vLLM-omni: Multi-modality model inference and serving with non-autoregressive structures
+
+### 1.2 Product Vision
+Extend vLLM beyond traditional text-based, autoregressive generation to support multi-modality models with non-autoregressive structures and non-textual outputs while maintaining vLLM's proven architecture and performance.
+
+### 1.3 Target Users
+- AI researchers working with multimodal models
+- ML engineers building production inference systems
+- Developers integrating DiT (Diffusion Transformer) models
+- Organizations requiring efficient multimodal model serving
+
+## 2. 
Core Requirements + +### 2.1 Functional Requirements + +#### 2.1.1 Multi-Stage Processing +- **REQ-001**: Support stage-based model processing where each stage can use different engine types (AR/DiT) +- **REQ-002**: Enable sequential processing through multiple stages with data flow between stages +- **REQ-003**: Support both autoregressive (AR) and diffusion (DiT) model stages + +#### 2.1.2 vLLM Compatibility +- **REQ-004**: Maintain full compatibility with vLLM V1 architecture (AsyncLLM and EngineCore patterns) +- **REQ-005**: Support existing vLLM CLI commands with `--omni` flag extension +- **REQ-006**: Reuse vLLM's multiprocess worker architecture for scalability + +#### 2.1.3 Multimodal Support +- **REQ-007**: Support text, image, and latent space inputs/outputs +- **REQ-008**: Enable image-to-image, text-to-image, and text-to-text generation +- **REQ-009**: Support hidden state passing between AR and DiT stages + +#### 2.1.4 CLI and API +- **REQ-010**: Provide CLI command: `vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8000` +- **REQ-011**: Support both online (AsyncOmniLLM) and offline (OmniLLM) inference modes +- **REQ-012**: Maintain vLLM's existing API compatibility + +### 2.2 Non-Functional Requirements + +#### 2.2.1 Performance +- **REQ-013**: Maintain vLLM's inference performance for AR stages +- **REQ-014**: Optimize DiT stage processing with caching mechanisms +- **REQ-015**: Support distributed inference across multiple GPUs + +#### 2.2.2 Scalability +- **REQ-016**: Support horizontal scaling through vLLM's worker process pattern +- **REQ-017**: Enable efficient memory management for large multimodal models +- **REQ-018**: Support batch processing for multiple requests + +#### 2.2.3 Extensibility +- **REQ-019**: Easy integration of new modalities and model architectures +- **REQ-020**: Pluggable scheduler and executor components +- **REQ-021**: Support for future non-autoregressive model types + +## 3. 
Technical Architecture + +### 3.1 Core Components + +#### 3.1.1 Entry Points +- **OmniServeCommand**: CLI wrapper that intercepts vLLM commands with `--omni` flag +- **OmniLLM**: Offline inference class supporting multi-stage processing +- **AsyncOmniLLM**: Online inference class with asynchronous processing + +#### 3.1.2 Stage Management +- **OmniStageConfig**: Configuration for each processing stage +- **Stage Engine List**: Multiple AsyncLLM instances for each stage +- **Stage I/O Management**: Data flow between stages + +#### 3.1.3 Engine Components +- **EngineCore**: Reused from vLLM (no changes needed) +- **OmniDiffusionScheduler**: New scheduler for DiT models +- **DiTCacheManager**: Caching system for DiT optimization +- **MultiprocExecutor**: Reused from vLLM for DiT without diffusers +- **DiffusersPipelineExecutor**: New executor for diffusers integration + +#### 3.1.4 Model Runners +- **OmniDiffusionModelRunner**: Handles DiT model execution +- **OmniARModelRunner**: Handles AR model execution with hidden state output +- **ModelRunnerOutput**: Extended to support multimodal outputs via pooler_output + +#### 3.1.5 Output Processing +- **MultimodalOutputProcessor**: Handles final multimodal output processing +- **RequestState**: Extended to support pooling outputs +- **Output Handlers**: Type-specific output processing + +### 3.2 Data Flow +``` +API Server β†’ OmniLLM/AsyncOmniLLM β†’ LLMEngine/AsyncLLM β†’ Engine Core +β†’ Scheduler (AR/DiT) β†’ Executor (AR/DiT) β†’ Worker (AR/DiT) +β†’ ModelRunner (AR/DiT) β†’ RequestState β†’ OutputProcessor β†’ Final Output +``` + +## 4. Implementation Phases + +### Phase 1: Foundation (Weeks 1-2) +- Package structure and dependencies +- Basic OmniLLM and AsyncOmniLLM classes +- Stage configuration system +- CLI integration + +### Phase 2: Core Components (Weeks 3-4) +- DiT scheduler implementation +- Model runners for AR and DiT +- Basic output processing + +### Phase 3: Advanced Features (Weeks 5-6) +- Caching system implementation +- Multimodal output processing +- Request state management + +### Phase 4: Integration & Testing (Weeks 7-8) +- End-to-end integration +- Comprehensive testing +- Performance optimization +- Documentation + +## 5. Success Criteria + +### 5.1 Functional Success +- [ ] Successfully run `vllm serve model --omni` command +- [ ] Process multi-stage ARβ†’DiT pipelines +- [ ] Generate multimodal outputs (text + image) +- [ ] Maintain vLLM API compatibility + +### 5.2 Performance Success +- [ ] AR stage performance within 5% of native vLLM +- [ ] DiT stage processing with reasonable latency +- [ ] Memory usage comparable to vLLM for equivalent models + +### 5.3 Quality Success +- [ ] 90%+ test coverage +- [ ] All integration tests passing +- [ ] Documentation complete and accurate + +## 6. Risk Assessment + +### 6.1 Technical Risks +- **High**: vLLM API changes breaking compatibility +- **Medium**: DiT model integration complexity +- **Low**: Performance overhead from multi-stage processing + +### 6.2 Mitigation Strategies +- Regular vLLM compatibility testing +- Incremental DiT integration with fallback options +- Performance benchmarking at each stage + +## 7. Dependencies + +### 7.1 External Dependencies +- vLLM >= 0.10.2 +- PyTorch >= 2.7 +- Transformers >= 4.30.0 +- FastAPI, Uvicorn for API serving +- Ray for distributed computing + +### 7.2 Optional Dependencies +- xDiT for DiT acceleration +- Cache-DiT for advanced caching +- Diffusers for pipeline-based DiT models + +## 8. 
Future Roadmap + +### 8.1 Short-term (3 months) +- Additional DiT model support +- Performance optimizations +- Enhanced caching strategies + +### 8.2 Medium-term (6 months) +- Support for video generation models +- Advanced scheduling strategies +- Multi-GPU DiT optimization + +### 8.3 Long-term (12 months) +- Custom model architecture support +- Advanced multimodal fusion +- Production deployment tools diff --git a/docs/api/API_DESIGN_TEMPLATE.md b/docs/api/API_DESIGN_TEMPLATE.md new file mode 100644 index 00000000000..eefaf9334cf --- /dev/null +++ b/docs/api/API_DESIGN_TEMPLATE.md @@ -0,0 +1,336 @@ +# API Design Template for vLLM-omni Modules + +This template provides a standardized structure for designing APIs for core, engine, executor, and worker modules in vLLM-omni. Use this template to ensure consistency and completeness across all modules. + +## Template Structure + +### 1. Module Overview +- **Purpose**: What this module does +- **Responsibilities**: Key responsibilities of the module +- **Dependencies**: What other modules this depends on +- **Integration Points**: How it integrates with other modules + +### 2. Core Classes/Interfaces +- **Base Classes**: Abstract base classes and interfaces +- **Implementation Classes**: Concrete implementations +- **Data Structures**: Key data structures and models + +### 3. Public API Methods +- **Initialization**: Constructor and setup methods +- **Core Operations**: Main functionality methods +- **Configuration**: Configuration and parameter methods +- **Lifecycle Management**: Start, stop, cleanup methods +- **Monitoring**: Status, metrics, and debugging methods + +### 4. Configuration +- **Configuration Classes**: Dataclasses or config objects +- **Required Parameters**: Must-have configuration +- **Optional Parameters**: Optional configuration with defaults +- **Validation**: Parameter validation rules + +### 5. Error Handling +- **Custom Exceptions**: Module-specific exceptions +- **Error Codes**: Standardized error codes +- **Recovery Strategies**: How to handle and recover from errors + +### 6. Examples +- **Basic Usage**: Simple usage examples +- **Advanced Usage**: Complex scenarios +- **Integration Examples**: How to use with other modules + +--- + +## Example: Core Module Template + +### 1. Module Overview + +**Purpose**: The core module provides fundamental scheduling and caching functionality for vLLM-omni. + +**Responsibilities**: +- Request scheduling and prioritization +- DiT cache management +- Resource allocation and coordination +- Inter-module communication + +**Dependencies**: +- `vllm_omni.request` - Request handling +- `vllm_omni.config` - Configuration management +- `vllm_omni.utils` - Utility functions + +**Integration Points**: +- Receives requests from entrypoints +- Coordinates with engine modules +- Manages worker allocation +- Provides status to monitoring systems + +### 2. 
Core Classes/Interfaces
+
+```python
+from abc import ABC, abstractmethod
+from typing import List, Optional, Dict, Any
+from dataclasses import dataclass, field
+from enum import Enum
+
+class DiTCacheManager:
+    """Manages DiT cache for diffusion models."""
+
+    def __init__(self, config: DiTCacheConfig):
+        self.config = config
+        self.cache = {}
+        self.cache_stats = {}
+
+    async def get_cache(self, cache_key: str) -> Optional[Any]:
+        """Retrieve cached data."""
+        pass
+
+    async def set_cache(self, cache_key: str, data: Any) -> None:
+        """Store data in cache."""
+        pass
+
+    async def invalidate_cache(self, cache_key: str) -> None:
+        """Invalidate cached data."""
+        pass
+```
+
+### 3. Public API Methods
+
+#### Initialization
+```python
+class OmniScheduler:
+    def __init__(self, config: SchedulerConfig):
+        """Initialize the core scheduler."""
+        self.config = config
+        self.scheduler = self._create_scheduler()
+        self.cache_manager = DiTCacheManager(config.dit_cache_config)
+        self._running = False
+
+    def _create_scheduler(self) -> BaseScheduler:
+        """Factory method to create appropriate scheduler."""
+        if self.config.scheduler_type == SchedulerType.FIFO:
+            return FIFOScheduler(self.config)
+        elif self.config.scheduler_type == SchedulerType.PRIORITY:
+            return PriorityScheduler(self.config)
+        # ... other scheduler types
+```
+
+#### Core Operations
+```python
+    async def schedule(self, request: OmniRequest) -> bool:
+        """
+        Schedule a request for processing.
+
+        Args:
+            request: The request to schedule
+
+        Returns:
+            bool: True if successfully scheduled, False otherwise
+
+        Raises:
+            SchedulerError: If scheduling fails
+            QueueFullError: If queue is at capacity
+        """
+        pass
+```
+
+#### Configuration
+```python
+    def update_config(self, new_config: SchedulerConfig) -> None:
+        """Update scheduler configuration."""
+        pass
+
+    def get_config(self) -> SchedulerConfig:
+        """Get current configuration."""
+        pass
+```
+
+#### Lifecycle Management
+```python
+    async def start(self) -> None:
+        """Start the scheduler."""
+        self._running = True
+        # Start background tasks
+
+    async def stop(self) -> None:
+        """Stop the scheduler gracefully."""
+        self._running = False
+        # Cleanup resources
+
+    async def shutdown(self) -> None:
+        """Force shutdown the scheduler."""
+        # Immediate cleanup
+```
+
+#### Monitoring
+```python
+    def get_status(self) -> Dict[str, Any]:
+        """Get current scheduler status."""
+        return {
+            "running": self._running,
+            "queue_size": self.scheduler.queue_size(),
+            "processed_requests": self.scheduler.processed_count(),
+            "cache_hit_rate": self.cache_manager.get_hit_rate()
+        }
+
+    def get_metrics(self) -> Dict[str, float]:
+        """Get performance metrics."""
+        pass
+```
+
+### 4. Configuration
+
+```python
+@dataclass
+class OmniConfig:
+    """Configuration for the core module."""
+
+    # Scheduler configuration
+    scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
+
+    # Cache configuration
+    dit_cache: DiTCacheConfig = field(default_factory=DiTCacheConfig)
+
+    # Resource limits
+    max_memory_gb: float = 16.0
+    max_gpu_utilization: float = 0.8
+
+    # Timeouts
+    request_timeout: int = 300
+    worker_timeout: int = 60
+
+    def validate(self) -> None:
+        """Validate configuration parameters."""
+        if self.max_memory_gb <= 0:
+            raise ValueError("max_memory_gb must be positive")
+        if not 0 < self.max_gpu_utilization <= 1:
+            raise ValueError("max_gpu_utilization must be between 0 and 1")
+```
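+
+A minimal usage sketch for `OmniConfig` above, showing construction and validation only; the nested `SchedulerConfig`/`DiTCacheConfig` defaults are assumed:
+
+```python
+# Build a config with a couple of overrides, then validate it before
+# handing it to OmniScheduler.
+config = OmniConfig(max_memory_gb=8.0, max_gpu_utilization=0.9)
+
+try:
+    config.validate()
+except ValueError as exc:
+    # e.g. "max_gpu_utilization must be between 0 and 1"
+    print(f"Invalid configuration: {exc}")
+```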
+### 5. Error Handling
+
+```python
+class CoreModuleError(Exception):
+    """Base exception for core module errors."""
+    pass
+
+class SchedulerError(CoreModuleError):
+    """Scheduler-related errors."""
+    pass
+
+class QueueFullError(SchedulerError):
+    """Queue is at capacity."""
+    pass
+
+class CacheError(CoreModuleError):
+    """Cache-related errors."""
+    pass
+
+class ResourceError(CoreModuleError):
+    """Resource allocation errors."""
+    pass
+```
+
+### 6. Examples
+
+#### Basic Usage
+```python
+from vllm_omni.core import OmniScheduler, OmniConfig
+from vllm_omni.request import create_text_request
+
+# Create configuration
+config = OmniConfig(
+    scheduler=SchedulerConfig(scheduler_type=SchedulerType.FIFO),
+    max_memory_gb=8.0
+)
+
+# Initialize scheduler
+scheduler = OmniScheduler(config)
+await scheduler.start()
+
+# Create and schedule a request
+request = create_text_request(
+    request_id="req_001",
+    prompt="Hello, world!",
+    sampling_params=sampling_params
+)
+
+success = await scheduler.schedule(request)
+if success:
+    print("Request scheduled for processing")
+```
+
+#### Advanced Usage
+```python
+# Custom scheduler with priority handling
+config = OmniConfig(
+    scheduler=SchedulerConfig(
+        scheduler_type=SchedulerType.PRIORITY,
+        priority_weights={"high": 2.0, "normal": 1.0, "low": 0.5}
+    )
+)
+
+scheduler = OmniScheduler(config)
+await scheduler.start()
+
+# Monitor scheduler status
+status = scheduler.get_status()
+print(f"Queue size: {status['queue_size']}")
+print(f"Cache hit rate: {status['cache_hit_rate']:.2%}")
+```
+
+For an integration example, see the appendix at the end of this template.
+
+---
+
+## Module-Specific Guidelines
+
+### Core Module
+- Focus on scheduling, caching, and resource management
+- Provide clear interfaces for other modules
+- Handle concurrency and thread safety
+- Implement comprehensive monitoring
+
+### Engine Module
+- Handle model loading and inference
+- Support both AR and diffusion models
+- Provide unified interface for different model types
+- Implement efficient memory management
+
+### Executor Module
+- Coordinate between different engines
+- Handle request routing and load balancing
+- Manage execution pipelines
+- Provide error recovery mechanisms
+
+### Worker Module
+- Handle actual model execution
+- Manage GPU resources
+- Implement batching strategies
+- Provide performance optimization
+
+---
+
+## Checklist for API Design
+
+- [ ] **Clear Purpose**: Module purpose is well-defined
+- [ ] **Complete Interface**: All public methods are documented
+- [ ] **Error Handling**: Comprehensive error handling strategy
+- [ ] **Configuration**: Flexible configuration system
+- [ ] **Examples**: Basic and advanced usage examples
+- [ ] **Type Hints**: All methods have proper type hints
+- [ ] **Documentation**: Comprehensive docstrings
+- [ ] **Testing**: Unit tests for all public methods
+- [ ] **Integration**: Clear integration points with other modules
+- [ ] **Performance**: Performance considerations documented
+
+---
+
+## Submission Guidelines
+
+1. Create a new file: `docs/api/[module_name]_api.md`
+2. Follow the template structure exactly
+3. Include all required sections
+4. Provide working code examples
+5. Ensure all methods have proper docstrings
+6. Include error handling strategies
+7. Submit for review before implementation
+
+This template ensures consistency and completeness across all vLLM-omni modules.
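+
+---
+
+## Appendix: Integration Sketch
+
+A hedged sketch of the "Integration Examples" item from Section 6, showing the scheduler and cache manager from this template cooperating; the `request` object and the cached `stage_output` tensor are assumed to come from the surrounding pipeline:
+
+```python
+# Reuse a cached intermediate result across requests that share a prompt,
+# falling back to normal scheduling on a cache miss.
+config = OmniConfig(max_memory_gb=8.0)
+scheduler = OmniScheduler(config)
+await scheduler.start()
+
+cache = scheduler.cache_manager
+cache_key = "prompt:a-beautiful-landscape"
+
+cached = await cache.get_cache(cache_key)
+if cached is None:
+    # Cache miss: schedule normally, then store the stage output.
+    await scheduler.schedule(request)               # request: from caller
+    await cache.set_cache(cache_key, stage_output)  # stage_output: hypothetical
+```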
diff --git a/docs/api/README.md b/docs/api/README.md new file mode 100644 index 00000000000..681bbba81d6 --- /dev/null +++ b/docs/api/README.md @@ -0,0 +1,235 @@ +# vLLM-omni API Documentation + +This directory contains comprehensive API design documentation for all core modules in vLLM-omni. These templates provide a standardized structure for designing and implementing the core, engine, executor, and worker modules. + +## πŸ“‹ API Design Templates + +### 1. [API Design Template](API_DESIGN_TEMPLATE.md) +**Master template** for creating API documentation for any module in vLLM-omni. Use this as a starting point for new modules. + +### 2. [Core Module API](core_module_api.md) +**Core module** provides fundamental scheduling, caching, and resource management functionality. + +**Key Components:** +- Request scheduling and prioritization +- DiT cache management for diffusion models +- Resource allocation and coordination +- Inter-module communication + +### 3. [Engine Module API](engine_module_api.md) +**Engine module** handles model loading, inference execution, and output processing. + +**Key Components:** +- Model loading and initialization +- Inference execution for AR and diffusion models +- Input preprocessing for various modalities +- Output postprocessing and formatting + +### 4. [Executor Module API](executor_module_api.md) +**Executor module** coordinates and manages request execution across different engines and workers. + +**Key Components:** +- Request routing and load balancing +- Execution pipeline coordination +- Worker management and task distribution +- Error handling and recovery + +### 5. [Worker Module API](worker_module_api.md) +**Worker module** provides the actual execution environment for model inference. + +**Key Components:** +- Model execution and inference +- GPU resource management +- Request batching and processing +- Performance optimization + +## 🎯 How to Use These Templates + +### For New Module Development + +1. **Start with the Master Template** + - Copy `API_DESIGN_TEMPLATE.md` + - Rename to `[module_name]_api.md` + - Follow the structure exactly + +2. **Fill in Module-Specific Details** + - Update the module overview + - Define core classes and interfaces + - Specify public API methods + - Add configuration options + - Define error handling strategy + - Provide usage examples + +3. **Review and Validate** + - Ensure all sections are complete + - Verify code examples work + - Check for consistency with other modules + - Validate configuration options + +### For Existing Module Updates + +1. **Update API Documentation** + - Modify existing API files + - Add new methods and classes + - Update examples and configuration + - Maintain backward compatibility notes + +2. **Version Control** + - Track changes in git + - Use clear commit messages + - Tag major API changes + +## πŸ“š Template Structure + +Each API template follows this standardized structure: + +### 1. Module Overview +- **Purpose**: What the module does +- **Responsibilities**: Key responsibilities +- **Dependencies**: Required modules +- **Integration Points**: How it connects with other modules + +### 2. Core Classes/Interfaces +- **Base Classes**: Abstract base classes +- **Implementation Classes**: Concrete implementations +- **Data Structures**: Key data models + +### 3. 
Public API Methods
+- **Initialization**: Constructor and setup
+- **Core Operations**: Main functionality
+- **Configuration**: Settings management
+- **Lifecycle Management**: Start/stop/cleanup
+- **Monitoring**: Status and metrics
+
+### 4. Configuration
+- **Configuration Classes**: Dataclasses for settings
+- **Required Parameters**: Must-have settings
+- **Optional Parameters**: Optional settings with defaults
+- **Validation**: Parameter validation rules
+
+### 5. Error Handling
+- **Custom Exceptions**: Module-specific errors
+- **Error Codes**: Standardized error codes
+- **Recovery Strategies**: Error recovery approaches
+
+### 6. Examples
+- **Basic Usage**: Simple usage examples
+- **Advanced Usage**: Complex scenarios
+- **Integration Examples**: Multi-module usage
+
+## πŸ”§ Implementation Guidelines
+
+### Code Standards
+- Use type hints for all methods
+- Include comprehensive docstrings
+- Follow PEP 8 style guidelines
+- Use async/await for I/O operations
+
+### Error Handling
+- Define specific exception types
+- Provide meaningful error messages
+- Include error recovery strategies
+- Log errors appropriately
+
+### Configuration
+- Use dataclasses for configuration
+- Provide sensible defaults
+- Include validation methods
+- Support environment variables
+
+### Testing
+- Write unit tests for all public methods
+- Include integration tests
+- Test error conditions
+- Validate configuration options
+
+## πŸš€ Getting Started
+
+### For Developers
+
+1. **Choose a Module**
+   - Pick the module you want to work on
+   - Read the corresponding API documentation
+   - Understand the responsibilities and interfaces
+
+2. **Set Up Development Environment**
+   ```bash
+   # Clone the repository
+   git clone https://github.com/hsliuustc0106/vllm-omni.git
+   cd vllm-omni
+
+   # Install dependencies
+   pip install -r requirements-dev.txt
+
+   # Set up pre-commit hooks
+   pre-commit install
+   ```
+
+3. **Start Implementation**
+   - Create the module directory
+   - Implement base classes first
+   - Add concrete implementations
+   - Write tests as you go
+
+### For Contributors
+
+1. **Review API Documentation**
+   - Read through all module APIs
+   - Understand the overall architecture
+   - Identify areas for improvement
+
+2. **Propose Changes**
+   - Create issues for API changes
+   - Discuss in pull requests
+   - Update documentation accordingly
+
+3. **Submit Contributions**
+   - Follow the coding standards
+   - Include tests for new features
+   - Update documentation
+   - Submit pull requests
+
+## πŸ“– Additional Resources
+
+### Architecture Overview
+- [vLLM-omni Architecture Design](../architecture/high%20level%20arch%20design.md)
+- [Component Categorization](../../ARCHITECTURE_CATEGORIZATION.md)
+
+### Development Guides
+- [Development Setup](../../README.md#development)
+- [Testing Guidelines](../../tests/README.md)
+- [Contributing Guidelines](../../CONTRIBUTING.md)
+
+### Examples
+- [Basic Examples](../../examples/basic/)
+- [Advanced Examples](../../examples/advanced/)
+- [Multimodal Examples](../../examples/multimodal/)
+
+## 🀝 Contributing
+
+We welcome contributions to improve the API documentation and implementation. Please:
+
+1. **Follow the Template Structure**
+   - Use the master template as a guide
+   - Maintain consistency across modules
+   - Include all required sections
+
+2. **Provide Working Examples**
+   - Test all code examples
+   - Include both basic and advanced usage
+   - Show integration patterns
+
+3. **Keep Documentation Updated**
+   - Update docs when APIs change
+   - Version control documentation changes
+   - Maintain backward compatibility notes
+
+## πŸ“ Notes
+
+- All API documentation should be kept up-to-date with implementation
+- Code examples should be tested and working
+- Configuration options should be validated
+- Error handling should be comprehensive
+- Performance considerations should be documented
+
+For questions or suggestions about the API documentation, please open an issue or start a discussion in the repository.
diff --git a/docs/architecture/detailed arch design.md b/docs/architecture/detailed arch design.md
new file mode 100644
index 00000000000..c442fc67d28
--- /dev/null
+++ b/docs/architecture/detailed arch design.md
@@ -0,0 +1,425 @@
+# vLLM-omni Software Design Document
+
+## Overview
+
+vLLM-omni is a multi-modality extension for vLLM that supports non-autoregressive structures and non-textual outputs. This document outlines the key software abstractions, APIs, and dependencies for the system, designed to maximize reuse of vLLM's proven architecture.
+
+## Architecture Principles
+
+1. **vLLM V1 Compatibility**: Built on vLLM's Engine V1 architecture with AsyncLLM and EngineCore patterns
+2. **Stage-based Processing**: Models are divided into multiple stages, each processed by different Engine Cores
+3. **Multiple Engine Core Support**: Each stage can use either the AR Engine Core (reusing vLLM) or the Diffusion Engine Core (new DiT support), or other new cores
+4. **Worker Process Pattern**: Follows vLLM's multiprocess worker architecture for scalability
+5. **Extensibility**: Easy integration of new modalities, model architectures, and output formats
+
+## Key Data Flow
+
+API Server --> OmniLLM/AsyncOmniLLM (New, including multiple engines) --> LLMEngine/AsyncLLM --> Engine Core
+    --> Scheduler (New one for DiT) --> Executor (New one for diffusers) --> Worker (New one for DiT)
+    --> ModelRunner (New one for AR hidden states, New one for DiT) --> RequestState --> OutputProcessor (New one for final multimodal output)
+
+## Core Components (Ordered by Data Flow)
+
+### 1. Installation
+```bash
+pip install vllm
+pip install vllm-omni
+```
+
+### 2. 
Online inference launch (Entry Point)
+
+Keep consistency with the vLLM main branch:
+```bash
+vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8000
+```
+
+#### Design Architecture
+```mermaid
+graph TD
+    A[vllm serve model --omni] --> B[vLLM-omni CLI Wrapper]
+    B --> C{Detect --omni flag}
+    C -->|Yes| D[Parse OmniConfig]
+    C -->|No| E[Forward to vLLM CLI]
+    D --> F[Initialize AsyncOmniLLM]
+    F --> G[Start Omni Server]
+    G --> H[Multi-stage Processing]
+    E --> I[Standard vLLM Pipeline]
+
+    subgraph "vLLM-omni Components"
+        F
+        G
+        H
+    end
+
+    subgraph "vLLM Components"
+        I
+    end
+```
+
+#### Omni Serve Command Implementation
+File: vllm_omni/entrypoints/cli/main.py
+```python
+import sys
+from vllm_omni.entrypoints.omni import OmniServeCommand
+
+def main():
+    """Main CLI entry point that intercepts vLLM commands"""
+    # Check if --omni flag is present
+    if "--omni" in sys.argv:
+        # Remove --omni flag and process with vLLM-omni
+        omni_args = [arg for arg in sys.argv[1:] if arg != "--omni"]
+        omni_serve = OmniServeCommand()
+        omni_serve.run(omni_args)
+    else:
+        # Forward to original vLLM CLI
+        from vllm.entrypoints.cli.main import main as vllm_main
+        vllm_main()
+
+if __name__ == "__main__":
+    main()
+```
+
+#### vLLM Plugin Integration
+File: vllm_omni/pyproject.toml
+```toml
+[project.scripts]
+# Override vLLM's CLI entry point
+vllm = "vllm_omni.entrypoints.cli.main:main"
+
+# Add entry point for vLLM integration
+[project.entry-points."vllm.plugins"]
+omni = "vllm_omni.plugin:OmniPlugin"
+```
+
+### 3. Offline inference launch (Entry Point)
+Design an upper-level class that incorporates multiple engines, where each engine has its own engine core.
+```python
+from vllm.entrypoints.llm import LLM
+from vllm.v1.engine.llm_engine import LLMEngine
+
+class OmniLLM(LLM):
+    """Extended LLM supporting multiple engines and stage-based processing"""
+
+    def __init__(self, stage_configs: List[OmniStageConfig]):
+        super().__init__()
+        self.stage_configs = stage_configs
+        self.engine_list = []  # List of LLMEngine instances, one per stage
+        self.output_processor = MultimodalOutputProcessor()
+        self._initialize_stage_engines()
+
+    def _initialize_stage_engines(self) -> None:
+        """Initialize LLMEngine instances for each stage"""
+        for stage_config in self.stage_configs:
+            stage_llm = LLMEngine.from_vllm_config(
+                vllm_config=stage_config.vllm_config,
+                executor_class=self.executor_class,
+                log_stats=self.log_stats
+            )
+            self.engine_list.append(stage_llm)
+
+    def generate(
+        self,
+        stage_args_list: List[Dict[str, Any]],
+        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
+        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
+        priority: Optional[list[int]] = None,
+    ) -> list[RequestOutput]:
+        """Main generation interface - orchestrates multi-stage processing"""
+
+        # Process through each stage sequentially
+        for i, stage_config in enumerate(self.stage_configs):
+            stage_engine = self.engine_list[i]
+
+            # Prepare input for this stage
+            stage_args = stage_args_list[i]
+            prompt_str, engine_request, tokenization_kwargs = self._process_stage_inputs(stage_config, **stage_args)
+
+            # Add inputs to the engine
+            stage_engine.add_request(request_id, prompt_str, tokenization_kwargs)
+            # Run the engine
+            stage_output = stage_engine.step()
+
+            # Record each stage's input and output for later use
+            self._update_stage_io(stage_output, stage_config)
+
+        # Process final output
+        final_output = self.output_processor.process_output(
+            self.stage_configs[-1].stage_output
+        )
+        return final_output
+
+    def _process_stage_inputs(self, stage_config, **stage_args) -> Any:
+        """Prepare input for a specific stage"""
+        if stage_config.engine_type == "AR":
+            return self._process_ar_inputs(**stage_args)
+        elif stage_config.engine_type == "DiT":
+            return self._process_dit_inputs(**stage_args)
+        else:
+            raise NotImplementedError
+
+    def _process_dit_inputs(self, **stage_args) -> Any:
+        image_latent = self.vae.encode(**stage_args)
+        return image_latent
+```
+
+### 4. Stage config for setting of different model stages
+```python
+@dataclass
+class OmniStageConfig:
+    """Configuration for a processing stage"""
+    stage_id: int
+    engine_type: str  # "AR" or "DiT"
+    model_path: str
+    input_modalities: List[str]
+    output_modalities: List[str]
+    executor_class: type[Executor]  # Executor class config of the corresponding stage
+    vllm_config: Optional[VllmConfig] = None  # Engine config of the corresponding stage
+    dit_config: Optional[DiTConfig] = None  # For diffusion stages
+    cache_config: Optional[DiTCacheConfig] = None
+```
+For the AR stage, the setting is:
+```python
+from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+ar_stage_config = OmniStageConfig()
+ar_stage_config.vllm_config.scheduler_config.scheduler_cls = Scheduler  # original vLLM scheduler for AR
+ar_stage_config.executor_class = MultiprocExecutor  # original vLLM executor for AR
+```
+For the DiT stage, the setting is:
+```python
+from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+dit_stage_config = OmniStageConfig()
+dit_stage_config.vllm_config.scheduler_config.scheduler_cls = DiffusionScheduler  # new scheduler for DiT
+
+# For diffusion models without using diffusers
+dit_stage_config.executor_class = MultiprocExecutor
+
+# For diffusion models using diffusers
+dit_stage_config.executor_class = DiffusersPipelineExecutor
+```
+
+### 5. Online inference main class
+
+Similar to OmniLLM in offline inference, with asynchronous processing added, following AsyncLLM:
+```python
+from vllm.v1.engine.async_llm import AsyncLLM
+
+
+class AsyncOmniLLM(AsyncLLM):
+    """Extended AsyncLLM supporting multiple engines and stage-based processing"""
+
+    def __init__(self, stage_configs: List[OmniStageConfig]):
+        super().__init__()
+```
+
+### 6. Engine Core
+No changes needed. The specific executor, scheduler, and output type can be passed into it via the new configs.
+```python
+# class: EngineCore
+# EngineCore.step (simplified)
+class EngineCore:
+    def step(self):
+        scheduler_output = scheduler.schedule()
+        model_output = executor.execute_model(scheduler_output)
+        engine_outputs = scheduler.update_from_output(
+            scheduler_output, model_output
+        )
+        return engine_outputs
+```
+
+### 7. Scheduler
+Create a new diffusion scheduler for DiT, inheriting from the original vLLM AR scheduler. At first there is no complicated strategy: requests are simply processed first-in, first-out. A DiTCacheManager class is also introduced for future optimization.
+```python
+from vllm.v1.core.sched.scheduler import SchedulerInterface
+
+
+class OmniDiffusionScheduler(SchedulerInterface):
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        kv_cache_config: KVCacheConfig,
+        dit_cache_config: DiTCacheConfig,
+        structured_output_manager: StructuredOutputManager,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        include_finished_set: bool = False,
+        log_stats: bool = False,
+    ) -> None:
+
+        self.dit_cache_manager = DiTCacheManager(dit_cache_config)
+```
+For the DiT Cache Manager (can refer to xDiT):
+```python
+class DiTCacheManager:
+    """Manages DiT-specific caching"""
+
+    def __init__(self, config: DiTCacheConfig):
+        self.cache_tensors: Dict[str, torch.Tensor] = {}
+        self.cache_groups: List[DiTCacheTensor] = config.dit_cache_tensors
+
+    def allocate_cache(self, request_id: str, size: int) -> torch.Tensor: ...
+    def get_cache(self, request_id: str) -> Optional[torch.Tensor]: ...
+    def release_cache(self, request_id: str) -> None: ...
+    def clear_expired_cache(self) -> None: ...
+```
+
+### 8. Executor for DiT without diffusers
+
+No changes needed; the original vLLM AR executor (`MultiprocExecutor`) is simply reused, as shown in the stage config above.
+
+
+
+### 9. Worker
+
+#### Inheritance strategy
+Prefer reusing the mature GPU Worker end-to-end. The worker class is selected by configuration (`vllm_config`); do not add a new executor-specific worker binding. If customization is needed, override only `init_device` to construct the `OmniDiffusionModelRunner`; all other behaviors (device init details, profiling, sleep/wake, PP/TP comms, execute path) remain from `vllm/v1/worker/gpu_worker.py::Worker`.
+
+
+Inherited and overridden:
+```python
+# Optional: only if you need to plug in a custom OmniDiffusionModelRunner.
+from vllm.v1.worker.gpu_worker import Worker as GPUWorker
+from vllm_omni.model_executor.dit_model_runner import OmniDiffusionModelRunner
+
+class Worker(GPUWorker):
+    def init_device(self) -> None:
+        # device checks and initialization
+        ...
+
+        self.model_runner = OmniDiffusionModelRunner(self.vllm_config, ...)
+```
+
+### 10. Model Runner
+
+### Function map (Model Runner)
+#### 1) Inherited and overridden
+The parts that rely on the KV cache are omitted if we do not register the model in the vLLM config; the engine core then treats the model as not requiring a KV cache and handles it accordingly.
+
+Reuse `vllm/v1/outputs.py::ModelRunnerOutput`:
+- OmniDiffusionModelRunner: uses `pooler_output=[Tensor,...]` to return multi-modal tensors
+- OmniARModelRunner: uses `pooler_output=[Tensor,...]` to return hidden states
+```python
+from typing import Optional, Union
+import torch
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.outputs import ModelRunnerOutput
+
+
+class OmniDiffusionModelRunner(GPUModelRunner):
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+        ...
+        return ModelRunnerOutput(
+            req_ids=[...],
+            req_id_to_index={...},
+            sampled_token_ids=[],
+            spec_token_ids=None,
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=[Tensor, ...],  # Multi-modal output tensors
+            kv_connector_output=None,
+            num_nans_in_logits=None,
+        )  # return multi-modal tensors via pooler_output=[Tensor,...]
+
+
+class OmniARModelRunner(GPUModelRunner):
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+        # Run the AR forward pass as GPUModelRunner normally would, then
+        # capture the final-layer hidden states per request so a later DiT
+        # stage can consume them via pooler_output below.
+        ...
+        return ModelRunnerOutput(
+            req_ids=[...],
+            req_id_to_index={...},
+            sampled_token_ids=[],
+            spec_token_ids=None,
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=[Tensor, ...],  # Return hidden states
+            kv_connector_output=None,
+            num_nans_in_logits=None,
+        )  # return hidden states via pooler_output=[Tensor,...]
+```
+
+### 11. RequestState Role and Usage
+```text
+# RequestState serves as the per-request state tracker in OutputProcessor:
+# - Maintains request-specific state (tokens, logprobs, detokenizer, etc.)
+# - Converts EngineCoreOutput β†’ RequestOutput/PoolingRequestOutput
+# - Manages request lifecycle from registration to completion
+AsyncLLM.add_request()
+↓
+OutputProcessor.add_request()
+↓
+RequestState.from_new_request() β†’ create the per-request state
+
+AsyncLLM.__init__() / AsyncLLM.generate() / AsyncLLM.encode() β†’ create a background loop that continuously pulls outputs from EngineCore
+↓
+OutputProcessor.process_outputs() β†’ update state and process outputs
+↓
+RequestState.make_request_output() β†’ convert to the final output, formatted as RequestOutput or PoolingRequestOutput
+↓
+RequestOutputCollector.put() β†’ push to the queue (AsyncLLM)
+```
+
+We need to add an implementation for an existing method:
+```python
+class RequestState:
+    def __init__(self, request_id: str, parent_req: Optional[ParentRequest],
+                 request_index: int, lora_name: Optional[str],
+                 output_kind: RequestOutputKind, prompt: Optional[str],
+                 prompt_token_ids: list[int], logprobs_processor: Optional[LogprobsProcessor],
+                 detokenizer: Optional[IncrementalDetokenizer],
+                 max_tokens_param: Optional[int], arrival_time: float,
+                 queue: Optional[RequestOutputCollector], log_stats: bool):
+        ...
+
+    def _new_pooling_output(self, pooling_output: torch.Tensor) -> PoolingOutput:
+        """Create PoolingOutput for multimodal/pooling requests"""
+        return PoolingOutput(data=pooling_output)
+```
+
+
+### 12. OutputProcessor
+
+For hidden-state output, the original OutputProcessor already provides support.
+
+We only need to add a new processor for the final multimodal output.
+```python +class MultimodalOutputProcessor(OutputProcessor): + """Handles multimodal output processing""" + + def __init__(self): + self.output_handlers: Dict[str, OutputProcessor] = {} + + def process_outputs(self, engine_core_outputs: list[EngineCoreOutput], ...): + for engine_core_output in engine_core_outputs: + # Option 1: Use output_type field (if EngineCoreOutput is extended) + if engine_core_output.output_type == "image": + self._process_image_output(engine_core_output) + elif engine_core_output.output_type == "text+image": + self._process_text_image_output(engine_core_output) + elif engine_core_output.output_type == "latents": + self._process_latents_output(engine_core_output) + elif engine_core_output.output_type == "text": + self._process_text_output(engine_core_output) + else: + # Fallback: use existing pooling_output logic + if engine_core_output.pooling_output is not None: + self._process_pooling_output(engine_core_output) + else: + self._process_text_output(engine_core_output) +``` + + + diff --git a/docs/architecture/vLLM-omni arch design doc.md b/docs/architecture/high level arch design.md similarity index 100% rename from docs/architecture/vLLM-omni arch design doc.md rename to docs/architecture/high level arch design.md diff --git a/docs/architecture/implementation_architecture.md b/docs/architecture/implementation_architecture.md new file mode 100644 index 00000000000..60af881e0d5 --- /dev/null +++ b/docs/architecture/implementation_architecture.md @@ -0,0 +1,374 @@ +# vLLM-omni Implementation Architecture + +## 1. Package Structure + +``` +vllm_omni/ +β”œβ”€β”€ __init__.py # Main package exports +β”œβ”€β”€ config/ # Configuration management +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ stage_config.py # OmniStageConfig implementation +β”‚ β”œβ”€β”€ dit_config.py # DiT-specific configurations +β”‚ └── cache_config.py # Caching configurations +β”œβ”€β”€ core/ # Core processing components +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ omni_llm.py # OmniLLM and AsyncOmniLLM +β”‚ β”œβ”€β”€ stage_manager.py # Multi-stage orchestration +β”‚ β”œβ”€β”€ dit_cache_manager.py # DiT caching system +β”‚ └── sched/ # Schedulers +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ scheduler.py # Base scheduler interface +β”‚ └── diffusion_scheduler.py # DiT scheduler +β”œβ”€β”€ engine/ # Engine components +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ processor.py # Output processing +β”‚ └── output_processor.py # Multimodal output handling +β”œβ”€β”€ executor/ # Executor implementations +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ base_executor.py # Base executor interface +β”‚ └── diffusers_executor.py # Diffusers pipeline executor +β”œβ”€β”€ model_executor/ # Model execution components +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ ar_model_runner.py # OmniARModelRunner +β”‚ └── dit_model_runner.py # OmniDiffusionModelRunner +β”œβ”€β”€ worker/ # Worker implementations +β”‚ β”œβ”€β”€ __init__.py +β”‚ └── omni_worker.py # Extended worker for DiT +β”œβ”€β”€ entrypoints/ # Entry points and CLI +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ omni.py # OmniServeCommand +β”‚ └── cli/ # CLI integration +β”‚ β”œβ”€β”€ __init__.py +β”‚ └── main.py # CLI main entry point +β”œβ”€β”€ request.py # Request handling +β”œβ”€β”€ dit_cache_interface.py # DiT cache interface +└── utils/ # Utility functions + β”œβ”€β”€ __init__.py + β”œβ”€β”€ multimodal.py # Multimodal utilities + └── vae.py # VAE utilities for image processing +``` + +## 2. 
Core Module Dependencies + +### 2.1 vLLM Integration Points + +```python +# Key vLLM imports and extensions +from vllm.v1.engine.llm_engine import LLMEngine +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.core.sched.scheduler import SchedulerInterface +from vllm.v1.executor.multiproc_executor import MultiprocExecutor +from vllm.v1.worker.gpu_worker import Worker as GPUWorker +from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.outputs import ModelRunnerOutput +``` + +### 2.2 Internal Dependencies + +```python +# Internal module dependencies +vllm_omni.config β†’ vllm_omni.core +vllm_omni.core β†’ vllm_omni.engine, vllm_omni.executor, vllm_omni.worker +vllm_omni.engine β†’ vllm_omni.model_executor +vllm_omni.entrypoints β†’ vllm_omni.core +``` + +## 3. Configuration System + +### 3.1 Stage Configuration + +```python +# vllm_omni/config/stage_config.py +@dataclass +class OmniStageConfig: + stage_id: int + engine_type: Literal["AR", "DiT"] + model_path: str + input_modalities: List[str] + output_modalities: List[str] + vllm_config: Optional[VllmConfig] = None + executor_class: type[Executor] = MultiprocExecutor + dit_config: Optional[DiTConfig] = None + cache_config: Optional[DiTCacheConfig] = None + stage_output: Optional[Any] = None +``` + +### 3.2 DiT Configuration + +```python +# vllm_omni/config/dit_config.py +@dataclass +class DiTConfig: + model_type: str + scheduler_type: str + num_inference_steps: int + guidance_scale: float + use_diffusers: bool = False + diffusers_pipeline: Optional[str] = None +``` + +### 3.3 Cache Configuration + +```python +# vllm_omni/config/cache_config.py +@dataclass +class DiTCacheConfig: + cache_tensors: List[DiTCacheTensor] + max_cache_size: int + cache_strategy: str = "fifo" + enable_optimization: bool = True +``` + +## 4. 
Core Implementation Details + +### 4.1 OmniLLM Implementation + +```python +# vllm_omni/core/omni_llm.py +class OmniLLM(LLM): + def __init__(self, stage_configs: List[OmniStageConfig]): + super().__init__() + self.stage_configs = stage_configs + self.engine_list: List[LLMEngine] = [] + self.output_processor = MultimodalOutputProcessor() + self._initialize_stage_engines() + + def _initialize_stage_engines(self) -> None: + """Initialize LLMEngine instances for each stage""" + for stage_config in self.stage_configs: + if stage_config.engine_type == "AR": + engine = self._create_ar_engine(stage_config) + elif stage_config.engine_type == "DiT": + engine = self._create_dit_engine(stage_config) + self.engine_list.append(engine) + + def generate(self, stage_args_list: List[Dict], **kwargs) -> List[RequestOutput]: + """Main generation interface - orchestrates multi-stage processing""" + current_output = None + + for i, (stage_config, stage_args) in enumerate(zip(self.stage_configs, stage_args_list)): + stage_engine = self.engine_list[i] + + # Prepare input for this stage + processed_input = self._process_stage_inputs(stage_config, stage_args, current_output) + + # Execute stage + stage_output = self._execute_stage(stage_engine, processed_input) + + # Update for next stage + current_output = stage_output + stage_config.stage_output = stage_output + + # Process final output + return self.output_processor.process_output(current_output) +``` + +### 4.2 AsyncOmniLLM Implementation + +```python +# vllm_omni/core/omni_llm.py +class AsyncOmniLLM(AsyncLLM): + def __init__(self, stage_configs: List[OmniStageConfig]): + super().__init__() + self.stage_configs = stage_configs + self.async_engine_list: List[AsyncLLM] = [] + self.output_processor = MultimodalOutputProcessor() + self._initialize_async_stage_engines() + + async def generate_async(self, stage_args_list: List[Dict], **kwargs) -> List[RequestOutput]: + """Async generation interface""" + # Similar to OmniLLM but with async/await patterns + pass +``` + +### 4.3 DiT Scheduler Implementation + +```python +# vllm_omni/core/sched/diffusion_scheduler.py +class OmniDiffusionScheduler(SchedulerInterface): + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + dit_cache_config: DiTCacheConfig, + structured_output_manager: StructuredOutputManager, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: + super().__init__(vllm_config, kv_cache_config, structured_output_manager, + mm_registry, include_finished_set, log_stats) + self.dit_cache_manager = DiTCacheManager(dit_cache_config) + + def schedule(self) -> SchedulerOutput: + """Schedule DiT requests with caching optimization""" + # Implement DiT-specific scheduling logic + pass +``` + +### 4.4 Model Runners + +```python +# vllm_omni/model_executor/dit_model_runner.py +class OmniDiffusionModelRunner(GPUModelRunner): + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, IntermediateTensors]: + # DiT model execution logic + return ModelRunnerOutput( + req_ids=[...], + req_id_to_index={...}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[tensor, ...], # DiT output tensors + kv_connector_output=None, + num_nans_in_logits=None, + ) + +# vllm_omni/model_executor/ar_model_runner.py +class OmniARModelRunner(GPUModelRunner): + def 
execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, IntermediateTensors]: + # AR model execution with hidden state output + return ModelRunnerOutput( + req_ids=[...], + req_id_to_index={...}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[hidden_states, ...], # Hidden states + kv_connector_output=None, + num_nans_in_logits=None, + ) +``` + +### 4.5 Output Processing + +```python +# vllm_omni/engine/output_processor.py +class MultimodalOutputProcessor(OutputProcessor): + def __init__(self): + self.output_handlers: Dict[str, Callable] = { + "image": self._process_image_output, + "text+image": self._process_text_image_output, + "latents": self._process_latents_output, + "text": self._process_text_output, + } + + def process_outputs(self, engine_core_outputs: List[EngineCoreOutput], ...): + """Process multimodal outputs based on type""" + for engine_core_output in engine_core_outputs: + output_type = self._detect_output_type(engine_core_output) + handler = self.output_handlers.get(output_type, self._process_pooling_output) + handler(engine_core_output) +``` + +## 5. CLI Integration + +### 5.1 Entry Point Override + +```python +# vllm_omni/entrypoints/cli/main.py +def main(): + """Main CLI entry point that intercepts vLLM commands""" + if "--omni" in sys.argv: + omni_args = [arg for arg in sys.argv[1:] if arg != "--omni"] + omni_serve = OmniServeCommand() + omni_serve.run(omni_args) + else: + from vllm.entrypoints.cli.main import main as vllm_main + vllm_main() +``` + +### 5.2 Package Configuration + +```toml +# pyproject.toml updates +[project.scripts] +vllm = "vllm_omni.entrypoints.cli.main:main" +vllm-omni = "vllm_omni.entrypoints.cli.main:main" + +[project.entry-points."vllm.plugins"] +omni = "vllm_omni.plugin:OmniPlugin" +``` + +## 6. Testing Strategy + +### 6.1 Unit Tests +- Individual component testing +- Mock vLLM dependencies +- Configuration validation + +### 6.2 Integration Tests +- End-to-end pipeline testing +- vLLM compatibility testing +- Multi-stage processing validation + +### 6.3 Performance Tests +- Benchmarking against native vLLM +- Memory usage profiling +- Latency measurements + +## 7. Installation and Setup + +### 7.1 Package Installation +```bash +pip install vllm>=0.2.0 +pip install vllm-omni +``` + +### 7.2 Development Setup +```bash +git clone https://github.com/hsliuustc0106/vllm-omni +cd vllm-omni +pip install -e ".[dev]" +``` + +### 7.3 Usage +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8000 +``` + +## 8. Memory Management + +### 8.1 DiT Cache Management +- Tensor caching for intermediate results +- Memory pooling for efficient allocation +- Cache eviction strategies + +### 8.2 Multi-Stage Memory +- Inter-stage data passing optimization +- Memory sharing between stages +- Garbage collection optimization + +## 9. Error Handling + +### 9.1 Stage Failure Handling +- Graceful degradation on stage failures +- Error propagation and reporting +- Recovery mechanisms + +### 9.2 vLLM Compatibility +- Version compatibility checks +- API change detection +- Fallback mechanisms + +## 10. 
Monitoring and Logging + +### 10.1 Performance Metrics +- Stage execution times +- Memory usage per stage +- Throughput measurements + +### 10.2 Debug Information +- Request tracing across stages +- Cache hit/miss ratios +- Error logging and reporting diff --git a/docs/implementation_summary.md b/docs/implementation_summary.md new file mode 100644 index 00000000000..204a798c896 --- /dev/null +++ b/docs/implementation_summary.md @@ -0,0 +1,293 @@ +# vLLM-omni Implementation Summary + +## Overview + +This document provides a comprehensive summary of the vLLM-omni implementation plan and current progress. The implementation follows a structured approach with PRD, architecture design, test design, and code implementation phases. + +## Implementation Status + +### βœ… Completed Phases + +#### 1. Product Requirements Document (PRD) +- **File**: `docs/PRD.md` +- **Status**: Complete +- **Contents**: + - Product vision and target users + - Functional and non-functional requirements + - Technical architecture overview + - Implementation phases and success criteria + - Risk assessment and mitigation strategies + +#### 2. Architecture Design +- **File**: `docs/architecture/implementation_architecture.md` +- **Status**: Complete +- **Contents**: + - Detailed package structure + - Core module dependencies + - Configuration system design + - Implementation details for all components + - Memory management and error handling strategies + +#### 3. Test Design +- **File**: `docs/testing/test_design.md` +- **Status**: Complete +- **Contents**: + - Comprehensive testing strategy + - Unit, integration, and E2E test specifications + - Performance and compatibility testing + - Test configuration and execution guidelines + +#### 4. Package Setup +- **Status**: Complete +- **Components**: + - Updated `pyproject.toml` with vLLM integration + - Package structure and dependencies + - CLI entry point configuration + - vLLM plugin system setup + +#### 5. Core Modules +- **Status**: Complete +- **Components**: + - `vllm_omni/config/`: Configuration management + - `vllm_omni/core/omni_llm.py`: OmniLLM and AsyncOmniLLM classes + - `vllm_omni/core/stage_manager.py`: Multi-stage orchestration + - `vllm_omni/core/dit_cache_manager.py`: DiT caching system + - `vllm_omni/core/sched/diffusion_scheduler.py`: DiT scheduler + - `vllm_omni/engine/output_processor.py`: Multimodal output processing + +### 🚧 In Progress / Pending + +#### 6. Scheduler and Executor Components +- **Status**: Partially Complete +- **Completed**: + - DiT scheduler implementation + - DiT cache manager +- **Pending**: + - Executor implementations + - Integration with vLLM's executor system + +#### 7. Model Runners +- **Status**: Pending +- **Components**: + - `OmniDiffusionModelRunner`: DiT model execution + - `OmniARModelRunner`: AR model execution with hidden states + - Integration with vLLM's model runner system + +#### 8. Output Processing +- **Status**: Partially Complete +- **Completed**: + - Basic multimodal output processor +- **Pending**: + - RequestState extensions + - Advanced output handling + +#### 9. CLI Integration +- **Status**: Complete +- **Components**: + - CLI entry point with `--omni` flag support + - OmniServeCommand implementation + - API server for serving + - vLLM plugin system + +#### 10. 
Testing and Validation +- **Status**: Partially Complete +- **Completed**: + - Basic test structure + - Configuration tests +- **Pending**: + - Integration tests + - E2E tests + - Performance benchmarks + +## Key Implementation Details + +### Package Structure +``` +vllm_omni/ +β”œβ”€β”€ config/ # Configuration management +β”‚ β”œβ”€β”€ stage_config.py # OmniStageConfig, DiTConfig, etc. +β”‚ └── __init__.py +β”œβ”€β”€ core/ # Core processing components +β”‚ β”œβ”€β”€ omni_llm.py # OmniLLM and AsyncOmniLLM +β”‚ β”œβ”€β”€ stage_manager.py # Multi-stage orchestration +β”‚ β”œβ”€β”€ dit_cache_manager.py # DiT caching system +β”‚ └── sched/ # Schedulers +β”‚ └── diffusion_scheduler.py +β”œβ”€β”€ engine/ # Engine components +β”‚ └── output_processor.py # Multimodal output handling +β”œβ”€β”€ entrypoints/ # Entry points and CLI +β”‚ β”œβ”€β”€ cli/main.py # CLI main entry point +β”‚ β”œβ”€β”€ omni.py # OmniServeCommand +β”‚ └── api_server.py # API server +β”œβ”€β”€ plugin.py # vLLM plugin system +└── __init__.py # Main package exports +``` + +### Core Classes + +#### OmniLLM +- **Purpose**: Offline multi-stage processing +- **Features**: + - Sequential stage processing + - AR and DiT stage support + - Input/output processing between stages + - Integration with vLLM's LLMEngine + +#### AsyncOmniLLM +- **Purpose**: Online multi-stage processing +- **Features**: + - Asynchronous stage processing + - Real-time request handling + - Integration with vLLM's AsyncLLM + +#### StageManager +- **Purpose**: Orchestrate multiple stage engines +- **Features**: + - Engine creation and management + - Stage configuration handling + - Resource cleanup + +#### DiTCacheManager +- **Purpose**: Optimize DiT inference with caching +- **Features**: + - Tensor caching with multiple strategies (FIFO, LRU, LFU) + - Memory management + - Cache statistics and monitoring + +### Configuration System + +#### OmniStageConfig +- **Purpose**: Configure individual processing stages +- **Features**: + - Support for AR and DiT engine types + - Modality specification + - Executor class selection + - DiT-specific configuration + +#### DiTConfig +- **Purpose**: Configure DiT-specific parameters +- **Features**: + - Model type and scheduler configuration + - Inference parameters (steps, guidance scale) + - Diffusers integration support + +### CLI Integration + +#### Command Line Interface +```bash +# Basic usage +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8000 + +# Multi-stage configuration +vllm serve model --omni --ar-stage text-model --dit-stage image-model + +# DiT-specific options +vllm serve model --omni --dit-steps 50 --dit-guidance-scale 7.5 +``` + +#### API Server +- **Endpoints**: + - `POST /generate`: Generate text or multimodal content + - `GET /health`: Health check + - `GET /info`: Service information +- **Features**: + - FastAPI-based REST API + - Async request handling + - Multi-stage processing support + +## Installation and Usage + +### Installation +```bash +# Install vLLM first +pip install vllm>=0.2.0 + +# Install vLLM-omni +pip install vllm-omni + +# Or install from source +git clone https://github.com/hsliuustc0106/vllm-omni +cd vllm-omni +pip install -e ".[dev]" +``` + +### Basic Usage +```python +from vllm_omni import OmniLLM, create_ar_stage_config, create_dit_stage_config + +# Create stage configurations +ar_config = create_ar_stage_config( + stage_id=0, + model_path="microsoft/DialoGPT-small", + input_modalities=["text"], + output_modalities=["text"] +) + +dit_config = create_dit_stage_config( + stage_id=1, 
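+    # Illustrative checkpoint; substitute any diffusion model your DiT stage supports.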
+ model_path="stabilityai/stable-diffusion-2-1", + input_modalities=["text"], + output_modalities=["image"] +) + +# Create OmniLLM instance +omni_llm = OmniLLM([ar_config, dit_config]) + +# Generate with multi-stage processing +stage_args = [ + {"prompt": "A beautiful landscape"}, + {"prompt": "A beautiful landscape"} +] + +outputs = omni_llm.generate(stage_args) +``` + +## Next Steps + +### Immediate Priorities +1. **Complete Model Runners**: Implement OmniDiffusionModelRunner and OmniARModelRunner +2. **Executor Integration**: Complete executor implementations and vLLM integration +3. **Testing**: Implement comprehensive test suite +4. **Documentation**: Create user guides and API documentation + +### Medium-term Goals +1. **Performance Optimization**: Optimize memory usage and inference speed +2. **Advanced Features**: Implement advanced scheduling and caching strategies +3. **Model Support**: Add support for more model architectures +4. **Production Readiness**: Add monitoring, logging, and deployment tools + +### Long-term Vision +1. **Multi-GPU Support**: Optimize for distributed inference +2. **Custom Architectures**: Support for custom model architectures +3. **Advanced Multimodal**: Enhanced multimodal fusion capabilities +4. **Ecosystem Integration**: Integration with popular ML frameworks + +## Technical Considerations + +### vLLM Compatibility +- Maintains compatibility with vLLM V1 architecture +- Reuses proven components (scheduler, executor, worker patterns) +- Minimal modifications to existing vLLM codebase + +### Memory Management +- Efficient caching system for DiT models +- Memory sharing between stages +- Garbage collection optimization + +### Extensibility +- Plugin-based architecture +- Easy integration of new modalities +- Configurable stage processing + +### Performance +- Optimized for both AR and DiT models +- Efficient multi-stage processing +- Scalable to multiple GPUs + +## Conclusion + +The vLLM-omni implementation provides a solid foundation for multi-modality model inference with non-autoregressive structures. The current implementation covers the core functionality needed for basic multi-stage processing, with a clear path forward for advanced features and optimizations. + +The modular architecture ensures maintainability and extensibility, while the vLLM integration provides a proven foundation for production deployment. The comprehensive testing strategy and documentation ensure reliability and ease of use. + +This implementation represents a significant step forward in making advanced multimodal AI models accessible and efficient for production use cases. diff --git a/docs/testing/test_design.md b/docs/testing/test_design.md new file mode 100644 index 00000000000..6206c1a35e9 --- /dev/null +++ b/docs/testing/test_design.md @@ -0,0 +1,759 @@ +# vLLM-omni Test Design Document + +## 1. Testing Strategy Overview + +### 1.1 Testing Pyramid +``` + E2E Tests (10%) + / \ + Integration Tests (30%) + / \ + Unit Tests (60%) +``` + +### 1.2 Test Categories +- **Unit Tests**: Individual component testing with mocked dependencies +- **Integration Tests**: Component interaction testing with real vLLM integration +- **End-to-End Tests**: Full pipeline testing with real models +- **Performance Tests**: Benchmarking and profiling +- **Compatibility Tests**: vLLM version compatibility validation + +## 2. 
Test Structure + +``` +tests/ +β”œβ”€β”€ __init__.py +β”œβ”€β”€ conftest.py # Shared fixtures and configuration +β”œβ”€β”€ unit/ # Unit tests (60%) +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ test_config/ # Configuration testing +β”‚ β”‚ β”œβ”€β”€ test_stage_config.py +β”‚ β”‚ β”œβ”€β”€ test_dit_config.py +β”‚ β”‚ └── test_cache_config.py +β”‚ β”œβ”€β”€ test_core/ # Core component testing +β”‚ β”‚ β”œβ”€β”€ test_omni_llm.py +β”‚ β”‚ β”œβ”€β”€ test_async_omni_llm.py +β”‚ β”‚ β”œβ”€β”€ test_stage_manager.py +β”‚ β”‚ └── test_dit_cache_manager.py +β”‚ β”œβ”€β”€ test_scheduler/ # Scheduler testing +β”‚ β”‚ β”œβ”€β”€ test_diffusion_scheduler.py +β”‚ β”‚ └── test_scheduler_interface.py +β”‚ β”œβ”€β”€ test_executor/ # Executor testing +β”‚ β”‚ β”œβ”€β”€ test_base_executor.py +β”‚ β”‚ └── test_diffusers_executor.py +β”‚ β”œβ”€β”€ test_model_executor/ # Model runner testing +β”‚ β”‚ β”œβ”€β”€ test_ar_model_runner.py +β”‚ β”‚ └── test_dit_model_runner.py +β”‚ β”œβ”€β”€ test_engine/ # Engine testing +β”‚ β”‚ β”œβ”€β”€ test_output_processor.py +β”‚ β”‚ └── test_multimodal_processor.py +β”‚ β”œβ”€β”€ test_worker/ # Worker testing +β”‚ β”‚ └── test_omni_worker.py +β”‚ └── test_utils/ # Utility testing +β”‚ β”œβ”€β”€ test_multimodal.py +β”‚ └── test_vae.py +β”œβ”€β”€ integration/ # Integration tests (30%) +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ test_vllm_integration.py +β”‚ β”œβ”€β”€ test_stage_processing.py +β”‚ β”œβ”€β”€ test_cli_integration.py +β”‚ └── test_api_compatibility.py +β”œβ”€β”€ e2e/ # End-to-end tests (10%) +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ test_full_pipeline.py +β”‚ β”œβ”€β”€ test_multimodal_generation.py +β”‚ └── test_performance.py +β”œβ”€β”€ benchmarks/ # Performance benchmarks +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ test_memory_usage.py +β”‚ β”œβ”€β”€ test_latency.py +β”‚ └── test_throughput.py +└── fixtures/ # Test data and fixtures + β”œβ”€β”€ __init__.py + β”œβ”€β”€ sample_models/ + β”œβ”€β”€ test_images/ + └── test_configs/ +``` + +## 3. 
Unit Test Specifications + +### 3.1 Configuration Tests + +#### test_stage_config.py +```python +import pytest +from vllm_omni.config.stage_config import OmniStageConfig, DiTConfig, DiTCacheConfig + +class TestOmniStageConfig: + def test_ar_stage_config_creation(self): + """Test AR stage configuration creation""" + config = OmniStageConfig( + stage_id=0, + engine_type="AR", + model_path="test-model", + input_modalities=["text"], + output_modalities=["text"] + ) + assert config.engine_type == "AR" + assert config.stage_id == 0 + + def test_dit_stage_config_creation(self): + """Test DiT stage configuration creation""" + dit_config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50 + ) + config = OmniStageConfig( + stage_id=1, + engine_type="DiT", + model_path="test-dit-model", + input_modalities=["text"], + output_modalities=["image"], + dit_config=dit_config + ) + assert config.engine_type == "DiT" + assert config.dit_config is not None + + def test_invalid_engine_type(self): + """Test validation of engine type""" + with pytest.raises(ValueError): + OmniStageConfig( + stage_id=0, + engine_type="INVALID", + model_path="test-model", + input_modalities=["text"], + output_modalities=["text"] + ) +``` + +#### test_dit_config.py +```python +import pytest +from vllm_omni.config.dit_config import DiTConfig + +class TestDiTConfig: + def test_dit_config_defaults(self): + """Test DiT configuration with defaults""" + config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50 + ) + assert config.use_diffusers is False + assert config.diffusers_pipeline is None + + def test_diffusers_config(self): + """Test DiT configuration with diffusers""" + config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50, + use_diffusers=True, + diffusers_pipeline="stable-diffusion" + ) + assert config.use_diffusers is True + assert config.diffusers_pipeline == "stable-diffusion" +``` + +### 3.2 Core Component Tests + +#### test_omni_llm.py +```python +import pytest +from unittest.mock import Mock, patch +from vllm_omni.core.omni_llm import OmniLLM +from vllm_omni.config.stage_config import OmniStageConfig + +class TestOmniLLM: + @pytest.fixture + def mock_stage_configs(self): + """Create mock stage configurations""" + return [ + OmniStageConfig( + stage_id=0, + engine_type="AR", + model_path="test-ar-model", + input_modalities=["text"], + output_modalities=["text"] + ), + OmniStageConfig( + stage_id=1, + engine_type="DiT", + model_path="test-dit-model", + input_modalities=["text"], + output_modalities=["image"] + ) + ] + + @patch('vllm_omni.core.omni_llm.LLMEngine') + def test_omni_llm_initialization(self, mock_llm_engine, mock_stage_configs): + """Test OmniLLM initialization""" + omni_llm = OmniLLM(mock_stage_configs) + assert len(omni_llm.engine_list) == 2 + assert omni_llm.stage_configs == mock_stage_configs + + @patch('vllm_omni.core.omni_llm.LLMEngine') + def test_stage_engine_creation(self, mock_llm_engine, mock_stage_configs): + """Test stage engine creation""" + omni_llm = OmniLLM(mock_stage_configs) + # Verify engines are created for each stage + assert mock_llm_engine.from_vllm_config.call_count == 2 + + def test_process_stage_inputs_ar(self, mock_stage_configs): + """Test AR stage input processing""" + omni_llm = OmniLLM(mock_stage_configs) + ar_config = mock_stage_configs[0] + stage_args = {"prompt": "test prompt"} + + result = omni_llm._process_stage_inputs(ar_config, stage_args, None) + assert "prompt" in result + + 
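def test_generate_rejects_mismatched_stage_args(self, mock_stage_configs):
+        """Hypothetical guard test: generate() raises ValueError when the
+        number of stage-arg dicts differs from the number of stages."""
+        with patch('vllm_omni.core.omni_llm.LLMEngine'):
+            omni_llm = OmniLLM(mock_stage_configs)
+        with pytest.raises(ValueError):
+            omni_llm.generate([{"prompt": "only one dict for two stages"}])
+
+    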
def test_process_stage_inputs_dit(self, mock_stage_configs): + """Test DiT stage input processing""" + omni_llm = OmniLLM(mock_stage_configs) + dit_config = mock_stage_configs[1] + stage_args = {"image": "test_image.jpg"} + + with patch.object(omni_llm, 'vae') as mock_vae: + mock_vae.encode.return_value = "encoded_image" + result = omni_llm._process_stage_inputs(dit_config, stage_args, None) + assert "encoded_image" in result +``` + +#### test_async_omni_llm.py +```python +import pytest +import asyncio +from unittest.mock import Mock, patch, AsyncMock +from vllm_omni.core.omni_llm import AsyncOmniLLM + +class TestAsyncOmniLLM: + @pytest.fixture + def mock_stage_configs(self): + """Create mock stage configurations""" + return [ + OmniStageConfig( + stage_id=0, + engine_type="AR", + model_path="test-ar-model", + input_modalities=["text"], + output_modalities=["text"] + ) + ] + + @pytest.mark.asyncio + async def test_async_generation(self, mock_stage_configs): + """Test async generation""" + with patch('vllm_omni.core.omni_llm.AsyncLLM') as mock_async_llm: + async_omni_llm = AsyncOmniLLM(mock_stage_configs) + + # Mock the generate_async method + mock_async_llm.return_value.generate_async = AsyncMock(return_value=[]) + + result = await async_omni_llm.generate_async([{"prompt": "test"}]) + assert result == [] +``` + +### 3.3 Scheduler Tests + +#### test_diffusion_scheduler.py +```python +import pytest +from unittest.mock import Mock, patch +from vllm_omni.core.sched.diffusion_scheduler import OmniDiffusionScheduler + +class TestOmniDiffusionScheduler: + @pytest.fixture + def mock_configs(self): + """Create mock configurations""" + return { + 'vllm_config': Mock(), + 'kv_cache_config': Mock(), + 'dit_cache_config': Mock(), + 'structured_output_manager': Mock(), + 'mm_registry': Mock(), + 'include_finished_set': False, + 'log_stats': False + } + + def test_scheduler_initialization(self, mock_configs): + """Test scheduler initialization""" + scheduler = OmniDiffusionScheduler(**mock_configs) + assert scheduler.dit_cache_manager is not None + + def test_schedule_method(self, mock_configs): + """Test scheduling method""" + scheduler = OmniDiffusionScheduler(**mock_configs) + + with patch.object(scheduler, 'dit_cache_manager') as mock_cache: + mock_cache.allocate_cache.return_value = Mock() + result = scheduler.schedule() + # Verify scheduling logic + assert result is not None +``` + +### 3.4 Model Runner Tests + +#### test_ar_model_runner.py +```python +import pytest +import torch +from unittest.mock import Mock, patch +from vllm_omni.model_executor.ar_model_runner import OmniARModelRunner + +class TestOmniARModelRunner: + @pytest.fixture + def mock_runner(self): + """Create mock model runner""" + with patch('vllm_omni.model_executor.ar_model_runner.GPUModelRunner.__init__'): + runner = OmniARModelRunner(Mock(), Mock(), Mock()) + return runner + + def test_execute_model_ar(self, mock_runner): + """Test AR model execution""" + mock_scheduler_output = Mock() + mock_scheduler_output.req_ids = ["req1"] + mock_scheduler_output.req_id_to_index = {"req1": 0} + + with patch.object(mock_runner, 'model') as mock_model: + mock_model.forward.return_value = (torch.randn(1, 10, 768), None) + + result = mock_runner.execute_model(mock_scheduler_output) + + assert result.req_ids == ["req1"] + assert len(result.pooler_output) > 0 # Hidden states +``` + +#### test_dit_model_runner.py +```python +import pytest +import torch +from unittest.mock import Mock, patch +from vllm_omni.model_executor.dit_model_runner import 
OmniDiffusionModelRunner + +class TestOmniDiffusionModelRunner: + @pytest.fixture + def mock_runner(self): + """Create mock DiT model runner""" + with patch('vllm_omni.model_executor.dit_model_runner.GPUModelRunner.__init__'): + runner = OmniDiffusionModelRunner(Mock(), Mock(), Mock()) + return runner + + def test_execute_model_dit(self, mock_runner): + """Test DiT model execution""" + mock_scheduler_output = Mock() + mock_scheduler_output.req_ids = ["req1"] + mock_scheduler_output.req_id_to_index = {"req1": 0} + + with patch.object(mock_runner, 'model') as mock_model: + mock_model.forward.return_value = torch.randn(1, 3, 512, 512) + + result = mock_runner.execute_model(mock_scheduler_output) + + assert result.req_ids == ["req1"] + assert len(result.pooler_output) > 0 # DiT output tensors +``` + +## 4. Integration Tests + +### 4.1 vLLM Integration Tests + +#### test_vllm_integration.py +```python +import pytest +from vllm_omni.core.omni_llm import OmniLLM +from vllm_omni.config.stage_config import OmniStageConfig + +class TestVLLMIntegration: + @pytest.mark.integration + def test_vllm_engine_creation(self): + """Test creation of vLLM engines""" + stage_configs = [ + OmniStageConfig( + stage_id=0, + engine_type="AR", + model_path="microsoft/DialoGPT-small", # Small test model + input_modalities=["text"], + output_modalities=["text"] + ) + ] + + omni_llm = OmniLLM(stage_configs) + assert len(omni_llm.engine_list) == 1 + assert omni_llm.engine_list[0] is not None + + @pytest.mark.integration + def test_stage_processing_integration(self): + """Test integration between stages""" + # Test with real vLLM components + pass +``` + +### 4.2 CLI Integration Tests + +#### test_cli_integration.py +```python +import pytest +import subprocess +import sys +from vllm_omni.entrypoints.cli.main import main + +class TestCLIIntegration: + def test_omni_flag_detection(self): + """Test --omni flag detection""" + original_argv = sys.argv + try: + sys.argv = ["vllm", "serve", "test-model", "--omni"] + # Test that omni command is triggered + with patch('vllm_omni.entrypoints.omni.OmniServeCommand') as mock_serve: + main() + mock_serve.assert_called_once() + finally: + sys.argv = original_argv + + def test_forward_to_vllm(self): + """Test forwarding to vLLM when --omni not present""" + original_argv = sys.argv + try: + sys.argv = ["vllm", "serve", "test-model"] + with patch('vllm.entrypoints.cli.main.main') as mock_vllm_main: + main() + mock_vllm_main.assert_called_once() + finally: + sys.argv = original_argv +``` + +## 5. 
End-to-End Tests + +### 5.1 Full Pipeline Tests + +#### test_full_pipeline.py +```python +import pytest +from vllm_omni.core.omni_llm import OmniLLM +from vllm_omni.config.stage_config import OmniStageConfig, DiTConfig + +class TestFullPipeline: + @pytest.mark.e2e + @pytest.mark.slow + def test_ar_to_dit_pipeline(self): + """Test complete AR to DiT pipeline""" + stage_configs = [ + OmniStageConfig( + stage_id=0, + engine_type="AR", + model_path="microsoft/DialoGPT-small", + input_modalities=["text"], + output_modalities=["text"] + ), + OmniStageConfig( + stage_id=1, + engine_type="DiT", + model_path="stabilityai/stable-diffusion-2-1", + input_modalities=["text"], + output_modalities=["image"], + dit_config=DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=10 # Reduced for testing + ) + ) + ] + + omni_llm = OmniLLM(stage_configs) + + stage_args = [ + {"prompt": "A beautiful landscape"}, + {"prompt": "A beautiful landscape"} + ] + + result = omni_llm.generate(stage_args) + assert result is not None + assert len(result) > 0 +``` + +### 5.2 Multimodal Generation Tests + +#### test_multimodal_generation.py +```python +import pytest +from vllm_omni.core.omni_llm import OmniLLM + +class TestMultimodalGeneration: + @pytest.mark.e2e + def test_text_to_image_generation(self): + """Test text to image generation""" + # Test implementation + pass + + @pytest.mark.e2e + def test_image_to_text_generation(self): + """Test image to text generation""" + # Test implementation + pass + + @pytest.mark.e2e + def test_text_and_image_generation(self): + """Test combined text and image generation""" + # Test implementation + pass +``` + +## 6. Performance Tests + +### 6.1 Memory Usage Tests + +#### test_memory_usage.py +```python +import pytest +import psutil +import torch +from vllm_omni.core.omni_llm import OmniLLM + +class TestMemoryUsage: + @pytest.mark.benchmark + def test_memory_usage_ar_stage(self): + """Test memory usage for AR stage""" + process = psutil.Process() + initial_memory = process.memory_info().rss + + # Run AR stage + # ... test implementation + + final_memory = process.memory_info().rss + memory_increase = final_memory - initial_memory + + # Assert memory usage is reasonable + assert memory_increase < 1024 * 1024 * 1024 # Less than 1GB + + @pytest.mark.benchmark + def test_memory_usage_dit_stage(self): + """Test memory usage for DiT stage""" + # Test implementation + pass + + @pytest.mark.benchmark + def test_cache_memory_management(self): + """Test DiT cache memory management""" + # Test implementation + pass +``` + +### 6.2 Latency Tests + +#### test_latency.py +```python +import pytest +import time +from vllm_omni.core.omni_llm import OmniLLM + +class TestLatency: + @pytest.mark.benchmark + def test_ar_stage_latency(self): + """Test AR stage latency""" + start_time = time.time() + + # Run AR stage + # ... test implementation + + end_time = time.time() + latency = end_time - start_time + + # Assert latency is reasonable + assert latency < 5.0 # Less than 5 seconds + + @pytest.mark.benchmark + def test_dit_stage_latency(self): + """Test DiT stage latency""" + # Test implementation + pass + + @pytest.mark.benchmark + def test_end_to_end_latency(self): + """Test end-to-end pipeline latency""" + # Test implementation + pass +``` + +## 7. 
Test Configuration

+### 7.1 pytest.ini
+```ini
+# pytest.ini uses the [pytest] section; [tool:pytest] is only read from setup.cfg
+[pytest]
+testpaths = tests
+python_files = test_*.py
+python_classes = Test*
+python_functions = test_*
+addopts =
+    --strict-markers
+    --strict-config
+    --cov=vllm_omni
+    --cov-report=term-missing
+    --cov-report=html
+    --cov-report=xml
+    --tb=short
+markers =
+    unit: Unit tests
+    integration: Integration tests
+    e2e: End-to-end tests
+    benchmark: Performance benchmark tests
+    slow: Slow running tests
+```
+
+### 7.2 conftest.py
+```python
+import pytest
+import torch
+from unittest.mock import Mock
+from vllm_omni.config.stage_config import OmniStageConfig, DiTConfig
+
+@pytest.fixture(scope="session")
+def device():
+    """Get available device for testing"""
+    return "cuda" if torch.cuda.is_available() else "cpu"
+
+@pytest.fixture
+def sample_stage_configs():
+    """Sample stage configurations for testing"""
+    return [
+        OmniStageConfig(
+            stage_id=0,
+            engine_type="AR",
+            model_path="test-ar-model",
+            input_modalities=["text"],
+            output_modalities=["text"]
+        ),
+        OmniStageConfig(
+            stage_id=1,
+            engine_type="DiT",
+            model_path="test-dit-model",
+            input_modalities=["text"],
+            output_modalities=["image"],
+            dit_config=DiTConfig(
+                model_type="dit",
+                scheduler_type="ddpm",
+                num_inference_steps=10
+            )
+        )
+    ]
+
+@pytest.fixture
+def mock_vllm_config():
+    """Mock vLLM configuration"""
+    config = Mock()
+    config.model = "test-model"
+    config.tensor_parallel_size = 1
+    config.pipeline_parallel_size = 1
+    return config
+```
+
+## 8. Test Execution
+
+### 8.1 Running Tests
+```bash
+# Run all tests
+pytest
+
+# Run specific test categories
+pytest -m unit
+pytest -m integration
+pytest -m e2e
+pytest -m benchmark
+
+# Run with coverage
+pytest --cov=vllm_omni --cov-report=html
+
+# Run specific test file
+pytest tests/unit/test_omni_llm.py
+
+# Run with verbose output
+pytest -v
+
+# Run in parallel (requires the pytest-xdist plugin)
+pytest -n auto
+```
+
+### 8.2 Continuous Integration
+```yaml
+# .github/workflows/test.yml
+name: Tests
+on: [push, pull_request]
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        # Versions are quoted: unquoted 3.10 is parsed by YAML as the float 3.1
+        python-version: ["3.8", "3.9", "3.10", "3.11"]
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          pip install -e ".[dev]"
+      - name: Run tests
+        run: |
+          pytest --cov=vllm_omni --cov-report=xml
+      - name: Upload coverage
+        uses: codecov/codecov-action@v3
+```
+
+## 9. Test Data Management
+
+### 9.1 Fixtures Directory
+```
+tests/fixtures/
+β”œβ”€β”€ sample_models/
+β”‚   β”œβ”€β”€ ar_model/
+β”‚   └── dit_model/
+β”œβ”€β”€ test_images/
+β”‚   β”œβ”€β”€ landscape.jpg
+β”‚   └── portrait.png
+β”œβ”€β”€ test_configs/
+β”‚   β”œβ”€β”€ ar_stage_config.json
+β”‚   └── dit_stage_config.json
+└── expected_outputs/
+    β”œβ”€β”€ ar_output.json
+    └── dit_output.json
+```
+
+### 9.2 Test Data Loading
+```python
+import json
+from pathlib import Path
+
+def load_test_config(config_name: str) -> dict:
+    """Load test configuration from fixtures"""
+    config_path = Path(__file__).parent / "fixtures" / "test_configs" / f"{config_name}.json"
+    with open(config_path) as f:
+        return json.load(f)
+
+def load_test_image(image_name: str) -> str:
+    """Load test image path from fixtures"""
+    return str(Path(__file__).parent / "fixtures" / "test_images" / image_name)
+```
+
+## 10. 
Test Quality Metrics + +### 10.1 Coverage Targets +- **Overall Coverage**: > 90% +- **Unit Tests**: > 95% +- **Integration Tests**: > 80% +- **Critical Paths**: 100% + +### 10.2 Performance Targets +- **Unit Test Execution**: < 30 seconds +- **Integration Test Execution**: < 5 minutes +- **E2E Test Execution**: < 15 minutes +- **Memory Usage**: < 8GB for full test suite + +### 10.3 Quality Gates +- All unit tests must pass +- Integration tests must pass +- Coverage threshold must be met +- No critical security vulnerabilities +- Performance regression tests must pass diff --git a/pyproject.toml b/pyproject.toml index ca6beb47748..9b54c16d499 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,11 @@ Documentation = "https://vllm-omni.readthedocs.io" "Bug Tracker" = "https://github.com/hsliuustc0106/vllm-omni/issues" [project.scripts] -vllm-omni = "vllm_omni.cli:main" +vllm = "vllm_omni.entrypoints.cli.main:main" +vllm-omni = "vllm_omni.entrypoints.cli.main:main" + +[project.entry-points."vllm.plugins"] +omni = "vllm_omni.plugin:OmniPlugin" [tool.setuptools.packages.find] where = ["."] diff --git a/requirements.txt b/requirements.txt index 45339bd420a..65c3d9971a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Core dependencies -vllm>=0.9.0 -torch>=2.0.0 +vllm>=0.10.2 +torch>=2.7 transformers>=4.30.0 numpy>=1.21.0 pillow>=8.0.0 @@ -16,6 +16,11 @@ gradio>=3.0.0 asyncio-mqtt>=0.11.0 ray>=2.0.0 -# Optional: DiT acceleration modules -# xDiT # Uncomment when available -# Cache-DiT # Uncomment when available +# Development dependencies +pytest>=7.0.0 +pytest-asyncio>=0.21.0 +pytest-cov>=4.0.0 +black>=23.0.0 +isort>=5.12.0 +flake8>=6.0.0 +mypy>=1.0.0 diff --git a/tests/conftest.py b/tests/conftest.py index 8d40a3e1383..cdec3720174 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,118 +1,82 @@ """ -Pytest configuration and fixtures for vLLM-omni tests. +Shared fixtures and configuration for vLLM-omni tests. 
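+
+Fixtures here rely on mocks and dummy model paths, so the suite can run
+without downloading real checkpoints.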
""" import pytest -import asyncio -from typing import AsyncGenerator, Generator -from unittest.mock import Mock, AsyncMock - import torch -import numpy as np - -from vllm_omni.configs import load_config +from unittest.mock import Mock +from vllm_omni.config import OmniStageConfig, DiTConfig, DiTCacheConfig, create_ar_stage_config, create_dit_stage_config @pytest.fixture(scope="session") -def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]: - """Create an instance of the default event loop for the test session.""" - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() - - -@pytest.fixture -async def mock_async_llm() -> AsyncMock: - """Mock AsyncLLM for testing.""" - mock = AsyncMock() - mock.process.return_value = {"output": "test_output", "hidden_states": None} - return mock - - -@pytest.fixture -async def mock_omni_engine() -> AsyncMock: - """Mock OmniEngine for testing.""" - mock = AsyncMock() - mock.execute.return_value = {"result": "test_result"} - return mock - - -@pytest.fixture -def sample_config() -> dict: - """Sample configuration for testing.""" - return { - "general": { - "name": "vllm-omni-test", - "version": "0.1.0", - "debug": True, - "log_level": "DEBUG" - }, - "model": { - "device": "cpu", - "dtype": "float32", - "max_model_len": 1024 - }, - "engines": { - "ar_engine": {"enabled": True, "max_batch_size": 32}, - "diffusion_engine": {"enabled": True, "max_batch_size": 8} - } - } +def device(): + """Get available device for testing.""" + return "cuda" if torch.cuda.is_available() else "cpu" @pytest.fixture -def sample_text_input() -> str: - """Sample text input for testing.""" - return "Hello, world! This is a test input for vLLM-omni." +def sample_ar_stage_config(): + """Sample AR stage configuration for testing.""" + return create_ar_stage_config( + stage_id=0, + model_path="test-ar-model", + input_modalities=["text"], + output_modalities=["text"] + ) @pytest.fixture -def sample_image_input() -> np.ndarray: - """Sample image input for testing.""" - return np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) +def sample_dit_stage_config(): + """Sample DiT stage configuration for testing.""" + dit_config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=10, + guidance_scale=7.5 + ) + + return create_dit_stage_config( + stage_id=1, + model_path="test-dit-model", + input_modalities=["text"], + output_modalities=["image"], + dit_config=dit_config + ) @pytest.fixture -def sample_audio_input() -> np.ndarray: - """Sample audio input for testing.""" - return np.random.randn(16000).astype(np.float32) # 1 second at 16kHz +def sample_stage_configs(sample_ar_stage_config, sample_dit_stage_config): + """Sample stage configurations for testing.""" + return [sample_ar_stage_config, sample_dit_stage_config] @pytest.fixture -def sample_multimodal_input() -> dict: - """Sample multimodal input for testing.""" - return { - "text": "Describe this image", - "image": np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8), - "audio": np.random.randn(16000).astype(np.float32) - } +def mock_vllm_config(): + """Mock vLLM configuration.""" + config = Mock() + config.model = "test-model" + config.tensor_parallel_size = 1 + config.pipeline_parallel_size = 1 + return config @pytest.fixture -def mock_torch_device() -> str: - """Mock torch device for testing.""" - if torch.cuda.is_available(): - return "cuda" - elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): - return "mps" - else: - return "cpu" - - 
-# Pytest markers -pytest_plugins = [] - - -def pytest_configure(config): - """Configure pytest with custom markers.""" - config.addinivalue_line( - "markers", "unit: mark test as a unit test" - ) - config.addinivalue_line( - "markers", "integration: mark test as an integration test" - ) - config.addinivalue_line( - "markers", "benchmark: mark test as a benchmark test" - ) - config.addinivalue_line( - "markers", "slow: mark test as slow running" - ) +def mock_dit_cache_config(): + """Mock DiT cache configuration.""" + from vllm_omni.config import DiTCacheTensor + + cache_tensors = [ + DiTCacheTensor( + name="test_tensor", + shape=[1, 512, 512], + dtype="float32", + persistent=True + ) + ] + + return DiTCacheConfig( + cache_tensors=cache_tensors, + max_cache_size=1024 * 1024 * 1024, # 1GB + cache_strategy="fifo", + enable_optimization=True + ) \ No newline at end of file diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 00000000000..1d8229156e3 --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,179 @@ +""" +Unit tests for configuration modules. +""" + +import pytest +from vllm_omni.config import ( + OmniStageConfig, + DiTConfig, + DiTCacheConfig, + DiTCacheTensor, + create_ar_stage_config, + create_dit_stage_config, +) + + +class TestOmniStageConfig: + def test_ar_stage_config_creation(self): + """Test AR stage configuration creation.""" + config = OmniStageConfig( + stage_id=0, + engine_type="AR", + model_path="test-model", + input_modalities=["text"], + output_modalities=["text"] + ) + assert config.engine_type == "AR" + assert config.stage_id == 0 + assert config.model_path == "test-model" + assert config.input_modalities == ["text"] + assert config.output_modalities == ["text"] + + def test_dit_stage_config_creation(self): + """Test DiT stage configuration creation.""" + dit_config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50 + ) + config = OmniStageConfig( + stage_id=1, + engine_type="DiT", + model_path="test-dit-model", + input_modalities=["text"], + output_modalities=["image"], + dit_config=dit_config + ) + assert config.engine_type == "DiT" + assert config.dit_config is not None + assert config.dit_config.model_type == "dit" + + def test_invalid_engine_type(self): + """Test validation of engine type.""" + with pytest.raises(ValueError): + OmniStageConfig( + stage_id=0, + engine_type="INVALID", + model_path="test-model", + input_modalities=["text"], + output_modalities=["text"] + ) + + def test_empty_modalities(self): + """Test validation of empty modalities.""" + with pytest.raises(ValueError): + OmniStageConfig( + stage_id=0, + engine_type="AR", + model_path="test-model", + input_modalities=[], + output_modalities=["text"] + ) + + with pytest.raises(ValueError): + OmniStageConfig( + stage_id=0, + engine_type="AR", + model_path="test-model", + input_modalities=["text"], + output_modalities=[] + ) + + def test_dit_requires_config(self): + """Test that DiT engine requires dit_config.""" + with pytest.raises(ValueError): + OmniStageConfig( + stage_id=1, + engine_type="DiT", + model_path="test-dit-model", + input_modalities=["text"], + output_modalities=["image"] + ) + + +class TestDiTConfig: + def test_dit_config_defaults(self): + """Test DiT configuration with defaults.""" + config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50 + ) + assert config.use_diffusers is False + assert config.diffusers_pipeline is None + assert config.guidance_scale == 7.5 + assert 
config.height == 512 + assert config.width == 512 + + def test_diffusers_config(self): + """Test DiT configuration with diffusers.""" + config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50, + use_diffusers=True, + diffusers_pipeline="stable-diffusion" + ) + assert config.use_diffusers is True + assert config.diffusers_pipeline == "stable-diffusion" + + +class TestDiTCacheConfig: + def test_cache_config_creation(self): + """Test cache configuration creation.""" + cache_tensors = [ + DiTCacheTensor( + name="test_tensor", + shape=[1, 512, 512], + dtype="float32", + persistent=True + ) + ] + + config = DiTCacheConfig( + cache_tensors=cache_tensors, + max_cache_size=1024 * 1024 * 1024, + cache_strategy="fifo" + ) + + assert len(config.cache_tensors) == 1 + assert config.max_cache_size == 1024 * 1024 * 1024 + assert config.cache_strategy == "fifo" + assert config.enable_optimization is True + + +class TestHelperFunctions: + def test_create_ar_stage_config(self): + """Test create_ar_stage_config helper function.""" + config = create_ar_stage_config( + stage_id=0, + model_path="test-model" + ) + + assert config.stage_id == 0 + assert config.engine_type == "AR" + assert config.model_path == "test-model" + assert config.input_modalities == ["text"] + assert config.output_modalities == ["text"] + + def test_create_dit_stage_config(self): + """Test create_dit_stage_config helper function.""" + dit_config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50 + ) + + config = create_dit_stage_config( + stage_id=1, + model_path="test-dit-model", + dit_config=dit_config + ) + + assert config.stage_id == 1 + assert config.engine_type == "DiT" + assert config.model_path == "test-dit-model" + assert config.input_modalities == ["text"] + assert config.output_modalities == ["image"] + assert config.dit_config is not None + diff --git a/vllm_omni/__init__.py b/vllm_omni/__init__.py index 408545dfd9e..2cdd71b1fb4 100644 --- a/vllm_omni/__init__.py +++ b/vllm_omni/__init__.py @@ -14,9 +14,14 @@ __email__ = "hsliuustc@gmail.com" # Main entry points -from .async_llm import AsyncLLM -from .omni_engine import OmniEngine -from .diffusion_engine import DiffusionEngine +from .core.omni_llm import OmniLLM, AsyncOmniLLM +from .config import ( + OmniStageConfig, + DiTConfig, + DiTCacheConfig, + create_ar_stage_config, + create_dit_stage_config, +) __all__ = [ # Version info @@ -25,9 +30,15 @@ "__email__", # Main components - "AsyncLLM", - "OmniEngine", - "DiffusionEngine", + "OmniLLM", + "AsyncOmniLLM", + + # Configuration + "OmniStageConfig", + "DiTConfig", + "DiTCacheConfig", + "create_ar_stage_config", + "create_dit_stage_config", # All other components are available through their respective modules # processors.*, schedulers.*, executors.*, etc. diff --git a/vllm_omni/config/__init__.py b/vllm_omni/config/__init__.py index 5174e4a0459..e48a874a7ce 100644 --- a/vllm_omni/config/__init__.py +++ b/vllm_omni/config/__init__.py @@ -1,22 +1,21 @@ """ -Configuration management for vLLM-omni. +Configuration module for vLLM-omni. """ -from dataclasses import dataclass, field -from pathlib import Path -from typing import Dict, Any, Optional -from vllm_omni.dit_cache_interface import DiTCacheConfig +from .stage_config import ( + OmniStageConfig, + DiTConfig, + DiTCacheConfig, + DiTCacheTensor, + create_ar_stage_config, + create_dit_stage_config, +) -from vllm.config import VllmConfig - - -@dataclass -class OmniConfig: - """ - The configuration for vLLM-omni. 
- """ - - """vllm config""" - vllm_config: VllmConfig = field(default_factory=VllmConfig) - """DiT cache config""" - dit_cache_config: DiTCacheConfig = field(default_factory=DiTCacheConfig) # DiT cache config \ No newline at end of file +__all__ = [ + "OmniStageConfig", + "DiTConfig", + "DiTCacheConfig", + "DiTCacheTensor", + "create_ar_stage_config", + "create_dit_stage_config", +] \ No newline at end of file diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py new file mode 100644 index 00000000000..e04cd3b66ef --- /dev/null +++ b/vllm_omni/config/stage_config.py @@ -0,0 +1,145 @@ +""" +Stage configuration for vLLM-omni multi-stage processing. +""" + +from dataclasses import dataclass +from typing import List, Optional, Type, Literal, Any +from vllm.config import VllmConfig +from vllm.executor.executor_base import ExecutorBase as Executor + + +@dataclass +class DiTConfig: + """Configuration for DiT (Diffusion Transformer) stages.""" + model_type: str + scheduler_type: str + num_inference_steps: int + guidance_scale: float = 7.5 + use_diffusers: bool = False + diffusers_pipeline: Optional[str] = None + height: int = 512 + width: int = 512 + batch_size: int = 1 + + +@dataclass +class DiTCacheTensor: + """Configuration for DiT cache tensors.""" + name: str + shape: List[int] + dtype: str = "float32" + persistent: bool = True + + +@dataclass +class DiTCacheConfig: + """Configuration for DiT caching system.""" + cache_tensors: List[DiTCacheTensor] + max_cache_size: int = 1024 * 1024 * 1024 # 1GB + cache_strategy: str = "fifo" # fifo, lru, lfu + enable_optimization: bool = True + cache_compression: bool = False + + +@dataclass +class OmniStageConfig: + """Configuration for a processing stage in vLLM-omni.""" + stage_id: int + engine_type: Literal["AR", "DiT"] + model_path: str + input_modalities: List[str] + output_modalities: List[str] + vllm_config: Optional[VllmConfig] = None + executor_class: Type[Executor] = None # Will be set based on engine_type + dit_config: Optional[DiTConfig] = None + cache_config: Optional[DiTCacheConfig] = None + stage_output: Optional[Any] = None + + def __post_init__(self): + """Validate configuration after initialization.""" + if self.engine_type not in ["AR", "DiT"]: + raise ValueError(f"Invalid engine_type: {self.engine_type}. 
Must be 'AR' or 'DiT'") + + if self.engine_type == "DiT" and self.dit_config is None: + raise ValueError("DiT engine requires dit_config") + + if not self.input_modalities: + raise ValueError("input_modalities cannot be empty") + + if not self.output_modalities: + raise ValueError("output_modalities cannot be empty") + + def get_executor_class(self) -> Type[Executor]: + """Get the appropriate executor class for this stage.""" + if self.executor_class is not None: + return self.executor_class + + if self.engine_type == "AR": + from vllm.executor.uniproc_executor import UniProcExecutor + return UniProcExecutor + elif self.engine_type == "DiT": + if self.dit_config and self.dit_config.use_diffusers: + from vllm_omni.executor.diffusers_executor import DiffusersPipelineExecutor + return DiffusersPipelineExecutor + else: + from vllm.executor.uniproc_executor import UniProcExecutor + return UniProcExecutor + + raise ValueError(f"No executor class available for engine_type: {self.engine_type}") + + +def create_ar_stage_config( + stage_id: int, + model_path: str, + input_modalities: List[str] = None, + output_modalities: List[str] = None, + vllm_config: Optional[VllmConfig] = None +) -> OmniStageConfig: + """Create a configuration for an AR (Autoregressive) stage.""" + if input_modalities is None: + input_modalities = ["text"] + if output_modalities is None: + output_modalities = ["text"] + + return OmniStageConfig( + stage_id=stage_id, + engine_type="AR", + model_path=model_path, + input_modalities=input_modalities, + output_modalities=output_modalities, + vllm_config=vllm_config + ) + + +def create_dit_stage_config( + stage_id: int, + model_path: str, + input_modalities: List[str] = None, + output_modalities: List[str] = None, + dit_config: Optional[DiTConfig] = None, + cache_config: Optional[DiTCacheConfig] = None, + vllm_config: Optional[VllmConfig] = None +) -> OmniStageConfig: + """Create a configuration for a DiT (Diffusion Transformer) stage.""" + if input_modalities is None: + input_modalities = ["text"] + if output_modalities is None: + output_modalities = ["image"] + + if dit_config is None: + dit_config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50 + ) + + return OmniStageConfig( + stage_id=stage_id, + engine_type="DiT", + model_path=model_path, + input_modalities=input_modalities, + output_modalities=output_modalities, + vllm_config=vllm_config, + dit_config=dit_config, + cache_config=cache_config + ) diff --git a/vllm_omni/core/dit_cache_manager.py b/vllm_omni/core/dit_cache_manager.py index e69de29bb2d..d2c800e76f5 100644 --- a/vllm_omni/core/dit_cache_manager.py +++ b/vllm_omni/core/dit_cache_manager.py @@ -0,0 +1,178 @@ +""" +DiT Cache Manager for vLLM-omni. + +This module provides caching functionality for DiT (Diffusion Transformer) models +to optimize inference performance and memory usage. 
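+
+Entries are keyed by request id and evicted under the configured strategy
+("fifo", "lru", or "lfu") once max_cache_size would be exceeded.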
+""" + +import time +import torch +from typing import Dict, List, Optional, Any +from dataclasses import dataclass +from ..config import DiTCacheConfig, DiTCacheTensor + + +class DiTCacheManager: + """Manages DiT-specific caching for optimized inference.""" + + def __init__(self, config: DiTCacheConfig): + self.config = config + self.cache_tensors: Dict[str, DiTCacheTensor] = {} + self.cache_groups: List[DiTCacheTensor] = config.cache_tensors + self.max_cache_size = config.max_cache_size + self.current_cache_size = 0 + self.cache_hits = 0 + self.cache_misses = 0 + self.cache_timeout = 3600 # 1 hour default timeout + + def allocate_cache(self, request_id: str, size: int) -> torch.Tensor: + """Allocate cache for a specific request.""" + # Check if we have enough space + if self.current_cache_size + size > self.max_cache_size: + self._evict_cache() + + # Create tensor + tensor = torch.zeros(size, dtype=torch.float32) + + # Store in cache + cache_tensor = DiTCacheTensor( + name=request_id, + tensor=tensor, + timestamp=time.time(), + access_count=0, + size_bytes=size * 4 # Assuming float32 + ) + + self.cache_tensors[request_id] = cache_tensor + self.current_cache_size += cache_tensor.size_bytes + + return tensor + + def get_cache(self, request_id: str) -> Optional[torch.Tensor]: + """Get cached tensor for a request.""" + if request_id in self.cache_tensors: + cache_tensor = self.cache_tensors[request_id] + cache_tensor.access_count += 1 + self.cache_hits += 1 + return cache_tensor.tensor + else: + self.cache_misses += 1 + return None + + def store_cache(self, request_id: str, tensor: torch.Tensor) -> None: + """Store a tensor in the cache.""" + if request_id in self.cache_tensors: + # Update existing cache + old_tensor = self.cache_tensors[request_id] + self.current_cache_size -= old_tensor.size_bytes + + new_size_bytes = tensor.numel() * 4 # Assuming float32 + self.current_cache_size += new_size_bytes + + self.cache_tensors[request_id] = DiTCacheTensor( + name=request_id, + tensor=tensor, + timestamp=time.time(), + access_count=old_tensor.access_count, + size_bytes=new_size_bytes + ) + else: + # Create new cache entry + new_size_bytes = tensor.numel() * 4 + if self.current_cache_size + new_size_bytes > self.max_cache_size: + self._evict_cache() + + self.cache_tensors[request_id] = DiTCacheTensor( + name=request_id, + tensor=tensor, + timestamp=time.time(), + access_count=0, + size_bytes=new_size_bytes + ) + self.current_cache_size += new_size_bytes + + def release_cache(self, request_id: str) -> None: + """Release cache for a request.""" + if request_id in self.cache_tensors: + cache_tensor = self.cache_tensors[request_id] + self.current_cache_size -= cache_tensor.size_bytes + del self.cache_tensors[request_id] + + def clear_expired_cache(self) -> None: + """Clear expired cache entries.""" + current_time = time.time() + expired_keys = [] + + for request_id, cache_tensor in self.cache_tensors.items(): + if current_time - cache_tensor.timestamp > self.cache_timeout: + expired_keys.append(request_id) + + for request_id in expired_keys: + self.release_cache(request_id) + + def _evict_cache(self) -> None: + """Evict cache entries based on strategy.""" + if self.config.cache_strategy == "fifo": + self._evict_fifo() + elif self.config.cache_strategy == "lru": + self._evict_lru() + elif self.config.cache_strategy == "lfu": + self._evict_lfu() + else: + self._evict_fifo() # Default to FIFO + + def _evict_fifo(self) -> None: + """Evict cache entries using FIFO strategy.""" + if not self.cache_tensors: + 
return + + # Find oldest entry + oldest_key = min(self.cache_tensors.keys(), + key=lambda k: self.cache_tensors[k].timestamp) + self.release_cache(oldest_key) + + def _evict_lru(self) -> None: + """Evict cache entries using LRU strategy.""" + if not self.cache_tensors: + return + + # Find least recently used entry + lru_key = min(self.cache_tensors.keys(), + key=lambda k: self.cache_tensors[k].access_count) + self.release_cache(lru_key) + + def _evict_lfu(self) -> None: + """Evict cache entries using LFU strategy.""" + if not self.cache_tensors: + return + + # Find least frequently used entry + lfu_key = min(self.cache_tensors.keys(), + key=lambda k: self.cache_tensors[k].access_count) + self.release_cache(lfu_key) + + def get_statistics(self) -> Dict[str, Any]: + """Get cache statistics.""" + total_requests = self.cache_hits + self.cache_misses + hit_rate = self.cache_hits / total_requests if total_requests > 0 else 0 + + return { + "cache_size": self.current_cache_size, + "max_cache_size": self.max_cache_size, + "cache_utilization": self.current_cache_size / self.max_cache_size, + "cache_hits": self.cache_hits, + "cache_misses": self.cache_misses, + "hit_rate": hit_rate, + "num_cached_tensors": len(self.cache_tensors) + } + + def clear_all_cache(self) -> None: + """Clear all cache entries.""" + self.cache_tensors.clear() + self.current_cache_size = 0 + self.cache_hits = 0 + self.cache_misses = 0 + + def set_cache_timeout(self, timeout: float) -> None: + """Set cache timeout in seconds.""" + self.cache_timeout = timeout \ No newline at end of file diff --git a/vllm_omni/core/omni_llm.py b/vllm_omni/core/omni_llm.py new file mode 100644 index 00000000000..9b94ccf8a79 --- /dev/null +++ b/vllm_omni/core/omni_llm.py @@ -0,0 +1,332 @@ +""" +Core OmniLLM and AsyncOmniLLM classes for multi-stage processing. 
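+
+OmniLLM drives the configured stages sequentially for offline use, while
+AsyncOmniLLM applies the same staging logic on top of vLLM's AsyncLLM for
+online serving.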
+""" + +import asyncio +from typing import List, Dict, Any, Optional, Union, Callable +from vllm.entrypoints.llm import LLM +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.llm_engine import LLMEngine +from vllm.outputs import RequestOutput +from vllm.outputs import LoRARequest + +from ..config import OmniStageConfig +from .stage_manager import StageManager +from ..engine.output_processor import MultimodalOutputProcessor + + +class OmniLLM(LLM): + """Extended LLM supporting multiple engines and stage-based processing.""" + + def __init__( + self, + stage_configs: List[OmniStageConfig], + log_stats: bool = False, + **kwargs + ): + # Use the first stage's model as the default model for LLM + default_model = stage_configs[0].model_path if stage_configs else "test-model" + super().__init__(model=default_model, **kwargs) + self.stage_configs = stage_configs + self.log_stats = log_stats + self.stage_manager = StageManager(stage_configs, log_stats) + self.output_processor = MultimodalOutputProcessor() + self._initialize_stage_engines() + + def _initialize_stage_engines(self) -> None: + """Initialize LLMEngine instances for each stage.""" + self.stage_manager.initialize_engines() + + def generate( + self, + stage_args_list: List[Dict[str, Any]], + use_tqdm: Union[bool, Callable[..., Any]] = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + priority: Optional[List[int]] = None, + **kwargs + ) -> List[RequestOutput]: + """Main generation interface - orchestrates multi-stage processing.""" + + if len(stage_args_list) != len(self.stage_configs): + raise ValueError( + f"Number of stage arguments ({len(stage_args_list)}) must match " + f"number of stage configs ({len(self.stage_configs)})" + ) + + # Process through each stage sequentially + current_output = None + + for i, (stage_config, stage_args) in enumerate(zip(self.stage_configs, stage_args_list)): + stage_engine = self.stage_manager.get_engine(i) + + # Prepare input for this stage + processed_input = self._process_stage_inputs( + stage_config, stage_args, current_output + ) + + # Execute stage + stage_output = self._execute_stage( + stage_engine, processed_input, lora_request, priority + ) + + # Update for next stage + current_output = stage_output + stage_config.stage_output = stage_output + + # Process final output + final_output = self.output_processor.process_output(current_output) + return final_output + + def _process_stage_inputs( + self, + stage_config: OmniStageConfig, + stage_args: Dict[str, Any], + previous_output: Optional[Any] + ) -> Dict[str, Any]: + """Prepare input for specific stage.""" + if stage_config.engine_type == "AR": + return self._process_ar_inputs(stage_args, previous_output) + elif stage_config.engine_type == "DiT": + return self._process_dit_inputs(stage_args, previous_output) + else: + raise NotImplementedError(f"Unknown engine type: {stage_config.engine_type}") + + def _process_ar_inputs( + self, + stage_args: Dict[str, Any], + previous_output: Optional[Any] + ) -> Dict[str, Any]: + """Process inputs for AR stage.""" + # For AR stages, we typically use text prompts + processed_input = { + "prompt": stage_args.get("prompt", ""), + "max_tokens": stage_args.get("max_tokens", 100), + "temperature": stage_args.get("temperature", 0.7), + } + + # If we have previous output (e.g., from a previous AR stage), + # we might want to use it as context + if previous_output is not None: + # Extract text from previous output if available + if hasattr(previous_output, 'outputs') and 
previous_output.outputs: + last_output = previous_output.outputs[-1] + if hasattr(last_output, 'text'): + processed_input["prompt"] = last_output.text + " " + processed_input["prompt"] + + return processed_input + + def _process_dit_inputs( + self, + stage_args: Dict[str, Any], + previous_output: Optional[Any] + ) -> Dict[str, Any]: + """Process inputs for DiT stage.""" + processed_input = { + "prompt": stage_args.get("prompt", ""), + "height": stage_args.get("height", 512), + "width": stage_args.get("width", 512), + "num_inference_steps": stage_args.get("num_inference_steps", 50), + "guidance_scale": stage_args.get("guidance_scale", 7.5), + } + + # Handle image inputs if present + if "image" in stage_args: + # For now, we'll pass the image path directly + # In a full implementation, this would involve VAE encoding + processed_input["image"] = stage_args["image"] + + # If we have previous output from an AR stage, we might want to use it + if previous_output is not None: + # Extract text from previous AR output + if hasattr(previous_output, 'outputs') and previous_output.outputs: + last_output = previous_output.outputs[-1] + if hasattr(last_output, 'text'): + processed_input["prompt"] = last_output.text + + return processed_input + + def _execute_stage( + self, + stage_engine: LLMEngine, + processed_input: Dict[str, Any], + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + priority: Optional[List[int]] = None + ) -> Any: + """Execute a single stage.""" + # This is a simplified implementation + # In practice, this would involve proper request management + # and integration with vLLM's engine system + + # For now, we'll return a mock output + # In the full implementation, this would call stage_engine.generate() + from vllm.outputs import RequestOutput, CompletionOutput + + mock_output = CompletionOutput( + index=0, + text=processed_input.get("prompt", ""), + token_ids=[], + cumulative_logprob=0.0, + logprobs=None, + finish_reason="length" + ) + + return RequestOutput( + request_id="mock_request", + prompt=processed_input.get("prompt", ""), + prompt_token_ids=[], + outputs=[mock_output], + finished=True + ) + + +class AsyncOmniLLM(AsyncLLM): + """Extended AsyncLLM supporting multiple engines and stage-based processing.""" + + def __init__( + self, + stage_configs: List[OmniStageConfig], + log_stats: bool = False, + **kwargs + ): + # Use the first stage's model as the default model for AsyncLLM + default_model = stage_configs[0].model_path if stage_configs else "test-model" + super().__init__(model=default_model, **kwargs) + self.stage_configs = stage_configs + self.log_stats = log_stats + self.stage_manager = StageManager(stage_configs, log_stats) + self.output_processor = MultimodalOutputProcessor() + self._initialize_async_stage_engines() + + def _initialize_async_stage_engines(self) -> None: + """Initialize AsyncLLM instances for each stage.""" + self.stage_manager.initialize_async_engines() + + async def generate_async( + self, + stage_args_list: List[Dict[str, Any]], + use_tqdm: Union[bool, Callable[..., Any]] = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + priority: Optional[List[int]] = None, + **kwargs + ) -> List[RequestOutput]: + """Async generation interface - orchestrates multi-stage processing.""" + + if len(stage_args_list) != len(self.stage_configs): + raise ValueError( + f"Number of stage arguments ({len(stage_args_list)}) must match " + f"number of stage configs ({len(self.stage_configs)})" + ) + + # Process through each 
stage sequentially + current_output = None + + for i, (stage_config, stage_args) in enumerate(zip(self.stage_configs, stage_args_list)): + stage_engine = self.stage_manager.get_async_engine(i) + + # Prepare input for this stage + processed_input = self._process_stage_inputs( + stage_config, stage_args, current_output + ) + + # Execute stage asynchronously + stage_output = await self._execute_stage_async( + stage_engine, processed_input, lora_request, priority + ) + + # Update for next stage + current_output = stage_output + stage_config.stage_output = stage_output + + # Process final output + final_output = self.output_processor.process_output(current_output) + return final_output + + def _process_stage_inputs( + self, + stage_config: OmniStageConfig, + stage_args: Dict[str, Any], + previous_output: Optional[Any] + ) -> Dict[str, Any]: + """Prepare input for specific stage (same as OmniLLM).""" + if stage_config.engine_type == "AR": + return self._process_ar_inputs(stage_args, previous_output) + elif stage_config.engine_type == "DiT": + return self._process_dit_inputs(stage_args, previous_output) + else: + raise NotImplementedError(f"Unknown engine type: {stage_config.engine_type}") + + def _process_ar_inputs( + self, + stage_args: Dict[str, Any], + previous_output: Optional[Any] + ) -> Dict[str, Any]: + """Process inputs for AR stage (same as OmniLLM).""" + processed_input = { + "prompt": stage_args.get("prompt", ""), + "max_tokens": stage_args.get("max_tokens", 100), + "temperature": stage_args.get("temperature", 0.7), + } + + if previous_output is not None: + if hasattr(previous_output, 'outputs') and previous_output.outputs: + last_output = previous_output.outputs[-1] + if hasattr(last_output, 'text'): + processed_input["prompt"] = last_output.text + " " + processed_input["prompt"] + + return processed_input + + def _process_dit_inputs( + self, + stage_args: Dict[str, Any], + previous_output: Optional[Any] + ) -> Dict[str, Any]: + """Process inputs for DiT stage (same as OmniLLM).""" + processed_input = { + "prompt": stage_args.get("prompt", ""), + "height": stage_args.get("height", 512), + "width": stage_args.get("width", 512), + "num_inference_steps": stage_args.get("num_inference_steps", 50), + "guidance_scale": stage_args.get("guidance_scale", 7.5), + } + + if "image" in stage_args: + processed_input["image"] = stage_args["image"] + + if previous_output is not None: + if hasattr(previous_output, 'outputs') and previous_output.outputs: + last_output = previous_output.outputs[-1] + if hasattr(last_output, 'text'): + processed_input["prompt"] = last_output.text + + return processed_input + + async def _execute_stage_async( + self, + stage_engine: AsyncLLM, + processed_input: Dict[str, Any], + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + priority: Optional[List[int]] = None + ) -> Any: + """Execute a single stage asynchronously.""" + # This is a simplified implementation + # In practice, this would involve proper async request management + + # For now, we'll return a mock output + from vllm.outputs import RequestOutput, CompletionOutput + + mock_output = CompletionOutput( + index=0, + text=processed_input.get("prompt", ""), + token_ids=[], + cumulative_logprob=0.0, + logprobs=None, + finish_reason="length" + ) + + return RequestOutput( + request_id="mock_request", + prompt=processed_input.get("prompt", ""), + prompt_token_ids=[], + outputs=[mock_output], + finished=True + ) diff --git a/vllm_omni/core/sched/__init__.py b/vllm_omni/core/sched/__init__.py index 
e69de29bb2d..0925dc9ca50 100644
--- a/vllm_omni/core/sched/__init__.py
+++ b/vllm_omni/core/sched/__init__.py
@@ -0,0 +1 @@
+from vllm.v1.core.sched import SchedulerInterface
\ No newline at end of file
diff --git a/vllm_omni/core/sched/diffusion_scheduler.py b/vllm_omni/core/sched/diffusion_scheduler.py
new file mode 100644
index 00000000000..4aec3a2f1ec
--- /dev/null
+++ b/vllm_omni/core/sched/diffusion_scheduler.py
@@ -0,0 +1,229 @@
+"""
+Diffusion scheduler for DiT (Diffusion Transformer) models in vLLM-omni.
+"""
+
+from typing import Any, Dict, List, Optional, TYPE_CHECKING
+
+from vllm.config import VllmConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.v1.core.sched import SchedulerInterface
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.structured_output import StructuredOutputManager
+
+from ..dit_cache_manager import DiTCacheManager
+from ...config import DiTCacheConfig
+
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+
+
+class OmniDiffusionScheduler(SchedulerInterface):
+    """Scheduler for DiT models with caching optimization."""
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        kv_cache_config: KVCacheConfig,
+        dit_cache_config: DiTCacheConfig,
+        structured_output_manager: StructuredOutputManager,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        include_finished_set: bool = False,
+        log_stats: bool = False,
+    ) -> None:
+        super().__init__(
+            vllm_config=vllm_config,
+            kv_cache_config=kv_cache_config,
+            structured_output_manager=structured_output_manager,
+            mm_registry=mm_registry,
+            include_finished_set=include_finished_set,
+            log_stats=log_stats,
+        )
+
+        self.dit_cache_config = dit_cache_config
+        self.dit_cache_manager = DiTCacheManager(dit_cache_config)
+
+        # DiT-specific scheduling state
+        self.dit_requests: Dict[str, Dict[str, Any]] = {}
+        self.current_step: int = 0
+        self.max_steps: int = 50  # Default; updated from config via set_max_steps()
+
+    def schedule(self) -> "SchedulerOutput":
+        """Schedule DiT requests with caching optimization."""
+        # Get pending requests
+        pending_requests = self._get_pending_requests()
+
+        if not pending_requests:
+            return self._create_empty_scheduler_output()
+
+        # Apply DiT-specific scheduling logic
+        scheduled_requests = self._schedule_dit_requests(pending_requests)
+
+        # Create scheduler output
+        return self._create_scheduler_output(scheduled_requests)
+
+    def _get_pending_requests(self) -> List[Dict[str, Any]]:
+        """Get pending DiT requests."""
+        # Placeholder: this will integrate with vLLM's request management
+        # system; for now it returns an empty list.
+        return []
+
+    def _schedule_dit_requests(self, requests: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Apply DiT-specific scheduling logic."""
+        scheduled_requests = []
+
+        for request in requests:
+            # Check cache for this request
+            cached_result = self.dit_cache_manager.get_cache(request['request_id'])
+
+            if cached_result is not None:
+                # Use cached result
+                request['cached'] = True
+                request['cached_result'] = cached_result
+            else:
+                # Allocate cache for new request
+                cache_tensor = self.dit_cache_manager.allocate_cache(
+                    request['request_id'],
+                    request.get('cache_size', 1024)
+                )
+                request['cache_tensor'] = cache_tensor
+                request['cached'] = False
+
+            # Apply DiT-specific scheduling policies
+            request = self._apply_dit_scheduling_policies(request)
+            scheduled_requests.append(request)
+
+        return scheduled_requests
+
+    def _apply_dit_scheduling_policies(self, request: Dict[str, Any]) -> Dict[str, Any]:
+        """Apply DiT-specific scheduling policies.
+
+        For now this is a simple FIFO policy. In practice, it could include:
+        - Priority-based scheduling
+        - Batch size optimization
+        - Memory-aware scheduling
+        - Quality vs. speed trade-offs
+        """
+        request['priority'] = request.get('priority', 0)
+        request['batch_size'] = request.get('batch_size', 1)
+        request['scheduling_policy'] = 'fifo'
+
+        return request
+
+    def _create_empty_scheduler_output(self) -> "SchedulerOutput":
+        """Create an empty scheduler output."""
+        from vllm.v1.core.sched.output import SchedulerOutput
+
+        # NOTE: the field set below is illustrative; align it with the
+        # SchedulerOutput definition of the targeted vLLM version.
+        return SchedulerOutput(
+            scheduled_seq_groups=[],
+            ignored_seq_groups=[],
+            preempted_seq_groups=[],
+            num_preemption_groups=0,
+            num_batched_tokens=0,
+            blocks_to_swap_in={},
+            blocks_to_swap_out={},
+            blocks_to_copy={},
+            num_ignored_seq_groups=0,
+        )
+
+    def _create_scheduler_output(self, scheduled_requests: List[Dict[str, Any]]) -> "SchedulerOutput":
+        """Create scheduler output from scheduled requests."""
+        from vllm.v1.core.sched.output import SchedulerOutput
+
+        # Convert scheduled requests to vLLM's expected format
+        scheduled_seq_groups = []
+        for request in scheduled_requests:
+            seq_group = self._create_seq_group_from_request(request)
+            scheduled_seq_groups.append(seq_group)
+
+        return SchedulerOutput(
+            scheduled_seq_groups=scheduled_seq_groups,
+            ignored_seq_groups=[],
+            preempted_seq_groups=[],
+            num_preemption_groups=0,
+            num_batched_tokens=sum(req.get('num_tokens', 0) for req in scheduled_requests),
+            blocks_to_swap_in={},
+            blocks_to_swap_out={},
+            blocks_to_copy={},
+            num_ignored_seq_groups=0,
+        )
+
+    def _create_seq_group_from_request(self, request: Dict[str, Any]) -> Any:
+        """Create a sequence group from a DiT request."""
+        # Placeholder: a full implementation would build a proper sequence
+        # group carrying the metadata required for DiT processing.
+        return None
+
+    def update_from_output(
+        self,
+        scheduler_output: "SchedulerOutput",
+        model_output: Any,
+    ) -> List[Any]:
+        """Update scheduler state from model output."""
+        # Update cache with new results
+        for seq_group in scheduler_output.scheduled_seq_groups:
+            if hasattr(seq_group, 'request_id'):
+                # Store result in cache
+                self.dit_cache_manager.store_cache(
+                    seq_group.request_id,
+                    model_output
+                )
+
+        # Update DiT-specific state
+        self.current_step += 1
+
+        # Call parent update method
+        return super().update_from_output(scheduler_output, model_output)
+
+    def add_request(self, request_id: str, **kwargs) -> None:
+        """Add a new DiT request to the scheduler."""
+        # Store request metadata
+        self.dit_requests[request_id] = {
+            'request_id': request_id,
+            'added_time': self.current_step,
+            'status': 'pending',
+            **kwargs
+        }
+
+        # Call parent add_request method
+        super().add_request(request_id, **kwargs)
+
+    def remove_request(self, request_id: str) -> None:
+        """Remove a DiT request from the scheduler."""
+        # Clean up request metadata
+        if request_id in self.dit_requests:
+            del self.dit_requests[request_id]
+
+        # Release cache for this request
+        self.dit_cache_manager.release_cache(request_id)
+
+        # Call parent remove_request method
+        super().remove_request(request_id)
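+
+    # Illustrative request lifecycle (hypothetical driver code; the engine
+    # core integration is expected to issue these calls):
+    #
+    #   scheduler.add_request("req-0", cache_size=2048)
+    #   scheduler_output = scheduler.schedule()
+    #   scheduler.update_from_output(scheduler_output, model_output)
+    #   scheduler.remove_request("req-0")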
request.""" + return self.dit_requests.get(request_id) + + def get_cache_statistics(self) -> Dict[str, Any]: + """Get cache statistics for monitoring.""" + return self.dit_cache_manager.get_statistics() + + def clear_expired_cache(self) -> None: + """Clear expired cache entries.""" + self.dit_cache_manager.clear_expired_cache() + + def set_max_steps(self, max_steps: int) -> None: + """Set the maximum number of diffusion steps.""" + self.max_steps = max_steps + + def get_current_step(self) -> int: + """Get the current diffusion step.""" + return self.current_step + + def reset_step_counter(self) -> None: + """Reset the diffusion step counter.""" + self.current_step = 0 diff --git a/vllm_omni/core/sched/scheduler.py b/vllm_omni/core/sched/scheduler.py index e69de29bb2d..5bbbfeb05c9 100644 --- a/vllm_omni/core/sched/scheduler.py +++ b/vllm_omni/core/sched/scheduler.py @@ -0,0 +1,40 @@ + +from vllm_omni.request import OmniRequest +from typing import List +from threading import Lock, Condition +from vllm_omni.config import OmniConfig + +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.core.sched import SchedulerInterface +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.v1.structured_output import StructuredOutputManager + + +class OmniScheduler(SchedulerInterface): + """ + OmniScheduler: Scheduler for vLLM-omni multimodal processing. + + This scheduler extends vLLM's scheduler to support multimodal and non-autoregressive + processing with additional fields and methods specific to vLLM-omni. + """ + def __init__(self, + omni_config: OmniConfig, + kv_cache_config: KVCacheConfig, + structured_output_manager: StructuredOutputManager, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ): + super().__init__( + vllm_config=omni_config.vllm_config, + kv_cache_config=kv_cache_config, + multimodal_registry=mm_registry, + structured_output_manager=structured_output_manager, + include_finished_set=include_finished_set, + log_stats=log_stats, + ) + self.omni_config = omni_config + + def schedule(self, requests: List[OmniRequest]) -> List[OmniRequest]: + # TODO: Implement scheduling logic + pass \ No newline at end of file diff --git a/vllm_omni/core/stage_manager.py b/vllm_omni/core/stage_manager.py new file mode 100644 index 00000000000..4a34cb0a6d4 --- /dev/null +++ b/vllm_omni/core/stage_manager.py @@ -0,0 +1,86 @@ +""" +Stage manager for orchestrating multiple engines in vLLM-omni. 
+""" + +from typing import List, Optional +from vllm.v1.engine.llm_engine import LLMEngine +from vllm.v1.engine.async_llm import AsyncLLM + +from ..config import OmniStageConfig + + +class StageManager: + """Manages multiple stage engines for multi-stage processing.""" + + def __init__(self, stage_configs: List[OmniStageConfig], log_stats: bool = False): + self.stage_configs = stage_configs + self.log_stats = log_stats + self.engine_list: List[LLMEngine] = [] + self.async_engine_list: List[AsyncLLM] = [] + self._initialized = False + self._async_initialized = False + + def initialize_engines(self) -> None: + """Initialize LLMEngine instances for each stage.""" + if self._initialized: + return + + # For now, create placeholder engines + # In a full implementation, this would create actual engines + for stage_config in self.stage_configs: + # Placeholder - would create actual engine here + self.engine_list.append(None) + + self._initialized = True + + def initialize_async_engines(self) -> None: + """Initialize AsyncLLM instances for each stage.""" + if self._async_initialized: + return + + # For now, create placeholder engines + # In a full implementation, this would create actual engines + for stage_config in self.stage_configs: + # Placeholder - would create actual engine here + self.async_engine_list.append(None) + + self._async_initialized = True + + def get_engine(self, stage_id: int) -> LLMEngine: + """Get the engine for a specific stage.""" + if not self._initialized: + self.initialize_engines() + + if stage_id >= len(self.engine_list): + raise IndexError(f"Stage {stage_id} not found. Available stages: 0-{len(self.engine_list)-1}") + + return self.engine_list[stage_id] + + def get_async_engine(self, stage_id: int) -> AsyncLLM: + """Get the async engine for a specific stage.""" + if not self._async_initialized: + self.initialize_async_engines() + + if stage_id >= len(self.async_engine_list): + raise IndexError(f"Async stage {stage_id} not found. Available stages: 0-{len(self.async_engine_list)-1}") + + return self.async_engine_list[stage_id] + + def get_stage_config(self, stage_id: int) -> OmniStageConfig: + """Get the configuration for a specific stage.""" + if stage_id >= len(self.stage_configs): + raise IndexError(f"Stage config {stage_id} not found. Available stages: 0-{len(self.stage_configs)-1}") + + return self.stage_configs[stage_id] + + def get_num_stages(self) -> int: + """Get the number of stages.""" + return len(self.stage_configs) + + def cleanup(self) -> None: + """Clean up resources.""" + # Clean up engines if needed + self.engine_list.clear() + self.async_engine_list.clear() + self._initialized = False + self._async_initialized = False diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py index 3768b785ee5..406c26b3994 100644 --- a/vllm_omni/engine/__init__.py +++ b/vllm_omni/engine/__init__.py @@ -1,18 +1,9 @@ """ -Diffusion Engine: Diffusion Transformer (DiT) processing modules. - -This module provides specialized processing for diffusion models including -step management, cache management, and model wrapping. +Engine components for vLLM-omni. 
""" -from .step_manager import DiffusionStepManager -from .cache_manager import DiffusionCacheManager -from .models import DiffusionModel -from .base import BaseDiffusionEngine +from .output_processor import MultimodalOutputProcessor __all__ = [ - "DiffusionStepManager", - "DiffusionCacheManager", - "DiffusionModel", - "BaseDiffusionEngine", + "MultimodalOutputProcessor", ] diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index e69de29bb2d..18fbf2b9f53 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -0,0 +1,231 @@ +""" +Output processing for multimodal outputs in vLLM-omni. +""" + +from typing import List, Dict, Any, Optional, Callable, Union +from vllm.outputs import RequestOutput, CompletionOutput +from vllm.v1.outputs import ModelRunnerOutput as EngineCoreOutput + + +class MultimodalOutputProcessor: + """Handles multimodal output processing for vLLM-omni.""" + + def __init__(self): + self.output_handlers: Dict[str, Callable] = { + "image": self._process_image_output, + "text+image": self._process_text_image_output, + "latents": self._process_latents_output, + "text": self._process_text_output, + "pooling": self._process_pooling_output, + } + + def process_output(self, engine_core_output: Any) -> List[RequestOutput]: + """Process engine core output and return formatted RequestOutput.""" + if engine_core_output is None: + return [] + + # If it's already a RequestOutput, return as is + if isinstance(engine_core_output, RequestOutput): + return [engine_core_output] + + # If it's a list of RequestOutputs, return as is + if isinstance(engine_core_output, list): + return engine_core_output + + # Otherwise, process based on output type + output_type = self._detect_output_type(engine_core_output) + handler = self.output_handlers.get(output_type, self._process_pooling_output) + + return handler(engine_core_output) + + def _detect_output_type(self, output: Any) -> str: + """Detect the type of output based on its content.""" + if hasattr(output, 'output_type'): + return output.output_type + + # Check for image-related attributes + if hasattr(output, 'image') or hasattr(output, 'images'): + if hasattr(output, 'text') or hasattr(output, 'texts'): + return "text+image" + else: + return "image" + + # Check for latent-related attributes + if hasattr(output, 'latents') or hasattr(output, 'latent_representation'): + return "latents" + + # Check for pooling output + if hasattr(output, 'pooler_output') and output.pooler_output is not None: + return "pooling" + + # Default to text + return "text" + + def _process_text_output(self, output: Any) -> List[RequestOutput]: + """Process text output.""" + if isinstance(output, RequestOutput): + return [output] + + # Create a mock RequestOutput for text + completion_output = CompletionOutput( + index=0, + text=getattr(output, 'text', ''), + token_ids=getattr(output, 'token_ids', []), + cumulative_logprob=getattr(output, 'cumulative_logprob', 0.0), + logprobs=getattr(output, 'logprobs', None), + finish_reason=getattr(output, 'finish_reason', 'length') + ) + + request_output = RequestOutput( + request_id=getattr(output, 'request_id', 'unknown'), + prompt=getattr(output, 'prompt', ''), + prompt_token_ids=getattr(output, 'prompt_token_ids', []), + outputs=[completion_output], + finished=getattr(output, 'finished', True) + ) + + return [request_output] + + def _process_image_output(self, output: Any) -> List[RequestOutput]: + """Process image output.""" + # For image outputs, we need 
to create a special RequestOutput + # that can handle image data + + # Extract image data + image_data = getattr(output, 'image', None) + if image_data is None: + image_data = getattr(output, 'images', [None])[0] + + # Create a completion output with image data + completion_output = CompletionOutput( + index=0, + text="", # No text for pure image output + token_ids=[], + cumulative_logprob=0.0, + logprobs=None, + finish_reason="stop" + ) + + # Add image data to the completion output + completion_output.image = image_data + + request_output = RequestOutput( + request_id=getattr(output, 'request_id', 'unknown'), + prompt=getattr(output, 'prompt', ''), + prompt_token_ids=getattr(output, 'prompt_token_ids', []), + outputs=[completion_output], + finished=getattr(output, 'finished', True) + ) + + return [request_output] + + def _process_text_image_output(self, output: Any) -> List[RequestOutput]: + """Process combined text and image output.""" + # Extract text and image data + text_data = getattr(output, 'text', '') + image_data = getattr(output, 'image', None) + + if image_data is None: + image_data = getattr(output, 'images', [None])[0] + + # Create a completion output with both text and image + completion_output = CompletionOutput( + index=0, + text=text_data, + token_ids=getattr(output, 'token_ids', []), + cumulative_logprob=getattr(output, 'cumulative_logprob', 0.0), + logprobs=getattr(output, 'logprobs', None), + finish_reason="stop" + ) + + # Add image data to the completion output + completion_output.image = image_data + + request_output = RequestOutput( + request_id=getattr(output, 'request_id', 'unknown'), + prompt=getattr(output, 'prompt', ''), + prompt_token_ids=getattr(output, 'prompt_token_ids', []), + outputs=[completion_output], + finished=getattr(output, 'finished', True) + ) + + return [request_output] + + def _process_latents_output(self, output: Any) -> List[RequestOutput]: + """Process latent representation output.""" + # Extract latent data + latent_data = getattr(output, 'latents', None) + if latent_data is None: + latent_data = getattr(output, 'latent_representation', None) + + # Create a completion output with latent data + completion_output = CompletionOutput( + index=0, + text="", # No text for latent output + token_ids=[], + cumulative_logprob=0.0, + logprobs=None, + finish_reason="stop" + ) + + # Add latent data to the completion output + completion_output.latents = latent_data + + request_output = RequestOutput( + request_id=getattr(output, 'request_id', 'unknown'), + prompt=getattr(output, 'prompt', ''), + prompt_token_ids=getattr(output, 'prompt_token_ids', []), + outputs=[completion_output], + finished=getattr(output, 'finished', True) + ) + + return [request_output] + + def _process_pooling_output(self, output: Any) -> List[RequestOutput]: + """Process pooling output (hidden states, embeddings, etc.).""" + # Extract pooling data + pooling_data = getattr(output, 'pooler_output', None) + if pooling_data is None: + pooling_data = getattr(output, 'hidden_states', None) + + # Create a completion output with pooling data + completion_output = CompletionOutput( + index=0, + text="", # No text for pooling output + token_ids=[], + cumulative_logprob=0.0, + logprobs=None, + finish_reason="stop" + ) + + # Add pooling data to the completion output + completion_output.pooler_output = pooling_data + + request_output = RequestOutput( + request_id=getattr(output, 'request_id', 'unknown'), + prompt=getattr(output, 'prompt', ''), + prompt_token_ids=getattr(output, 
'prompt_token_ids', []),
+            outputs=[completion_output],
+            finished=getattr(output, 'finished', True)
+        )
+
+        return [request_output]
+
+    def process_outputs(self, engine_core_outputs: List[EngineCoreOutput],
+                        **kwargs) -> List[RequestOutput]:
+        """Process multiple engine core outputs."""
+        all_outputs = []
+
+        for engine_core_output in engine_core_outputs:
+            outputs = self.process_output(engine_core_output)
+            all_outputs.extend(outputs)
+
+        return all_outputs
+
+    def add_output_handler(self, output_type: str, handler: Callable) -> None:
+        """Add a custom output handler for a specific output type."""
+        self.output_handlers[output_type] = handler
+
+    def remove_output_handler(self, output_type: str) -> None:
+        """Remove an output handler for a specific output type."""
+        if output_type in self.output_handlers:
+            del self.output_handlers[output_type]
diff --git a/vllm_omni/entrypoints/api_server.py b/vllm_omni/entrypoints/api_server.py
new file mode 100644
index 00000000000..66c5952a777
--- /dev/null
+++ b/vllm_omni/entrypoints/api_server.py
@@ -0,0 +1,143 @@
+"""
+API server for vLLM-omni.
+"""
+
+import asyncio
+from typing import Any, Dict, List, Optional
+
+import uvicorn
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+from ..core.omni_llm import AsyncOmniLLM
+
+
+class GenerateRequest(BaseModel):
+    """Request model for generation."""
+    prompts: List[str]
+    max_tokens: int = 100
+    temperature: float = 0.7
+    stage_args: Optional[List[Dict[str, Any]]] = None
+
+
+class GenerateResponse(BaseModel):
+    """Response model for generation."""
+    outputs: List[Dict[str, Any]]
+    stage_outputs: Optional[List[Dict[str, Any]]] = None
+
+
+app = FastAPI(
+    title="vLLM-omni API",
+    description="Multi-modality models inference and serving",
+    version="0.1.0"
+)
+
+# Global omni_llm instance, injected by run_server() below.
+omni_llm: Optional[AsyncOmniLLM] = None
+
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize the omni_llm instance on startup."""
+    # The instance is set by run_server(); nothing to do here.
+    pass
+
+
+@app.post("/generate", response_model=GenerateResponse)
+async def generate(request: GenerateRequest):
+    """Generate text or multimodal content."""
+    try:
+        if omni_llm is None:
+            raise HTTPException(status_code=500, detail="OmniLLM not initialized")
+
+        # Prepare stage arguments
+        if request.stage_args is None:
+            # Create default stage arguments, one per prompt
+            stage_args = []
+            for prompt in request.prompts:
+                stage_args.append({
+                    "prompt": prompt,
+                    "max_tokens": request.max_tokens,
+                    "temperature": request.temperature
+                })
+        else:
+            stage_args = request.stage_args
+
+        # Generate using omni_llm
+        outputs = await omni_llm.generate_async(stage_args)
+
+        # Convert outputs to response format
+        response_outputs = []
+        for output in outputs:
+            if hasattr(output, 'outputs') and output.outputs:
+                for out in output.outputs:
+                    response_outputs.append({
+                        "text": getattr(out, 'text', ''),
+                        "finished": getattr(out, 'finish_reason', 'length') != 'length',
+                        "tokens": getattr(out, 'token_ids', [])
+                    })
+            else:
+                response_outputs.append({
+                    "text": "",
+                    "finished": True,
+                    "tokens": []
+                })
+
+        return GenerateResponse(
+            outputs=response_outputs,
+            stage_outputs=[{"stage": i, "output": "processed"} for i in range(len(stage_args))]
+        )
+
+    except HTTPException:
+        # Don't re-wrap deliberate HTTP errors (e.g. the 500 above).
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint."""
+    return {"status": "healthy", "service": "vllm-omni"}
+
+
+@app.get("/info")
+async def get_info(): + """Get information about the service.""" + if omni_llm is None: + return {"error": "OmniLLM not initialized"} + + return { + "service": "vllm-omni", + "version": "0.1.0", + "num_stages": omni_llm.get_num_stages() if hasattr(omni_llm, 'get_num_stages') else 0, + "stage_configs": [ + { + "stage_id": config.stage_id, + "engine_type": config.engine_type, + "model_path": config.model_path, + "input_modalities": config.input_modalities, + "output_modalities": config.output_modalities + } + for config in omni_llm.stage_configs + ] + } + + +async def run_server(omni_llm_instance: AsyncOmniLLM, host: str = "0.0.0.0", port: int = 8000): + """Run the API server.""" + global omni_llm + omni_llm = omni_llm_instance + + config = uvicorn.Config( + app=app, + host=host, + port=port, + log_level="info" + ) + + server = uvicorn.Server(config) + await server.serve() + + +if __name__ == "__main__": + # This is for testing purposes + asyncio.run(run_server(None)) diff --git a/vllm_omni/entrypoints/cli/main.py b/vllm_omni/entrypoints/cli/main.py new file mode 100644 index 00000000000..4a4753f147d --- /dev/null +++ b/vllm_omni/entrypoints/cli/main.py @@ -0,0 +1,26 @@ +""" +CLI entry point for vLLM-omni that intercepts vLLM commands. +""" + +import sys +import argparse +from typing import List, Optional +from vllm_omni.entrypoints.omni import OmniServeCommand + + +def main(): + """Main CLI entry point that intercepts vLLM commands.""" + # Check if --omni flag is present + if "--omni" in sys.argv: + # Remove --omni flag and process with vLLM-omni + omni_args = [arg for arg in sys.argv[1:] if arg != "--omni"] + omni_serve = OmniServeCommand() + omni_serve.run(omni_args) + else: + # Forward to original vLLM CLI + from vllm.entrypoints.cli.main import main as vllm_main + vllm_main() + + +if __name__ == "__main__": + main() diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index e69de29bb2d..fab1f2d7dac 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -0,0 +1,195 @@ +""" +Omni serve command for vLLM-omni. 
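+
+Example invocations (as listed in OmniPlugin.get_help_text):
+
+    vllm serve Qwen/Qwen2.5-Omni-7B --omni
+    vllm serve model --omni --ar-stage text-model --dit-stage image-model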
+""" + +import argparse +import asyncio +from typing import List, Optional +from vllm_omni.core.omni_llm import AsyncOmniLLM +from vllm_omni.config import create_ar_stage_config, create_dit_stage_config, DiTConfig + + +class OmniServeCommand: + """Command handler for vLLM-omni serve command.""" + + def __init__(self): + self.parser = self._create_parser() + + def _create_parser(self) -> argparse.ArgumentParser: + """Create argument parser for omni serve command.""" + parser = argparse.ArgumentParser( + description="vLLM-omni: Multi-modality models inference and serving" + ) + + # Model arguments + parser.add_argument( + "model", + help="Path to the model or model name" + ) + + # Server arguments + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to run the server on" + ) + + parser.add_argument( + "--host", + type=str, + default="0.0.0.0", + help="Host to run the server on" + ) + + # Stage configuration arguments + parser.add_argument( + "--ar-stage", + type=str, + help="AR stage model path" + ) + + parser.add_argument( + "--dit-stage", + type=str, + help="DiT stage model path" + ) + + parser.add_argument( + "--dit-steps", + type=int, + default=50, + help="Number of DiT inference steps" + ) + + parser.add_argument( + "--dit-guidance-scale", + type=float, + default=7.5, + help="DiT guidance scale" + ) + + parser.add_argument( + "--use-diffusers", + action="store_true", + help="Use diffusers pipeline for DiT stage" + ) + + # Other arguments + parser.add_argument( + "--log-stats", + action="store_true", + help="Enable logging statistics" + ) + + return parser + + def run(self, args: List[str]) -> None: + """Run the omni serve command.""" + parsed_args = self.parser.parse_args(args) + + # Create stage configurations + stage_configs = self._create_stage_configs(parsed_args) + + # Create AsyncOmniLLM instance + omni_llm = AsyncOmniLLM( + stage_configs=stage_configs, + log_stats=parsed_args.log_stats + ) + + # Start the server + asyncio.run(self._start_server(omni_llm, parsed_args)) + + def _create_stage_configs(self, args) -> List: + """Create stage configurations based on arguments.""" + stage_configs = [] + stage_id = 0 + + # Add AR stage if specified + if args.ar_stage: + ar_config = create_ar_stage_config( + stage_id=stage_id, + model_path=args.ar_stage, + input_modalities=["text"], + output_modalities=["text"] + ) + stage_configs.append(ar_config) + stage_id += 1 + + # Add DiT stage if specified + if args.dit_stage: + dit_config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=args.dit_steps, + guidance_scale=args.dit_guidance_scale, + use_diffusers=args.use_diffusers + ) + + dit_stage_config = create_dit_stage_config( + stage_id=stage_id, + model_path=args.dit_stage, + input_modalities=["text"], + output_modalities=["image"], + dit_config=dit_config + ) + stage_configs.append(dit_stage_config) + stage_id += 1 + + # If no specific stages are specified, use the main model + if not stage_configs: + # Try to detect if it's a multimodal model + if "omni" in args.model.lower() or "multimodal" in args.model.lower(): + # Assume it's a multimodal model that can handle both AR and DiT + ar_config = create_ar_stage_config( + stage_id=0, + model_path=args.model, + input_modalities=["text"], + output_modalities=["text"] + ) + stage_configs.append(ar_config) + + dit_config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=args.dit_steps, + guidance_scale=args.dit_guidance_scale, + use_diffusers=args.use_diffusers + ) + 
+                dit_stage_config = create_dit_stage_config(
+                    stage_id=1,
+                    model_path=args.model,
+                    input_modalities=["text"],
+                    output_modalities=["image"],
+                    dit_config=dit_config
+                )
+                stage_configs.append(dit_stage_config)
+            else:
+                # Default to AR stage
+                ar_config = create_ar_stage_config(
+                    stage_id=0,
+                    model_path=args.model,
+                    input_modalities=["text"],
+                    output_modalities=["text"]
+                )
+                stage_configs.append(ar_config)
+
+        return stage_configs
+
+    async def _start_server(self, omni_llm: AsyncOmniLLM, args) -> None:
+        """Start the API server."""
+        try:
+            # Import here to avoid circular imports
+            from vllm_omni.entrypoints.api_server import run_server
+
+            # run_server's first parameter is omni_llm_instance (the module
+            # keeps `omni_llm` as its global), so pass it by that name.
+            await run_server(
+                omni_llm_instance=omni_llm,
+                host=args.host,
+                port=args.port
+            )
+        except KeyboardInterrupt:
+            print("\nShutting down server...")
+        except Exception as e:
+            print(f"Error starting server: {e}")
+            raise
diff --git a/vllm_omni/plugin.py b/vllm_omni/plugin.py
new file mode 100644
index 00000000000..732cdc474c6
--- /dev/null
+++ b/vllm_omni/plugin.py
@@ -0,0 +1,166 @@
+"""
+vLLM plugin system for vLLM-omni integration.
+"""
+
+from typing import Dict, Any, Optional
+from vllm_omni.core.omni_llm import OmniLLM, AsyncOmniLLM
+from vllm_omni.config import create_ar_stage_config, create_dit_stage_config
+
+
+class OmniPlugin:
+    """vLLM plugin for vLLM-omni integration."""
+
+    def __init__(self):
+        self.name = "omni"
+        self.version = "0.1.0"
+        self.description = "Multi-modality models inference and serving"
+
+    def register_components(self) -> Dict[str, Any]:
+        """Register vLLM-omni components with vLLM."""
+        return {
+            "omni_llm": OmniLLM,
+            "async_omni_llm": AsyncOmniLLM,
+            "create_ar_stage_config": create_ar_stage_config,
+            "create_dit_stage_config": create_dit_stage_config,
+        }
+
+    def get_config_schema(self) -> Dict[str, Any]:
+        """Get configuration schema for the plugin."""
+        return {
+            "type": "object",
+            "properties": {
+                "stages": {
+                    "type": "array",
+                    "items": {
+                        "type": "object",
+                        "properties": {
+                            "stage_id": {"type": "integer"},
+                            "engine_type": {"type": "string", "enum": ["AR", "DiT"]},
+                            "model_path": {"type": "string"},
+                            "input_modalities": {
+                                "type": "array",
+                                "items": {"type": "string"}
+                            },
+                            "output_modalities": {
+                                "type": "array",
+                                "items": {"type": "string"}
+                            }
+                        },
+                        "required": ["stage_id", "engine_type", "model_path", "input_modalities", "output_modalities"]
+                    }
+                }
+            },
+            "required": ["stages"]
+        }
+
+    def validate_config(self, config: Dict[str, Any]) -> bool:
+        """Validate plugin configuration."""
+        if "stages" not in config:
+            return False
+
+        stages = config["stages"]
+        if not isinstance(stages, list):
+            return False
+
+        for stage in stages:
+            required_fields = ["stage_id", "engine_type", "model_path", "input_modalities", "output_modalities"]
+            for field in required_fields:
+                if field not in stage:
+                    return False
+
+            if stage["engine_type"] not in ["AR", "DiT"]:
+                return False
+
+        return True
+
+    def create_omni_llm(self, config: Dict[str, Any]) -> OmniLLM:
+        """Create an OmniLLM instance from configuration."""
+        from vllm_omni.config import OmniStageConfig, DiTConfig
+
+        stage_configs = []
+        for stage_config in config["stages"]:
+            if stage_config["engine_type"] == "AR":
+                stage_config_obj = create_ar_stage_config(
+                    stage_id=stage_config["stage_id"],
+                    model_path=stage_config["model_path"],
+                    input_modalities=stage_config["input_modalities"],
+                    output_modalities=stage_config["output_modalities"]
+                )
+            elif stage_config["engine_type"] == "DiT":
+                dit_config = DiTConfig(
+                    model_type="dit",
scheduler_type="ddpm", + num_inference_steps=50, + guidance_scale=7.5 + ) + + stage_config_obj = create_dit_stage_config( + stage_id=stage_config["stage_id"], + model_path=stage_config["model_path"], + input_modalities=stage_config["input_modalities"], + output_modalities=stage_config["output_modalities"], + dit_config=dit_config + ) + else: + raise ValueError(f"Unknown engine type: {stage_config['engine_type']}") + + stage_configs.append(stage_config_obj) + + return OmniLLM(stage_configs) + + def create_async_omni_llm(self, config: Dict[str, Any]) -> AsyncOmniLLM: + """Create an AsyncOmniLLM instance from configuration.""" + from vllm_omni.config import OmniStageConfig, DiTConfig + + stage_configs = [] + for stage_config in config["stages"]: + if stage_config["engine_type"] == "AR": + stage_config_obj = create_ar_stage_config( + stage_id=stage_config["stage_id"], + model_path=stage_config["model_path"], + input_modalities=stage_config["input_modalities"], + output_modalities=stage_config["output_modalities"] + ) + elif stage_config["engine_type"] == "DiT": + dit_config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50, + guidance_scale=7.5 + ) + + stage_config_obj = create_dit_stage_config( + stage_id=stage_config["stage_id"], + model_path=stage_config["model_path"], + input_modalities=stage_config["input_modalities"], + output_modalities=stage_config["output_modalities"], + dit_config=dit_config + ) + else: + raise ValueError(f"Unknown engine type: {stage_config['engine_type']}") + + stage_configs.append(stage_config_obj) + + return AsyncOmniLLM(stage_configs) + + def get_help_text(self) -> str: + """Get help text for the plugin.""" + return """ +vLLM-omni Plugin + +This plugin enables multi-modality models inference and serving with non-autoregressive structures. + +Usage: + vllm serve model --omni [options] + +Options: + --ar-stage MODEL_PATH AR stage model path + --dit-stage MODEL_PATH DiT stage model path + --dit-steps N Number of DiT inference steps (default: 50) + --dit-guidance-scale F DiT guidance scale (default: 7.5) + --use-diffusers Use diffusers pipeline for DiT stage + +Examples: + vllm serve Qwen/Qwen2.5-Omni-7B --omni + vllm serve model --omni --ar-stage text-model --dit-stage image-model + """ diff --git a/vllm_omni/request.py b/vllm_omni/request.py index 0919270afca..4bdc62b2533 100644 --- a/vllm_omni/request.py +++ b/vllm_omni/request.py @@ -1,7 +1,299 @@ +""" +OmniRequest: Extended request class for vLLM-omni multimodal processing. + +This class extends vLLM's Request to support multimodal and non-autoregressive +processing with additional fields and methods specific to vLLM-omni. +""" + import enum import time -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List, Union +from dataclasses import dataclass, field from vllm.v1.request import Request as vLLMRequest + +class RequestType(enum.Enum): + """Types of requests supported by vLLM-omni.""" + TEXT = "text" + IMAGE = "image" + AUDIO = "audio" + VIDEO = "video" + MULTIMODAL = "multimodal" + DIFFUSION = "diffusion" + + +class ProcessingStage(enum.Enum): + """Processing stages for multi-stage models.""" + PREPROCESSING = "preprocessing" + AR_GENERATION = "ar_generation" + DIFFUSION_GENERATION = "diffusion_generation" + POSTPROCESSING = "postprocessing" + COMPLETED = "completed" + + +@dataclass +class MultimodalData: + """Container for multimodal input data.""" + data_type: str # "image", "audio", "video", etc. 
+ data: Any # The actual data (numpy array, bytes, etc.) + metadata: Dict[str, Any] = field(default_factory=dict) + processed: bool = False + + +@dataclass +class DiffusionParams: + """Parameters specific to diffusion model generation.""" + num_inference_steps: int = 50 + guidance_scale: float = 7.5 + seed: Optional[int] = None + scheduler: str = "ddpm" + strength: float = 1.0 + eta: float = 0.0 + + class OmniRequest(vLLMRequest): - # pass \ No newline at end of file + """ + Extended request class for vLLM-omni multimodal and non-autoregressive processing. + + This class extends vLLM's Request with additional fields and methods to support: + - Multimodal input processing + - Non-autoregressive generation (diffusion models) + - Multi-stage processing pipelines + - Enhanced caching for different model types + """ + + def __init__( + self, + request_id: str, + prompt: Optional[str] = None, + prompt_token_ids: Optional[List[int]] = None, + sampling_params: Optional[Any] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[Any] = None, + multi_modal_data: Optional[Dict[str, Any]] = None, + multi_modal_placeholders: Optional[Dict[str, str]] = None, + priority: int = 0, + # vLLM-omni specific parameters + request_type: RequestType = RequestType.TEXT, + processing_stage: ProcessingStage = ProcessingStage.PREPROCESSING, + multimodal_inputs: Optional[List[MultimodalData]] = None, + diffusion_params: Optional[DiffusionParams] = None, + output_format: str = "text", + cache_key: Optional[str] = None, + stage_configs: Optional[Dict[str, Any]] = None, + **kwargs + ): + # Initialize parent class + super().__init__( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + arrival_time=arrival_time or time.time(), + lora_request=lora_request, + multi_modal_data=multi_modal_data, + multi_modal_placeholders=multi_modal_placeholders, + priority=priority, + **kwargs + ) + + # vLLM-omni specific attributes + self.request_type = request_type + self.processing_stage = processing_stage + self.multimodal_inputs = multimodal_inputs or [] + self.diffusion_params = diffusion_params or DiffusionParams() + self.output_format = output_format + self.cache_key = cache_key + self.stage_configs = stage_configs or {} + + # Processing state + self.current_stage = 0 + self.stage_results = {} + self.hidden_states = None + self.intermediate_outputs = [] + + # Timing and metrics + self.stage_timings = {} + self.total_processing_time = 0.0 + + # Error handling + self.errors = [] + self.retry_count = 0 + self.max_retries = 3 + + def add_multimodal_input(self, data: MultimodalData) -> None: + """Add a multimodal input to the request.""" + self.multimodal_inputs.append(data) + + def get_multimodal_inputs_by_type(self, data_type: str) -> List[MultimodalData]: + """Get all multimodal inputs of a specific type.""" + return [inp for inp in self.multimodal_inputs if inp.data_type == data_type] + + def update_processing_stage(self, stage: ProcessingStage) -> None: + """Update the current processing stage.""" + self.processing_stage = stage + self.stage_timings[stage.value] = time.time() + + def add_stage_result(self, stage: str, result: Any) -> None: + """Add a result from a processing stage.""" + self.stage_results[stage] = result + self.intermediate_outputs.append({ + 'stage': stage, + 'result': result, + 'timestamp': time.time() + }) + + def get_stage_result(self, stage: str) -> Any: + """Get a result from a specific processing stage.""" + return 
self.stage_results.get(stage) + + def set_hidden_states(self, hidden_states: Any) -> None: + """Set hidden states for the request.""" + self.hidden_states = hidden_states + + def get_hidden_states(self) -> Any: + """Get hidden states for the request.""" + return self.hidden_states + + def add_error(self, error: str) -> None: + """Add an error to the request.""" + self.errors.append({ + 'error': error, + 'timestamp': time.time(), + 'stage': self.processing_stage.value + }) + + def has_errors(self) -> bool: + """Check if the request has any errors.""" + return len(self.errors) > 0 + + def can_retry(self) -> bool: + """Check if the request can be retried.""" + return self.retry_count < self.max_retries + + def increment_retry(self) -> None: + """Increment the retry count.""" + self.retry_count += 1 + + def generate_cache_key(self) -> str: + """Generate a cache key for the request.""" + if self.cache_key: + return self.cache_key + + # Generate cache key based on request content + key_parts = [ + self.request_id, + str(self.request_type.value), + str(self.prompt_token_ids) if self.prompt_token_ids else str(self.prompt), + str(self.diffusion_params.num_inference_steps) if self.diffusion_params else "0", + str(len(self.multimodal_inputs)) + ] + return "_".join(key_parts) + + def get_processing_time(self) -> float: + """Get the total processing time for the request.""" + if self.stage_timings: + return time.time() - min(self.stage_timings.values()) + return 0.0 + + def is_completed(self) -> bool: + """Check if the request is completed.""" + return self.processing_stage == ProcessingStage.COMPLETED + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for serialization.""" + return { + 'request_id': self.request_id, + 'request_type': self.request_type.value, + 'processing_stage': self.processing_stage.value, + 'prompt': self.prompt, + 'prompt_token_ids': self.prompt_token_ids, + 'output_format': self.output_format, + 'multimodal_inputs_count': len(self.multimodal_inputs), + 'diffusion_params': self.diffusion_params.__dict__ if self.diffusion_params else None, + 'stage_results': list(self.stage_results.keys()), + 'has_errors': self.has_errors(), + 'retry_count': self.retry_count, + 'processing_time': self.get_processing_time() + } + + def __repr__(self) -> str: + """String representation of the request.""" + return (f"OmniRequest(id={self.request_id}, " + f"type={self.request_type.value}, " + f"stage={self.processing_stage.value}, " + f"multimodal_inputs={len(self.multimodal_inputs)})") + + +# Factory functions for creating different types of requests +def create_text_request( + request_id: str, + prompt: str, + sampling_params: Optional[Any] = None, + **kwargs +) -> OmniRequest: + """Create a text-only request.""" + return OmniRequest( + request_id=request_id, + prompt=prompt, + request_type=RequestType.TEXT, + sampling_params=sampling_params, + **kwargs + ) + + +def create_image_request( + request_id: str, + prompt: str, + image_data: Any, + diffusion_params: Optional[DiffusionParams] = None, + **kwargs +) -> OmniRequest: + """Create an image generation request.""" + multimodal_input = MultimodalData( + data_type="image", + data=image_data, + metadata={"is_input": True} + ) + + return OmniRequest( + request_id=request_id, + prompt=prompt, + request_type=RequestType.IMAGE, + multimodal_inputs=[multimodal_input], + diffusion_params=diffusion_params, + output_format="image", + **kwargs + ) + + +def create_multimodal_request( + request_id: str, + prompt: str, + 
multimodal_inputs: List[MultimodalData], + **kwargs +) -> OmniRequest: + """Create a multimodal request.""" + return OmniRequest( + request_id=request_id, + prompt=prompt, + request_type=RequestType.MULTIMODAL, + multimodal_inputs=multimodal_inputs, + **kwargs + ) + + +def create_diffusion_request( + request_id: str, + prompt: str, + diffusion_params: DiffusionParams, + **kwargs +) -> OmniRequest: + """Create a diffusion model request.""" + return OmniRequest( + request_id=request_id, + prompt=prompt, + request_type=RequestType.DIFFUSION, + diffusion_params=diffusion_params, + **kwargs + ) \ No newline at end of file
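
A quick usage sketch of the request factories above (illustrative only; it
assumes `OmniRequest` can be constructed without the additional engine-side
arguments vLLM's `Request` may require):

```python
from vllm_omni.request import (
    DiffusionParams,
    create_diffusion_request,
    create_text_request,
)

# Text-only request, routed to an AR stage.
text_req = create_text_request("req-0", prompt="Describe a sunset.")

# Image-generation request, routed to a DiT stage.
diff_req = create_diffusion_request(
    "req-1",
    prompt="A watercolor sunset over mountains",
    diffusion_params=DiffusionParams(num_inference_steps=30, guidance_scale=7.5),
)

print(text_req)                       # OmniRequest(id=req-0, type=text, ...)
print(diff_req.generate_cache_key())  # cache key derived from id/type/params
```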