diff --git a/vllm_omni/config/__init__.py b/vllm_omni/config/__init__.py
index e48a874a7c..a688faf1b9 100644
--- a/vllm_omni/config/__init__.py
+++ b/vllm_omni/config/__init__.py
@@ -1,6 +1,14 @@
 """
 Configuration module for vLLM-omni.
 """
+from typing import Optional
+from pydantic.dataclasses import dataclass
+from pydantic import ConfigDict
+
+from vllm.config import ModelConfig
+from vllm.config import config
+
+import vllm_omni.model_executor.models as me_models
 
 from .stage_config import (
     OmniStageConfig,
@@ -11,7 +19,28 @@
     create_dit_stage_config,
 )
 
+
+@config
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
+class OmniModelConfig(ModelConfig):
+    """Configuration for Omni models, extending the base ModelConfig."""
+
+    stage_id: int = 0
+    model_stage: str = "thinker"
+    model_arch: str = "Qwen2_5OmniForConditionalGeneration"
+    engine_output_type: Optional[str] = None
+
+    @property
+    def registry(self):
+        return me_models.OmniModelRegistry
+
+    @property
+    def architectures(self) -> list[str]:
+        return [self.model_arch]
+
+
 __all__ = [
+    "OmniModelConfig",
     "OmniStageConfig",
     "DiTConfig",
     "DiTCacheConfig",
diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py
new file mode 100644
index 0000000000..41a74ca476
--- /dev/null
+++ b/vllm_omni/engine/arg_utils.py
@@ -0,0 +1,51 @@
+from typing import Optional
+from dataclasses import dataclass
+from vllm.engine.arg_utils import EngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+from vllm_omni.config import OmniModelConfig
+
+
+@dataclass
+class OmniEngineArgs(EngineArgs):
+    stage_id: int = 0
+    model_stage: str = "thinker"
+    model_arch: str = "Qwen2_5OmniForConditionalGeneration"
+    engine_output_type: Optional[str] = None
+
+    @staticmethod
+    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+        """Shared CLI arguments for vLLM engine."""
+        parser.add_argument(
+            "--engine-output-type",
+            type=str,
+            default=OmniEngineArgs.engine_output_type,
+            help=(
+                "Declare EngineCoreOutput.output_type (e.g., 'text', 'image', "
+                "'text+image', 'latent'). This will be written into "
+                "model_config.engine_output_type for schedulers to use."
+            ),
+        )
+        parser.add_argument("--model-stage", type=str, default=OmniEngineArgs.model_stage,
+                            help="Declare model stage (e.g., 'thinker', 'talker', 'token2wav'). This will be written into model_config.model_stage for schedulers to use.")
+        return parser
+
+    def create_model_config(self) -> OmniModelConfig:
+        # First, get the base ModelConfig from the parent class
+        base_config = super().create_model_config()
+
+        # Create OmniModelConfig by copying all base config attributes
+        # and adding the new omni-specific fields
+        config_dict = base_config.__dict__.copy()
+
+        # Add the new omni-specific fields
+        config_dict['stage_id'] = self.stage_id
+        config_dict['model_stage'] = self.model_stage
+        config_dict['model_arch'] = self.model_arch
+        config_dict['engine_output_type'] = self.engine_output_type
+
+        # Create and return the OmniModelConfig instance
+        omni_config = OmniModelConfig(**config_dict)
+        omni_config.hf_config.architectures = omni_config.architectures
+
+        return omni_config
\ No newline at end of file
diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py
index 8223c6f20e..cdfe5d53c6 100644
--- a/vllm_omni/entrypoints/omni_llm.py
+++ b/vllm_omni/entrypoints/omni_llm.py
@@ -1,576 +1,157 @@
-"""
-Core OmniLLM and AsyncOmniLLM classes for multi-stage processing.
-""" +from typing import Union, Sequence, Optional, Any +import cloudpickle +from pydantic import ValidationError -from typing import List, Dict, Any, Optional, Union, Callable +from vllm.inputs import PromptType +from vllm.sampling_params import SamplingParams from vllm.entrypoints.llm import LLM -from vllm.v1.engine.async_llm import AsyncLLM + from vllm.v1.engine.llm_engine import LLMEngine -from vllm.outputs import RequestOutput, LoRARequest +from vllm.engine.arg_utils import HfOverrides +from vllm.usage.usage_lib import UsageContext +from vllm.config import CompilationConfig, is_init_field +from vllm.utils import Counter +from vllm.logger import init_logger +import vllm.envs as envs -from ..config import OmniStageConfig -from .stage_manager import StageManager -from ..engine.output_processor import MultimodalOutputProcessor -from ..engine.diffusion_engine import DiffusersPipelineEngine +from vllm_omni.entrypoints.utils import load_stage_configs_from_model +from vllm_omni.entrypoints.omni_stage import OmniStage +from vllm_omni.engine.arg_utils import OmniEngineArgs +from vllm_omni.engine.output_processor import MultimodalOutputProcessor +from vllm_omni.engine.processor import OmniProcessor +from vllm_omni.outputs import OmniRequestOutput +logger = init_logger(__name__) -class OmniLLM(LLM): - """Extended LLM supporting multiple engines and stage-based processing.""" - - def __init__( - self, - stage_configs: List[OmniStageConfig], - log_stats: bool = False, - **kwargs - ): - # Use the first stage's model as the default model for LLM - default_model = stage_configs[0].model_path if stage_configs else "test-model" - # Track whether we should launch engines in multiprocess mode. - self.multiprocess_mode = kwargs.pop("multiprocess_mode", False) - - # Fix configuration validation issues - # Ensure max_num_batched_tokens is at least as large as max_model_len - if 'max_model_len' in kwargs and 'max_num_batched_tokens' in kwargs: - if kwargs['max_num_batched_tokens'] < kwargs['max_model_len']: - kwargs['max_num_batched_tokens'] = kwargs['max_model_len'] - elif 'max_model_len' in kwargs: - # If max_model_len is set but max_num_batched_tokens is not, set it to max_model_len - kwargs['max_num_batched_tokens'] = kwargs['max_model_len'] +class OmniLLM: + def __init__(self, model: str, stage_configs = None, log_stats: bool = False, **kwargs): + if stage_configs is None: + self.initialize_stage_configs(model) else: - # Set reasonable defaults to avoid validation errors - kwargs['max_model_len'] = 2048 - kwargs['max_num_batched_tokens'] = 2048 - - super().__init__(model=default_model, **kwargs) - self.stage_configs = stage_configs - self.log_stats = log_stats - self.stage_manager = StageManager(stage_configs, log_stats) - self.output_processor = MultimodalOutputProcessor() - self._initialize_stage_engines() - - def _initialize_stage_engines(self) -> None: - """Initialize LLMEngine instances for each stage.""" - self.stage_manager.initialize_engines() - - def generate( - self, - stage_args_list: Optional[List[Dict[str, Any]]] = None, - use_tqdm: Union[bool, Callable[..., Any]] = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - priority: Optional[List[int]] = None, - *, - prompt: Optional[str] = None, - stage_overrides: Optional[Dict[int, Dict[str, Any]]] = None, - ) -> List[RequestOutput]: - """Main generation interface - orchestrates multi-stage processing.""" - if stage_args_list is None: - if prompt is None: - raise ValueError( - "prompt must be provided when 
stage_args_list is not supplied" - ) - stage_args_list = self._build_stage_args_from_config( - prompt, stage_overrides or {} - ) - - if len(stage_args_list) != len(self.stage_configs): - raise ValueError( - f"Number of stage arguments ({len(stage_args_list)}) must match " - f"number of stage configs ({len(self.stage_configs)})" - ) - - # Process through each stage sequentially - current_output = None + self.stage_configs = stage_configs - for i, (stage_config, stage_args) in enumerate(zip(self.stage_configs, stage_args_list)): - stage_engine = self.stage_manager.get_engine(i) - - # Prepare input for this stage - processed_input = self._process_stage_inputs( - stage_config, stage_args or {}, current_output - ) - - # Execute stage - stage_output = self._execute_stage( - stage_engine, processed_input, lora_request, priority, stage_config - ) - - # Update for next stage - current_output = stage_output - stage_config.stage_output = stage_output + self.stage_list = [] + self.initialize_stages(model) - # Process final output - final_output = self.output_processor.process_output(current_output) - return final_output + def initialize_stage_configs(self, model: str): + self.stage_configs = load_stage_configs_from_model(model) - def _build_stage_args_from_config( - self, - prompt: str, - stage_overrides: Dict[int, Dict[str, Any]], - ) -> List[Dict[str, Any]]: - """Derive per-stage argument dictionaries from configuration defaults.""" - stage_args: List[Dict[str, Any]] = [] + def initialize_stages(self, model: str): for stage_config in self.stage_configs: - combined: Dict[str, Any] = dict(stage_config.default_stage_args or {}) - override = stage_overrides.get(stage_config.stage_id) - if override: - combined.update(override) - if stage_config.engine_type == "AR": - combined["prompt"] = prompt - stage_args.append(combined) - return stage_args - - def _process_stage_inputs( - self, - stage_config: OmniStageConfig, - stage_args: Dict[str, Any], - previous_output: Optional[Any] - ) -> Dict[str, Any]: - """Prepare input for specific stage.""" - if stage_config.engine_type == "AR": - return self._process_ar_inputs(stage_config, stage_args, previous_output) - elif stage_config.engine_type == "DiT": - return self._process_dit_inputs(stage_config, stage_args, previous_output) - else: - raise NotImplementedError(f"Unknown engine type: {stage_config.engine_type}") - - def _process_ar_inputs( - self, - stage_config: OmniStageConfig, - stage_args: Dict[str, Any], - previous_output: Optional[Any] - ) -> Dict[str, Any]: - """Process inputs for AR stage.""" - combined = dict(stage_config.default_stage_args or {}) - combined.update(stage_args) - combined.setdefault("prompt", "") - combined.setdefault("max_tokens", 100) - combined.setdefault("temperature", 0.7) - - # If we have previous output (e.g., from a previous AR stage), - # we might want to use it as context - if previous_output is not None: - # Extract text from previous output if available - if hasattr(previous_output, 'outputs') and previous_output.outputs: - last_output = previous_output.outputs[-1] - if hasattr(last_output, 'text'): - combined["prompt"] = last_output.text + " " + combined["prompt"] - - return combined - - def _process_dit_inputs( - self, - stage_config: OmniStageConfig, - stage_args: Dict[str, Any], - previous_output: Optional[Any] - ) -> Dict[str, Any]: - """Process inputs for DiT stage.""" - combined = dict(stage_config.default_stage_args or {}) - combined.update(stage_args) - - dit = stage_config.dit_config - if dit is not None: - 
combined.setdefault("height", getattr(dit, "height", 512)) - combined.setdefault("width", getattr(dit, "width", 512)) - combined.setdefault( - "num_inference_steps", getattr(dit, "num_inference_steps", 50) - ) - combined.setdefault( - "guidance_scale", getattr(dit, "guidance_scale", 7.5) - ) - else: - combined.setdefault("height", 512) - combined.setdefault("width", 512) - combined.setdefault("num_inference_steps", 50) - combined.setdefault("guidance_scale", 7.5) - - # Handle image inputs if present - if "image" in stage_args: - # For now, we'll pass the image path directly - # In a full implementation, this would involve VAE encoding - combined["image"] = stage_args["image"] - - # If we have previous output from an AR stage, we might want to use it - if previous_output is not None: - # Extract text from previous AR output - if hasattr(previous_output, 'outputs') and previous_output.outputs: - last_output = previous_output.outputs[-1] - if hasattr(last_output, 'text'): - combined["prompt"] = last_output.text - - combined.setdefault("prompt", stage_args.get("prompt", "")) - - return combined + stage = OmniStage(stage_config) + stage_llm = OmniStageLLM(model=model, **stage_config.engine_args) + stage.set_engine(stage_llm) + self.stage_list.append(stage) - def _execute_stage( - self, - stage_engine: Optional[Union[LLMEngine, DiffusersPipelineEngine]], - processed_input: Dict[str, Any], - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - priority: Optional[List[int]] = None, - stage_config: Optional[OmniStageConfig] = None, - ) -> Any: - """Execute a single stage.""" - # DiT via diffusers backend - if stage_config and stage_config.engine_type == "DiT": - dit = stage_config.dit_config - if dit and getattr(dit, "use_diffusers", False): - # Lazy-init executor per stage - if not hasattr(self, "_dit_engines"): - self._dit_engines = {} - exec_inst = self._dit_engines.get(stage_config.stage_id) - if exec_inst is None: - exec_inst = DiffusersPipelineEngine( - dit_config=dit, - model_path=stage_config.model_path, - log_stats=self.log_stats, - multiprocess_mode=self.multiprocess_mode, - ) - - self._dit_engines[stage_config.stage_id] = exec_inst - - return exec_inst.generate( - prompt=processed_input.get("prompt", ""), - height=processed_input.get("height", getattr(dit, "height", 512)), - width=processed_input.get("width", getattr(dit, "width", 512)), - num_inference_steps=processed_input.get( - "num_inference_steps", getattr(dit, "num_inference_steps", 30) - ), - guidance_scale=processed_input.get( - "guidance_scale", getattr(dit, "guidance_scale", 5.0) - ), - negative_prompt=processed_input.get("negative_prompt"), - seed=processed_input.get("seed"), - image=processed_input.get("image"), - ) - - # Use the parent LLM's generate method for AR text generation - prompt = processed_input.get("prompt", "") - max_tokens = processed_input.get("max_tokens", 100) - temperature = processed_input.get("temperature", 0.7) - - # Generate using the base LLM class - from vllm.sampling_params import SamplingParams - - sampling_params = SamplingParams( - max_tokens=max_tokens, - temperature=temperature, - top_p=processed_input.get("top_p", 1.0), - frequency_penalty=processed_input.get("frequency_penalty", 0.0), - presence_penalty=processed_input.get("presence_penalty", 0.0), - stop=processed_input.get("stop", None) - ) - - # Use the parent class's generate method - outputs = super().generate([prompt], sampling_params) - - # Return the first output (we're processing one prompt at a time) - if outputs: 
- return outputs[0] - else: - # Fallback to mock output if generation fails - from vllm.outputs import RequestOutput, CompletionOutput - - mock_output = CompletionOutput( - index=0, - text="Generation failed", - token_ids=[], - cumulative_logprob=0.0, - logprobs=None, - finish_reason="error" - ) - - return RequestOutput( - request_id="fallback_request", - prompt=prompt, - prompt_token_ids=[], - prompt_logprobs=None, - outputs=[mock_output], - finished=True - ) - - -class AsyncOmniLLM(LLM): - """Extended LLM class supporting multiple engines and stage-based processing.""" - - def __init__( + def generate( self, - stage_configs: List[OmniStageConfig], - log_stats: bool = False, - **kwargs - ): - # Use the first stage's model for the base LLM - if stage_configs and stage_configs[0].model_path: - model = stage_configs[0].model_path - else: - model = "Qwen/Qwen3-0.6B" - - # Fix configuration validation issues - # Ensure max_num_batched_tokens is at least as large as max_model_len - if 'max_model_len' in kwargs and 'max_num_batched_tokens' in kwargs: - if kwargs['max_num_batched_tokens'] < kwargs['max_model_len']: - kwargs['max_num_batched_tokens'] = kwargs['max_model_len'] - elif 'max_model_len' in kwargs: - # If max_model_len is set but max_num_batched_tokens is not, set it to max_model_len - kwargs['max_num_batched_tokens'] = kwargs['max_model_len'] - else: - # Set reasonable defaults to avoid validation errors - kwargs['max_model_len'] = 2048 - kwargs['max_num_batched_tokens'] = 2048 - - super().__init__(model=model, **kwargs) - self.stage_configs = stage_configs - self.log_stats = log_stats - self.stage_manager = StageManager(stage_configs, log_stats) - self.output_processor = MultimodalOutputProcessor() + prompts: Union[PromptType, Sequence[PromptType]], + sampling_params_list: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + ) -> list[OmniRequestOutput]: + """Generate text outputs for the given prompts.""" + final_outputs: list[OmniRequestOutput] = [] + if len(sampling_params_list) != len(self.stage_list): + raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") + for stage_id, stage in enumerate(self.stage_list): + if stage_id > 0: + engine_inputs = stage.process_engine_inputs(self.stage_list, prompts) + else: + engine_inputs = prompts + engine_outputs = self._run_generation(stage, sampling_params_list[stage_id], engine_inputs) + stage.set_engine_outputs(engine_outputs) + if hasattr(stage, 'final_output') and stage.final_output: + final_outputs.append(OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, + request_output=engine_outputs)) + return final_outputs - def _initialize_async_stage_engines(self) -> None: - """Initialize AsyncLLM instances for each stage.""" - self.stage_manager.initialize_async_engines() - - async def generate_async( - self, - stage_args_list: Optional[List[Dict[str, Any]]] = None, - use_tqdm: Union[bool, Callable[..., Any]] = True, - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - priority: Optional[List[int]] = None, - *, - prompt: Optional[str] = None, - stage_overrides: Optional[Dict[int, Dict[str, Any]]] = None, - ) -> List[RequestOutput]: - """Async generation interface - orchestrates multi-stage processing.""" - - if stage_args_list is None: - if prompt is None: + def _run_generation(self, stage: OmniStage, sampling_params: SamplingParams, prompts: Union[PromptType, Sequence[PromptType]]): + engine_outputs = [] + for ro in 
stage.engine.generate(prompts, sampling_params): + engine_outputs.append(ro) + return engine_outputs + + +class OmniStageLLM(LLM): + def __init__(self, + model: str, + compilation_config: Optional[Union[int, dict[str, Any], + CompilationConfig]] = None, + hf_overrides: Optional[HfOverrides] = None, + **kwargs): + """LLM constructor.""" + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + + if "worker_cls" in kwargs: + worker_cls = kwargs["worker_cls"] + # if the worker_cls is not qualified string name, + # we serialize it using cloudpickle to avoid pickling issues + if isinstance(worker_cls, type): + kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) + + if "kv_transfer_config" in kwargs and isinstance( + kwargs["kv_transfer_config"], dict): + from vllm.config import KVTransferConfig + raw_config_dict = kwargs["kv_transfer_config"] + try: + kwargs["kv_transfer_config"] = KVTransferConfig( + **raw_config_dict) + except ValidationError as e: + logger.error( + "Failed to convert 'kv_transfer_config' dict to " + "KVTransferConfig object. Dict: %s. Error: %s", + raw_config_dict, e) + # Consider re-raising a more specific vLLM error or ValueError + # to provide better context to the user. raise ValueError( - "prompt must be provided when stage_args_list is not supplied" - ) - stage_args_list = self._build_stage_args_from_config( - prompt, stage_overrides or {} - ) - - if len(stage_args_list) != len(self.stage_configs): - raise ValueError( - f"Number of stage arguments ({len(stage_args_list)}) must match " - f"number of stage configs ({len(self.stage_configs)})" - ) - - # Process through each stage sequentially - current_output = None - - for i, (stage_config, stage_args) in enumerate(zip(self.stage_configs, stage_args_list)): - stage_engine = self.stage_manager.get_async_engine(i) - - # Prepare input for this stage - processed_input = self._process_stage_inputs( - stage_config, stage_args or {}, current_output - ) - - # Execute stage asynchronously - stage_output = await self._execute_stage_async( - stage_engine, processed_input, lora_request, priority, stage_config - ) - - # Update for next stage - current_output = stage_output - stage_config.stage_output = stage_output - - # Process final output - final_output = self.output_processor.process_output(current_output) - return final_output - - def _build_stage_args_from_config( - self, - prompt: str, - stage_overrides: Dict[int, Dict[str, Any]], - ) -> List[Dict[str, Any]]: - stage_args: List[Dict[str, Any]] = [] - for stage_config in self.stage_configs: - combined: Dict[str, Any] = dict(stage_config.default_stage_args or {}) - override = stage_overrides.get(stage_config.stage_id) - if override: - combined.update(override) - if stage_config.engine_type == "AR": - combined["prompt"] = prompt - stage_args.append(combined) - return stage_args - - def _process_stage_inputs( - self, - stage_config: OmniStageConfig, - stage_args: Dict[str, Any], - previous_output: Optional[Any] - ) -> Dict[str, Any]: - """Prepare input for specific stage (same as OmniLLM).""" - if stage_config.engine_type == "AR": - return self._process_ar_inputs(stage_config, stage_args, previous_output) - elif stage_config.engine_type == "DiT": - return self._process_dit_inputs(stage_config, stage_args, previous_output) + f"Invalid 'kv_transfer_config' provided: {e}") from e + + if hf_overrides is None: + hf_overrides = {} + + if compilation_config is not None: + if isinstance(compilation_config, int): + compilation_config_instance = CompilationConfig( + 
level=compilation_config) + elif isinstance(compilation_config, dict): + predicate = lambda x: is_init_field(CompilationConfig, x[0]) + compilation_config_instance = CompilationConfig( + **dict(filter(predicate, compilation_config.items()))) + else: + compilation_config_instance = compilation_config else: - raise NotImplementedError(f"Unknown engine type: {stage_config.engine_type}") - - def _process_ar_inputs( - self, - stage_config: OmniStageConfig, - stage_args: Dict[str, Any], - previous_output: Optional[Any] - ) -> Dict[str, Any]: - """Process inputs for AR stage (same as OmniLLM).""" - combined = dict(stage_config.default_stage_args or {}) - combined.update(stage_args) - combined.setdefault("prompt", "") - combined.setdefault("max_tokens", 100) - combined.setdefault("temperature", 0.7) + compilation_config_instance = CompilationConfig() - if previous_output is not None: - if hasattr(previous_output, 'outputs') and previous_output.outputs: - last_output = previous_output.outputs[-1] - if hasattr(last_output, 'text'): - combined["prompt"] = last_output.text + " " + combined["prompt"] - - return combined - - def _process_dit_inputs( - self, - stage_config: OmniStageConfig, - stage_args: Dict[str, Any], - previous_output: Optional[Any] - ) -> Dict[str, Any]: - """Process inputs for DiT stage (same as OmniLLM).""" - combined = dict(stage_config.default_stage_args or {}) - combined.update(stage_args) + engine_args = OmniEngineArgs( + model=model, + hf_overrides=hf_overrides, + compilation_config=compilation_config_instance, + **kwargs, + ) - dit = stage_config.dit_config - if dit is not None: - combined.setdefault("height", getattr(dit, "height", 512)) - combined.setdefault("width", getattr(dit, "width", 512)) - combined.setdefault( - "num_inference_steps", getattr(dit, "num_inference_steps", 50) - ) - combined.setdefault( - "guidance_scale", getattr(dit, "guidance_scale", 7.5) - ) + # Create the Engine (autoselects V0 vs V1) + self.llm_engine = LLMEngine.from_engine_args( + engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) + self.llm_engine.output_processor = MultimodalOutputProcessor(tokenizer=self.llm_engine.tokenizer, + log_stats=self.llm_engine.log_stats) + self.llm_engine.processor = OmniProcessor(vllm_config=self.llm_engine.vllm_config, + tokenizer=self.llm_engine.tokenizer) + self.engine_class = type(self.llm_engine) + + self.request_counter = Counter() + self.default_sampling_params: Union[dict[str, Any], None] = None + + if envs.VLLM_USE_V1: + supported_tasks = self.llm_engine \ + .get_supported_tasks() # type: ignore else: - combined.setdefault("height", 512) - combined.setdefault("width", 512) - combined.setdefault("num_inference_steps", 50) - combined.setdefault("guidance_scale", 7.5) - - if "image" in stage_args: - combined["image"] = stage_args["image"] - - if previous_output is not None: - if hasattr(previous_output, 'outputs') and previous_output.outputs: - last_output = previous_output.outputs[-1] - if hasattr(last_output, 'text'): - combined["prompt"] = last_output.text - - combined.setdefault("prompt", stage_args.get("prompt", "")) - - return combined - - async def _execute_stage_async( - self, - stage_engine: AsyncLLM, - processed_input: Dict[str, Any], - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - priority: Optional[List[int]] = None, - stage_config: Optional[OmniStageConfig] = None, - ) -> Any: - """Execute a single stage asynchronously.""" - # DiT via diffusers backend (sync call inside async for MVP) - if stage_config and 
stage_config.engine_type == "DiT": - dit = stage_config.dit_config - if dit and getattr(dit, "use_diffusers", False): - if not hasattr(self, "_dit_engines"): - self._dit_engines = {} - exec_inst = self._dit_engines.get(stage_config.stage_id) - if exec_inst is None: - from vllm_omni.engine.diffusion_engine import ( - DiffusersPipelineEngine, - ) + supported_tasks = self.llm_engine.model_config.supported_tasks - pipeline_name = getattr(dit, "diffusers_pipeline", None) - device_cfg = getattr(dit, "device_config", None) - model_cfg = getattr(dit, "model_config", None) - if isinstance(device_cfg, dict): - device = device_cfg.get("device") - dtype = device_cfg.get("dtype") - else: - device = getattr(device_cfg, "device", None) - dtype = getattr(device_cfg, "dtype", None) + logger.info("Supported_tasks: %s", supported_tasks) - if dtype is None: - if isinstance(model_cfg, dict): - dtype = model_cfg.get("dtype") - else: - dtype = getattr(model_cfg, "dtype", None) - - exec_inst = DiffusersPipelineEngine( - model_path=stage_config.model_path, - pipeline_name=pipeline_name, - device=device, - dtype=dtype, - ) - self._dit_engines[stage_config.stage_id] = exec_inst - - return exec_inst.generate( - prompt=processed_input.get("prompt", ""), - height=processed_input.get("height", getattr(dit, "height", 512)), - width=processed_input.get("width", getattr(dit, "width", 512)), - num_inference_steps=processed_input.get( - "num_inference_steps", getattr(dit, "num_inference_steps", 30) - ), - guidance_scale=processed_input.get( - "guidance_scale", getattr(dit, "guidance_scale", 5.0) - ), - negative_prompt=processed_input.get("negative_prompt"), - seed=processed_input.get("seed"), - image=processed_input.get("image"), - ) - - # Use the parent LLM's generate method for AR text generation - prompt = processed_input.get("prompt", "") - max_tokens = processed_input.get("max_tokens", 100) - temperature = processed_input.get("temperature", 0.7) - - # Generate using the base LLM class - from vllm.sampling_params import SamplingParams - - sampling_params = SamplingParams( - max_tokens=max_tokens, - temperature=temperature, - top_p=processed_input.get("top_p", 1.0), - frequency_penalty=processed_input.get("frequency_penalty", 0.0), - presence_penalty=processed_input.get("presence_penalty", 0.0), - stop=processed_input.get("stop", None) - ) - - # Use the parent class's generate method - outputs = super().generate([prompt], sampling_params) - - # Return the first output (we're processing one prompt at a time) - if outputs: - return outputs[0] - else: - # Fallback to mock output if generation fails - from vllm.outputs import RequestOutput, CompletionOutput - - mock_output = CompletionOutput( - index=0, - text="Generation failed", - token_ids=[], - cumulative_logprob=0.0, - logprobs=None, - finish_reason="error" - ) - - return RequestOutput( - request_id="fallback_request", - prompt=prompt, - prompt_token_ids=[], - prompt_logprobs=None, - outputs=[mock_output], - finished=True - ) + self.supported_tasks = supported_tasks \ No newline at end of file diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py new file mode 100644 index 0000000000..ef2249ac31 --- /dev/null +++ b/vllm_omni/entrypoints/omni_stage.py @@ -0,0 +1,80 @@ +""" +Stage manager for orchestrating multiple engines in vLLM-omni. 
+""" + +import importlib +from typing import List, Union +from vllm.v1.engine.llm_engine import LLMEngine +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.inputs import TextPrompt +from vllm.v1.engine import EngineCoreOutput + +from vllm_omni.inputs.data import OmniTokensPrompt + + +class OmniStage: + def __init__(self, stage_config): + self.stage_config = stage_config + self.engine = None + self.async_engine = None + self.stage_id = stage_config.stage_id + self.engine_args = stage_config.engine_args + self.model_stage = stage_config.engine_args.model_stage + if hasattr(stage_config, 'engine_input_source'): + self.engine_input_source = stage_config.engine_input_source + else: + self.engine_input_source = [] + self.engine_output_type = stage_config.engine_args.engine_output_type + self.engine_outputs = None + if hasattr(stage_config, 'custom_process_input_func'): + # Import the module specified in the config (already a full module path) + module_path, func_name = stage_config.custom_process_input_func.rsplit('.', 1) + module = importlib.import_module(module_path) + self.custom_process_input_func = getattr(module, func_name) + else: + self.custom_process_input_func = None + + if hasattr(stage_config, 'final_output'): + self.final_output = stage_config.final_output + else: + self.final_output = False + + if hasattr(stage_config, 'final_output_type'): + self.final_output_type = stage_config.final_output_type + else: + self.final_output_type = None + + def set_engine(self, engine: LLMEngine) -> None: + """Initialize the engine for the stage.""" + self.engine = engine + + def set_async_engine(self, async_engine: AsyncLLM) -> None: + """Initialize the async engine for the stage.""" + self.async_engine = async_engine + + def set_engine_outputs(self, engine_outputs: EngineCoreOutput) -> None: + """Set the engine output for the stage.""" + self.engine_outputs = engine_outputs + + def process_engine_inputs(self, stage_list, prompt: Union[OmniTokensPrompt, TextPrompt] = None) -> List[Union[OmniTokensPrompt, TextPrompt]]: + """Process the engine input for the stage.""" + if self.custom_process_input_func is None: + engine_inputs = [] + if len(self.engine_input_source) == 0: + raise ValueError("engine_input_source is empty") + source_stage_id = self.engine_input_source[0] + source_outputs = stage_list[source_stage_id].engine_outputs + multi_modal_data = {source_output.request_id: + prompt.get('multi_modal_data', None) for source_output, prompt in zip(source_outputs, prompt)} + + for source_output in source_outputs: + engine_input = OmniTokensPrompt( + prompt_token_ids = source_output.outputs[0].token_ids, + multi_modal_data=multi_modal_data[source_output.request_id] if multi_modal_data else None, + ) + engine_inputs.append(engine_input) + return engine_inputs + + else: + engine_input_source = self.engine_input_source + return self.custom_process_input_func(stage_list, engine_input_source, prompt) \ No newline at end of file diff --git a/vllm_omni/entrypoints/stage_manager.py b/vllm_omni/entrypoints/stage_manager.py deleted file mode 100644 index 22a03d51c2..0000000000 --- a/vllm_omni/entrypoints/stage_manager.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -Stage manager for orchestrating multiple engines in vLLM-omni. 
-""" - -from typing import List, Optional, Union -from vllm.v1.engine.llm_engine import LLMEngine -from vllm.v1.engine.async_llm import AsyncLLM -from ..engine.diffusion_engine import DiffusionEngine - -from ..config import OmniStageConfig - - -class StageManager: - """Manages multiple stage engines for multi-stage processing.""" - - def __init__(self, stage_configs: List[OmniStageConfig], log_stats: bool = False): - self.stage_configs = stage_configs - self.log_stats = log_stats - self.engine_list: List[Union[LLMEngine, DiffusionEngine]] = [] - self.async_engine_list: List[AsyncLLM] = [] - self._initialized = False - self._async_initialized = False - - def initialize_engines(self) -> None: - """Initialize LLMEngine instances for each stage.""" - if self._initialized: - return - - # For now, create placeholder engines - # In a full implementation, this would create actual engines - for stage_config in self.stage_configs: - # Placeholder - would create actual engine here - self.engine_list.append(None) - - self._initialized = True - - def initialize_async_engines(self) -> None: - """Initialize AsyncLLM instances for each stage.""" - if self._async_initialized: - return - - # For now, create placeholder engines - # In a full implementation, this would create actual engines - for stage_config in self.stage_configs: - # Placeholder - would create actual engine here - self.async_engine_list.append(None) - - self._async_initialized = True - - def get_engine(self, stage_id: int) -> LLMEngine: - """Get the engine for a specific stage.""" - if not self._initialized: - self.initialize_engines() - - if stage_id >= len(self.engine_list): - raise IndexError(f"Stage {stage_id} not found. Available stages: 0-{len(self.engine_list)-1}") - - return self.engine_list[stage_id] - - def get_async_engine(self, stage_id: int) -> AsyncLLM: - """Get the async engine for a specific stage.""" - if not self._async_initialized: - self.initialize_async_engines() - - if stage_id >= len(self.async_engine_list): - raise IndexError(f"Async stage {stage_id} not found. Available stages: 0-{len(self.async_engine_list)-1}") - - return self.async_engine_list[stage_id] - - def get_stage_config(self, stage_id: int) -> OmniStageConfig: - """Get the configuration for a specific stage.""" - if stage_id >= len(self.stage_configs): - raise IndexError(f"Stage config {stage_id} not found. 
Available stages: 0-{len(self.stage_configs)-1}") - - return self.stage_configs[stage_id] - - def get_num_stages(self) -> int: - """Get the number of stages.""" - return len(self.stage_configs) - - def cleanup(self) -> None: - """Clean up resources.""" - # Clean up engines if needed - self.engine_list.clear() - self.async_engine_list.clear() - self._initialized = False - self._async_initialized = False diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py new file mode 100644 index 0000000000..42a98f1648 --- /dev/null +++ b/vllm_omni/entrypoints/utils.py @@ -0,0 +1,25 @@ +import os +from pathlib import Path +from omegaconf import OmegaConf +from vllm.transformers_utils.config import get_config + +# Get the project root directory (2 levels up from this file) +PROJECT_ROOT = Path(__file__).parent.parent.parent + + +def load_stage_configs_from_model(model: str): + """Load stage configs from model.""" + hf_config = get_config(model, trust_remote_code=True) + model_type = hf_config.model_type + stage_config_file = f"vllm_omni/model_executor/stage_configs/{model_type}.yaml" + stage_config_path = PROJECT_ROOT / stage_config_file + if not os.path.exists(stage_config_path): + raise FileNotFoundError(f"Stage config file {stage_config_path} not found") + stage_configs = load_stage_configs_from_yaml(config_path=str(stage_config_path)) + return stage_configs + + +def load_stage_configs_from_yaml(config_path: str): + """Load stage configs from yaml file.""" + config_data = OmegaConf.load(config_path) + return config_data.stage_args \ No newline at end of file diff --git a/vllm_omni/inputs/__init__.py b/vllm_omni/inputs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py new file mode 100644 index 0000000000..b3a4339e47 --- /dev/null +++ b/vllm_omni/inputs/data.py @@ -0,0 +1,12 @@ +from vllm.inputs.data import TokensPrompt +from typing import Any, NotRequired +import torch + + +class OmniTokensPrompt(TokensPrompt): + prompt_embeds: NotRequired[torch.Tensor] + """The embeddings of the prompt.""" + + # New: optional additional information dictionary + # Values may be torch.Tensor or list + additional_information: NotRequired[dict[str, Any]] \ No newline at end of file diff --git a/vllm_omni/model_executor/stage_configs/__init__.py b/vllm_omni/model_executor/stage_configs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml new file mode 100644 index 0000000000..49a1222ccb --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml @@ -0,0 +1,40 @@ +# stage config for running qwen2.5-omni with architecture of OmniLLM. 
+stage_args: + - stage_id: 0 + engine_args: + model_stage: thinker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.scheduler.OmniScheduler + gpu_memory_utilization: 0.32 + enforce_eager: true # need to discuss + trust_remote_code: true + engine_output_type: latent # change the param name,such as pooling_output + final_output: true + final_output_type: text + - stage_id: 1 + engine_args: + model_stage: talker + model_arch: Qwen2_5OmniForConditionalGeneration + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.scheduler.OmniScheduler + gpu_memory_utilization: 0.32 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker + - stage_id: 2 + engine_args: + model_stage: code2wav + model_arch: Qwen2_5OmniForConditionalGeneration + worker_cls: vllm_omni.worker.gpu_diffusion_worker.GPUDiffusionWorker + scheduler_cls: vllm_omni.core.sched.diffusion_scheduler.DiffusionScheduler + gpu_memory_utilization: 0.3 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + engine_output_type: audio + engine_input_source: [1] + final_output: true + final_output_type: audio \ No newline at end of file diff --git a/vllm_omni/model_executor/stage_input_processors/__init__.py b/vllm_omni/model_executor/stage_input_processors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py new file mode 100644 index 0000000000..474e2c3445 --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py @@ -0,0 +1,40 @@ +from typing import Union +import torch + +from vllm.inputs import TextPrompt + +from vllm_omni.inputs.data import OmniTokensPrompt + +def thinker2talker(stage_list, engine_input_source, prompt: Union[OmniTokensPrompt, TextPrompt] = None): + if not engine_input_source: + raise ValueError("engine_input_source cannot be empty") + source_stage_id = engine_input_source[0] + if source_stage_id >= len(stage_list): + raise IndexError(f"Invalid stage_id: {source_stage_id}") + if stage_list[source_stage_id].engine_outputs is None: + raise RuntimeError(f"Stage {source_stage_id} has no outputs yet") + thinker_outputs = stage_list[source_stage_id].engine_outputs + talker_inputs = [] + multi_modal_data = {thinker_output.request_id: + prompt.get('multi_modal_data', None) for thinker_output, prompt in zip(thinker_outputs, prompt)} + + for i, thinker_output in enumerate(thinker_outputs): + output = thinker_output.outputs[0] + prompt_token_ids = thinker_output.prompt_token_ids + thinker_output_ids = output.token_ids + prompt_token_ids_len = len(prompt_token_ids) + thinker_hidden_states = output.multimodal_output["latent"].clone().detach().cuda() + talker_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0] * (len(prompt_token_ids) + 2), # add 2 for codec pad and start token + additional_information={ + "thinker_result": thinker_hidden_states[prompt_token_ids_len:].to(torch.float32), + "prompt_embeds": thinker_hidden_states[:prompt_token_ids_len].to(torch.float32), + "prompt_token_ids": prompt_token_ids, + "thinker_output_token_ids": thinker_output_ids, + }, + multi_modal_data=multi_modal_data[thinker_output.request_id] if multi_modal_data is not 
None else None, + mm_processor_kwargs=None, + ) + ) + return talker_inputs \ No newline at end of file diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py new file mode 100644 index 0000000000..2425fec481 --- /dev/null +++ b/vllm_omni/outputs.py @@ -0,0 +1,9 @@ +from vllm.outputs import RequestOutput +from dataclasses import dataclass + + +@dataclass +class OmniRequestOutput(RequestOutput): + stage_id: int + final_output_type: str + request_output: RequestOutput \ No newline at end of file
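
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the new multi-stage entrypoint is expected to be driven, given the three-stage qwen2_5_omni.yaml above. OmniLLM.generate() takes one SamplingParams per configured stage and returns an OmniRequestOutput only for stages marked final_output. The checkpoint name, prompt, and sampling values are illustrative assumptions, not taken from this diff.

# Hypothetical driver for the new OmniLLM entrypoint (illustrative values only).
from vllm import SamplingParams

from vllm_omni.entrypoints.omni_llm import OmniLLM

# With stage_configs=None, OmniLLM resolves the stage YAML from the model's
# model_type via load_stage_configs_from_model()
# (vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml for Qwen2.5-Omni).
llm = OmniLLM(model="Qwen/Qwen2.5-Omni-7B")

# generate() requires exactly one SamplingParams per configured stage
# (thinker, talker, code2wav in the YAML above).
sampling_params_list = [
    SamplingParams(temperature=0.7, max_tokens=512),   # stage 0: thinker
    SamplingParams(temperature=0.9, max_tokens=2048),  # stage 1: talker
    SamplingParams(temperature=0.0),                   # stage 2: code2wav (placeholder)
]

outputs = llm.generate(
    prompts=[{"prompt": "Describe a rainy afternoon in two sentences."}],
    sampling_params_list=sampling_params_list,
)

# Only stages with final_output: true contribute results
# (text from the thinker, audio from code2wav).
for out in outputs:
    print(out.stage_id, out.final_output_type)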