diff --git a/.gitignore b/.gitignore index d42c877939..5941d027c7 100644 --- a/.gitignore +++ b/.gitignore @@ -152,6 +152,18 @@ dmypy.json # Cython debug symbols cython_debug/ +# cursor +.cursor/ + +# docker +docker/ + +# scripts +scripts/ + +# tests +tests/ + # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be added to the global gitignore or merged into this project gitignore. For a PyCharm @@ -226,7 +238,6 @@ Dockerfile.dev # Kubernetes k8s/ -*.yaml *.yml !configs/*.yaml !configs/*.yml \ No newline at end of file diff --git a/README.md b/README.md index 75b1bcf23b..66c8725ee5 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,146 @@ # vLLM-omni: Multi-modal Extension for vLLM vLLM-omni is designed to extend vLLM capabilities to support multi-modality model inference and serving, particularly focusing on non-autoregressive structures and non-textual outputs. + +## 🎯 Overview + +Traditional vLLM systems are limited to text-based, autoregressive generation. vLLM-omni addresses this limitation by enabling support for: + +- **Multi-modal Models**: Text, image, video, audio, and sensor data processing +- **Non-autoregressive Architectures**: Diffusion Transformers (DiT) and other parallel generation models +- **Heterogeneous Outputs**: Beyond traditional text generation to structured, binary, and streaming outputs + +## 🏗️ Architecture + +vLLM-omni is built on a modular architecture that extends vLLM's core functionality: + + +## 🚀 Key Features + +### Multi-Engine Support + +- **Autoregressive Engine**: Traditional text generation with enhanced KV-caching +- **Diffusion Engine**: Support for DiT models and iterative generation +- **Hybrid Engine**: Combined AR+DiT processing pipelines + +### Modality Processing + +- **Text**: Advanced tokenization and embedding generation +- **Image**: Vision encoder integration (CLIP, etc.) 
+- **Audio**: Speech processing and audio embedding
+- **Video**: Frame-by-frame and temporal processing
+- **Sensor**: IoT and sensor data interpretation
+
+### Output Formats
+
+- **Structured Data**: JSON, XML, and custom formats
+- **Binary Outputs**: Images, audio, and video generation
+- **Streaming**: Real-time progressive generation
+- **Multipart**: Combined multi-modal responses
+
+## 📋 Supported Models
+
+### AR + Diffusion Transformer (DiT) Models
+- Qwen-Image (Image generation and editing)
+- Qwen-omni (Thinker-Talker-Codec structure)
+- Custom DiT and hybrid architectures
+
+## 🛠️ Installation
+
+### Quick Start
+
+#### Option 1: Docker (Recommended for macOS)
+
+```bash
+# Clone the repository
+git clone https://github.com/hsliuustc0106/vllm-omni.git
+cd vllm-omni
+
+# Run the automated Docker setup
+./scripts/docker-setup-macos.sh
+```
+
+#### Option 2: Local Installation
+
+```bash
+# Clone the repository
+git clone https://github.com/hsliuustc0106/vllm-omni.git
+cd vllm-omni
+
+# Run the installation script
+./install.sh
+```
+
+### Prerequisites
+
+- Python 3.11+ (recommended)
+- Conda or Miniconda
+- Git
+- CUDA 11.8+ (for GPU acceleration) or CPU-only installation
+
+### Installation Methods
+
+#### Method 1: Automated Installation (Recommended)
+```bash
+# Using shell script
+./install.sh
+
+# Or using Python script
+python install.py
+```
+
+#### Method 2: Manual Installation
+```bash
+# Create conda environment
+conda create -n vllm_omni python=3.11 -y
+conda activate vllm_omni
+
+# Install PyTorch (CPU or GPU)
+pip install torch>=2.7 --index-url https://download.pytorch.org/whl/cpu # CPU
+# pip install torch>=2.7 --index-url https://download.pytorch.org/whl/cu121 # GPU
+
+# Install dependencies
+pip install -r requirements.txt
+pip install "vllm>=0.10.2"
+
+# Install vLLM-omni
+pip install -e . 
+``` + +### Verify Installation + +```bash +# Test the installation +python test_installation.py + +# Test basic functionality +python -c "import vllm_omni; print('Ready!')" + +# Test CLI +vllm --help +``` + +For detailed installation instructions, see [INSTALL.md](INSTALL.md). + +## 📥 Model Download + +Models are automatically downloaded when first used, or you can pre-download them: + +```bash +# Check downloaded models +python scripts/download_models.py --check-cache + +# Download all default models +python scripts/download_models.py --all + +# Download specific models +python scripts/download_models.py --ar-models Qwen/Qwen3-0.6B +python scripts/download_models.py --dit-models stabilityai/stable-diffusion-2-1 +``` + +**Model Storage Location:** +- Default: `~/.cache/huggingface/hub/` +- AR models: 100MB - 1GB each +- DiT models: 2GB - 7GB each + +For detailed model management, see [MODEL_DOWNLOAD_GUIDE.md](docs/MODEL_DOWNLOAD_GUIDE.md). diff --git a/docs/README.md b/docs/README.md deleted file mode 100644 index d52c6c0998..0000000000 --- a/docs/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# vLLM-omni Documentation - -Welcome to the vLLM-omni documentation! This documentation provides comprehensive information about using and developing with vLLM-omni. - -# TODO (add examples) \ No newline at end of file diff --git a/docs/api/API_DESIGN_TEMPLATE.md b/docs/api/API_DESIGN_TEMPLATE.md new file mode 100644 index 0000000000..eefaf9334c --- /dev/null +++ b/docs/api/API_DESIGN_TEMPLATE.md @@ -0,0 +1,336 @@ +# API Design Template for vLLM-omni Modules + +This template provides a standardized structure for designing APIs for core, engine, executor, and worker modules in vLLM-omni. Use this template to ensure consistency and completeness across all modules. + +## Template Structure + +### 1. 
Module Overview +- **Purpose**: What this module does +- **Responsibilities**: Key responsibilities of the module +- **Dependencies**: What other modules this depends on +- **Integration Points**: How it integrates with other modules + +### 2. Core Classes/Interfaces +- **Base Classes**: Abstract base classes and interfaces +- **Implementation Classes**: Concrete implementations +- **Data Structures**: Key data structures and models + +### 3. Public API Methods +- **Initialization**: Constructor and setup methods +- **Core Operations**: Main functionality methods +- **Configuration**: Configuration and parameter methods +- **Lifecycle Management**: Start, stop, cleanup methods +- **Monitoring**: Status, metrics, and debugging methods + +### 4. Configuration +- **Configuration Classes**: Dataclasses or config objects +- **Required Parameters**: Must-have configuration +- **Optional Parameters**: Optional configuration with defaults +- **Validation**: Parameter validation rules + +### 5. Error Handling +- **Custom Exceptions**: Module-specific exceptions +- **Error Codes**: Standardized error codes +- **Recovery Strategies**: How to handle and recover from errors + +### 6. Examples +- **Basic Usage**: Simple usage examples +- **Advanced Usage**: Complex scenarios +- **Integration Examples**: How to use with other modules + +--- + +## Example: Core Module Template + +### 1. Module Overview + +**Purpose**: The core module provides fundamental scheduling and caching functionality for vLLM-omni. 
+ +**Responsibilities**: +- Request scheduling and prioritization +- DiT cache management +- Resource allocation and coordination +- Inter-module communication + +**Dependencies**: +- `vllm_omni.request` - Request handling +- `vllm_omni.config` - Configuration management +- `vllm_omni.utils` - Utility functions + +**Integration Points**: +- Receives requests from entrypoints +- Coordinates with engine modules +- Manages worker allocation +- Provides status to monitoring systems + +### 2. Core Classes/Interfaces + +```python +from abc import ABC, abstractmethod +from typing import List, Optional, Dict, Any +from dataclasses import dataclass +from enum import Enum + +class DiTCacheManager: + """Manages DiT cache for diffusion models.""" + + def __init__(self, config: DiTCacheConfig): + self.config = config + self.cache = {} + self.cache_stats = {} + + async def get_cache(self, cache_key: str) -> Optional[Any]: + """Retrieve cached data.""" + pass + + async def set_cache(self, cache_key: str, data: Any) -> None: + """Store data in cache.""" + pass + + async def invalidate_cache(self, cache_key: str) -> None: + """Invalidate cached data.""" + pass +``` + +### 3. Public API Methods + +#### Initialization +```python +class OmniScheduler(): + def __init__(self, config: SchedulerConfig): + """Initialize the core scheduler.""" + self.config = config + self.scheduler = self._create_scheduler() + self.cache_manager = DiTCacheManager(config.dit_cache_config) + self._running = False + + def _create_scheduler(self) -> BaseScheduler: + """Factory method to create appropriate scheduler.""" + if self.config.scheduler_type == SchedulerType.FIFO: + return FIFOScheduler(self.config) + elif self.config.scheduler_type == SchedulerType.PRIORITY: + return PriorityScheduler(self.config) + # ... other scheduler types +``` + +#### Core Operations +```python + async def schedule(self, request: OmniRequest) -> bool: + """ + Schedule a request for processing. 
+ + Args: + request: The request to schedule + + Returns: + bool: True if successfully scheduled, False otherwise + + Raises: + SchedulerError: If scheduling fails + QueueFullError: If queue is at capacity + """ + pass +``` + +#### Configuration +```python + def update_config(self, new_config: SchedulerConfig) -> None: + """Update scheduler configuration.""" + pass + + def get_config(self) -> SchedulerConfig: + """Get current configuration.""" + pass +``` + +#### Lifecycle Management +```python + async def start(self) -> None: + """Start the scheduler.""" + self._running = True + # Start background tasks + + async def stop(self) -> None: + """Stop the scheduler gracefully.""" + self._running = False + # Cleanup resources + + async def shutdown(self) -> None: + """Force shutdown the scheduler.""" + # Immediate cleanup +``` + +#### Monitoring +```python + def get_status(self) -> Dict[str, Any]: + """Get current scheduler status.""" + return { + "running": self._running, + "queue_size": self.scheduler.queue_size(), + "processed_requests": self.scheduler.processed_count(), + "cache_hit_rate": self.cache_manager.get_hit_rate() + } + + def get_metrics(self) -> Dict[str, float]: + """Get performance metrics.""" + pass +``` + +### 4. Configuration + +```python +@dataclass +class OmniConfig: + """Configuration for the core module.""" + + # Scheduler configuration + scheduler: SchedulerConfig = field(default_factory=SchedulerConfig) + + # Cache configuration + dit_cache: DiTCacheConfig = field(default_factory=DiTCacheConfig) + + # Resource limits + max_memory_gb: float = 16.0 + max_gpu_utilization: float = 0.8 + + # Timeouts + request_timeout: int = 300 + worker_timeout: int = 60 + + def validate(self) -> None: + """Validate configuration parameters.""" + if self.max_memory_gb <= 0: + raise ValueError("max_memory_gb must be positive") + if not 0 < self.max_gpu_utilization <= 1: + raise ValueError("max_gpu_utilization must be between 0 and 1") +``` + +### 5. 
Error Handling + +```python +class CoreModuleError(Exception): + """Base exception for core module errors.""" + pass + +class SchedulerError(CoreModuleError): + """Scheduler-related errors.""" + pass + +class QueueFullError(SchedulerError): + """Queue is at capacity.""" + pass + +class CacheError(CoreModuleError): + """Cache-related errors.""" + pass + +class ResourceError(CoreModuleError): + """Resource allocation errors.""" + pass +``` + +### 6. Examples + +#### Basic Usage +```python +from vllm_omni.core import CoreScheduler, CoreModuleConfig +from vllm_omni.request import create_text_request + +# Create configuration +config = CoreModuleConfig( + scheduler=SchedulerConfig(scheduler_type=SchedulerType.FIFO), + max_memory_gb=8.0 +) + +# Initialize scheduler +scheduler = CoreScheduler(config) +await scheduler.start() + +# Create and schedule a request +request = create_text_request( + request_id="req_001", + prompt="Hello, world!", + sampling_params=sampling_params +) + +success = await scheduler.schedule_request(request) +if success: + result = await scheduler.process_request(request) + print(f"Result: {result}") +``` + +#### Advanced Usage +```python +# Custom scheduler with priority handling +config = CoreModuleConfig( + scheduler=SchedulerConfig( + scheduler_type=SchedulerType.PRIORITY, + priority_weights={"high": 2.0, "normal": 1.0, "low": 0.5} + ) +) + +scheduler = CoreScheduler(config) +await scheduler.start() + +# Monitor scheduler status +status = scheduler.get_status() +print(f"Queue size: {status['queue_size']}") +print(f"Cache hit rate: {status['cache_hit_rate']:.2%}") +``` + +--- + +## Module-Specific Guidelines + +### Core Module +- Focus on scheduling, caching, and resource management +- Provide clear interfaces for other modules +- Handle concurrency and thread safety +- Implement comprehensive monitoring + +### Engine Module +- Handle model loading and inference +- Support both AR and diffusion models +- Provide unified interface for different 
model types +- Implement efficient memory management + +### Executor Module +- Coordinate between different engines +- Handle request routing and load balancing +- Manage execution pipelines +- Provide error recovery mechanisms + +### Worker Module +- Handle actual model execution +- Manage GPU resources +- Implement batching strategies +- Provide performance optimization + +--- + +## Checklist for API Design + +- [ ] **Clear Purpose**: Module purpose is well-defined +- [ ] **Complete Interface**: All public methods are documented +- [ ] **Error Handling**: Comprehensive error handling strategy +- [ ] **Configuration**: Flexible configuration system +- [ ] **Examples**: Basic and advanced usage examples +- [ ] **Type Hints**: All methods have proper type hints +- [ ] **Documentation**: Comprehensive docstrings +- [ ] **Testing**: Unit tests for all public methods +- [ ] **Integration**: Clear integration points with other modules +- [ ] **Performance**: Performance considerations documented + +--- + +## Submission Guidelines + +1. Create a new file: `docs/api/[module_name]_api.md` +2. Follow the template structure exactly +3. Include all required sections +4. Provide working code examples +5. Ensure all methods have proper docstrings +6. Include error handling strategies +7. Submit for review before implementation + +This template ensures consistency and completeness across all vLLM-omni modules. diff --git a/docs/api/README.md b/docs/api/README.md new file mode 100644 index 0000000000..681bbba81d --- /dev/null +++ b/docs/api/README.md @@ -0,0 +1,235 @@ +# vLLM-omni API Documentation + +This directory contains comprehensive API design documentation for all core modules in vLLM-omni. These templates provide a standardized structure for designing and implementing the core, engine, executor, and worker modules. + +## 📋 API Design Templates + +### 1. 
[API Design Template](API_DESIGN_TEMPLATE.md) +**Master template** for creating API documentation for any module in vLLM-omni. Use this as a starting point for new modules. + +### 2. [Core Module API](core_module_api.md) +**Core module** provides fundamental scheduling, caching, and resource management functionality. + +**Key Components:** +- Request scheduling and prioritization +- DiT cache management for diffusion models +- Resource allocation and coordination +- Inter-module communication + +### 3. [Engine Module API](engine_module_api.md) +**Engine module** handles model loading, inference execution, and output processing. + +**Key Components:** +- Model loading and initialization +- Inference execution for AR and diffusion models +- Input preprocessing for various modalities +- Output postprocessing and formatting + +### 4. [Executor Module API](executor_module_api.md) +**Executor module** coordinates and manages request execution across different engines and workers. + +**Key Components:** +- Request routing and load balancing +- Execution pipeline coordination +- Worker management and task distribution +- Error handling and recovery + +### 5. [Worker Module API](worker_module_api.md) +**Worker module** provides the actual execution environment for model inference. + +**Key Components:** +- Model execution and inference +- GPU resource management +- Request batching and processing +- Performance optimization + +## 🎯 How to Use These Templates + +### For New Module Development + +1. **Start with the Master Template** + - Copy `API_DESIGN_TEMPLATE.md` + - Rename to `[module_name]_api.md` + - Follow the structure exactly + +2. **Fill in Module-Specific Details** + - Update the module overview + - Define core classes and interfaces + - Specify public API methods + - Add configuration options + - Define error handling strategy + - Provide usage examples + +3. 
**Review and Validate** + - Ensure all sections are complete + - Verify code examples work + - Check for consistency with other modules + - Validate configuration options + +### For Existing Module Updates + +1. **Update API Documentation** + - Modify existing API files + - Add new methods and classes + - Update examples and configuration + - Maintain backward compatibility notes + +2. **Version Control** + - Track changes in git + - Use clear commit messages + - Tag major API changes + +## 📚 Template Structure + +Each API template follows this standardized structure: + +### 1. Module Overview +- **Purpose**: What the module does +- **Responsibilities**: Key responsibilities +- **Dependencies**: Required modules +- **Integration Points**: How it connects with other modules + +### 2. Core Classes/Interfaces +- **Base Classes**: Abstract base classes +- **Implementation Classes**: Concrete implementations +- **Data Structures**: Key data models + +### 3. Public API Methods +- **Initialization**: Constructor and setup +- **Core Operations**: Main functionality +- **Configuration**: Settings management +- **Lifecycle Management**: Start/stop/cleanup +- **Monitoring**: Status and metrics + +### 4. Configuration +- **Configuration Classes**: Dataclasses for settings +- **Required Parameters**: Must-have settings +- **Optional Parameters**: Optional settings with defaults +- **Validation**: Parameter validation rules + +### 5. Error Handling +- **Custom Exceptions**: Module-specific errors +- **Error Codes**: Standardized error codes +- **Recovery Strategies**: Error recovery approaches + +### 6. 
Examples +- **Basic Usage**: Simple usage examples +- **Advanced Usage**: Complex scenarios +- **Integration Examples**: Multi-module usage + +## 🔧 Implementation Guidelines + +### Code Standards +- Use type hints for all methods +- Include comprehensive docstrings +- Follow PEP 8 style guidelines +- Use async/await for I/O operations + +### Error Handling +- Define specific exception types +- Provide meaningful error messages +- Include error recovery strategies +- Log errors appropriately + +### Configuration +- Use dataclasses for configuration +- Provide sensible defaults +- Include validation methods +- Support environment variables + +### Testing +- Write unit tests for all public methods +- Include integration tests +- Test error conditions +- Validate configuration options + +## 🚀 Getting Started + +### For Developers + +1. **Choose a Module** + - Pick the module you want to work on + - Read the corresponding API documentation + - Understand the responsibilities and interfaces + +2. **Set Up Development Environment** + ```bash + # Clone the repository + git clone https://github.com/hsliuustc0106/vllm-omni.git + cd vllm-omni + + # Install dependencies + pip install -r requirements-dev.txt + + # Set up pre-commit hooks + pre-commit install + ``` + +3. **Start Implementation** + - Create the module directory + - Implement base classes first + - Add concrete implementations + - Write tests as you go + +### For Contributors + +1. **Review API Documentation** + - Read through all module APIs + - Understand the overall architecture + - Identify areas for improvement + +2. **Propose Changes** + - Create issues for API changes + - Discuss in pull requests + - Update documentation accordingly + +3. 
**Submit Contributions** + - Follow the coding standards + - Include tests for new features + - Update documentation + - Submit pull requests + +## 📖 Additional Resources + +### Architecture Overview +- [vLLM-omni Architecture Design](../architecture/vLLM-omni%20arch%20design%20doc.md) +- [Component Categorization](../../ARCHITECTURE_CATEGORIZATION.md) + +### Development Guides +- [Development Setup](../../README.md#development) +- [Testing Guidelines](../../tests/README.md) +- [Contributing Guidelines](../../CONTRIBUTING.md) + +### Examples +- [Basic Examples](../../examples/basic/) +- [Advanced Examples](../../examples/advanced/) +- [Multimodal Examples](../../examples/multimodal/) + +## 🤝 Contributing + +We welcome contributions to improve the API documentation and implementation. Please: + +1. **Follow the Template Structure** + - Use the master template as a guide + - Maintain consistency across modules + - Include all required sections + +2. **Provide Working Examples** + - Test all code examples + - Include both basic and advanced usage + - Show integration patterns + +3. **Keep Documentation Updated** + - Update docs when APIs change + - Version control documentation changes + - Maintain backward compatibility notes + +## 📝 Notes + +- All API documentation should be kept up-to-date with implementation +- Code examples should be tested and working +- Configuration options should be validated +- Error handling should be comprehensive +- Performance considerations should be documented + +For questions or suggestions about the API documentation, please open an issue or start a discussion in the repository. 
diff --git a/docs/architecture/vLLM-omni arch design doc.md b/docs/architecture/high level arch design.md similarity index 99% rename from docs/architecture/vLLM-omni arch design doc.md rename to docs/architecture/high level arch design.md index 0306606c0c..1545ca877a 100644 --- a/docs/architecture/vLLM-omni arch design doc.md +++ b/docs/architecture/high level arch design.md @@ -136,4 +136,4 @@ Hongsheng Liu: [hsliuustc@gmail.com](mailto:hsliuustc@gmail.com) [image10]: -[image11]: \ No newline at end of file +[image11]: diff --git a/docs/architecture/implementation_architecture.md b/docs/architecture/implementation_architecture.md new file mode 100644 index 0000000000..a83d96d1e9 --- /dev/null +++ b/docs/architecture/implementation_architecture.md @@ -0,0 +1,353 @@ +# vLLM-omni Implementation Architecture + +## 1. Package Structure + +``` +vllm_omni/ +├── __init__.py # Main package exports +├── config/ # Configuration management +│ ├── __init__.py +│ ├── stage_config.py # OmniStageConfig implementation +│ ├── dit_config.py # DiT-specific configurations +│ └── cache_config.py # Caching configurations +├── core/ # Core processing components +│ ├── __init__.py +│ ├── stage_manager.py # Multi-stage orchestration +│ ├── dit_cache_manager.py # DiT caching system +│ └── sched/ # Schedulers +│ ├── __init__.py +│ ├── scheduler.py # Base scheduler interface + +├── engine/ # Engine components +│ ├── __init__.py +│ ├── processor.py # Output processing +│ └── output_processor.py # Multimodal output handling +├── executor/ # Executor implementations +│ ├── __init__.py +│ ├── base_executor.py # Base executor interface +│ └── diffusers_executor.py # Diffusers pipeline executor +├── model_executor/ # Model execution components +│ ├── __init__.py +├── worker/ # Worker implementations +│ ├── __init__.py +│ ├── gpu_diffusion_model_runner.py # Loads and runs diffusers pipelines +│ └── gpu_diffusion_worker.py # Worker facade around the diffusion runner +├── entrypoints/ # Entry points 
and CLI +│ ├── __init__.py +│ ├── omni.py # OmniServeCommand +│ ├── omni_llm.py # OmniLLM and AsyncOmniLLM +│ └── cli/ # CLI integration +│ ├── __init__.py +│ └── main.py # CLI main entry point +├── request.py # Request handling +├── dit_cache_interface.py # DiT cache interface +└── utils/ # Utility functions + ├── __init__.py + ├── multimodal.py # Multimodal utilities + └── vae.py # VAE utilities for image processing +``` + +## 2. Core Module Dependencies + +### 2.1 vLLM Integration Points + +```python +# Key vLLM imports and extensions +from vllm.v1.engine.llm_engine import LLMEngine +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.core.sched.scheduler import SchedulerInterface +from vllm.v1.executor.multiproc_executor import MultiprocExecutor +from vllm.v1.worker.gpu_worker import Worker as GPUWorker +from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.outputs import ModelRunnerOutput +``` + +### 2.2 Internal Dependencies + +```python +# Internal module dependencies +vllm_omni.config → vllm_omni.core +vllm_omni.core → vllm_omni.engine, vllm_omni.executor, vllm_omni.worker +vllm_omni.engine → vllm_omni.model_executor +vllm_omni.entrypoints → vllm_omni.core +``` + +## 3. 
Configuration System + +### 3.1 Stage Configuration + +```python +# vllm_omni/config/stage_config.py +@dataclass +class OmniStageConfig: + stage_id: int + engine_type: Literal["AR", "DiT"] + model_path: str + input_modalities: List[str] + output_modalities: List[str] + vllm_config: Optional[VllmConfig] = None + dit_config: Optional[DiTConfig] = None + executor_class: Optional[type[Executor]] = None + stage_output: Optional[Any] = None +``` + +### 3.2 DiT Configuration + +```python +# vllm_omni/config/stage_config.py +@dataclass +class DiTConfig: + model_config: Optional[ModelConfig] = None + scheduler_config: Optional[Any] = None + device_config: Optional[DeviceConfig] = None + load_config: Optional[LoadConfig] = None + compilation_config: Optional[CompilationConfig] = None + dit_cache_config: Optional[DiTCacheConfig] = None + num_inference_steps: int + guidance_scale: float = 7.5 + use_diffusers: bool = False + diffusers_pipeline: Optional[str] = None + height: int = 512 + width: int = 512 +``` + +### 3.3 Cache Configuration + +```python +# vllm_omni/config/cache_config.py +@dataclass +class DiTCacheConfig: + cache_tensors: List[DiTCacheTensor] + max_cache_size: int + cache_strategy: str = "fifo" + enable_optimization: bool = True +``` + +## 4. 
Core Implementation Details + +### 4.1 OmniLLM Implementation + +```python +# vllm_omni/core/omni_llm.py +class OmniLLM(LLM): + def __init__(self, stage_configs: List[OmniStageConfig]): + super().__init__() + self.stage_configs = stage_configs + self.engine_list: List[LLMEngine] = [] + self.output_processor = MultimodalOutputProcessor() + self._initialize_stage_engines() + + def _initialize_stage_engines(self) -> None: + """Initialize LLMEngine instances for each stage""" + for stage_config in self.stage_configs: + if stage_config.engine_type == "AR": + engine = self._create_ar_engine(stage_config) + elif stage_config.engine_type == "DiT": + engine = self._create_dit_engine(stage_config) + self.engine_list.append(engine) + + def generate(self, stage_args_list: List[Dict], **kwargs) -> List[RequestOutput]: + """Main generation interface - orchestrates multi-stage processing""" + current_output = None + + for i, (stage_config, stage_args) in enumerate(zip(self.stage_configs, stage_args_list)): + stage_engine = self.engine_list[i] + + # Prepare input for this stage + processed_input = self._process_stage_inputs(stage_config, stage_args, current_output) + + # Execute stage + stage_output = self._execute_stage(stage_engine, processed_input) + + # Update for next stage + current_output = stage_output + stage_config.stage_output = stage_output + + # Process final output + return self.output_processor.process_output(current_output) +``` + +### 4.2 AsyncOmniLLM Implementation + +```python +# vllm_omni/core/omni_llm.py +class AsyncOmniLLM(AsyncLLM): + def __init__(self, stage_configs: List[OmniStageConfig]): + super().__init__() + self.stage_configs = stage_configs + self.async_engine_list: List[AsyncLLM] = [] + self.output_processor = MultimodalOutputProcessor() + self._initialize_async_stage_engines() + + async def generate_async(self, stage_args_list: List[Dict], **kwargs) -> List[RequestOutput]: + """Async generation interface""" + # Similar to OmniLLM but with 
async/await patterns
+        pass
+```
+
+### 4.3 Model Runners
+
+```python
+# vllm_omni/model_executor/dit_model_runner.py
+class OmniDiffusionModelRunner(GPUModelRunner):
+    def execute_model(
+        self,
+        scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+        # DiT model execution logic
+        return ModelRunnerOutput(
+            req_ids=[...],
+            req_id_to_index={...},
+            sampled_token_ids=[],
+            spec_token_ids=None,
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=[tensor, ...], # DiT output tensors
+            kv_connector_output=None,
+            num_nans_in_logits=None,
+        )
+
+# vllm_omni/model_executor/ar_model_runner.py
+class OmniARModelRunner(GPUModelRunner):
+    def execute_model(
+        self,
+        scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+        # AR model execution with hidden state output
+        return ModelRunnerOutput(
+            req_ids=[...],
+            req_id_to_index={...},
+            sampled_token_ids=[],
+            spec_token_ids=None,
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=[hidden_states, ...], # Hidden states
+            kv_connector_output=None,
+            num_nans_in_logits=None,
+        )
+```
+
+### 4.5 Output Processing
+
+```python
+# vllm_omni/engine/output_processor.py
+class MultimodalOutputProcessor(OutputProcessor):
+    def __init__(self):
+        self.output_handlers: Dict[str, Callable] = {
+            "image": self._process_image_output,
+            "text+image": self._process_text_image_output,
+            "latents": self._process_latents_output,
+            "text": self._process_text_output,
+        }
+
+    def process_outputs(self, engine_core_outputs: List[EngineCoreOutput], ...):
+        """Process multimodal outputs based on type"""
+        for engine_core_output in engine_core_outputs:
+            output_type = self._detect_output_type(engine_core_output)
+            handler = self.output_handlers.get(output_type, self._process_pooling_output)
+            handler(engine_core_output)
+```
+
+## 5. 
CLI Integration + +### 5.1 Entry Point Override + +```python +# vllm_omni/entrypoints/cli/main.py +def main(): + """Main CLI entry point that intercepts vLLM commands""" + if "--omni" in sys.argv: + omni_args = [arg for arg in sys.argv[1:] if arg != "--omni"] + omni_serve = OmniServeCommand() + omni_serve.run(omni_args) + else: + from vllm.entrypoints.cli.main import main as vllm_main + vllm_main() +``` + +### 5.2 Package Configuration + +```toml +# pyproject.toml updates +[project.scripts] +vllm = "vllm_omni.entrypoints.cli.main:main" +vllm-omni = "vllm_omni.entrypoints.cli.main:main" + +[project.entry-points."vllm.plugins"] +omni = "vllm_omni.plugin:OmniPlugin" +``` + +## 6. Testing Strategy + +### 6.1 Unit Tests +- Individual component testing +- Mock vLLM dependencies +- Configuration validation + +### 6.2 Integration Tests +- End-to-end pipeline testing +- vLLM compatibility testing +- Multi-stage processing validation + +### 6.3 Performance Tests +- Benchmarking against native vLLM +- Memory usage profiling +- Latency measurements + +## 7. Installation and Setup + +### 7.1 Package Installation +```bash +pip install vllm>=0.10.2 +pip install vllm-omni +``` + +### 7.2 Development Setup +```bash +git clone https://github.com/hsliuustc0106/vllm-omni +cd vllm-omni +pip install -e ".[dev]" +``` + +### 7.3 Usage +```bash +vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8000 +``` + +## 8. Memory Management + +### 8.1 DiT Cache Management +- Tensor caching for intermediate results +- Memory pooling for efficient allocation +- Cache eviction strategies + +### 8.2 Multi-Stage Memory +- Inter-stage data passing optimization +- Memory sharing between stages +- Garbage collection optimization + +## 9. Error Handling + +### 9.1 Stage Failure Handling +- Graceful degradation on stage failures +- Error propagation and reporting +- Recovery mechanisms + +### 9.2 vLLM Compatibility +- Version compatibility checks +- API change detection +- Fallback mechanisms + +## 10. 
Monitoring and Logging + +### 10.1 Performance Metrics +- Stage execution times +- Memory usage per stage +- Throughput measurements + +### 10.2 Debug Information +- Request tracing across stages +- Cache hit/miss ratios +- Error logging and reporting diff --git a/examples/README.md b/examples/README.md deleted file mode 100644 index e79470c893..0000000000 --- a/examples/README.md +++ /dev/null @@ -1,62 +0,0 @@ -# vLLM-omni Examples - -This directory contains examples demonstrating how to use vLLM-omni for various tasks. - -## Basic Examples - -- [Text Generation](basic/text_generation.py) - Basic text generation using AR models -- [Image Generation](basic/image_generation.py) - Image generation using diffusion models -- [Multimodal Processing](basic/multimodal_processing.py) - Processing text and images together - -## Advanced Examples - -- [Custom Model Integration](advanced/custom_model.py) - Integrating custom models -- [Batch Processing](advanced/batch_processing.py) - Efficient batch processing -- [Streaming Output](advanced/streaming.py) - Real-time streaming output - -## Multimodal Examples - -- [Image-to-Text](multimodal/image_to_text.py) - Image captioning and description -- [Text-to-Image](multimodal/text_to_image.py) - Text-to-image generation -- [Audio Processing](multimodal/audio_processing.py) - Audio generation and processing -- [Video Generation](multimodal/video_generation.py) - Video generation workflows - -## API Examples - -- [REST API](api/rest_api.py) - Using the REST API interface -- [Gradio Interface](api/gradio_interface.py) - Creating Gradio web interfaces -- [ComfyUI Integration](api/comfyui_integration.py) - ComfyUI workflow integration - -## Configuration Examples - -- [Custom Configuration](config/custom_config.py) - Custom configuration setup -- [Multi-GPU Setup](config/multi_gpu.py) - Multi-GPU configuration -- [Distributed Processing](config/distributed.py) - Distributed processing setup - -## Getting Started - -1. 
Install vLLM-omni: - ```bash - pip install vllm-omni - ``` - -2. Run a basic example: - ```bash - python examples/basic/text_generation.py - ``` - -3. Explore the examples in each subdirectory for more advanced usage. - -## Requirements - -Most examples require additional dependencies. Install them with: - -```bash -pip install -r requirements.txt -``` - -For development examples, install the development dependencies: - -```bash -pip install -r requirements-dev.txt -``` diff --git a/examples/basic/api_client.py b/examples/basic/api_client.py new file mode 100644 index 0000000000..61e45a81be --- /dev/null +++ b/examples/basic/api_client.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +""" +Simple API client example for vLLM-omni server. + +This example shows how to interact with the vLLM-omni API server +using HTTP requests. +""" + +import requests +import json +import time + + +def test_health(host="localhost", port=8000): + """Test if the server is healthy.""" + url = f"http://{host}:{port}/health" + try: + response = requests.get(url, timeout=5) + if response.status_code == 200: + print("✅ Server is healthy") + return True + else: + print(f"❌ Server health check failed: {response.status_code}") + return False + except requests.exceptions.RequestException as e: + print(f"❌ Cannot connect to server: {e}") + return False + + +def generate_text(prompt, host="localhost", port=8000, **kwargs): + """Generate text using the API server.""" + url = f"http://{host}:{port}/generate" + + # Default parameters + params = { + "prompts": [prompt], + "max_tokens": kwargs.get("max_tokens", 100), + "temperature": kwargs.get("temperature", 0.7), + "top_p": kwargs.get("top_p", 1.0), + "frequency_penalty": kwargs.get("frequency_penalty", 0.0), + "presence_penalty": kwargs.get("presence_penalty", 0.0), + } + + # Add optional parameters + if "stop" in kwargs: + params["stop"] = kwargs["stop"] + + try: + response = requests.post(url, json=params, timeout=30) + response.raise_for_status() + 
return response.json() + except requests.exceptions.RequestException as e: + print(f"❌ API request failed: {e}") + return None + + +def main(): + """Main example function.""" + print("vLLM-omni API Client Example") + print("=" * 40) + + # Test server health + if not test_health(): + print("Please start the server first:") + print(" vllm serve Qwen/Qwen3-0.6B --omni --port 8000") + return + + # Example 1: Simple generation + print("\n1. Simple text generation:") + result = generate_text("Hello, how are you?", max_tokens=50) + if result: + text = result["outputs"][0]["text"] + print(f" Input: Hello, how are you?") + print(f" Output: {text}") + + # Example 2: Creative writing + print("\n2. Creative writing:") + result = generate_text( + "Write a short story about a robot learning to paint", + max_tokens=150, + temperature=0.9 + ) + if result: + text = result["outputs"][0]["text"] + print(f" Input: Write a short story about a robot learning to paint") + print(f" Output: {text}") + + # Example 3: Question answering + print("\n3. Question answering:") + result = generate_text( + "What is the capital of France?", + max_tokens=30, + temperature=0.3 + ) + if result: + text = result["outputs"][0]["text"] + print(f" Input: What is the capital of France?") + print(f" Output: {text}") + + # Example 4: Multiple prompts + print("\n4. Multiple prompts:") + prompts = [ + "Tell me a joke", + "What's 2+2?", + "Describe the color blue" + ] + + for prompt in prompts: + result = generate_text(prompt, max_tokens=30) + if result: + text = result["outputs"][0]["text"] + print(f" Q: {prompt}") + print(f" A: {text}") + + # Example 5: Show full response structure + print("\n5. 
Full response structure:") + result = generate_text("Hello world", max_tokens=20) + if result: + print(" Full API response:") + print(json.dumps(result, indent=2)) + + print("\n" + "=" * 40) + print("API client examples completed!") + + +if __name__ == "__main__": + main() + diff --git a/examples/basic/text_generation.py b/examples/basic/text_generation.py new file mode 100644 index 0000000000..a6a9b5179b --- /dev/null +++ b/examples/basic/text_generation.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 +""" +Simple vLLM-omni usage example. + +This example demonstrates how to use vLLM-omni for basic text generation +with a single AR (Autoregressive) stage. +""" + +import asyncio +from vllm_omni.entrypoints.omni_llm import OmniLLM, AsyncOmniLLM +from vllm_omni.config import create_ar_stage_config + + +def sync_example(): + """Synchronous usage example.""" + print("=== Synchronous vLLM-omni Example ===") + + # Create a simple AR stage configuration + stage_config = create_ar_stage_config( + stage_id=0, + model_path="Qwen/Qwen3-0.6B", + input_modalities=["text"], + output_modalities=["text"] + ) + + # Initialize OmniLLM + omni_llm = OmniLLM([stage_config]) + + # Prepare stage arguments + stage_args = [{ + "prompt": "Hello, how are you today?", + "max_tokens": 50, + "temperature": 0.7 + }] + + # Generate text + print(f"Input: {stage_args[0]['prompt']}") + outputs = omni_llm.generate(stage_args) + + # Display results + for i, output in enumerate(outputs): + print(f"Output {i}:") + if hasattr(output, 'outputs') and output.outputs: + for completion in output.outputs: + print(f" Text: {completion.text}") + print(f" Finished: {completion.finish_reason != 'length'}") + print(f" Tokens: {len(completion.token_ids)} tokens") + + +async def async_example(): + """Asynchronous usage example.""" + print("\n=== Asynchronous vLLM-omni Example ===") + + # Create a simple AR stage configuration + stage_config = create_ar_stage_config( + stage_id=0, + model_path="Qwen/Qwen3-0.6B", + 
input_modalities=["text"], + output_modalities=["text"] + ) + + # Initialize AsyncOmniLLM + omni_llm = AsyncOmniLLM([stage_config]) + + # Prepare stage arguments + stage_args = [{ + "prompt": "What is artificial intelligence?", + "max_tokens": 100, + "temperature": 0.8 + }] + + # Generate text asynchronously + print(f"Input: {stage_args[0]['prompt']}") + outputs = await omni_llm.generate_async(stage_args) + + # Display results + for i, output in enumerate(outputs): + print(f"Output {i}:") + if hasattr(output, 'outputs') and output.outputs: + for completion in output.outputs: + print(f" Text: {completion.text}") + print(f" Finished: {completion.finish_reason != 'length'}") + print(f" Tokens: {len(completion.token_ids)} tokens") + + +def multi_prompt_example(): + """Example with multiple prompts.""" + print("\n=== Multi-Prompt Example ===") + + # Create stage configuration + stage_config = create_ar_stage_config( + stage_id=0, + model_path="Qwen/Qwen3-0.6B", + input_modalities=["text"], + output_modalities=["text"] + ) + + # Initialize OmniLLM + omni_llm = OmniLLM([stage_config]) + + # Multiple prompts + prompts = [ + "Tell me a joke", + "What's the weather like?", + "Explain quantum computing" + ] + + for prompt in prompts: + print(f"\nInput: {prompt}") + stage_args = [{ + "prompt": prompt, + "max_tokens": 30, + "temperature": 0.9 + }] + + outputs = omni_llm.generate(stage_args) + + if outputs and hasattr(outputs[0], 'outputs') and outputs[0].outputs: + completion = outputs[0].outputs[0] + print(f"Output: {completion.text}") + + +if __name__ == "__main__": + print("vLLM-omni Simple Usage Examples") + print("=" * 40) + + # Run synchronous example + sync_example() + + # Run asynchronous example + asyncio.run(async_example()) + + # Run multi-prompt example + multi_prompt_example() + + print("\n" + "=" * 40) + print("Examples completed!") diff --git a/examples/omni/README.md b/examples/omni/README.md new file mode 100644 index 0000000000..69274991ba --- /dev/null +++ 
b/examples/omni/README.md @@ -0,0 +1,31 @@ +# Omni Examples + +Examples showcasing multi-stage AR + DiT pipelines powered by vLLM-omni. + +## AR ➜ DiT with Diffusers Backend + +```bash +# Optional: customise models +export AR_MODEL=Qwen/Qwen3-0.6B +export DIT_MODEL=./models/stable-diffusion-2-1/ + +# Run with defaults (20 steps @ 512x512) +python examples/omni/ar_dit_diffusers.py + +# Faster run +python examples/omni/ar_dit_diffusers.py --steps 14 --height 384 --width 384 --guidance 4.5 +``` + +Arguments: +- `--ar-model`: Override AR stage model (default `Qwen/Qwen3-0.6B`). +- `--dit-model`: Override DiT stage model (default env `DIT_MODEL`, else diffusers repo). +- `--steps`: Number of diffusion steps (default 20). +- `--guidance`: Guidance scale (default 5.0). +- `--height` / `--width`: Output resolution (default 512). +- `--seed`: Optional RNG seed. +- `--prompt`, `--temperature`, `--max-tokens`: Control AR stage generation. +- `--output`: Destination path for the generated image file. + +Generated images are saved next to the script as `omni_dit_output.png` by default. + +See the script for additional tweaks (scheduler, prompt chaining, etc.). diff --git a/examples/omni/ar_dit_diffusers.py b/examples/omni/ar_dit_diffusers.py new file mode 100644 index 0000000000..5685b5dc3c --- /dev/null +++ b/examples/omni/ar_dit_diffusers.py @@ -0,0 +1,175 @@ +"""Run the AR → DiT (diffusers) pipeline using YAML configuration defaults.""" + +from __future__ import annotations + +import argparse +import os +from pathlib import Path +from typing import Dict, List + +import yaml + +from vllm_omni import ( + OmniLLM, + DiTConfig, + create_ar_stage_config, + create_dit_stage_config, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run a two-stage AR → DiT pipeline with config-derived defaults." 
+ ) + parser.add_argument( + "--config", + type=Path, + default=Path(__file__).resolve().parent / "configs" / "ar_dit_local.yaml", + help="Path to a YAML configuration describing the pipeline stages.", + ) + parser.add_argument( + "--prompt", + default="A scenic watercolor painting of a lighthouse at sunset", + help="Prompt passed to the AR stage.", + ) + parser.add_argument( + "--negative-prompt", + default=None, + help="Optional negative prompt forwarded to the diffusion stage.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Optional seed for deterministic diffusion sampling.", + ) + parser.add_argument( + "--output", + type=Path, + default=Path("./omni_dit_output.png"), + help="Destination path for the generated image.", + ) + return parser.parse_args() + + +def _normalize_modalities(values: List[str] | None, default: List[str]) -> List[str]: + if not values: + return default + return list(values) + + +def _load_stage_configs_from_yaml(config_path: Path): + with config_path.open("r", encoding="utf-8") as f: + config = yaml.safe_load(f) or {} + + stages = config.get("stages", []) + if not stages: + raise ValueError(f"No stages defined in config: {config_path}") + + stage_configs = [] + for stage in stages: + stage_id = stage["stage_id"] + engine_type = stage["engine_type"] + default_stage_args = stage.get("default_stage_args") + + if engine_type == "AR": + stage_configs.append( + create_ar_stage_config( + stage_id=stage_id, + model_path=stage["model_path"], + input_modalities=_normalize_modalities( + stage.get("input_modalities"), ["text"] + ), + output_modalities=_normalize_modalities( + stage.get("output_modalities"), ["text"] + ), + default_stage_args=default_stage_args, + ) + ) + elif engine_type == "DiT": + dit_cfg_dict: Dict = dict(stage.get("dit_config", {})) + if "num_inference_steps" not in dit_cfg_dict: + raise ValueError( + "DiT stage requires 'num_inference_steps' in dit_config" + ) + dit_cfg = DiTConfig(**dit_cfg_dict) + 
+ stage_configs.append( + create_dit_stage_config( + stage_id=stage_id, + model_path=stage["model_path"], + input_modalities=_normalize_modalities( + stage.get("input_modalities"), ["text"] + ), + output_modalities=_normalize_modalities( + stage.get("output_modalities"), ["image"] + ), + dit_config=dit_cfg, + default_stage_args=default_stage_args, + ) + ) + else: + raise ValueError(f"Unsupported engine_type '{engine_type}' in config") + + return stage_configs + + +def _apply_env_overrides(stage_configs): + ar_override = os.environ.get("AR_MODEL") + dit_override = os.environ.get("DIT_MODEL") + + for stage_config in stage_configs: + if ar_override and stage_config.engine_type == "AR": + stage_config.model_path = ar_override + if dit_override and stage_config.engine_type == "DiT": + stage_config.model_path = dit_override + + +def main(): + args = parse_args() + + stage_configs = _load_stage_configs_from_yaml(args.config) + _apply_env_overrides(stage_configs) + + omni = OmniLLM(stage_configs) + + stage_overrides: Dict[int, Dict[str, object]] = {} + if args.seed is not None or args.negative_prompt: + for stage_config in stage_configs: + if stage_config.engine_type != "DiT": + continue + override = stage_overrides.setdefault(stage_config.stage_id, {}) + if args.seed is not None: + override["seed"] = args.seed + if args.negative_prompt: + override["negative_prompt"] = args.negative_prompt + + outputs = omni.generate( + prompt=args.prompt, + stage_overrides=stage_overrides if stage_overrides else None, + ) + + image = None + for request_output in outputs: + for completion in getattr(request_output, "outputs", []) or []: + candidate = getattr(completion, "image", None) + if candidate is not None: + image = candidate + break + if image is not None: + break + + if image is None: + print("No image found in outputs.") + return + + out_path = args.output.resolve() + try: + image.save(out_path) + print(f"Saved image to {out_path}") + except Exception: + print("Generated output is 
not a PIL.Image; skipping save.") + + +if __name__ == "__main__": + main() diff --git a/examples/omni/configs/ar_dit_local.yaml b/examples/omni/configs/ar_dit_local.yaml new file mode 100644 index 0000000000..5bae87e3c1 --- /dev/null +++ b/examples/omni/configs/ar_dit_local.yaml @@ -0,0 +1,28 @@ +# Example config for running OmniLLM with local AR + DiT models. +stages: + - stage_id: 0 + engine_type: AR + model_path: Qwen/Qwen3-0.6B + input_modalities: [text] + output_modalities: [text] + default_stage_args: + max_tokens: 16 + temperature: 0.7 + - stage_id: 1 + engine_type: DiT + model_path: stabilityai/stable-diffusion-2-1 + input_modalities: [text] + output_modalities: [image] + default_stage_args: + height: 512 + width: 512 + dit_config: + num_inference_steps: 20 + guidance_scale: 5.0 + use_diffusers: true + diffusers_pipeline: auto + height: 512 + width: 512 + device_config: + device: mps + dtype: float16 diff --git a/pyproject.toml b/pyproject.toml index ca6beb4774..4e3382ccfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "numpy>=1.21.0", "pillow>=8.0.0", "opencv-python>=4.5.0", + "PyYAML>=6.0", "gradio>=3.0.0", "fastapi>=0.100.0", "uvicorn>=0.20.0", @@ -71,7 +72,8 @@ Documentation = "https://vllm-omni.readthedocs.io" "Bug Tracker" = "https://github.com/hsliuustc0106/vllm-omni/issues" [project.scripts] -vllm-omni = "vllm_omni.cli:main" +vllm = "vllm_omni.entrypoints.cli.main:main" +vllm-omni = "vllm_omni.entrypoints.cli.main:main" [tool.setuptools.packages.find] where = ["."] diff --git a/requirements.txt b/requirements.txt index 45339bd420..50606db8b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,11 @@ # Core dependencies -vllm>=0.9.0 -torch>=2.0.0 +vllm>=0.10.2 +torch>=2.7 transformers>=4.30.0 numpy>=1.21.0 pillow>=8.0.0 opencv-python>=4.5.0 +PyYAML>=6.0 # API and serving fastapi>=0.100.0 @@ -16,6 +17,11 @@ gradio>=3.0.0 asyncio-mqtt>=0.11.0 ray>=2.0.0 -# Optional: DiT acceleration modules -# xDiT # 
Uncomment when available -# Cache-DiT # Uncomment when available +# Development dependencies +pytest>=7.0.0 +pytest-asyncio>=0.21.0 +pytest-cov>=4.0.0 +black>=23.0.0 +isort>=5.12.0 +flake8>=6.0.0 +mypy>=1.0.0 diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000000..6f4e90bb5e --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,109 @@ +# vLLM-omni Testing Scripts + +This directory contains automated testing scripts for vLLM-omni serving functionality. + +## Scripts Overview + +### 1. `test_serving.sh` - Comprehensive Test Suite +Full-featured testing script that performs complete validation of the vLLM-omni serving functionality. + +**Features:** +- Model existence validation +- Environment setup verification +- Import functionality testing +- Server startup and health checks +- Text generation testing +- Performance benchmarking +- API client integration testing +- Comprehensive logging and error handling + +**Usage:** +```bash +# Use default model and port +./scripts/test_serving.sh + +# Use specific model and port +./scripts/test_serving.sh ./models/Qwen3-0.6B 8000 + +# Use HuggingFace model +./scripts/test_serving.sh Qwen/Qwen3-0.6B 8001 + +# Show help +./scripts/test_serving.sh --help +``` + +**Test Coverage:** +- ✅ Model loading and server startup +- ✅ Health and info endpoints +- ✅ Text generation functionality +- ✅ Performance metrics +- ✅ API client integration +- ✅ All imports working correctly + +### 2. `quick_test.sh` - Fast Validation +Lightweight script for quick validation after making changes. + +**Features:** +- Fast import testing +- Basic server startup +- Health endpoint validation +- Simple text generation test +- Retry mechanism for reliability + +**Usage:** +```bash +# Use default port (8000) +./scripts/quick_test.sh + +# Use specific port +./scripts/quick_test.sh 8001 +``` + +**Test Coverage:** +- ✅ Import functionality +- ✅ Server startup +- ✅ Health endpoint +- ✅ Basic text generation + +## Prerequisites + +1. 
**Conda Environment**: Ensure `vllm_omni` environment is activated +2. **Model Available**: Qwen3-0.6B model should be available in `./models/Qwen3-0.6B/` +3. **Dependencies**: All vLLM-omni dependencies should be installed + +## Environment Setup + +```bash +# Activate conda environment +conda activate vllm_omni + +# Verify installation +python -c "import vllm_omni; print('vLLM-omni ready')" +``` + +## Model Setup + +The scripts expect the Qwen3-0.6B model to be available. You can: + +1. **Use local model**: Place model in `./models/Qwen3-0.6B/` +2. **Use HuggingFace model**: Pass `Qwen/Qwen3-0.6B` as model path +3. **Download model**: Use the `download_models.py` script + +```bash +# Download model using the provided script +python scripts/download_models.py +``` + +## Output and Logging + +### Test Results +- **Success**: Green `[SUCCESS]` messages +- **Info**: Blue `[INFO]` messages +- **Warnings**: Yellow `[WARNING]` messages +- **Errors**: Red `[ERROR]` messages + +### Log Files +- `server.log`: Server startup and runtime logs +- `api_client_test.log`: API client test results +- `simple_usage_test.log`: Simple usage test results + diff --git a/scripts/test_serving.sh b/scripts/test_serving.sh new file mode 100755 index 0000000000..d3a8d6cda5 --- /dev/null +++ b/scripts/test_serving.sh @@ -0,0 +1,363 @@ +#!/bin/bash + +# vLLM-omni Serving Functionality Test Script +# This script tests the complete serving functionality of vLLM-omni +# Usage: ./scripts/test_serving.sh [model_path] [port] + +set -e # Exit on any error + +# Configuration +MODEL_PATH=${1:-"./models/Qwen3-0.6B"} +PORT=${2:-8000} +HOST="localhost" +TIMEOUT=30 +SERVER_STARTUP_WAIT=15 + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Logging functions +log_info() { + echo -e "${BLUE}[INFO]${NC} $1" +} + +log_success() { + echo -e "${GREEN}[SUCCESS]${NC} $1" +} + +log_warning() { + echo -e "${YELLOW}[WARNING]${NC} $1" +} + 
+log_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +# Cleanup function +cleanup() { + log_info "Cleaning up..." + if [ ! -z "$SERVER_PID" ]; then + log_info "Stopping server (PID: $SERVER_PID)..." + kill $SERVER_PID 2>/dev/null || true + wait $SERVER_PID 2>/dev/null || true + fi + # Kill any remaining vllm processes + pkill -f "vllm serve" 2>/dev/null || true + log_info "Cleanup completed" +} + +# Set up trap for cleanup on exit +trap cleanup EXIT + +# Check if model exists +check_model() { + log_info "Checking if model exists at: $MODEL_PATH" + if [ ! -d "$MODEL_PATH" ]; then + log_error "Model directory not found: $MODEL_PATH" + log_info "Available models:" + ls -la models/ 2>/dev/null || log_warning "No models directory found" + exit 1 + fi + log_success "Model found: $MODEL_PATH" +} + +# Check if conda environment is activated +check_environment() { + log_info "Checking conda environment..." + if [[ "$CONDA_DEFAULT_ENV" != "vllm_omni" ]]; then + log_warning "vllm_omni conda environment not activated" + log_info "Attempting to activate vllm_omni environment..." + source $(conda info --base)/etc/profile.d/conda.sh + conda activate vllm_omni + if [[ "$CONDA_DEFAULT_ENV" != "vllm_omni" ]]; then + log_error "Failed to activate vllm_omni environment" + exit 1 + fi + fi + log_success "Environment check passed: $CONDA_DEFAULT_ENV" +} + +# Test import functionality +test_imports() { + log_info "Testing imports..." + python -c " +import vllm_omni +from vllm_omni.entrypoints.omni_llm import OmniLLM, AsyncOmniLLM +print('✅ All imports successful') +" || { + log_error "Import test failed" + exit 1 + } + log_success "Import test passed" +} + +# Start the server +start_server() { + log_info "Starting vLLM-omni server on port $PORT..." + log_info "Command: vllm serve $MODEL_PATH --omni --port $PORT" + + # Start server in background + vllm serve "$MODEL_PATH" --omni --port "$PORT" > server.log 2>&1 & + SERVER_PID=$! 
+ + log_info "Server started with PID: $SERVER_PID" + log_info "Waiting $SERVER_STARTUP_WAIT seconds for server to initialize..." + sleep $SERVER_STARTUP_WAIT + + # Check if server is still running + if ! kill -0 $SERVER_PID 2>/dev/null; then + log_error "Server failed to start" + log_info "Server logs:" + cat server.log + exit 1 + fi + + log_success "Server appears to be running" +} + +# Test health endpoint +test_health() { + log_info "Testing health endpoint..." + for attempt in {1..10}; do + local response=$(curl -s -w "%{http_code}" http://$HOST:$PORT/health 2>/dev/null || echo "000") + local http_code="${response: -3}" + local body="${response%???}" + + if [ "$http_code" = "200" ]; then + log_success "Health check passed: $body" + return 0 + fi + + log_warning "Health check attempt $attempt failed (HTTP $http_code). Retrying in 3s..." + sleep 3 + done + + log_error "Health check failed after multiple attempts" + return 1 +} + +# Test info endpoint +test_info() { + log_info "Testing info endpoint..." + local response=$(curl -s -w "%{http_code}" http://$HOST:$PORT/info 2>/dev/null || echo "000") + local http_code="${response: -3}" + local body="${response%???}" + + if [ "$http_code" = "200" ]; then + log_success "Info endpoint working" + echo "$body" | python -m json.tool 2>/dev/null || log_warning "Info response not valid JSON" + else + log_error "Info endpoint failed (HTTP $http_code): $body" + return 1 + fi +} + +# Test text generation +test_generation() { + log_info "Testing text generation..." 
+ + # Test 1: Simple generation + local response=$(curl -s -X POST http://$HOST:$PORT/generate \ + -H "Content-Type: application/json" \ + -d '{"prompts": ["Test the server functionality"], "max_tokens": 20, "temperature": 0.7}' \ + 2>/dev/null || echo "{}") + + if echo "$response" | python -c "import sys, json; data=json.load(sys.stdin); exit(0 if 'outputs' in data and len(data['outputs']) > 0 else 1)" 2>/dev/null; then + log_success "Text generation test passed" + echo "$response" | python -c "import sys, json; data=json.load(sys.stdin); print('Generated text:', data['outputs'][0]['text'][:100] + '...' if len(data['outputs'][0]['text']) > 100 else data['outputs'][0]['text'])" 2>/dev/null + else + log_error "Text generation test failed" + echo "Response: $response" + return 1 + fi +} + +# Test API client example +test_api_client() { + log_info "Testing API client example..." + if [ -f "examples/basic/api_client.py" ]; then + python examples/basic/api_client.py > api_client_test.log 2>&1 || { + log_error "API client test failed" + log_info "API client logs:" + cat api_client_test.log + return 1 + } + log_success "API client test passed" + else + log_warning "API client example not found, skipping" + fi +} + +# Test AR → DiT pipeline example +test_ar_dit_pipeline() { + log_info "Testing AR → DiT diffusers pipeline example..." + + if ! python - 2>/dev/null <<'PY' +import importlib +for pkg in ("torch", "diffusers"): + if importlib.util.find_spec(pkg) is None: + raise ImportError(pkg) +PY + then + log_warning "Required packages for diffusers pipeline not found, skipping AR → DiT test" + return 0 + fi + + if [ ! 
-f "examples/omni/ar_dit_diffusers.py" ]; then + log_warning "AR → DiT example not found, skipping" + return 0 + fi + + local output_path="logs/ar_dit_pipeline.png" + local log_path="logs/ar_dit_pipeline.log" + rm -f "$output_path" + + local cmd=(python examples/omni/ar_dit_diffusers.py + --prompt "Smoke test prompt of a lighthouse" + --seed 0 + --output "$output_path") + + local env_prefix=() + if [ -n "$AR_MODEL" ]; then + env_prefix+=("AR_MODEL=$AR_MODEL") + fi + if [ -n "$DIT_MODEL" ]; then + env_prefix+=("DIT_MODEL=$DIT_MODEL") + fi + + if ! env "${env_prefix[@]}" "${cmd[@]}" >"$log_path" 2>&1; then + log_error "AR → DiT pipeline test failed" + log_info "AR → DiT logs:" + cat "$log_path" + return 1 + fi + + if [ ! -f "$output_path" ]; then + log_error "AR → DiT pipeline did not produce an image" + log_info "AR → DiT logs:" + cat "$log_path" + return 1 + fi + + log_success "AR → DiT pipeline test passed (output: $output_path)" +} + +# Test simple usage example +test_simple_usage() { + log_info "Testing simple usage example..." + if [ -f "examples/basic/simple_usage.py" ]; then + # Run with timeout to avoid hanging + timeout 60 python examples/basic/simple_usage.py > simple_usage_test.log 2>&1 || { + local exit_code=$? + if [ $exit_code -eq 124 ]; then + log_warning "Simple usage test timed out (60s), but this might be normal for model loading" + else + log_error "Simple usage test failed (exit code: $exit_code)" + log_info "Simple usage logs:" + cat simple_usage_test.log + return 1 + fi + } + log_success "Simple usage test passed" + else + log_warning "Simple usage example not found, skipping" + fi +} + +# Performance test +test_performance() { + log_info "Running performance test..." 
+ + local start_time=$(date +%s) + local response=$(curl -s -X POST http://$HOST:$PORT/generate \ + -H "Content-Type: application/json" \ + -d '{"prompts": ["Performance test prompt"], "max_tokens": 50, "temperature": 0.7}' \ + 2>/dev/null) + local end_time=$(date +%s) + local duration=$((end_time - start_time)) + + if echo "$response" | python -c "import sys, json; data=json.load(sys.stdin); exit(0 if 'outputs' in data else 1)" 2>/dev/null; then + log_success "Performance test passed (${duration}s response time)" + else + log_error "Performance test failed" + return 1 + fi +} + +# Main test function +run_tests() { + log_info "Starting vLLM-omni serving functionality tests..." + log_info "Model: $MODEL_PATH" + log_info "Port: $PORT" + log_info "Host: $HOST" + echo "==========================================" + + # Run all tests + check_model + check_environment + test_imports + start_server + + # Wait a bit more for server to be fully ready + sleep 5 + + test_health + test_info + test_generation + test_performance + test_api_client + test_ar_dit_pipeline + + # Note: Simple usage test is commented out as it starts its own server + # test_simple_usage + + log_success "All tests completed successfully!" 
+ echo "==========================================" + log_info "Test Summary:" + log_info "✅ Model loading and server startup" + log_info "✅ Health and info endpoints" + log_info "✅ Text generation functionality" + log_info "✅ Performance metrics" + log_info "✅ API client integration" + log_info "✅ AR → DiT diffusers pipeline" + log_info "✅ All imports working correctly" +} + +# Show usage +show_usage() { + echo "Usage: $0 [model_path] [port]" + echo "" + echo "Arguments:" + echo " model_path Path to the model directory (default: ./models/Qwen3-0.6B)" + echo " port Port to run the server on (default: 8000)" + echo "" + echo "Examples:" + echo " $0 # Use default model and port" + echo " $0 ./models/Qwen3-0.6B 8001 # Use specific model and port" + echo " $0 Qwen/Qwen3-0.6B 8000 # Use HuggingFace model" +} + +# Main execution +main() { + # Check for help flag + if [[ "$1" == "-h" || "$1" == "--help" ]]; then + show_usage + exit 0 + fi + + # Change to script directory + cd "$(dirname "$0")/.." + + # Create logs directory if it doesn't exist + mkdir -p logs + + # Run tests + run_tests +} + +# Run main function with all arguments +main "$@" diff --git a/tests/conftest.py b/tests/conftest.py index 8d40a3e138..cdec372017 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,118 +1,82 @@ """ -Pytest configuration and fixtures for vLLM-omni tests. +Shared fixtures and configuration for vLLM-omni tests. 
""" import pytest -import asyncio -from typing import AsyncGenerator, Generator -from unittest.mock import Mock, AsyncMock - import torch -import numpy as np - -from vllm_omni.configs import load_config +from unittest.mock import Mock +from vllm_omni.config import OmniStageConfig, DiTConfig, DiTCacheConfig, create_ar_stage_config, create_dit_stage_config @pytest.fixture(scope="session") -def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]: - """Create an instance of the default event loop for the test session.""" - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() - - -@pytest.fixture -async def mock_async_llm() -> AsyncMock: - """Mock AsyncLLM for testing.""" - mock = AsyncMock() - mock.process.return_value = {"output": "test_output", "hidden_states": None} - return mock - - -@pytest.fixture -async def mock_omni_engine() -> AsyncMock: - """Mock OmniEngine for testing.""" - mock = AsyncMock() - mock.execute.return_value = {"result": "test_result"} - return mock - - -@pytest.fixture -def sample_config() -> dict: - """Sample configuration for testing.""" - return { - "general": { - "name": "vllm-omni-test", - "version": "0.1.0", - "debug": True, - "log_level": "DEBUG" - }, - "model": { - "device": "cpu", - "dtype": "float32", - "max_model_len": 1024 - }, - "engines": { - "ar_engine": {"enabled": True, "max_batch_size": 32}, - "diffusion_engine": {"enabled": True, "max_batch_size": 8} - } - } +def device(): + """Get available device for testing.""" + return "cuda" if torch.cuda.is_available() else "cpu" @pytest.fixture -def sample_text_input() -> str: - """Sample text input for testing.""" - return "Hello, world! This is a test input for vLLM-omni." 
+def sample_ar_stage_config(): + """Sample AR stage configuration for testing.""" + return create_ar_stage_config( + stage_id=0, + model_path="test-ar-model", + input_modalities=["text"], + output_modalities=["text"] + ) @pytest.fixture -def sample_image_input() -> np.ndarray: - """Sample image input for testing.""" - return np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) +def sample_dit_stage_config(): + """Sample DiT stage configuration for testing.""" + dit_config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=10, + guidance_scale=7.5 + ) + + return create_dit_stage_config( + stage_id=1, + model_path="test-dit-model", + input_modalities=["text"], + output_modalities=["image"], + dit_config=dit_config + ) @pytest.fixture -def sample_audio_input() -> np.ndarray: - """Sample audio input for testing.""" - return np.random.randn(16000).astype(np.float32) # 1 second at 16kHz +def sample_stage_configs(sample_ar_stage_config, sample_dit_stage_config): + """Sample stage configurations for testing.""" + return [sample_ar_stage_config, sample_dit_stage_config] @pytest.fixture -def sample_multimodal_input() -> dict: - """Sample multimodal input for testing.""" - return { - "text": "Describe this image", - "image": np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8), - "audio": np.random.randn(16000).astype(np.float32) - } +def mock_vllm_config(): + """Mock vLLM configuration.""" + config = Mock() + config.model = "test-model" + config.tensor_parallel_size = 1 + config.pipeline_parallel_size = 1 + return config @pytest.fixture -def mock_torch_device() -> str: - """Mock torch device for testing.""" - if torch.cuda.is_available(): - return "cuda" - elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): - return "mps" - else: - return "cpu" - - -# Pytest markers -pytest_plugins = [] - - -def pytest_configure(config): - """Configure pytest with custom markers.""" - config.addinivalue_line( - "markers", "unit: 
mark test as a unit test" - ) - config.addinivalue_line( - "markers", "integration: mark test as an integration test" - ) - config.addinivalue_line( - "markers", "benchmark: mark test as a benchmark test" - ) - config.addinivalue_line( - "markers", "slow: mark test as slow running" - ) +def mock_dit_cache_config(): + """Mock DiT cache configuration.""" + from vllm_omni.config import DiTCacheTensor + + cache_tensors = [ + DiTCacheTensor( + name="test_tensor", + shape=[1, 512, 512], + dtype="float32", + persistent=True + ) + ] + + return DiTCacheConfig( + cache_tensors=cache_tensors, + max_cache_size=1024 * 1024 * 1024, # 1GB + cache_strategy="fifo", + enable_optimization=True + ) \ No newline at end of file diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000000..1d8229156e --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,179 @@ +""" +Unit tests for configuration modules. +""" + +import pytest +from vllm_omni.config import ( + OmniStageConfig, + DiTConfig, + DiTCacheConfig, + DiTCacheTensor, + create_ar_stage_config, + create_dit_stage_config, +) + + +class TestOmniStageConfig: + def test_ar_stage_config_creation(self): + """Test AR stage configuration creation.""" + config = OmniStageConfig( + stage_id=0, + engine_type="AR", + model_path="test-model", + input_modalities=["text"], + output_modalities=["text"] + ) + assert config.engine_type == "AR" + assert config.stage_id == 0 + assert config.model_path == "test-model" + assert config.input_modalities == ["text"] + assert config.output_modalities == ["text"] + + def test_dit_stage_config_creation(self): + """Test DiT stage configuration creation.""" + dit_config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50 + ) + config = OmniStageConfig( + stage_id=1, + engine_type="DiT", + model_path="test-dit-model", + input_modalities=["text"], + output_modalities=["image"], + dit_config=dit_config + ) + assert config.engine_type == 
"DiT" + assert config.dit_config is not None + assert config.dit_config.model_type == "dit" + + def test_invalid_engine_type(self): + """Test validation of engine type.""" + with pytest.raises(ValueError): + OmniStageConfig( + stage_id=0, + engine_type="INVALID", + model_path="test-model", + input_modalities=["text"], + output_modalities=["text"] + ) + + def test_empty_modalities(self): + """Test validation of empty modalities.""" + with pytest.raises(ValueError): + OmniStageConfig( + stage_id=0, + engine_type="AR", + model_path="test-model", + input_modalities=[], + output_modalities=["text"] + ) + + with pytest.raises(ValueError): + OmniStageConfig( + stage_id=0, + engine_type="AR", + model_path="test-model", + input_modalities=["text"], + output_modalities=[] + ) + + def test_dit_requires_config(self): + """Test that DiT engine requires dit_config.""" + with pytest.raises(ValueError): + OmniStageConfig( + stage_id=1, + engine_type="DiT", + model_path="test-dit-model", + input_modalities=["text"], + output_modalities=["image"] + ) + + +class TestDiTConfig: + def test_dit_config_defaults(self): + """Test DiT configuration with defaults.""" + config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50 + ) + assert config.use_diffusers is False + assert config.diffusers_pipeline is None + assert config.guidance_scale == 7.5 + assert config.height == 512 + assert config.width == 512 + + def test_diffusers_config(self): + """Test DiT configuration with diffusers.""" + config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50, + use_diffusers=True, + diffusers_pipeline="stable-diffusion" + ) + assert config.use_diffusers is True + assert config.diffusers_pipeline == "stable-diffusion" + + +class TestDiTCacheConfig: + def test_cache_config_creation(self): + """Test cache configuration creation.""" + cache_tensors = [ + DiTCacheTensor( + name="test_tensor", + shape=[1, 512, 512], + dtype="float32", + 
persistent=True + ) + ] + + config = DiTCacheConfig( + cache_tensors=cache_tensors, + max_cache_size=1024 * 1024 * 1024, + cache_strategy="fifo" + ) + + assert len(config.cache_tensors) == 1 + assert config.max_cache_size == 1024 * 1024 * 1024 + assert config.cache_strategy == "fifo" + assert config.enable_optimization is True + + +class TestHelperFunctions: + def test_create_ar_stage_config(self): + """Test create_ar_stage_config helper function.""" + config = create_ar_stage_config( + stage_id=0, + model_path="test-model" + ) + + assert config.stage_id == 0 + assert config.engine_type == "AR" + assert config.model_path == "test-model" + assert config.input_modalities == ["text"] + assert config.output_modalities == ["text"] + + def test_create_dit_stage_config(self): + """Test create_dit_stage_config helper function.""" + dit_config = DiTConfig( + model_type="dit", + scheduler_type="ddpm", + num_inference_steps=50 + ) + + config = create_dit_stage_config( + stage_id=1, + model_path="test-dit-model", + dit_config=dit_config + ) + + assert config.stage_id == 1 + assert config.engine_type == "DiT" + assert config.model_path == "test-dit-model" + assert config.input_modalities == ["text"] + assert config.output_modalities == ["image"] + assert config.dit_config is not None + diff --git a/vllm_omni/__init__.py b/vllm_omni/__init__.py index 408545dfd9..fb544b0faa 100644 --- a/vllm_omni/__init__.py +++ b/vllm_omni/__init__.py @@ -14,9 +14,14 @@ __email__ = "hsliuustc@gmail.com" # Main entry points -from .async_llm import AsyncLLM -from .omni_engine import OmniEngine -from .diffusion_engine import DiffusionEngine +from .entrypoints.omni_llm import OmniLLM, AsyncOmniLLM +from .config import ( + OmniStageConfig, + DiTConfig, + DiTCacheConfig, + create_ar_stage_config, + create_dit_stage_config, +) __all__ = [ # Version info @@ -25,9 +30,15 @@ "__email__", # Main components - "AsyncLLM", - "OmniEngine", - "DiffusionEngine", + "OmniLLM", + "AsyncOmniLLM", + + # Configuration 
+ "OmniStageConfig", + "DiTConfig", + "DiTCacheConfig", + "create_ar_stage_config", + "create_dit_stage_config", # All other components are available through their respective modules # processors.*, schedulers.*, executors.*, etc. diff --git a/vllm_omni/config/__init__.py b/vllm_omni/config/__init__.py index 5174e4a045..e48a874a7c 100644 --- a/vllm_omni/config/__init__.py +++ b/vllm_omni/config/__init__.py @@ -1,22 +1,21 @@ """ -Configuration management for vLLM-omni. +Configuration module for vLLM-omni. """ -from dataclasses import dataclass, field -from pathlib import Path -from typing import Dict, Any, Optional -from vllm_omni.dit_cache_interface import DiTCacheConfig +from .stage_config import ( + OmniStageConfig, + DiTConfig, + DiTCacheConfig, + DiTCacheTensor, + create_ar_stage_config, + create_dit_stage_config, +) -from vllm.config import VllmConfig - - -@dataclass -class OmniConfig: - """ - The configuration for vLLM-omni. - """ - - """vllm config""" - vllm_config: VllmConfig = field(default_factory=VllmConfig) - """DiT cache config""" - dit_cache_config: DiTCacheConfig = field(default_factory=DiTCacheConfig) # DiT cache config \ No newline at end of file +__all__ = [ + "OmniStageConfig", + "DiTConfig", + "DiTCacheConfig", + "DiTCacheTensor", + "create_ar_stage_config", + "create_dit_stage_config", +] \ No newline at end of file diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py new file mode 100644 index 0000000000..8cca98c63e --- /dev/null +++ b/vllm_omni/config/stage_config.py @@ -0,0 +1,166 @@ +""" +Stage configuration for vLLM-omni multi-stage processing. 
+""" + +from dataclasses import dataclass +from typing import List, Optional, Type, Literal, Any, Dict +from vllm.config import ( + VllmConfig, + DeviceConfig, + LoadConfig, + ModelConfig, + CompilationConfig, +) +from vllm.executor.executor_base import ExecutorBase as Executor + +@dataclass +class DiTCacheTensor: + """Configuration for DiT cache tensors.""" + name: str + shape: List[int] + dtype: str = "float32" + persistent: bool = True + + +@dataclass +class DiTCacheConfig: + """Configuration for DiT caching system.""" + cache_tensors: List[DiTCacheTensor] + max_cache_size: int = 1024 * 1024 * 1024 # 1GB + cache_strategy: str = "fifo" # fifo, lru, lfu + enable_optimization: bool = True + cache_compression: bool = False + + +@dataclass +class DiTConfig: + """Configuration for DiT (Diffusion Transformer) stages.""" + + num_inference_steps: int + """Number of diffusion inference steps.""" + + model_config: Optional[ModelConfig] = None + """Model configuration.""" + scheduler_config: Optional[Any] = None + """Scheduler configuration.""" + device_config: Optional[DeviceConfig] = None + """Device configuration.""" + load_config: Optional[LoadConfig] = None + """Model loading configuration.""" + compilation_config: Optional[CompilationConfig] = None + """Compilation configuration.""" + dit_cache_config: Optional[DiTCacheConfig] = None + """DiT cache configuration.""" + guidance_scale: float = 7.5 + use_diffusers: bool = False + diffusers_pipeline: Optional[str] = None + height: int = 512 + width: int = 512 + batch_size: int = 1 + + +@dataclass +class OmniStageConfig: + """Configuration for a processing stage in vLLM-omni.""" + + stage_id: int + engine_type: Literal["AR", "DiT"] + model_path: str + input_modalities: List[str] + output_modalities: List[str] + vllm_config: Optional[VllmConfig] = None + dit_config: Optional[DiTConfig] = None + executor_class: Type[Executor] = None # Will be set based on engine_type + default_stage_args: Optional[Dict[str, Any]] = None + 
stage_output: Optional[Any] = None + + def __post_init__(self): + """Validate configuration after initialization.""" + if self.engine_type not in ["AR", "DiT"]: + raise ValueError(f"Invalid engine_type: {self.engine_type}. Must be 'AR' or 'DiT'") + + if self.engine_type == "DiT" and self.dit_config is None: + raise ValueError("DiT engine requires dit_config") + + if not self.input_modalities: + raise ValueError("input_modalities cannot be empty") + + if not self.output_modalities: + raise ValueError("output_modalities cannot be empty") + + def get_executor_class(self) -> Type[Executor]: + """Get the appropriate executor class for this stage.""" + if self.executor_class is not None: + return self.executor_class + + if self.engine_type == "AR": + from vllm.executor.uniproc_executor import UniProcExecutor + return UniProcExecutor + elif self.engine_type == "DiT": + if self.dit_config and self.dit_config.use_diffusers: + from vllm_omni.executor.diffusers_executor import DiffusersPipelineExecutor + return DiffusersPipelineExecutor + else: + from vllm.executor.uniproc_executor import UniProcExecutor + return UniProcExecutor + + raise ValueError(f"No executor class available for engine_type: {self.engine_type}") + + +def create_ar_stage_config( + stage_id: int, + model_path: str, + input_modalities: List[str] = None, + output_modalities: List[str] = None, + vllm_config: Optional[VllmConfig] = None, + executor_class: Optional[Type[Executor]] = None, + default_stage_args: Optional[Dict[str, Any]] = None, +) -> OmniStageConfig: + """Create a configuration for an AR (Autoregressive) stage.""" + if input_modalities is None: + input_modalities = ["text"] + if output_modalities is None: + output_modalities = ["text"] + + # For now, we'll create minimal configs + # vllm_config and executor_class will be handled by the LLM class + return OmniStageConfig( + stage_id=stage_id, + engine_type="AR", + model_path=model_path, + input_modalities=input_modalities, + 
output_modalities=output_modalities, + vllm_config=vllm_config, + executor_class=executor_class, + default_stage_args=default_stage_args, + ) + + +def create_dit_stage_config( + stage_id: int, + model_path: str, + input_modalities: Optional[List[str]] = None, + output_modalities: Optional[List[str]] = None, + dit_config: Optional[DiTConfig] = None, + executor_class: Optional[Type[Executor]] = None, + default_stage_args: Optional[Dict[str, Any]] = None, +) -> OmniStageConfig: + """Create a configuration for a DiT (Diffusion Transformer) stage.""" + if input_modalities is None: + input_modalities = ["text"] + if output_modalities is None: + output_modalities = ["image"] + + if dit_config is None: + dit_config = DiTConfig(num_inference_steps=50) + + return OmniStageConfig( + stage_id=stage_id, + engine_type="DiT", + model_path=model_path, + input_modalities=input_modalities, + output_modalities=output_modalities, + dit_config=dit_config, + executor_class=executor_class, + default_stage_args=default_stage_args, + ) diff --git a/vllm_omni/core/dit_cache_manager.py b/vllm_omni/core/dit_cache_manager.py index e69de29bb2..d2c800e76f 100644 --- a/vllm_omni/core/dit_cache_manager.py +++ b/vllm_omni/core/dit_cache_manager.py @@ -0,0 +1,178 @@ +""" +DiT Cache Manager for vLLM-omni. + +This module provides caching functionality for DiT (Diffusion Transformer) models +to optimize inference performance and memory usage. 
+""" + +import time +import torch +from typing import Dict, List, Optional, Any +from dataclasses import dataclass +from ..config import DiTCacheConfig, DiTCacheTensor + + +class DiTCacheManager: + """Manages DiT-specific caching for optimized inference.""" + + def __init__(self, config: DiTCacheConfig): + self.config = config + self.cache_tensors: Dict[str, DiTCacheTensor] = {} + self.cache_groups: List[DiTCacheTensor] = config.cache_tensors + self.max_cache_size = config.max_cache_size + self.current_cache_size = 0 + self.cache_hits = 0 + self.cache_misses = 0 + self.cache_timeout = 3600 # 1 hour default timeout + + def allocate_cache(self, request_id: str, size: int) -> torch.Tensor: + """Allocate cache for a specific request.""" + # Check if we have enough space + if self.current_cache_size + size > self.max_cache_size: + self._evict_cache() + + # Create tensor + tensor = torch.zeros(size, dtype=torch.float32) + + # Store in cache + cache_tensor = DiTCacheTensor( + name=request_id, + tensor=tensor, + timestamp=time.time(), + access_count=0, + size_bytes=size * 4 # Assuming float32 + ) + + self.cache_tensors[request_id] = cache_tensor + self.current_cache_size += cache_tensor.size_bytes + + return tensor + + def get_cache(self, request_id: str) -> Optional[torch.Tensor]: + """Get cached tensor for a request.""" + if request_id in self.cache_tensors: + cache_tensor = self.cache_tensors[request_id] + cache_tensor.access_count += 1 + self.cache_hits += 1 + return cache_tensor.tensor + else: + self.cache_misses += 1 + return None + + def store_cache(self, request_id: str, tensor: torch.Tensor) -> None: + """Store a tensor in the cache.""" + if request_id in self.cache_tensors: + # Update existing cache + old_tensor = self.cache_tensors[request_id] + self.current_cache_size -= old_tensor.size_bytes + + new_size_bytes = tensor.numel() * 4 # Assuming float32 + self.current_cache_size += new_size_bytes + + self.cache_tensors[request_id] = DiTCacheTensor( + 
name=request_id, + tensor=tensor, + timestamp=time.time(), + access_count=old_tensor.access_count, + size_bytes=new_size_bytes + ) + else: + # Create new cache entry + new_size_bytes = tensor.numel() * 4 + if self.current_cache_size + new_size_bytes > self.max_cache_size: + self._evict_cache() + + self.cache_tensors[request_id] = DiTCacheTensor( + name=request_id, + tensor=tensor, + timestamp=time.time(), + access_count=0, + size_bytes=new_size_bytes + ) + self.current_cache_size += new_size_bytes + + def release_cache(self, request_id: str) -> None: + """Release cache for a request.""" + if request_id in self.cache_tensors: + cache_tensor = self.cache_tensors[request_id] + self.current_cache_size -= cache_tensor.size_bytes + del self.cache_tensors[request_id] + + def clear_expired_cache(self) -> None: + """Clear expired cache entries.""" + current_time = time.time() + expired_keys = [] + + for request_id, cache_tensor in self.cache_tensors.items(): + if current_time - cache_tensor.timestamp > self.cache_timeout: + expired_keys.append(request_id) + + for request_id in expired_keys: + self.release_cache(request_id) + + def _evict_cache(self) -> None: + """Evict cache entries based on strategy.""" + if self.config.cache_strategy == "fifo": + self._evict_fifo() + elif self.config.cache_strategy == "lru": + self._evict_lru() + elif self.config.cache_strategy == "lfu": + self._evict_lfu() + else: + self._evict_fifo() # Default to FIFO + + def _evict_fifo(self) -> None: + """Evict cache entries using FIFO strategy.""" + if not self.cache_tensors: + return + + # Find oldest entry + oldest_key = min(self.cache_tensors.keys(), + key=lambda k: self.cache_tensors[k].timestamp) + self.release_cache(oldest_key) + + def _evict_lru(self) -> None: + """Evict cache entries using LRU strategy.""" + if not self.cache_tensors: + return + + # Find least recently used entry + lru_key = min(self.cache_tensors.keys(), + key=lambda k: self.cache_tensors[k].access_count) + 
self.release_cache(lru_key) + + def _evict_lfu(self) -> None: + """Evict cache entries using LFU strategy.""" + if not self.cache_tensors: + return + + # Find least frequently used entry + lfu_key = min(self.cache_tensors.keys(), + key=lambda k: self.cache_tensors[k].access_count) + self.release_cache(lfu_key) + + def get_statistics(self) -> Dict[str, Any]: + """Get cache statistics.""" + total_requests = self.cache_hits + self.cache_misses + hit_rate = self.cache_hits / total_requests if total_requests > 0 else 0 + + return { + "cache_size": self.current_cache_size, + "max_cache_size": self.max_cache_size, + "cache_utilization": self.current_cache_size / self.max_cache_size, + "cache_hits": self.cache_hits, + "cache_misses": self.cache_misses, + "hit_rate": hit_rate, + "num_cached_tensors": len(self.cache_tensors) + } + + def clear_all_cache(self) -> None: + """Clear all cache entries.""" + self.cache_tensors.clear() + self.current_cache_size = 0 + self.cache_hits = 0 + self.cache_misses = 0 + + def set_cache_timeout(self, timeout: float) -> None: + """Set cache timeout in seconds.""" + self.cache_timeout = timeout \ No newline at end of file diff --git a/vllm_omni/core/sched/scheduler.py b/vllm_omni/core/sched/scheduler.py index e69de29bb2..5bbbfeb05c 100644 --- a/vllm_omni/core/sched/scheduler.py +++ b/vllm_omni/core/sched/scheduler.py @@ -0,0 +1,40 @@ + +from vllm_omni.request import OmniRequest +from typing import List +from threading import Lock, Condition +from vllm_omni.config import OmniConfig + +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.core.sched import SchedulerInterface +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.v1.structured_output import StructuredOutputManager + + +class OmniScheduler(SchedulerInterface): + """ + OmniScheduler: Scheduler for vLLM-omni multimodal processing. 
+ + This scheduler extends vLLM's scheduler to support multimodal and non-autoregressive + processing with additional fields and methods specific to vLLM-omni. + """ + def __init__(self, + omni_config: OmniConfig, + kv_cache_config: KVCacheConfig, + structured_output_manager: StructuredOutputManager, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ): + super().__init__( + vllm_config=omni_config.vllm_config, + kv_cache_config=kv_cache_config, + multimodal_registry=mm_registry, + structured_output_manager=structured_output_manager, + include_finished_set=include_finished_set, + log_stats=log_stats, + ) + self.omni_config = omni_config + + def schedule(self, requests: List[OmniRequest]) -> List[OmniRequest]: + # TODO: Implement scheduling logic + pass \ No newline at end of file diff --git a/vllm_omni/dit_cache_interface.py b/vllm_omni/dit_cache_interface.py deleted file mode 100644 index 260acac5f0..0000000000 --- a/vllm_omni/dit_cache_interface.py +++ /dev/null @@ -1,24 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class DiTCacheTensor: - """ - A class for specifying how the workers should initialize the DiT cache. - """ - size: int # the size of the cache tensor in bytes - - -@dataclass -class DiTCacheConfig: - """ - The DiT cache configuration of a model. - """ - """How should model runner initialize the KV cache tensors for each layer""" - dit_cache_tensors: list[DiTCacheTensor] - """ - The DiT cache groups of the model. - For models with only one type of DiT, there is only one group that - contains all layers. - """ - kv_cache_groups: list[DiTCacheTensor] \ No newline at end of file diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py index 3768b785ee..406c26b399 100644 --- a/vllm_omni/engine/__init__.py +++ b/vllm_omni/engine/__init__.py @@ -1,18 +1,9 @@ """ -Diffusion Engine: Diffusion Transformer (DiT) processing modules. 
- -This module provides specialized processing for diffusion models including -step management, cache management, and model wrapping. +Engine components for vLLM-omni. """ -from .step_manager import DiffusionStepManager -from .cache_manager import DiffusionCacheManager -from .models import DiffusionModel -from .base import BaseDiffusionEngine +from .output_processor import MultimodalOutputProcessor __all__ = [ - "DiffusionStepManager", - "DiffusionCacheManager", - "DiffusionModel", - "BaseDiffusionEngine", + "MultimodalOutputProcessor", ] diff --git a/vllm_omni/engine/diffusion_engine.py b/vllm_omni/engine/diffusion_engine.py new file mode 100644 index 0000000000..ef8a5b51d0 --- /dev/null +++ b/vllm_omni/engine/diffusion_engine.py @@ -0,0 +1,133 @@ +"""Diffusion engine adapters for diffusers-backed DiT stages.""" + +from __future__ import annotations + +from typing import Any, Optional + +from ..config import DiTConfig +from ..worker.gpu_diffusion_worker import ( + DiffusionGPUWorker, + DiffusionRunnerOutput, +) + + +class DiffusionEngine: + """Base diffusion engine storing shared DiT configuration.""" + + def __init__( + self, + dit_config: DiTConfig, + model_path: Optional[str] = None, + log_stats: bool = False, + multiprocess_mode: bool = False, + ) -> None: + self.dit_config = dit_config + + model_cfg = getattr(dit_config, "model_config", None) + config_model_path = None + if model_cfg is not None: + if isinstance(model_cfg, dict): + config_model_path = ( + model_cfg.get("model") + or model_cfg.get("model_path") + ) + else: + config_model_path = ( + getattr(model_cfg, "model", None) + or getattr(model_cfg, "model_path", None) + ) + + self.model_path = model_path or config_model_path + self.model_config = dit_config.model_config + self.cache_config = dit_config.dit_cache_config + + self.height = dit_config.height + self.width = dit_config.width + self.num_inference_steps = dit_config.num_inference_steps + self.guidance_scale = dit_config.guidance_scale + + 
self.log_stats = log_stats + self.multiprocess_mode = multiprocess_mode + + def generate(self, *args, **kwargs): # pragma: no cover - placeholder + raise NotImplementedError("This method should be implemented in subclasses.") + + +class DiffusersPipelineEngine(DiffusionEngine): + """Adapter that invokes a diffusers pipeline via DiffusionGPUWorker.""" + + def __init__( + self, + dit_config: DiTConfig, + model_path: Optional[str] = None, + log_stats: bool = False, + multiprocess_mode: bool = False, + ) -> None: + super().__init__(dit_config, model_path, log_stats, multiprocess_mode) + + resolved_model_path = self.model_path + if resolved_model_path is None: + raise ValueError("Model path must be provided for DiffusersPipelineEngine.") + + device_cfg = getattr(dit_config, "device_config", None) + model_cfg = getattr(dit_config, "model_config", None) + + device = None + dtype = None + + if device_cfg: + if isinstance(device_cfg, dict): + device = device_cfg.get("device") + dtype = device_cfg.get("dtype") + else: + device = getattr(device_cfg, "device", None) + dtype = getattr(device_cfg, "dtype", None) + + if dtype is None and model_cfg: + if isinstance(model_cfg, dict): + dtype = model_cfg.get("dtype") + else: + dtype = getattr(model_cfg, "dtype", None) + + self.worker = DiffusionGPUWorker( + resolved_model_path, + pipeline_name=dit_config.diffusers_pipeline, + device=device, + dtype=dtype, + ) + + def generate( + self, + prompt: str, + *, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + negative_prompt: Optional[str] = None, + seed: Optional[int] = None, + image: Optional[Any] = None, + ) -> DiffusionRunnerOutput: + height = int(height) if height is not None else self.height + width = int(width) if width is not None else self.width + num_steps = ( + int(num_inference_steps) + if num_inference_steps is not None + else self.num_inference_steps + ) + guidance = ( + 
float(guidance_scale) + if guidance_scale is not None + else self.guidance_scale + ) + + return self.worker.generate( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_steps, + guidance_scale=guidance, + negative_prompt=negative_prompt, + seed=seed, + image=image, + ) diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index e69de29bb2..8f5fcb3571 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -0,0 +1,215 @@ +""" +Output processing for multimodal outputs in vLLM-omni. +""" + +from typing import List, Dict, Any, Optional, Callable, Union +from vllm.outputs import RequestOutput, CompletionOutput +from vllm.v1.outputs import ModelRunnerOutput as EngineCoreOutput + + +class MultimodalOutputProcessor: + """Handles multimodal output processing for vLLM-omni.""" + + def __init__(self): + self.output_handlers: Dict[str, Callable] = { + "image": self._process_image_output, + "text+image": self._process_text_image_output, + "latents": self._process_latents_output, + "text": self._process_text_output, + "pooling": self._process_pooling_output, + } + + def process_output(self, engine_core_output: Any) -> List[RequestOutput]: + """Process engine core output and return formatted RequestOutput.""" + if engine_core_output is None: + return [] + + # If it's already a RequestOutput, return as is + if isinstance(engine_core_output, RequestOutput): + return [engine_core_output] + + # If it's a list of RequestOutputs, return as is + if isinstance(engine_core_output, list): + return engine_core_output + + # Otherwise, process based on output type + output_type = self._detect_output_type(engine_core_output) + handler = self.output_handlers.get(output_type, self._process_pooling_output) + + return handler(engine_core_output) + + def _build_request_output( + self, + source: Any, + completion_outputs: List[CompletionOutput], + ) -> RequestOutput: + """Helper to construct 
RequestOutput with safe defaults.""" + + return RequestOutput( + request_id=getattr(source, "request_id", "unknown"), + prompt=getattr(source, "prompt", ""), + prompt_token_ids=getattr(source, "prompt_token_ids", []), + prompt_logprobs=getattr(source, "prompt_logprobs", None), + outputs=completion_outputs, + finished=getattr(source, "finished", True), + ) + + def _detect_output_type(self, output: Any) -> str: + """Detect the type of output based on its content.""" + if hasattr(output, 'output_type'): + return output.output_type + + # Check for image-related attributes + if hasattr(output, 'image') or hasattr(output, 'images'): + if hasattr(output, 'text') or hasattr(output, 'texts'): + return "text+image" + else: + return "image" + + # Check for latent-related attributes + if hasattr(output, 'latents') or hasattr(output, 'latent_representation'): + return "latents" + + # Check for pooling output + if hasattr(output, 'pooler_output') and output.pooler_output is not None: + return "pooling" + + # Default to text + return "text" + + def _process_text_output(self, output: Any) -> List[RequestOutput]: + """Process text output.""" + if isinstance(output, RequestOutput): + return [output] + + # Create a mock RequestOutput for text + completion_output = CompletionOutput( + index=0, + text=getattr(output, 'text', ''), + token_ids=getattr(output, 'token_ids', []), + cumulative_logprob=getattr(output, 'cumulative_logprob', 0.0), + logprobs=getattr(output, 'logprobs', None), + finish_reason=getattr(output, 'finish_reason', 'length') + ) + + return [self._build_request_output(output, [completion_output])] + + def _process_image_output(self, output: Any) -> List[RequestOutput]: + """Process image output.""" + # For image outputs, we need to create a special RequestOutput + # that can handle image data + + # Extract image data + image_data = getattr(output, 'image', None) + if image_data is None: + image_data = getattr(output, 'images', [None])[0] + + # Create a completion output 
with image data + completion_output = CompletionOutput( + index=0, + text="", # No text for pure image output + token_ids=[], + cumulative_logprob=0.0, + logprobs=None, + finish_reason="stop" + ) + + # Add image data to the completion output + completion_output.image = image_data + + return [ + self._build_request_output(output, [completion_output]) + ] + + def _process_text_image_output(self, output: Any) -> List[RequestOutput]: + """Process combined text and image output.""" + # Extract text and image data + text_data = getattr(output, 'text', '') + image_data = getattr(output, 'image', None) + + if image_data is None: + image_data = getattr(output, 'images', [None])[0] + + # Create a completion output with both text and image + completion_output = CompletionOutput( + index=0, + text=text_data, + token_ids=getattr(output, 'token_ids', []), + cumulative_logprob=getattr(output, 'cumulative_logprob', 0.0), + logprobs=getattr(output, 'logprobs', None), + finish_reason="stop" + ) + + # Add image data to the completion output + completion_output.image = image_data + + return [ + self._build_request_output(output, [completion_output]) + ] + + def _process_latents_output(self, output: Any) -> List[RequestOutput]: + """Process latent representation output.""" + # Extract latent data + latent_data = getattr(output, 'latents', None) + if latent_data is None: + latent_data = getattr(output, 'latent_representation', None) + + # Create a completion output with latent data + completion_output = CompletionOutput( + index=0, + text="", # No text for latent output + token_ids=[], + cumulative_logprob=0.0, + logprobs=None, + finish_reason="stop" + ) + + # Add latent data to the completion output + completion_output.latents = latent_data + + return [ + self._build_request_output(output, [completion_output]) + ] + + def _process_pooling_output(self, output: Any) -> List[RequestOutput]: + """Process pooling output (hidden states, embeddings, etc.).""" + # Extract pooling data + 
pooling_data = getattr(output, 'pooler_output', None) + if pooling_data is None: + pooling_data = getattr(output, 'hidden_states', None) + + # Create a completion output with pooling data + completion_output = CompletionOutput( + index=0, + text="", # No text for pooling output + token_ids=[], + cumulative_logprob=0.0, + logprobs=None, + finish_reason="stop" + ) + + # Add pooling data to the completion output + completion_output.pooler_output = pooling_data + + return [ + self._build_request_output(output, [completion_output]) + ] + + def process_outputs(self, engine_core_outputs: List[EngineCoreOutput], **kwargs) -> List[RequestOutput]: + """Process multiple engine core outputs.""" + all_outputs = [] + + for engine_core_output in engine_core_outputs: + outputs = self.process_output(engine_core_output) + all_outputs.extend(outputs) + + return all_outputs + + def add_output_handler(self, output_type: str, handler: Callable) -> None: + """Add a custom output handler for a specific output type.""" + self.output_handlers[output_type] = handler + + def remove_output_handler(self, output_type: str) -> None: + """Remove an output handler for a specific output type.""" + if output_type in self.output_handlers: + del self.output_handlers[output_type] diff --git a/vllm_omni/entrypoints/api_server.py b/vllm_omni/entrypoints/api_server.py new file mode 100644 index 0000000000..af14fbf08d --- /dev/null +++ b/vllm_omni/entrypoints/api_server.py @@ -0,0 +1,146 @@ +""" +API server for vLLM-omni. 
+""" + +import asyncio +from typing import Dict, Any, List +from fastapi import FastAPI, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel +import uvicorn + +from .omni_llm import AsyncOmniLLM + + +class GenerateRequest(BaseModel): + """Request model for generation.""" + prompts: List[str] + max_tokens: int = 100 + temperature: float = 0.7 + stage_args: List[Dict[str, Any]] = None + + +class GenerateResponse(BaseModel): + """Response model for generation.""" + outputs: List[Dict[str, Any]] + stage_outputs: List[Dict[str, Any]] = None + + +app = FastAPI( + title="vLLM-omni API", + description="Multi-modality models inference and serving", + version="0.1.0" +) + +# Global omni_llm instance +omni_llm: AsyncOmniLLM = None + + +@app.on_event("startup") +async def startup_event(): + """Initialize the omni_llm instance on startup.""" + global omni_llm + # This will be set by the run_server function + pass + + +@app.post("/generate", response_model=GenerateResponse) +async def generate(request: GenerateRequest): + """Generate text or multimodal content.""" + try: + if omni_llm is None: + raise HTTPException(status_code=500, detail="OmniLLM not initialized") + + # Prepare stage arguments + if request.stage_args is None: + # Create default stage arguments - one per stage config + # For now, we'll process all prompts in the first stage + stage_args = [{ + "prompt": " ".join(request.prompts) if request.prompts else "", + "max_tokens": request.max_tokens, + "temperature": request.temperature + }] + else: + stage_args = request.stage_args + + # Generate using omni_llm + try: + outputs = await omni_llm.generate_async(stage_args) + except Exception as e: + # Fallback to synchronous generation + outputs = omni_llm.generate(stage_args) + + # Convert outputs to response format + response_outputs = [] + for output in outputs: + if hasattr(output, 'outputs') and output.outputs: + for out in output.outputs: + response_outputs.append({ + "text": 
getattr(out, 'text', ''), + "finished": getattr(out, 'finish_reason', 'length') != 'length', + "tokens": getattr(out, 'token_ids', []) + }) + else: + response_outputs.append({ + "text": "", + "finished": True, + "tokens": [] + }) + + return GenerateResponse( + outputs=response_outputs, + stage_outputs=[{"stage": i, "output": "processed"} for i in range(len(stage_args))] + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy", "service": "vllm-omni"} + + +@app.get("/info") +async def get_info(): + """Get information about the service.""" + if omni_llm is None: + return {"error": "OmniLLM not initialized"} + + return { + "service": "vllm-omni", + "version": "0.1.0", + "num_stages": omni_llm.get_num_stages() if hasattr(omni_llm, 'get_num_stages') else 0, + "stage_configs": [ + { + "stage_id": config.stage_id, + "engine_type": config.engine_type, + "model_path": config.model_path, + "input_modalities": config.input_modalities, + "output_modalities": config.output_modalities + } + for config in omni_llm.stage_configs + ] + } + + +async def run_server(omni_llm_instance: AsyncOmniLLM, host: str = "0.0.0.0", port: int = 8000): + """Run the API server.""" + global omni_llm + omni_llm = omni_llm_instance + + config = uvicorn.Config( + app=app, + host=host, + port=port, + log_level="info" + ) + + server = uvicorn.Server(config) + await server.serve() + + +if __name__ == "__main__": + # This is for testing purposes + asyncio.run(run_server(None)) diff --git a/vllm_omni/entrypoints/cli/__init__.py b/vllm_omni/entrypoints/cli/__init__.py new file mode 100644 index 0000000000..2df4defa76 --- /dev/null +++ b/vllm_omni/entrypoints/cli/__init__.py @@ -0,0 +1,5 @@ +"""CLI helpers for vLLM-omni entrypoints.""" + +from .serve import OmniServeCommand + +__all__ = ["OmniServeCommand"] diff --git a/vllm_omni/entrypoints/cli/main.py 
b/vllm_omni/entrypoints/cli/main.py new file mode 100644 index 0000000000..830e181a99 --- /dev/null +++ b/vllm_omni/entrypoints/cli/main.py @@ -0,0 +1,26 @@ +""" +CLI entry point for vLLM-omni that intercepts vLLM commands. +""" + +import sys +import argparse +from typing import List, Optional +from vllm_omni.entrypoints.cli.serve import OmniServeCommand + + +def main(): + """Main CLI entry point that intercepts vLLM commands.""" + # Check if --omni flag is present + if "--omni" in sys.argv: + # Remove --omni flag and "serve" command, process with vLLM-omni + omni_args = [arg for arg in sys.argv[1:] if arg not in ["--omni", "serve"]] + omni_serve = OmniServeCommand() + omni_serve.run(omni_args) + else: + # Forward to original vLLM CLI + from vllm.entrypoints.cli.main import main as vllm_main + vllm_main() + + +if __name__ == "__main__": + main() diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py new file mode 100644 index 0000000000..1fdc0f79aa --- /dev/null +++ b/vllm_omni/entrypoints/cli/serve.py @@ -0,0 +1,208 @@ +""" +Omni serve command for vLLM-omni. 
+""" + +import argparse +import asyncio +from typing import List, Optional + +from ..omni_llm import AsyncOmniLLM +from vllm_omni.config import ( + DiTConfig, + create_ar_stage_config, + create_dit_stage_config, +) + + +class OmniServeCommand: + """Command handler for vLLM-omni serve command.""" + + def __init__(self): + self.parser = self._create_parser() + + def _create_parser(self) -> argparse.ArgumentParser: + """Create argument parser for omni serve command.""" + parser = argparse.ArgumentParser( + description="vLLM-omni: Multi-modality models inference and serving" + ) + + # Model arguments - make it optional with default + parser.add_argument( + "model", + nargs="?", + default="Qwen/Qwen3-0.6B", + help="Path to the model or model name (default: Qwen/Qwen3-0.6B)" + ) + + # Server arguments + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to run the server on" + ) + + parser.add_argument( + "--host", + type=str, + default="0.0.0.0", + help="Host to run the server on" + ) + + # Stage configuration arguments + parser.add_argument( + "--ar-stage", + type=str, + help="AR stage model path" + ) + + parser.add_argument( + "--dit-stage", + type=str, + help="DiT stage model path" + ) + + parser.add_argument( + "--dit-steps", + type=int, + default=50, + help="Number of DiT inference steps" + ) + + parser.add_argument( + "--dit-guidance-scale", + type=float, + default=7.5, + help="DiT guidance scale" + ) + + parser.add_argument( + "--use-diffusers", + action="store_true", + help="Use diffusers pipeline for DiT stage" + ) + + # Other arguments + parser.add_argument( + "--log-stats", + action="store_true", + help="Enable logging statistics" + ) + + return parser + + def run(self, args: List[str]) -> None: + """Run the omni serve command.""" + parsed_args = self.parser.parse_args(args) + + # Create stage configurations + stage_configs = self._create_stage_configs(parsed_args) + + # Create AsyncOmniLLM instance + omni_llm = AsyncOmniLLM( + 
stage_configs=stage_configs, + log_stats=parsed_args.log_stats + ) + + # Start the server + asyncio.run(self._start_server(omni_llm, parsed_args)) + + def _create_stage_configs(self, args) -> List: + """Create stage configurations based on arguments.""" + stage_configs = [] + stage_id = 0 + + # Add AR stage - use main model if ar_stage not specified + ar_model = args.ar_stage if args.ar_stage else args.model + ar_config = create_ar_stage_config( + stage_id=stage_id, + model_path=ar_model, + input_modalities=["text"], + output_modalities=["text"] + ) + stage_configs.append(ar_config) + stage_id += 1 + + # Add DiT stage if specified + if args.dit_stage: + dit_model = args.dit_stage + elif args.use_diffusers: + # Use a default DiT model if diffusers is enabled + dit_model = "stabilityai/stable-diffusion-2-1" + else: + dit_model = None + + if dit_model: + dit_config = DiTConfig( + num_inference_steps=args.dit_steps, + guidance_scale=args.dit_guidance_scale, + use_diffusers=args.use_diffusers, + diffusers_pipeline="auto" if args.use_diffusers else None, + ) + + dit_stage_config = create_dit_stage_config( + stage_id=stage_id, + model_path=dit_model, + input_modalities=["text"], + output_modalities=["image"], + dit_config=dit_config + ) + stage_configs.append(dit_stage_config) + stage_id += 1 + + # If no specific stages are specified, use the main model + if not stage_configs: + # Try to detect if it's a multimodal model + if "omni" in args.model.lower() or "multimodal" in args.model.lower(): + # Assume it's a multimodal model that can handle both AR and DiT + ar_config = create_ar_stage_config( + stage_id=0, + model_path=args.model, + input_modalities=["text"], + output_modalities=["text"] + ) + stage_configs.append(ar_config) + + dit_config = DiTConfig( + num_inference_steps=args.dit_steps, + guidance_scale=args.dit_guidance_scale, + use_diffusers=args.use_diffusers, + diffusers_pipeline="auto" if args.use_diffusers else None, + ) + + dit_stage_config = 
create_dit_stage_config( + stage_id=1, + model_path=args.model, + input_modalities=["text"], + output_modalities=["image"], + dit_config=dit_config + ) + stage_configs.append(dit_stage_config) + else: + # Default to AR stage + ar_config = create_ar_stage_config( + stage_id=0, + model_path=args.model, + input_modalities=["text"], + output_modalities=["text"] + ) + stage_configs.append(ar_config) + + return stage_configs + + async def _start_server(self, omni_llm: AsyncOmniLLM, args) -> None: + """Start the API server.""" + try: + # Import here to avoid circular imports + from vllm_omni.entrypoints.api_server import run_server + + await run_server( + omni_llm_instance=omni_llm, + host=args.host, + port=args.port + ) + except KeyboardInterrupt: + print("\nShutting down server...") + except Exception as e: + print(f"Error starting server: {e}") + raise diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py new file mode 100644 index 0000000000..8223c6f20e --- /dev/null +++ b/vllm_omni/entrypoints/omni_llm.py @@ -0,0 +1,576 @@ +""" +Core OmniLLM and AsyncOmniLLM classes for multi-stage processing. 
+""" + +from typing import List, Dict, Any, Optional, Union, Callable +from vllm.entrypoints.llm import LLM +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.llm_engine import LLMEngine +from vllm.outputs import RequestOutput, LoRARequest + +from ..config import OmniStageConfig +from .stage_manager import StageManager +from ..engine.output_processor import MultimodalOutputProcessor +from ..engine.diffusion_engine import DiffusersPipelineEngine + + + +class OmniLLM(LLM): + """Extended LLM supporting multiple engines and stage-based processing.""" + + def __init__( + self, + stage_configs: List[OmniStageConfig], + log_stats: bool = False, + **kwargs + ): + # Use the first stage's model as the default model for LLM + default_model = stage_configs[0].model_path if stage_configs else "test-model" + + # Track whether we should launch engines in multiprocess mode. + self.multiprocess_mode = kwargs.pop("multiprocess_mode", False) + + # Fix configuration validation issues + # Ensure max_num_batched_tokens is at least as large as max_model_len + if 'max_model_len' in kwargs and 'max_num_batched_tokens' in kwargs: + if kwargs['max_num_batched_tokens'] < kwargs['max_model_len']: + kwargs['max_num_batched_tokens'] = kwargs['max_model_len'] + elif 'max_model_len' in kwargs: + # If max_model_len is set but max_num_batched_tokens is not, set it to max_model_len + kwargs['max_num_batched_tokens'] = kwargs['max_model_len'] + else: + # Set reasonable defaults to avoid validation errors + kwargs['max_model_len'] = 2048 + kwargs['max_num_batched_tokens'] = 2048 + + super().__init__(model=default_model, **kwargs) + self.stage_configs = stage_configs + self.log_stats = log_stats + self.stage_manager = StageManager(stage_configs, log_stats) + self.output_processor = MultimodalOutputProcessor() + self._initialize_stage_engines() + + def _initialize_stage_engines(self) -> None: + """Initialize LLMEngine instances for each stage.""" + self.stage_manager.initialize_engines() 
+ + def generate( + self, + stage_args_list: Optional[List[Dict[str, Any]]] = None, + use_tqdm: Union[bool, Callable[..., Any]] = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + priority: Optional[List[int]] = None, + *, + prompt: Optional[str] = None, + stage_overrides: Optional[Dict[int, Dict[str, Any]]] = None, + ) -> List[RequestOutput]: + """Main generation interface - orchestrates multi-stage processing.""" + if stage_args_list is None: + if prompt is None: + raise ValueError( + "prompt must be provided when stage_args_list is not supplied" + ) + stage_args_list = self._build_stage_args_from_config( + prompt, stage_overrides or {} + ) + + if len(stage_args_list) != len(self.stage_configs): + raise ValueError( + f"Number of stage arguments ({len(stage_args_list)}) must match " + f"number of stage configs ({len(self.stage_configs)})" + ) + + # Process through each stage sequentially + current_output = None + + for i, (stage_config, stage_args) in enumerate(zip(self.stage_configs, stage_args_list)): + stage_engine = self.stage_manager.get_engine(i) + + # Prepare input for this stage + processed_input = self._process_stage_inputs( + stage_config, stage_args or {}, current_output + ) + + # Execute stage + stage_output = self._execute_stage( + stage_engine, processed_input, lora_request, priority, stage_config + ) + + # Update for next stage + current_output = stage_output + stage_config.stage_output = stage_output + + # Process final output + final_output = self.output_processor.process_output(current_output) + return final_output + + def _build_stage_args_from_config( + self, + prompt: str, + stage_overrides: Dict[int, Dict[str, Any]], + ) -> List[Dict[str, Any]]: + """Derive per-stage argument dictionaries from configuration defaults.""" + stage_args: List[Dict[str, Any]] = [] + for stage_config in self.stage_configs: + combined: Dict[str, Any] = dict(stage_config.default_stage_args or {}) + override = 
stage_overrides.get(stage_config.stage_id) + if override: + combined.update(override) + if stage_config.engine_type == "AR": + combined["prompt"] = prompt + stage_args.append(combined) + return stage_args + + def _process_stage_inputs( + self, + stage_config: OmniStageConfig, + stage_args: Dict[str, Any], + previous_output: Optional[Any] + ) -> Dict[str, Any]: + """Prepare input for specific stage.""" + if stage_config.engine_type == "AR": + return self._process_ar_inputs(stage_config, stage_args, previous_output) + elif stage_config.engine_type == "DiT": + return self._process_dit_inputs(stage_config, stage_args, previous_output) + else: + raise NotImplementedError(f"Unknown engine type: {stage_config.engine_type}") + + def _process_ar_inputs( + self, + stage_config: OmniStageConfig, + stage_args: Dict[str, Any], + previous_output: Optional[Any] + ) -> Dict[str, Any]: + """Process inputs for AR stage.""" + combined = dict(stage_config.default_stage_args or {}) + combined.update(stage_args) + combined.setdefault("prompt", "") + combined.setdefault("max_tokens", 100) + combined.setdefault("temperature", 0.7) + + # If we have previous output (e.g., from a previous AR stage), + # we might want to use it as context + if previous_output is not None: + # Extract text from previous output if available + if hasattr(previous_output, 'outputs') and previous_output.outputs: + last_output = previous_output.outputs[-1] + if hasattr(last_output, 'text'): + combined["prompt"] = last_output.text + " " + combined["prompt"] + + return combined + + def _process_dit_inputs( + self, + stage_config: OmniStageConfig, + stage_args: Dict[str, Any], + previous_output: Optional[Any] + ) -> Dict[str, Any]: + """Process inputs for DiT stage.""" + combined = dict(stage_config.default_stage_args or {}) + combined.update(stage_args) + + dit = stage_config.dit_config + if dit is not None: + combined.setdefault("height", getattr(dit, "height", 512)) + combined.setdefault("width", getattr(dit, 
"width", 512)) + combined.setdefault( + "num_inference_steps", getattr(dit, "num_inference_steps", 50) + ) + combined.setdefault( + "guidance_scale", getattr(dit, "guidance_scale", 7.5) + ) + else: + combined.setdefault("height", 512) + combined.setdefault("width", 512) + combined.setdefault("num_inference_steps", 50) + combined.setdefault("guidance_scale", 7.5) + + # Handle image inputs if present + if "image" in stage_args: + # For now, we'll pass the image path directly + # In a full implementation, this would involve VAE encoding + combined["image"] = stage_args["image"] + + # If we have previous output from an AR stage, we might want to use it + if previous_output is not None: + # Extract text from previous AR output + if hasattr(previous_output, 'outputs') and previous_output.outputs: + last_output = previous_output.outputs[-1] + if hasattr(last_output, 'text'): + combined["prompt"] = last_output.text + + combined.setdefault("prompt", stage_args.get("prompt", "")) + + return combined + + def _execute_stage( + self, + stage_engine: Optional[Union[LLMEngine, DiffusersPipelineEngine]], + processed_input: Dict[str, Any], + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + priority: Optional[List[int]] = None, + stage_config: Optional[OmniStageConfig] = None, + ) -> Any: + """Execute a single stage.""" + # DiT via diffusers backend + if stage_config and stage_config.engine_type == "DiT": + dit = stage_config.dit_config + if dit and getattr(dit, "use_diffusers", False): + # Lazy-init executor per stage + if not hasattr(self, "_dit_engines"): + self._dit_engines = {} + exec_inst = self._dit_engines.get(stage_config.stage_id) + if exec_inst is None: + exec_inst = DiffusersPipelineEngine( + dit_config=dit, + model_path=stage_config.model_path, + log_stats=self.log_stats, + multiprocess_mode=self.multiprocess_mode, + ) + + self._dit_engines[stage_config.stage_id] = exec_inst + + return exec_inst.generate( + prompt=processed_input.get("prompt", 
""), + height=processed_input.get("height", getattr(dit, "height", 512)), + width=processed_input.get("width", getattr(dit, "width", 512)), + num_inference_steps=processed_input.get( + "num_inference_steps", getattr(dit, "num_inference_steps", 30) + ), + guidance_scale=processed_input.get( + "guidance_scale", getattr(dit, "guidance_scale", 5.0) + ), + negative_prompt=processed_input.get("negative_prompt"), + seed=processed_input.get("seed"), + image=processed_input.get("image"), + ) + + # Use the parent LLM's generate method for AR text generation + prompt = processed_input.get("prompt", "") + max_tokens = processed_input.get("max_tokens", 100) + temperature = processed_input.get("temperature", 0.7) + + # Generate using the base LLM class + from vllm.sampling_params import SamplingParams + + sampling_params = SamplingParams( + max_tokens=max_tokens, + temperature=temperature, + top_p=processed_input.get("top_p", 1.0), + frequency_penalty=processed_input.get("frequency_penalty", 0.0), + presence_penalty=processed_input.get("presence_penalty", 0.0), + stop=processed_input.get("stop", None) + ) + + # Use the parent class's generate method + outputs = super().generate([prompt], sampling_params) + + # Return the first output (we're processing one prompt at a time) + if outputs: + return outputs[0] + else: + # Fallback to mock output if generation fails + from vllm.outputs import RequestOutput, CompletionOutput + + mock_output = CompletionOutput( + index=0, + text="Generation failed", + token_ids=[], + cumulative_logprob=0.0, + logprobs=None, + finish_reason="error" + ) + + return RequestOutput( + request_id="fallback_request", + prompt=prompt, + prompt_token_ids=[], + prompt_logprobs=None, + outputs=[mock_output], + finished=True + ) + + +class AsyncOmniLLM(LLM): + """Extended LLM class supporting multiple engines and stage-based processing.""" + + def __init__( + self, + stage_configs: List[OmniStageConfig], + log_stats: bool = False, + **kwargs + ): + # Use the first 
class AsyncOmniLLM(LLM):
    """Extended LLM class supporting multiple engines and stage-based processing.

    NOTE(review): despite the name, the AR path below still calls the
    synchronous ``LLM.generate`` inside ``generate_async`` (MVP behavior);
    a true async engine integration would use ``AsyncLLM``.
    """

    def __init__(
        self,
        stage_configs: List[OmniStageConfig],
        log_stats: bool = False,
        **kwargs
    ):
        """Build the async multi-stage LLM.

        Args:
            stage_configs: Ordered per-stage configuration; the first stage's
                model (when set) is used as the base model.
            log_stats: Whether stage engines should log statistics.
            **kwargs: Forwarded to ``LLM.__init__`` after length-related
                defaults are normalized.
        """
        # Use the first stage's model for the base LLM.
        if stage_configs and stage_configs[0].model_path:
            model = stage_configs[0].model_path
        else:
            model = "Qwen/Qwen3-0.6B"

        self._normalize_length_kwargs(kwargs)

        super().__init__(model=model, **kwargs)
        self.stage_configs = stage_configs
        self.log_stats = log_stats
        self.stage_manager = StageManager(stage_configs, log_stats)
        self.output_processor = MultimodalOutputProcessor()
        # Stage engines are initialized lazily: StageManager.get_async_engine
        # self-initializes on first access.

    @staticmethod
    def _normalize_length_kwargs(kwargs: Dict[str, Any]) -> None:
        """Ensure ``max_num_batched_tokens >= max_model_len`` in *kwargs*.

        Bug fix (mirrors OmniLLM): the previous branch logic reset a
        caller-supplied ``max_num_batched_tokens`` back to 2048 whenever
        ``max_model_len`` was absent.  Caller values are now preserved.
        """
        kwargs.setdefault('max_model_len', 2048)
        max_len = kwargs['max_model_len']
        if kwargs.get('max_num_batched_tokens', 0) < max_len:
            kwargs['max_num_batched_tokens'] = max_len

    def _initialize_async_stage_engines(self) -> None:
        """Initialize AsyncLLM instances for each stage."""
        self.stage_manager.initialize_async_engines()

    async def generate_async(
        self,
        stage_args_list: Optional[List[Dict[str, Any]]] = None,
        use_tqdm: Union[bool, Callable[..., Any]] = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        priority: Optional[List[int]] = None,
        *,
        prompt: Optional[str] = None,
        stage_overrides: Optional[Dict[int, Dict[str, Any]]] = None,
    ) -> List[RequestOutput]:
        """Async generation interface - orchestrates multi-stage processing."""
        if stage_args_list is None:
            if prompt is None:
                raise ValueError(
                    "prompt must be provided when stage_args_list is not supplied"
                )
            stage_args_list = self._build_stage_args_from_config(
                prompt, stage_overrides or {}
            )

        if len(stage_args_list) != len(self.stage_configs):
            raise ValueError(
                f"Number of stage arguments ({len(stage_args_list)}) must match "
                f"number of stage configs ({len(self.stage_configs)})"
            )

        # Process through each stage sequentially.
        current_output = None
        for i, (stage_config, stage_args) in enumerate(zip(self.stage_configs, stage_args_list)):
            stage_engine = self.stage_manager.get_async_engine(i)

            processed_input = self._process_stage_inputs(
                stage_config, stage_args or {}, current_output
            )

            stage_output = await self._execute_stage_async(
                stage_engine, processed_input, lora_request, priority, stage_config
            )

            current_output = stage_output
            stage_config.stage_output = stage_output

        return self.output_processor.process_output(current_output)

    def _build_stage_args_from_config(
        self,
        prompt: str,
        stage_overrides: Dict[int, Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Derive per-stage argument dictionaries from configuration defaults."""
        stage_args: List[Dict[str, Any]] = []
        for stage_config in self.stage_configs:
            combined: Dict[str, Any] = dict(stage_config.default_stage_args or {})
            override = stage_overrides.get(stage_config.stage_id)
            if override:
                combined.update(override)
            if stage_config.engine_type == "AR":
                combined["prompt"] = prompt
            stage_args.append(combined)
        return stage_args

    def _process_stage_inputs(
        self,
        stage_config: OmniStageConfig,
        stage_args: Dict[str, Any],
        previous_output: Optional[Any]
    ) -> Dict[str, Any]:
        """Prepare input for a specific stage (same as OmniLLM)."""
        if stage_config.engine_type == "AR":
            return self._process_ar_inputs(stage_config, stage_args, previous_output)
        elif stage_config.engine_type == "DiT":
            return self._process_dit_inputs(stage_config, stage_args, previous_output)
        else:
            raise NotImplementedError(f"Unknown engine type: {stage_config.engine_type}")

    def _process_ar_inputs(
        self,
        stage_config: OmniStageConfig,
        stage_args: Dict[str, Any],
        previous_output: Optional[Any]
    ) -> Dict[str, Any]:
        """Process inputs for an AR stage (same as OmniLLM)."""
        combined = dict(stage_config.default_stage_args or {})
        combined.update(stage_args)
        combined.setdefault("prompt", "")
        combined.setdefault("max_tokens", 100)
        combined.setdefault("temperature", 0.7)

        if previous_output is not None:
            if hasattr(previous_output, 'outputs') and previous_output.outputs:
                last_output = previous_output.outputs[-1]
                if hasattr(last_output, 'text'):
                    combined["prompt"] = last_output.text + " " + combined["prompt"]

        return combined

    def _process_dit_inputs(
        self,
        stage_config: OmniStageConfig,
        stage_args: Dict[str, Any],
        previous_output: Optional[Any]
    ) -> Dict[str, Any]:
        """Process inputs for a DiT stage (same as OmniLLM)."""
        combined = dict(stage_config.default_stage_args or {})
        combined.update(stage_args)

        dit = stage_config.dit_config
        if dit is not None:
            combined.setdefault("height", getattr(dit, "height", 512))
            combined.setdefault("width", getattr(dit, "width", 512))
            combined.setdefault(
                "num_inference_steps", getattr(dit, "num_inference_steps", 50)
            )
            combined.setdefault(
                "guidance_scale", getattr(dit, "guidance_scale", 7.5)
            )
        else:
            combined.setdefault("height", 512)
            combined.setdefault("width", 512)
            combined.setdefault("num_inference_steps", 50)
            combined.setdefault("guidance_scale", 7.5)

        if "image" in stage_args:
            combined["image"] = stage_args["image"]

        if previous_output is not None:
            if hasattr(previous_output, 'outputs') and previous_output.outputs:
                last_output = previous_output.outputs[-1]
                if hasattr(last_output, 'text'):
                    combined["prompt"] = last_output.text

        combined.setdefault("prompt", stage_args.get("prompt", ""))

        return combined

    async def _execute_stage_async(
        self,
        stage_engine: AsyncLLM,
        processed_input: Dict[str, Any],
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        priority: Optional[List[int]] = None,
        stage_config: Optional[OmniStageConfig] = None,
    ) -> Any:
        """Execute a single stage asynchronously.

        The DiT branch calls the diffusers executor synchronously inside this
        coroutine (MVP behavior).
        """
        if stage_config and stage_config.engine_type == "DiT":
            dit = stage_config.dit_config
            if dit and getattr(dit, "use_diffusers", False):
                if not hasattr(self, "_dit_engines"):
                    self._dit_engines = {}
                exec_inst = self._dit_engines.get(stage_config.stage_id)
                if exec_inst is None:
                    from vllm_omni.engine.diffusion_engine import (
                        DiffusersPipelineEngine,
                    )

                    # NOTE(review): this constructor call (model_path /
                    # pipeline_name / device / dtype) differs from the one in
                    # OmniLLM._execute_stage (dit_config / log_stats /
                    # multiprocess_mode) -- confirm which signature
                    # DiffusersPipelineEngine actually supports.
                    pipeline_name = getattr(dit, "diffusers_pipeline", None)
                    device_cfg = getattr(dit, "device_config", None)
                    model_cfg = getattr(dit, "model_config", None)
                    if isinstance(device_cfg, dict):
                        device = device_cfg.get("device")
                        dtype = device_cfg.get("dtype")
                    else:
                        device = getattr(device_cfg, "device", None)
                        dtype = getattr(device_cfg, "dtype", None)

                    if dtype is None:
                        if isinstance(model_cfg, dict):
                            dtype = model_cfg.get("dtype")
                        else:
                            dtype = getattr(model_cfg, "dtype", None)

                    exec_inst = DiffusersPipelineEngine(
                        model_path=stage_config.model_path,
                        pipeline_name=pipeline_name,
                        device=device,
                        dtype=dtype,
                    )
                    self._dit_engines[stage_config.stage_id] = exec_inst

                return exec_inst.generate(
                    prompt=processed_input.get("prompt", ""),
                    height=processed_input.get("height", getattr(dit, "height", 512)),
                    width=processed_input.get("width", getattr(dit, "width", 512)),
                    num_inference_steps=processed_input.get(
                        "num_inference_steps", getattr(dit, "num_inference_steps", 30)
                    ),
                    guidance_scale=processed_input.get(
                        "guidance_scale", getattr(dit, "guidance_scale", 5.0)
                    ),
                    negative_prompt=processed_input.get("negative_prompt"),
                    seed=processed_input.get("seed"),
                    image=processed_input.get("image"),
                )

        # AR text generation via the (synchronous) parent LLM.
        prompt = processed_input.get("prompt", "")
        max_tokens = processed_input.get("max_tokens", 100)
        temperature = processed_input.get("temperature", 0.7)

        from vllm.sampling_params import SamplingParams

        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=processed_input.get("top_p", 1.0),
            frequency_penalty=processed_input.get("frequency_penalty", 0.0),
            presence_penalty=processed_input.get("presence_penalty", 0.0),
            stop=processed_input.get("stop", None)
        )

        outputs = super().generate([prompt], sampling_params)

        if outputs:
            return outputs[0]

        # Fallback to a mock output if generation produced nothing.
        from vllm.outputs import RequestOutput, CompletionOutput

        mock_output = CompletionOutput(
            index=0,
            text="Generation failed",
            token_ids=[],
            cumulative_logprob=0.0,
            logprobs=None,
            finish_reason="error"
        )
        return RequestOutput(
            request_id="fallback_request",
            prompt=prompt,
            prompt_token_ids=[],
            prompt_logprobs=None,
            outputs=[mock_output],
            finished=True
        )
+"""
+
+from typing import List, Optional, Union
+from vllm.v1.engine.llm_engine import LLMEngine
+from vllm.v1.engine.async_llm import AsyncLLM
+from ..engine.diffusion_engine import DiffusionEngine
+
+from ..config import OmniStageConfig
+
+
+class StageManager:
+    """Manages multiple stage engines for multi-stage processing."""
+
+    def __init__(self, stage_configs: List[OmniStageConfig], log_stats: bool = False):
+        # One engine slot per stage; engines are populated lazily by the
+        # initialize_*() methods below (currently None placeholders).
+        self.stage_configs = stage_configs
+        self.log_stats = log_stats
+        self.engine_list: List[Union[LLMEngine, DiffusionEngine]] = []
+        self.async_engine_list: List[AsyncLLM] = []
+        self._initialized = False
+        self._async_initialized = False
+
+    def initialize_engines(self) -> None:
+        """Initialize LLMEngine instances for each stage (idempotent)."""
+        if self._initialized:
+            return
+
+        # For now, create placeholder engines
+        # In a full implementation, this would create actual engines
+        for stage_config in self.stage_configs:
+            # Placeholder - would create actual engine here
+            self.engine_list.append(None)
+
+        self._initialized = True
+
+    def initialize_async_engines(self) -> None:
+        """Initialize AsyncLLM instances for each stage (idempotent)."""
+        if self._async_initialized:
+            return
+
+        # For now, create placeholder engines
+        # In a full implementation, this would create actual engines
+        for stage_config in self.stage_configs:
+            # Placeholder - would create actual engine here
+            self.async_engine_list.append(None)
+
+        self._async_initialized = True
+
+    def get_engine(self, stage_id: int) -> LLMEngine:
+        """Get the engine for a specific stage."""
+        if not self._initialized:
+            self.initialize_engines()
+
+        # NOTE(review): only the upper bound is checked; a negative stage_id
+        # would silently index from the end of the list instead of raising -
+        # confirm callers never pass stage_id < 0.
+        if stage_id >= len(self.engine_list):
+            raise IndexError(f"Stage {stage_id} not found. Available stages: 0-{len(self.engine_list)-1}")
+
+        return self.engine_list[stage_id]
+
+    def get_async_engine(self, stage_id: int) -> AsyncLLM:
+        """Get the async engine for a specific stage."""
+        if not self._async_initialized:
+            self.initialize_async_engines()
+
+        if stage_id >= len(self.async_engine_list):
+            raise IndexError(f"Async stage {stage_id} not found. Available stages: 0-{len(self.async_engine_list)-1}")
+
+        return self.async_engine_list[stage_id]
+
+    def get_stage_config(self, stage_id: int) -> OmniStageConfig:
+        """Get the configuration for a specific stage."""
+        if stage_id >= len(self.stage_configs):
+            raise IndexError(f"Stage config {stage_id} not found. Available stages: 0-{len(self.stage_configs)-1}")
+
+        return self.stage_configs[stage_id]
+
+    def get_num_stages(self) -> int:
+        """Get the number of stages."""
+        return len(self.stage_configs)
+
+    def cleanup(self) -> None:
+        """Clean up resources and reset to the uninitialized state."""
+        # Clean up engines if needed
+        self.engine_list.clear()
+        self.async_engine_list.clear()
+        self._initialized = False
+        self._async_initialized = False
diff --git a/vllm_omni/executor/__init__.py b/vllm_omni/model_executor/__init__.py
similarity index 100%
rename from vllm_omni/executor/__init__.py
rename to vllm_omni/model_executor/__init__.py
diff --git a/vllm_omni/request.py b/vllm_omni/request.py
index 0919270afc..4bdc62b253 100644
--- a/vllm_omni/request.py
+++ b/vllm_omni/request.py
@@ -1,7 +1,299 @@
+"""
+OmniRequest: Extended request class for vLLM-omni multimodal processing.
+
+This class extends vLLM's Request to support multimodal and non-autoregressive
+processing with additional fields and methods specific to vLLM-omni.
+""" + import enum import time -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List, Union +from dataclasses import dataclass, field from vllm.v1.request import Request as vLLMRequest + +class RequestType(enum.Enum): + """Types of requests supported by vLLM-omni.""" + TEXT = "text" + IMAGE = "image" + AUDIO = "audio" + VIDEO = "video" + MULTIMODAL = "multimodal" + DIFFUSION = "diffusion" + + +class ProcessingStage(enum.Enum): + """Processing stages for multi-stage models.""" + PREPROCESSING = "preprocessing" + AR_GENERATION = "ar_generation" + DIFFUSION_GENERATION = "diffusion_generation" + POSTPROCESSING = "postprocessing" + COMPLETED = "completed" + + +@dataclass +class MultimodalData: + """Container for multimodal input data.""" + data_type: str # "image", "audio", "video", etc. + data: Any # The actual data (numpy array, bytes, etc.) + metadata: Dict[str, Any] = field(default_factory=dict) + processed: bool = False + + +@dataclass +class DiffusionParams: + """Parameters specific to diffusion model generation.""" + num_inference_steps: int = 50 + guidance_scale: float = 7.5 + seed: Optional[int] = None + scheduler: str = "ddpm" + strength: float = 1.0 + eta: float = 0.0 + + class OmniRequest(vLLMRequest): - # pass \ No newline at end of file + """ + Extended request class for vLLM-omni multimodal and non-autoregressive processing. 
+ + This class extends vLLM's Request with additional fields and methods to support: + - Multimodal input processing + - Non-autoregressive generation (diffusion models) + - Multi-stage processing pipelines + - Enhanced caching for different model types + """ + + def __init__( + self, + request_id: str, + prompt: Optional[str] = None, + prompt_token_ids: Optional[List[int]] = None, + sampling_params: Optional[Any] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[Any] = None, + multi_modal_data: Optional[Dict[str, Any]] = None, + multi_modal_placeholders: Optional[Dict[str, str]] = None, + priority: int = 0, + # vLLM-omni specific parameters + request_type: RequestType = RequestType.TEXT, + processing_stage: ProcessingStage = ProcessingStage.PREPROCESSING, + multimodal_inputs: Optional[List[MultimodalData]] = None, + diffusion_params: Optional[DiffusionParams] = None, + output_format: str = "text", + cache_key: Optional[str] = None, + stage_configs: Optional[Dict[str, Any]] = None, + **kwargs + ): + # Initialize parent class + super().__init__( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + arrival_time=arrival_time or time.time(), + lora_request=lora_request, + multi_modal_data=multi_modal_data, + multi_modal_placeholders=multi_modal_placeholders, + priority=priority, + **kwargs + ) + + # vLLM-omni specific attributes + self.request_type = request_type + self.processing_stage = processing_stage + self.multimodal_inputs = multimodal_inputs or [] + self.diffusion_params = diffusion_params or DiffusionParams() + self.output_format = output_format + self.cache_key = cache_key + self.stage_configs = stage_configs or {} + + # Processing state + self.current_stage = 0 + self.stage_results = {} + self.hidden_states = None + self.intermediate_outputs = [] + + # Timing and metrics + self.stage_timings = {} + self.total_processing_time = 0.0 + + # Error handling + self.errors = [] + 
self.retry_count = 0 + self.max_retries = 3 + + def add_multimodal_input(self, data: MultimodalData) -> None: + """Add a multimodal input to the request.""" + self.multimodal_inputs.append(data) + + def get_multimodal_inputs_by_type(self, data_type: str) -> List[MultimodalData]: + """Get all multimodal inputs of a specific type.""" + return [inp for inp in self.multimodal_inputs if inp.data_type == data_type] + + def update_processing_stage(self, stage: ProcessingStage) -> None: + """Update the current processing stage.""" + self.processing_stage = stage + self.stage_timings[stage.value] = time.time() + + def add_stage_result(self, stage: str, result: Any) -> None: + """Add a result from a processing stage.""" + self.stage_results[stage] = result + self.intermediate_outputs.append({ + 'stage': stage, + 'result': result, + 'timestamp': time.time() + }) + + def get_stage_result(self, stage: str) -> Any: + """Get a result from a specific processing stage.""" + return self.stage_results.get(stage) + + def set_hidden_states(self, hidden_states: Any) -> None: + """Set hidden states for the request.""" + self.hidden_states = hidden_states + + def get_hidden_states(self) -> Any: + """Get hidden states for the request.""" + return self.hidden_states + + def add_error(self, error: str) -> None: + """Add an error to the request.""" + self.errors.append({ + 'error': error, + 'timestamp': time.time(), + 'stage': self.processing_stage.value + }) + + def has_errors(self) -> bool: + """Check if the request has any errors.""" + return len(self.errors) > 0 + + def can_retry(self) -> bool: + """Check if the request can be retried.""" + return self.retry_count < self.max_retries + + def increment_retry(self) -> None: + """Increment the retry count.""" + self.retry_count += 1 + + def generate_cache_key(self) -> str: + """Generate a cache key for the request.""" + if self.cache_key: + return self.cache_key + + # Generate cache key based on request content + key_parts = [ + 
self.request_id, + str(self.request_type.value), + str(self.prompt_token_ids) if self.prompt_token_ids else str(self.prompt), + str(self.diffusion_params.num_inference_steps) if self.diffusion_params else "0", + str(len(self.multimodal_inputs)) + ] + return "_".join(key_parts) + + def get_processing_time(self) -> float: + """Get the total processing time for the request.""" + if self.stage_timings: + return time.time() - min(self.stage_timings.values()) + return 0.0 + + def is_completed(self) -> bool: + """Check if the request is completed.""" + return self.processing_stage == ProcessingStage.COMPLETED + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for serialization.""" + return { + 'request_id': self.request_id, + 'request_type': self.request_type.value, + 'processing_stage': self.processing_stage.value, + 'prompt': self.prompt, + 'prompt_token_ids': self.prompt_token_ids, + 'output_format': self.output_format, + 'multimodal_inputs_count': len(self.multimodal_inputs), + 'diffusion_params': self.diffusion_params.__dict__ if self.diffusion_params else None, + 'stage_results': list(self.stage_results.keys()), + 'has_errors': self.has_errors(), + 'retry_count': self.retry_count, + 'processing_time': self.get_processing_time() + } + + def __repr__(self) -> str: + """String representation of the request.""" + return (f"OmniRequest(id={self.request_id}, " + f"type={self.request_type.value}, " + f"stage={self.processing_stage.value}, " + f"multimodal_inputs={len(self.multimodal_inputs)})") + + +# Factory functions for creating different types of requests +def create_text_request( + request_id: str, + prompt: str, + sampling_params: Optional[Any] = None, + **kwargs +) -> OmniRequest: + """Create a text-only request.""" + return OmniRequest( + request_id=request_id, + prompt=prompt, + request_type=RequestType.TEXT, + sampling_params=sampling_params, + **kwargs + ) + + +def create_image_request( + request_id: str, + prompt: str, + 
image_data: Any, + diffusion_params: Optional[DiffusionParams] = None, + **kwargs +) -> OmniRequest: + """Create an image generation request.""" + multimodal_input = MultimodalData( + data_type="image", + data=image_data, + metadata={"is_input": True} + ) + + return OmniRequest( + request_id=request_id, + prompt=prompt, + request_type=RequestType.IMAGE, + multimodal_inputs=[multimodal_input], + diffusion_params=diffusion_params, + output_format="image", + **kwargs + ) + + +def create_multimodal_request( + request_id: str, + prompt: str, + multimodal_inputs: List[MultimodalData], + **kwargs +) -> OmniRequest: + """Create a multimodal request.""" + return OmniRequest( + request_id=request_id, + prompt=prompt, + request_type=RequestType.MULTIMODAL, + multimodal_inputs=multimodal_inputs, + **kwargs + ) + + +def create_diffusion_request( + request_id: str, + prompt: str, + diffusion_params: DiffusionParams, + **kwargs +) -> OmniRequest: + """Create a diffusion model request.""" + return OmniRequest( + request_id=request_id, + prompt=prompt, + request_type=RequestType.DIFFUSION, + diffusion_params=diffusion_params, + **kwargs + ) \ No newline at end of file diff --git a/vllm_omni/worker/gpu_diffusion_model_runner.py b/vllm_omni/worker/gpu_diffusion_model_runner.py new file mode 100644 index 0000000000..e464922945 --- /dev/null +++ b/vllm_omni/worker/gpu_diffusion_model_runner.py @@ -0,0 +1,133 @@ +"""Model runner that wraps a diffusers text-to-image pipeline.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional + +@dataclass +class DiffusionRunnerOutput: + """Structured output returned by the diffusion runner.""" + + prompt: str + images: list + request_id: str = "diffusion_request" + finished: bool = True + output_type: str = "image" + + +class DiffusionModelRunner: + """Lightweight runner that loads and executes diffusers pipelines.""" + + def __init__( + self, + model_path: str, + *, + pipeline_name: 
Optional[str] = None, + device: Optional[str] = None, + dtype: Optional[str] = None, + ) -> None: + self.model_path = model_path + self.pipeline_name = pipeline_name or "auto" + self.device, self.torch_dtype = self._resolve_device_and_dtype(device, dtype) + self._pipeline = self._load_pipeline() + + def _resolve_device_and_dtype( + self, + device: Optional[str], + dtype: Optional[str], + ): + import torch + + resolved_device = device + if resolved_device is None: + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + resolved_device = "mps" + elif torch.cuda.is_available(): + resolved_device = "cuda" + else: + resolved_device = "cpu" + + if dtype is not None: + resolved_dtype = getattr(torch, dtype, None) + else: + if resolved_device in {"cuda", "mps"}: + resolved_dtype = torch.float16 + else: + resolved_dtype = torch.float32 + + return resolved_device, resolved_dtype + + def _load_pipeline(self): + import diffusers + + if self.pipeline_name != "auto": + PipelineCls = getattr(diffusers, self.pipeline_name) + else: + PipelineCls = diffusers.AutoPipelineForText2Image + + pipeline = PipelineCls.from_pretrained( + self.model_path, + torch_dtype=self.torch_dtype, + ) + + try: + pipeline = pipeline.to(self.device) + except Exception: + pipeline = pipeline.to("cpu") + + self._apply_optimisations(pipeline) + return pipeline + + @staticmethod + def _apply_optimisations(pipeline) -> None: + # Enable attention slicing / VAE tiling where available to reduce memory. 
+        try:
+            if hasattr(pipeline, "enable_attention_slicing"):
+                pipeline.enable_attention_slicing()
+        except Exception:
+            pass
+        try:
+            vae = getattr(pipeline, "vae", None)
+            if vae and hasattr(vae, "enable_tiling"):
+                vae.enable_tiling()
+        except Exception:
+            pass
+
+    def generate(
+        self,
+        prompt: str,
+        *,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 30,
+        guidance_scale: float = 5.0,
+        negative_prompt: Optional[str] = None,
+        seed: Optional[int] = None,
+        image: Optional[Any] = None,
+    ) -> DiffusionRunnerOutput:
+        """Run one text-to-image generation and return a structured result."""
+        import torch
+
+        # Seeded generation uses a device-local Generator for reproducibility.
+        generator = None
+        if seed is not None:
+            try:
+                generator = torch.Generator(device=self.device).manual_seed(int(seed))
+            except Exception:
+                # Some backends (e.g., CPU without full support) may not accept the device.
+                generator = torch.Generator().manual_seed(int(seed))
+
+        # NOTE(review): `image` is forwarded unconditionally; pure text2image
+        # pipelines may reject this kwarg when it is not None - confirm.
+        output = self._pipeline(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            num_inference_steps=int(num_inference_steps),
+            guidance_scale=float(guidance_scale),
+            height=int(height),
+            width=int(width),
+            generator=generator,
+            image=image,
+        )
+
+        return DiffusionRunnerOutput(
+            prompt=prompt,
+            images=getattr(output, "images", []) or [],
+        )
diff --git a/vllm_omni/worker/gpu_diffusion_worker.py b/vllm_omni/worker/gpu_diffusion_worker.py
new file mode 100644
index 0000000000..74f2474c4c
--- /dev/null
+++ b/vllm_omni/worker/gpu_diffusion_worker.py
@@ -0,0 +1,56 @@
+"""Worker abstraction for executing diffusion model runners."""
+
+from __future__ import annotations
+
+from typing import Any, Optional
+
+from vllm.v1.worker.worker_base import WorkerBase
+
+from .gpu_diffusion_model_runner import DiffusionModelRunner, DiffusionRunnerOutput
+
+
+class DiffusionGPUWorker(WorkerBase):
+    """Minimal worker wrapping DiffusionModelRunner.
+
+    The worker interface mirrors the shape expected by vLLM executors: it owns
+    a model runner instance and exposes a `generate` method that can be called
+    by an executor implementation.
+    """
+
+    def __init__(
+        self,
+        model_path: str,
+        *,
+        pipeline_name: Optional[str] = None,
+        device: Optional[str] = None,
+        dtype: Optional[str] = None,
+    ) -> None:
+        # The worker owns exactly one runner; the heavy lifting (pipeline
+        # loading, device placement) happens inside DiffusionModelRunner.
+        self.model_runner = DiffusionModelRunner(
+            model_path,
+            pipeline_name=pipeline_name,
+            device=device,
+            dtype=dtype,
+        )
+
+    def generate(
+        self,
+        prompt: str,
+        *,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 30,
+        guidance_scale: float = 5.0,
+        negative_prompt: Optional[str] = None,
+        seed: Optional[int] = None,
+        image: Optional[Any] = None,
+    ) -> DiffusionRunnerOutput:
+        """Delegate a single generation call to the underlying model runner."""
+        return self.model_runner.generate(
+            prompt=prompt,
+            height=height,
+            width=width,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            negative_prompt=negative_prompt,
+            seed=seed,
+            image=image,
+        )