From a920102d79d9cfe0a59c8b356d5708c5d3851913 Mon Sep 17 00:00:00 2001 From: "Huang, Zeyu" <11222265+fhfuih@users.noreply.github.com> Date: Wed, 7 Jan 2026 09:05:08 +0000 Subject: [PATCH] [Frontend][Model] Support batch request with refined OmniDiffusionRequest, new OmniTextPropmt & OmniDiffusionSamplingParams Signed-off-by: Huang, Zeyu <11222265+fhfuih@users.noreply.github.com> --- docs/getting_started/quickstart.md | 40 + .../diffusion/cache_dit_acceleration.md | 3 +- .../diffusion/parallelism_acceleration.md | 42 +- docs/user_guide/diffusion/teacache.md | 16 +- docs/user_guide/diffusion_acceleration.md | 48 +- .../offline_inference/text_to_image.md | 37 +- examples/offline_inference/bagel/end2end.py | 48 +- .../image_to_image/image_edit.py | 32 +- .../image_to_video/image_to_video.py | 25 +- .../lora_inference/lora_inference.py | 18 +- .../text_to_audio/text_to_audio.py | 23 +- .../offline_inference/text_to_image/README.md | 37 +- .../text_to_image/gradio_demo.py | 15 +- .../text_to_image/text_to_image.py | 23 +- .../text_to_video/text_to_video.py | 23 +- .../omni_connectors/test_kv_flow.py | 21 +- tests/e2e/offline_inference/conftest.py | 21 +- tests/e2e/offline_inference/test_cache_dit.py | 16 +- .../test_diffusion_cpu_offload.py | 13 +- .../offline_inference/test_diffusion_lora.py | 33 +- .../e2e/offline_inference/test_ovis_image.py | 39 +- .../test_sequence_parallel.py | 30 +- .../test_stable_audio_model.py | 23 +- tests/e2e/offline_inference/test_t2i_model.py | 15 +- tests/e2e/offline_inference/test_t2v_model.py | 18 +- tests/e2e/offline_inference/test_teacache.py | 16 +- .../test_zimage_tensor_parallel.py | 16 +- tests/e2e/online_serving/test_async_omni.py | 19 +- .../openai_api/test_image_server.py | 46 +- tests/entrypoints/test_omni_diffusion.py | 1103 +++++++++++++++++ tests/entrypoints/test_omni_llm.py | 22 +- vllm_omni/diffusion/diffusion_engine.py | 100 +- vllm_omni/diffusion/executor/abstract.py | 4 +- .../diffusion/executor/multiproc_executor.py | 6 +- .../diffusion/models/bagel/pipeline_bagel.py | 36 +- .../flux2_klein/pipeline_flux2_klein.py | 69 +- .../models/glm_image/pipeline_glm_image.py | 94 +- .../longcat_image/pipeline_longcat_image.py | 60 +- .../pipeline_longcat_image_edit.py | 105 +- .../models/ovis_image/pipeline_ovis_image.py | 35 +- .../models/qwen_image/pipeline_qwen_image.py | 37 +- .../qwen_image/pipeline_qwen_image_edit.py | 96 +- .../pipeline_qwen_image_edit_plus.py | 104 +- .../qwen_image/pipeline_qwen_image_layered.py | 114 +- .../diffusion/models/sd3/pipeline_sd3.py | 33 +- .../stable_audio/pipeline_stable_audio.py | 30 +- .../models/wan2_2/pipeline_wan2_2.py | 144 ++- .../models/wan2_2/pipeline_wan2_2_i2v.py | 148 ++- .../models/wan2_2/pipeline_wan2_2_ti2v.py | 144 ++- .../models/z_image/pipeline_z_image.py | 38 +- vllm_omni/diffusion/request.py | 195 +-- vllm_omni/diffusion/scheduler.py | 4 +- .../worker/gpu_diffusion_model_runner.py | 40 +- .../diffusion/worker/gpu_diffusion_worker.py | 36 +- vllm_omni/entrypoints/async_omni.py | 78 +- vllm_omni/entrypoints/async_omni_diffusion.py | 122 +- vllm_omni/entrypoints/omni.py | 132 +- vllm_omni/entrypoints/omni_diffusion.py | 65 +- vllm_omni/entrypoints/omni_stage.py | 201 ++- vllm_omni/entrypoints/openai/api_server.py | 58 +- vllm_omni/entrypoints/openai/serving_chat.py | 59 +- vllm_omni/inputs/data.py | 177 ++- .../model_executor/stage_configs/bagel.yaml | 6 - vllm_omni/outputs.py | 6 +- 64 files changed, 3019 insertions(+), 1438 deletions(-) create mode 100644 tests/entrypoints/test_omni_diffusion.py diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index b64bf142321..d4087621ad5 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -46,6 +46,46 @@ if __name__ == "__main__": images[0].save("coffee.png") ``` +You can pass a list of prompts and wait for them to process altogether, shown below. + +!!! info + + However, it is not currently recommended to do so + because not all models support batch inference, + and batch requesting mostly does not provide significant performance improvement (despite the impression that it does). + This feature is primarily for the sake of interface compatibility with vLLM and to allow for future improvements. + +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni( + model="Tongyi-MAI/Z-Image-Turbo", + # stage_configs_path="./stage-config.yaml", # See below + ) + prompts = [ + "a cup of coffee on a table", + "a toy dinosaur on a sandy beach", + "a fox waking up in bed and yawning", + ] + omni_outputs = omni.generate(prompts) + for i_prompt, prompt_output in enumerate(omni_outputs): + this_request_output = prompt_output.request_output[0] + this_images = this_request_output.images + for i_image, image in enumerate(this_images): + image.save(f"p{i_prompt}-img{i_image}.jpg") + print("saved to", f"p{i_prompt}-img{i_image}.jpg") + # saved to p0-img0.jpg + # saved to p1-img0.jpg + # saved to p2-img0.jpg +``` + +!!! info + + For diffusion pipelines, the stage config field `stage_args.[].runtime.max_batch_size` is 1 by default, and the input + list is sliced into single-item requests before feeding into the diffusion pipeline. For models that do internally support + batched inputs, you can [modify this configuration](../configuration/stage_configs.md) to let the model accept a longer batch of prompts. + For more usages, please refer to [offline inference](../user_guide/examples/offline_inference/qwen2_5_omni.md) ## Online Serving with OpenAI-Completions API diff --git a/docs/user_guide/diffusion/cache_dit_acceleration.md b/docs/user_guide/diffusion/cache_dit_acceleration.md index 57f6a94e65f..fd1bd522a20 100644 --- a/docs/user_guide/diffusion/cache_dit_acceleration.md +++ b/docs/user_guide/diffusion/cache_dit_acceleration.md @@ -18,6 +18,7 @@ Enable cache-dit acceleration by simply setting `cache_backend="cache_dit"`. Cac ```python from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams # Simplest way: just enable cache-dit with default parameters omni = Omni( @@ -27,7 +28,7 @@ omni = Omni( images = omni.generate( "a beautiful landscape", - num_inference_steps=50, + OmniDiffusionSamplingParams(num_inference_steps=50), ) ``` diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index 09e4651ae23..45e53294e5a 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -67,10 +67,12 @@ omni = Omni( ) outputs = omni.generate( - prompt="a cat reading a book", - num_inference_steps=9, - width=512, - height=512, + "a cat reading a book", + OmniDiffusionSamplingParams( + num_inference_steps=9, + width=512, + height=512, + ), ) ``` @@ -83,6 +85,7 @@ outputs = omni.generate( An example of offline inference script using [Ulysses-SP](https://arxiv.org/pdf/2309.14509) is shown below: ```python from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.diffusion.data import DiffusionParallelConfig ulysses_degree = 2 @@ -91,7 +94,10 @@ omni = Omni( parallel_config=DiffusionParallelConfig(ulysses_degree=2) ) -outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048), +) ``` See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example. @@ -133,6 +139,7 @@ Ring-Attention ([arxiv paper](https://arxiv.org/abs/2310.01889)) splits the inpu An example of offline inference script using Ring-Attention is shown below: ```python from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.diffusion.data import DiffusionParallelConfig ring_degree = 2 @@ -141,7 +148,10 @@ omni = Omni( parallel_config=DiffusionParallelConfig(ring_degree=2) ) -outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048), +) ``` See `examples/offline_inference/text_to_image/text_to_image.py` for a complete working example. @@ -183,6 +193,7 @@ You can combine both Ulysses-SP and Ring-Attention for larger scale parallelism. ```python from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.diffusion.data import DiffusionParallelConfig # Hybrid: 2 Ulysses × 2 Ring = 4 GPUs total @@ -191,7 +202,10 @@ omni = Omni( parallel_config=DiffusionParallelConfig(ulysses_degree=2, ring_degree=2) ) -outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048), +) ``` ##### Online Serving @@ -374,11 +388,15 @@ omni = Omni( ) outputs = omni.generate( - prompt="turn this cat to a dog", - negative_prompt="low quality, blurry", - true_cfg_scale=4.0, - pil_image=input_image, - num_inference_steps=50, + { + "prompt": "turn this cat to a dog", + "negative_prompt": "low quality, blurry", + }, + OmniDiffusionSamplingParams( + true_cfg_scale=4.0, + pil_image=input_image, + num_inference_steps=50, + ), ) ``` diff --git a/docs/user_guide/diffusion/teacache.md b/docs/user_guide/diffusion/teacache.md index 3fe9614057c..40dafeb88ad 100644 --- a/docs/user_guide/diffusion/teacache.md +++ b/docs/user_guide/diffusion/teacache.md @@ -8,6 +8,7 @@ Enable TeaCache by setting `cache_backend` to `"tea_cache"`: ```python from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams # Simple configuration - model_type is automatically extracted from pipeline.__class__.__name__ omni = Omni( @@ -17,7 +18,12 @@ omni = Omni( "rel_l1_thresh": 0.2 # Optional, defaults to 0.2 } ) -outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50) +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams( + num_inference_steps=50, + ), +) ``` ### Using Environment Variable @@ -68,13 +74,19 @@ Controls the balance between speed and quality. Lower values prioritize quality, ```python from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams omni = Omni( model="Qwen/Qwen-Image", cache_backend="tea_cache", cache_config={"rel_l1_thresh": 0.2} ) -outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50) +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams( + num_inference_steps=50, + ), +) ``` ## Performance Tuning diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 42202d8d7ed..8d31747d21c 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -96,6 +96,7 @@ To measure the parallelism methods, we run benchmarks with **Qwen/Qwen-Image** m ```python from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams omni = Omni( model="Qwen/Qwen-Image", @@ -103,13 +104,19 @@ omni = Omni( cache_config={"rel_l1_thresh": 0.2} # Optional, defaults to 0.2 ) -outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50) +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams( + num_inference_steps=50, + ), +) ``` ### Using Cache-DiT ```python from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams omni = Omni( model="Qwen/Qwen-Image", @@ -123,7 +130,12 @@ omni = Omni( } ) -outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50) +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams( + num_inference_steps=50, + ), +) ``` ### Using Ulysses-SP @@ -131,6 +143,7 @@ outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_st Run text-to-image: ```python from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.diffusion.data import DiffusionParallelConfig ulysses_degree = 2 @@ -139,13 +152,17 @@ omni = Omni( parallel_config=DiffusionParallelConfig(ulysses_degree=ulysses_degree) ) -outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048), +) ``` Run image-to-image: ```python from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.diffusion.data import DiffusionParallelConfig ulysses_degree = 2 @@ -154,8 +171,13 @@ omni = Omni( parallel_config=DiffusionParallelConfig(ulysses_degree=ulysses_degree) ) -outputs = omni.generate(prompt="turn this cat to a dog", - pil_image=input_image, num_inference_steps=50) +outputs = omni.generate( + { + "prompt": "turn this cat to a dog", + "multi_modal_data": {"image": input_image} + }, + OmniDiffusionSamplingParams(num_inference_steps=50), +) ``` ### Using Ring-Attention @@ -163,6 +185,7 @@ outputs = omni.generate(prompt="turn this cat to a dog", Run text-to-image: ```python from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.diffusion.data import DiffusionParallelConfig ring_degree = 2 @@ -171,7 +194,10 @@ omni = Omni( parallel_config=DiffusionParallelConfig(ring_degree=2) ) -outputs = omni.generate(prompt="A cat sitting on a windowsill", num_inference_steps=50, width=2048, height=2048) +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50, width=2048, height=2048), +) ``` ### Using CFG-Parallel @@ -182,6 +208,7 @@ CFG-Parallel splits the CFG positive/negative branches across GPUs. Use it when ```python from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.diffusion.data import DiffusionParallelConfig cfg_parallel_size = 2 @@ -190,8 +217,13 @@ omni = Omni( parallel_config=DiffusionParallelConfig(cfg_parallel_size=cfg_parallel_size) ) -outputs = omni.generate(prompt="turn this cat to a dog", - pil_image=input_image, num_inference_steps=50, true_cfg_scale=4.0) +outputs = omni.generate( + { + "prompt": "turn this cat to a dog", + "multi_modal_data": {"image": input_image} + }, + OmniDiffusionSamplingParams(num_inference_steps=50, true_cfg_scale=4.0), +) ``` ## Documentation diff --git a/docs/user_guide/examples/offline_inference/text_to_image.md b/docs/user_guide/examples/offline_inference/text_to_image.md index 5c6f167e1d4..2dd0f7b7ccb 100644 --- a/docs/user_guide/examples/offline_inference/text_to_image.md +++ b/docs/user_guide/examples/offline_inference/text_to_image.md @@ -23,7 +23,7 @@ if __name__ == "__main__": images[0].save("coffee.png") ``` -Or put more than one prompt in a request, processing them sequentially. +Or put more than one prompt in a request. ```python from vllm_omni.entrypoints.omni import Omni @@ -40,6 +40,41 @@ if __name__ == "__main__": image = output.request_output[0].images[0].save(f"{i}.jpg") ``` +!!! info + + However, it is not currently recommended to do so + because not all models support batch inference, + and batch requesting mostly does not provide significant performance improvement (despite the impression that it does). + This feature is primarily for the sake of interface compatibility with vLLM and to allow for future improvements. + +!!! info + + For diffusion pipelines, the stage config field `stage_args.[].runtime.max_batch_size` is 1 by default, and the input + list is sliced into single-item requests before feeding into the diffusion pipeline. For models that do internally support + batched inputs, you can [modify this configuration](../../../configuration/stage_configs.md) to let the model accept a longer batch of prompts. + +Apart from string prompt, vLLM-Omni also supports dictionary prompts in the same style as vLLM. +This is useful for models that support negative prompts. + +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni(model="Qwen/Qwen-Image") + outputs = omni.generate([ + { + "prompt": "a cup of coffee on a table", + "negative_prompt": "low resolution" + }, + { + "prompt": "a toy dinosaur on a sandy beach", + "negative_prompt": "cinematic, realistic" + } + ]) + for i, output in enumerate(outputs): + image = output.request_output[0].images[0].save(f"{i}.jpg") +``` + ## Local CLI Usage ```bash diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py index 93000662dcd..397fd333ec0 100644 --- a/examples/offline_inference/bagel/end2end.py +++ b/examples/offline_inference/bagel/end2end.py @@ -1,5 +1,8 @@ import argparse import os +from typing import cast + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType def parse_args(): @@ -50,21 +53,24 @@ def parse_args(): def main(): args = parse_args() model_name = args.model + prompts: list[OmniPromptType] = [] try: # Preferred: load from txt file (one prompt per line) if getattr(args, "txt_prompts", None) and args.prompt_type == "text": with open(args.txt_prompts, encoding="utf-8") as f: lines = [ln.strip() for ln in f.readlines()] - args.prompts = [ln for ln in lines if ln != ""] - print(f"[Info] Loaded {len(args.prompts)} prompts from {args.txt_prompts}") + prompts = [ln for ln in lines if ln != ""] + print(f"[Info] Loaded {len(prompts)} prompts from {args.txt_prompts}") + else: + prompts = args.prompts except Exception as e: print(f"[Error] Failed to load prompts: {e}") raise - if args.prompts is None: + if not prompts: # Default prompt for text2img test if none provided - args.prompts = ["<|im_start|>A cute cat<|im_end|>"] - print(f"[Info] No prompts provided, using default: {args.prompts}") + prompts = ["<|im_start|>A cute cat<|im_end|>"] + print(f"[Info] No prompts provided, using default: {prompts}") omni_outputs = [] from PIL import Image @@ -77,21 +83,27 @@ def main(): print("[Info] Running in img2img mode (Stage 1 only)") client = OmniDiffusion(model=model_name) - generate_kwargs = { - "prompt": args.prompts, - "seed": 52, - "need_kv_receive": False, - "num_inference_steps": args.steps, - } - if args.image_path: if os.path.exists(args.image_path): loaded_image = Image.open(args.image_path).convert("RGB") - generate_kwargs["pil_image"] = loaded_image + prompts = [ + { + "prompt": cast(str, p), + "multi_modal_data": {"image": loaded_image}, + } + for p in prompts + ] else: print(f"[Warning] Image path {args.image_path} does not exist.") - result = client.generate(**generate_kwargs) + result = client.generate( + prompts, + OmniDiffusionSamplingParams( + seed=52, + need_kv_receive=False, + num_inference_steps=args.steps, + ), + ) # Ensure result is a list for iteration if not isinstance(result, list): @@ -100,8 +112,6 @@ def main(): omni_outputs = result else: - import copy - from vllm_omni.entrypoints.omni import Omni omni_kwargs = {} @@ -144,11 +154,11 @@ def main(): prompt_dict = {"prompt": final_prompt_text, "modalities": ["image"]} formatted_prompts.append(prompt_dict) - params_list = copy.deepcopy(omni.default_sampling_params_list) + params_list = omni.default_sampling_params_list if args.modality == "text2img": - params_list[0]["max_tokens"] = 1 + params_list[0].max_tokens = 1 # type: ignore # The first stage is a SamplingParam (vllm) if len(params_list) > 1: - params_list[1]["num_inference_steps"] = args.steps + params_list[1].num_inference_steps = args.steps # type: ignore # The second stage is an OmniDiffusionSamplingParam omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list)) diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index 405a0cafa14..f2065e6b2b6 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -79,6 +79,7 @@ from vllm_omni.diffusion.data import DiffusionParallelConfig, logger from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -374,25 +375,28 @@ def main(): print(f"{'=' * 60}\n") generation_start = time.perf_counter() - # Generate edited image - generate_kwargs = { - "prompt": args.prompt, - "pil_image": input_image, - "negative_prompt": args.negative_prompt, - "generator": generator, - "true_cfg_scale": args.cfg_scale, - "guidance_scale": args.guidance_scale, - "num_inference_steps": args.num_inference_steps, - "num_outputs_per_prompt": args.num_outputs_per_prompt, - "layers": args.layers, - "resolution": args.resolution, - } if profiler_enabled: print("[Profiler] Starting profiling...") omni.start_profile() - outputs = omni.generate(**generate_kwargs) + # Generate edited image + outputs = omni.generate( + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + "multi_modal_data": {"image": input_image}, + }, + OmniDiffusionSamplingParams( + generator=generator, + true_cfg_scale=args.cfg_scale, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + num_outputs_per_prompt=args.num_outputs_per_prompt, + layers=args.layers, + resolution=args.resolution, + ), + ) generation_end = time.perf_counter() generation_time = generation_end - generation_start diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index 2104ec0fd5c..06c13fbeed3 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -27,6 +27,7 @@ import torch from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -119,16 +120,20 @@ def main(): # omni.generate() returns Generator[OmniRequestOutput, None, None] frames = omni.generate( - args.prompt, - negative_prompt=args.negative_prompt, - pil_image=image, - height=height, - width=width, - generator=generator, - guidance_scale=args.guidance_scale, - guidance_scale_2=args.guidance_scale_high, - num_inference_steps=args.num_inference_steps, - num_frames=args.num_frames, + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + "multi_modal_data": {"image": image}, + }, + OmniDiffusionSamplingParams( + height=height, + width=width, + generator=generator, + guidance_scale=args.guidance_scale, + guidance_scale_2=args.guidance_scale_high, + num_inference_steps=args.num_inference_steps, + num_frames=args.num_frames, + ), ) # Extract video frames from OmniRequestOutput diff --git a/examples/offline_inference/lora_inference/lora_inference.py b/examples/offline_inference/lora_inference/lora_inference.py index 17e9d6196dd..5e4299edb84 100644 --- a/examples/offline_inference/lora_inference/lora_inference.py +++ b/examples/offline_inference/lora_inference/lora_inference.py @@ -5,6 +5,7 @@ from pathlib import Path from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.lora.request import LoRARequest from vllm_omni.lora.utils import stable_lora_int_id @@ -95,18 +96,17 @@ def main(): ) print(f"Activating pre-loaded LoRA: id={lora_request_id}, scale={args.lora_scale}") - gen_kwargs = { - "prompt": args.prompt, - "height": args.height, - "width": args.width, - "num_inference_steps": args.num_inference_steps, - } + sampling_params = OmniDiffusionSamplingParams( + height=args.height, + width=args.width, + num_inference_steps=args.num_inference_steps, + ) if lora_request: - gen_kwargs["lora_request"] = lora_request - gen_kwargs["lora_scale"] = args.lora_scale + sampling_params.lora_request = lora_request + sampling_params.lora_scale = args.lora_scale - outputs = omni.generate(**gen_kwargs) + outputs = omni.generate(args.prompt, sampling_params) if not outputs or len(outputs) == 0: raise ValueError("No output generated from omni.generate()") diff --git a/examples/offline_inference/text_to_audio/text_to_audio.py b/examples/offline_inference/text_to_audio/text_to_audio.py index ca1a455d145..1124dbdb4dd 100644 --- a/examples/offline_inference/text_to_audio/text_to_audio.py +++ b/examples/offline_inference/text_to_audio/text_to_audio.py @@ -21,6 +21,7 @@ import torch from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.utils.platform_utils import detect_device_type @@ -143,16 +144,20 @@ def main(): # Generate audio outputs = omni.generate( - args.prompt, - negative_prompt=args.negative_prompt, - generator=generator, - guidance_scale=args.guidance_scale, - num_inference_steps=args.num_inference_steps, - num_outputs_per_prompt=args.num_waveforms, - extra={ - "audio_start_in_s": args.audio_start, - "audio_end_in_s": audio_end_in_s, + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, }, + OmniDiffusionSamplingParams( + generator=generator, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + num_outputs_per_prompt=args.num_waveforms, + extra_args={ + "audio_start_in_s": args.audio_start, + "audio_end_in_s": audio_end_in_s, + }, + ), ) generation_end = time.perf_counter() diff --git a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md index 0f4351ac228..f4de19db891 100644 --- a/examples/offline_inference/text_to_image/README.md +++ b/examples/offline_inference/text_to_image/README.md @@ -20,7 +20,7 @@ if __name__ == "__main__": images[0].save("coffee.png") ``` -Or put more than one prompt in a request, processing them sequentially. +Or put more than one prompt in a request. ```python from vllm_omni.entrypoints.omni import Omni @@ -37,6 +37,41 @@ if __name__ == "__main__": image = output.request_output[0].images[0].save(f"{i}.jpg") ``` +!!! info + + However, it is not currently recommended to do so + because not all models support batch inference, + and batch requesting mostly does not provide significant performance improvement (despite the impression that it does). + This feature is primarily for the sake of interface compatibility with vLLM and to allow for future improvements. + +!!! info + + For diffusion pipelines, the stage config field `stage_args.[].runtime.max_batch_size` is 1 by default, and the input + list is sliced into single-item requests before feeding into the diffusion pipeline. For models that do internally support + batched inputs, you can [modify this configuration](../../../configuration/stage_configs.md) to let the model accept a longer batch of prompts. + +Apart from string prompt, vLLM-Omni also supports dictionary prompts in the same style as vLLM. +This is useful for models that support negative prompts. + +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni(model="Qwen/Qwen-Image") + outputs = omni.generate([ + { + "prompt": "a cup of coffee on a table", + "negative_prompt": "low resolution" + }, + { + "prompt": "a toy dinosaur on a sandy beach", + "negative_prompt": "cinematic, realistic" + } + ]) + for i, output in enumerate(outputs): + image = output.request_output[0].images[0].save(f"{i}.jpg") +``` + ## Local CLI Usage ```bash diff --git a/examples/offline_inference/text_to_image/gradio_demo.py b/examples/offline_inference/text_to_image/gradio_demo.py index 8d4ff2b2093..30446f50eb4 100644 --- a/examples/offline_inference/text_to_image/gradio_demo.py +++ b/examples/offline_inference/text_to_image/gradio_demo.py @@ -5,6 +5,7 @@ import torch from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -101,12 +102,14 @@ def run_inference( generator = torch.Generator(device=device).manual_seed(seed) outputs = omni.generate( prompt.strip(), - height=height, - width=width, - generator=generator, - true_cfg_scale=float(cfg_scale_value), - num_inference_steps=num_steps, - num_outputs_per_prompt=num_images, + OmniDiffusionSamplingParams( + height=height, + width=width, + generator=generator, + true_cfg_scale=float(cfg_scale_value), + num_inference_steps=num_steps, + num_outputs_per_prompt=num_images, + ), ) images_outputs = [] for output in outputs: diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 428c2437e6f..7fc6ec832f9 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -10,6 +10,7 @@ from vllm_omni.diffusion.data import DiffusionParallelConfig, logger from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -189,15 +190,19 @@ def main(): generation_start = time.perf_counter() outputs = omni.generate( - args.prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - generator=generator, - true_cfg_scale=args.cfg_scale, - guidance_scale=args.guidance_scale, - num_inference_steps=args.num_inference_steps, - num_outputs_per_prompt=args.num_images_per_prompt, + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + }, + OmniDiffusionSamplingParams( + height=args.height, + width=args.width, + generator=generator, + true_cfg_scale=args.cfg_scale, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + num_outputs_per_prompt=args.num_images_per_prompt, + ), ) generation_end = time.perf_counter() generation_time = generation_end - generation_start diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index c655fff1468..054db95b3e9 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -11,6 +11,7 @@ from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -108,15 +109,19 @@ def main(): generation_start = time.perf_counter() frames = omni.generate( - args.prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - generator=generator, - guidance_scale=args.guidance_scale, - guidance_scale_2=args.guidance_scale_high, - num_inference_steps=args.num_inference_steps, - num_frames=args.num_frames, + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + }, + OmniDiffusionSamplingParams( + height=args.height, + width=args.width, + generator=generator, + guidance_scale=args.guidance_scale, + guidance_scale_2=args.guidance_scale_high, + num_inference_steps=args.num_inference_steps, + num_frames=args.num_frames, + ), ) generation_end = time.perf_counter() generation_time = generation_end - generation_start diff --git a/tests/distributed/omni_connectors/test_kv_flow.py b/tests/distributed/omni_connectors/test_kv_flow.py index a8d2f03b222..2bb06d4e00d 100644 --- a/tests/distributed/omni_connectors/test_kv_flow.py +++ b/tests/distributed/omni_connectors/test_kv_flow.py @@ -1,10 +1,12 @@ import unittest +from types import SimpleNamespace from unittest.mock import MagicMock, patch import torch from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner @@ -109,11 +111,12 @@ def test_receiver_injection_logic(self): transfer_data.metadata = {"kv_lens": [self.seq_len], "ropes": [0]} # 2. Setup Request with Injected Data - req = OmniDiffusionRequest(prompt="test") - from types import SimpleNamespace + sp = OmniDiffusionSamplingParams( + past_key_values=SimpleNamespace(**layer_blocks), + kv_metadata=transfer_data.metadata, + ) - req.past_key_values = SimpleNamespace(**layer_blocks) - req.kv_metadata = transfer_data.metadata + req = OmniDiffusionRequest(["test"], sp) # 3. Setup Pipeline pipeline = MockBagelPipeline() @@ -144,7 +147,7 @@ def mock_prepare_prompts(curr_kvlens, curr_rope, **kwargs): current_cache = RealNaiveCache(self.num_layers) # --- Logic from Source Code --- - injected_kv = req.past_key_values + injected_kv = req.sampling_params.past_key_values if isinstance(current_cache, RealNaiveCache) and hasattr(injected_kv, "key_cache"): # Assuming injected_kv is SimpleNamespace or object with list attrs for layer_idx in range(len(injected_kv.key_cache)): @@ -176,9 +179,11 @@ def test_integration(self): data_dict = runner_test_result.to_dict() # 3. Receiver (Request Setup) - req = OmniDiffusionRequest(prompt="integration_test") - req.past_key_values = data_dict["layer_blocks"] - req.kv_metadata = data_dict["metadata"] + sp = OmniDiffusionSamplingParams( + past_key_values=data_dict["layer_blocks"], + kv_metadata=data_dict["metadata"], + ) + req = OmniDiffusionRequest(["integration_test"], sp) # noqa: F841 # 4. Receiver (Injection Simulation) # Use the logic verification again diff --git a/tests/e2e/offline_inference/conftest.py b/tests/e2e/offline_inference/conftest.py index 276fd8844c2..e30c49612fa 100644 --- a/tests/e2e/offline_inference/conftest.py +++ b/tests/e2e/offline_inference/conftest.py @@ -7,11 +7,12 @@ from typing import Any import pytest +from vllm import TextPrompt from vllm.distributed.parallel_state import cleanup_dist_env_and_memory -from vllm.sampling_params import SamplingParams from tests.conftest import clean_gpu_memory from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniSamplingParams from vllm_omni.outputs import OmniRequestOutput PromptAudioInput = list[tuple[Any, int]] | tuple[Any, int] | None @@ -64,7 +65,7 @@ def __init__( **kwargs, ) - def get_default_sampling_params_list(self) -> list[SamplingParams]: + def get_default_sampling_params_list(self) -> list[OmniSamplingParams]: """ Get a list of default sampling parameters for all stages. @@ -82,7 +83,7 @@ def get_omni_inputs( videos: PromptVideoInput = None, mm_processor_kwargs: dict[str, Any] | None = None, modalities: list[str] | None = None, - ) -> list[dict[str, Any]]: + ) -> list[TextPrompt]: """ Construct Omni input format from prompts and multimodal data. @@ -175,7 +176,7 @@ def _normalize_mm_input(mm_input, num_prompts): f"<|im_start|>assistant\n" ) - input_dict: dict[str, Any] = {"prompt": full_prompt} + input_dict: TextPrompt = {"prompt": full_prompt} if multi_modal_data: input_dict["multi_modal_data"] = multi_modal_data if modalities: @@ -189,8 +190,8 @@ def _normalize_mm_input(mm_input, num_prompts): def generate( self, - prompts: list[dict[str, Any]], - sampling_params_list: list[SamplingParams] | None = None, + prompts: list[TextPrompt], + sampling_params_list: list[OmniSamplingParams] | None = None, ) -> list[OmniRequestOutput]: """ Generate outputs for the given prompts. @@ -212,7 +213,7 @@ def generate( def generate_multimodal( self, prompts: list[str] | str, - sampling_params_list: list[SamplingParams] | None = None, + sampling_params_list: list[OmniSamplingParams] | None = None, system_prompt: str | None = None, audios: PromptAudioInput = None, images: PromptImageInput = None, @@ -249,7 +250,7 @@ def generate_multimodal( def generate_audio( self, prompts: list[str] | str, - sampling_params_list: list[SamplingParams] | None = None, + sampling_params_list: list[OmniSamplingParams] | None = None, system_prompt: str | None = None, audios: PromptAudioInput = None, mm_processor_kwargs: dict[str, Any] | None = None, @@ -276,7 +277,7 @@ def generate_audio( def generate_video( self, prompts: list[str] | str, - sampling_params_list: list[SamplingParams] | None = None, + sampling_params_list: list[OmniSamplingParams] | None = None, system_prompt: str | None = None, videos: PromptVideoInput = None, mm_processor_kwargs: dict[str, Any] | None = None, @@ -303,7 +304,7 @@ def generate_video( def generate_image( self, prompts: list[str] | str, - sampling_params_list: list[SamplingParams] | None = None, + sampling_params_list: list[OmniSamplingParams] | None = None, system_prompt: str | None = None, images: PromptImageInput = None, mm_processor_kwargs: dict[str, Any] | None = None, diff --git a/tests/e2e/offline_inference/test_cache_dit.py b/tests/e2e/offline_inference/test_cache_dit.py index 9eaee00c389..18d3988e042 100644 --- a/tests/e2e/offline_inference/test_cache_dit.py +++ b/tests/e2e/offline_inference/test_cache_dit.py @@ -15,6 +15,8 @@ import pytest import torch +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + # ruff: noqa: E402 REPO_ROOT = Path(__file__).resolve().parents[2] if str(REPO_ROOT) not in sys.path: @@ -55,12 +57,14 @@ def test_cache_dit(model_name: str): outputs = m.generate( "a photo of a cat sitting on a laptop keyboard", - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=0.0, - generator=torch.Generator("cuda").manual_seed(42), - num_outputs_per_prompt=1, # Single output for speed + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, # Single output for speed + ), ) # Extract images from request_output[0]['images'] first_output = outputs[0] diff --git a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py index 5b43400e85b..f6606dce62f 100644 --- a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py +++ b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py @@ -6,6 +6,7 @@ from vllm.distributed.parallel_state import cleanup_dist_env_and_memory from tests.utils import GPUMemoryMonitor +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.utils.platform_utils import is_npu, is_rocm # ruff: noqa: E402 @@ -30,11 +31,13 @@ def inference(model_name: str, offload: bool = True): m.generate( "a photo of a cat sitting on a laptop keyboard", - height=height, - width=width, - num_inference_steps=9, - guidance_scale=0.0, - generator=torch.Generator("cuda").manual_seed(42), + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=9, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + ), ) peak = monitor.peak_used_mb monitor.stop() diff --git a/tests/e2e/offline_inference/test_diffusion_lora.py b/tests/e2e/offline_inference/test_diffusion_lora.py index 1761f253c6a..ef734da6148 100644 --- a/tests/e2e/offline_inference/test_diffusion_lora.py +++ b/tests/e2e/offline_inference/test_diffusion_lora.py @@ -7,6 +7,7 @@ import torch from safetensors.torch import save_file +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.utils.platform_utils import is_npu @@ -88,12 +89,14 @@ def _write_zimage_lora(adapter_dir: Path) -> str: outputs = m.generate( prompt, - height=height, - width=width, - num_inference_steps=2, - guidance_scale=0.0, - generator=torch.Generator("cuda").manual_seed(42), - num_outputs_per_prompt=1, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=2, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, + ), ) images = _extract_images(outputs) @@ -116,14 +119,16 @@ def _write_zimage_lora(adapter_dir: Path) -> str: ) outputs_lora = m.generate( prompt, - height=height, - width=width, - num_inference_steps=2, - guidance_scale=0.0, - generator=torch.Generator("cuda").manual_seed(42), - num_outputs_per_prompt=1, - lora_request=lora_request, - lora_scale=2.0, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=2, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, + lora_request=lora_request, + lora_scale=2.0, + ), ) images_lora = _extract_images(outputs_lora) assert len(images_lora) == 1 diff --git a/tests/e2e/offline_inference/test_ovis_image.py b/tests/e2e/offline_inference/test_ovis_image.py index 33e1cb880a8..f1bc73817d3 100644 --- a/tests/e2e/offline_inference/test_ovis_image.py +++ b/tests/e2e/offline_inference/test_ovis_image.py @@ -25,6 +25,7 @@ from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.models.ovis_image.pipeline_ovis_image import OvisImagePipeline from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams @pytest.fixture @@ -162,11 +163,13 @@ def test_basic_generation(ovis_pipeline): """Test the forward pass logic.""" # Setup request req = OmniDiffusionRequest( - prompt="A photo of a cat", - height=256, - width=256, - num_inference_steps=2, - guidance_scale=1.0, + prompts=["A photo of a cat"], + sampling_params=OmniDiffusionSamplingParams( + height=256, + width=256, + num_inference_steps=2, + guidance_scale=1.0, + ), ) output = ovis_pipeline(req) @@ -184,12 +187,18 @@ def test_basic_generation(ovis_pipeline): def test_guidance_scale(ovis_pipeline): """Test that classifier-free guidance path is taken when scale > 1.0.""" req = OmniDiffusionRequest( - prompt="A photo of a cat", - negative_prompt="bad quality", - height=256, - width=256, - num_inference_steps=1, - guidance_scale=2.0, # Trigger CFG + prompts=[ + { + "prompt": "A photo of a cat", + "negative_prompt": "bad quality", + } + ], + sampling_params=OmniDiffusionSamplingParams( + height=256, + width=256, + num_inference_steps=1, + guidance_scale=2.0, # Trigger CFG + ), ) ovis_pipeline(req) @@ -200,9 +209,11 @@ def test_resolution_check(ovis_pipeline): """Test resolution divisible validation logic if present.""" # Pass odd resolution req = OmniDiffusionRequest( - prompt="test", - height=250, # Not divisible by 16 (8*2) - width=250, + prompts=["test"], + sampling_params=OmniDiffusionSamplingParams( + height=250, # Not divisible by 16 (8*2) + width=250, + ), ) # Should warn but proceed (as per code I read earlier) or resize? diff --git a/tests/e2e/offline_inference/test_sequence_parallel.py b/tests/e2e/offline_inference/test_sequence_parallel.py index 9de9fd98183..367e2e59e2d 100644 --- a/tests/e2e/offline_inference/test_sequence_parallel.py +++ b/tests/e2e/offline_inference/test_sequence_parallel.py @@ -21,6 +21,8 @@ import torch.distributed as dist from PIL import Image +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + # ruff: noqa: E402 REPO_ROOT = Path(__file__).resolve().parents[3] if str(REPO_ROOT) not in sys.path: @@ -81,12 +83,14 @@ def _run_baseline(model_name: str, dtype: torch.dtype, attn_backend: str, height try: outputs = baseline.generate( PROMPT, - height=height, - width=width, - num_inference_steps=4, - guidance_scale=0.0, - generator=torch.Generator(get_device_name()).manual_seed(seed), - num_outputs_per_prompt=1, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=4, + guidance_scale=0.0, + generator=torch.Generator(get_device_name()).manual_seed(seed), + num_outputs_per_prompt=1, + ), ) return outputs[0].request_output[0].images finally: @@ -116,12 +120,14 @@ def _run_sp( try: outputs = sp.generate( PROMPT, - height=height, - width=width, - num_inference_steps=4, - guidance_scale=0.0, - generator=torch.Generator(get_device_name()).manual_seed(seed), - num_outputs_per_prompt=1, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=4, + guidance_scale=0.0, + generator=torch.Generator(get_device_name()).manual_seed(seed), + num_outputs_per_prompt=1, + ), ) return outputs[0].request_output[0].images finally: diff --git a/tests/e2e/offline_inference/test_stable_audio_model.py b/tests/e2e/offline_inference/test_stable_audio_model.py index 729b501e49c..df2ca5e4283 100644 --- a/tests/e2e/offline_inference/test_stable_audio_model.py +++ b/tests/e2e/offline_inference/test_stable_audio_model.py @@ -5,6 +5,7 @@ import pytest import torch +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput # ruff: noqa: E402 @@ -29,16 +30,20 @@ def test_stable_audio_model(model_name: str): sample_rate = 44100 # Stable Audio uses 44100 Hz outputs = m.generate( - "The sound of a dog barking", - negative_prompt="Low quality.", - num_inference_steps=4, # Minimal steps for speed - guidance_scale=7.0, - generator=torch.Generator("cuda").manual_seed(42), - num_outputs_per_prompt=1, - extra={ - "audio_start_in_s": audio_start_in_s, - "audio_end_in_s": audio_end_in_s, + prompts={ + "prompt": "The sound of a dog barking", + "negative_prompt": "Low quality.", }, + sampling_params_list=OmniDiffusionSamplingParams( + num_inference_steps=4, # Minimal steps for speed + guidance_scale=7.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, + extra_args={ + "audio_start_in_s": audio_start_in_s, + "audio_end_in_s": audio_end_in_s, + }, + ), ) # Extract audio from OmniRequestOutput diff --git a/tests/e2e/offline_inference/test_t2i_model.py b/tests/e2e/offline_inference/test_t2i_model.py index 09fadc93a81..ba6cac34ca0 100644 --- a/tests/e2e/offline_inference/test_t2i_model.py +++ b/tests/e2e/offline_inference/test_t2i_model.py @@ -5,6 +5,7 @@ import pytest import torch +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.utils.platform_utils import is_npu, is_rocm @@ -42,12 +43,14 @@ def test_diffusion_model(model_name: str): width = 256 outputs = m.generate( "a photo of a cat sitting on a laptop keyboard", - height=height, - width=width, - num_inference_steps=2, - guidance_scale=0.0, - generator=torch.Generator("cuda").manual_seed(42), - num_outputs_per_prompt=2, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=2, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=2, + ), ) # Extract images from request_output[0]['images'] first_output = outputs[0] diff --git a/tests/e2e/offline_inference/test_t2v_model.py b/tests/e2e/offline_inference/test_t2v_model.py index 0263310e358..a378291acdc 100644 --- a/tests/e2e/offline_inference/test_t2v_model.py +++ b/tests/e2e/offline_inference/test_t2v_model.py @@ -5,6 +5,8 @@ import pytest import torch +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + # ruff: noqa: E402 REPO_ROOT = Path(__file__).resolve().parents[2] if str(REPO_ROOT) not in sys.path: @@ -33,13 +35,15 @@ def test_video_diffusion_model(model_name: str): width = 640 num_frames = 5 outputs = m.generate( - "A cat sitting on a table", - height=height, - width=width, - num_frames=num_frames, - num_inference_steps=2, - guidance_scale=1.0, - generator=torch.Generator("cuda").manual_seed(42), + prompts="A cat sitting on a table", + sampling_params_list=OmniDiffusionSamplingParams( + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=2, + guidance_scale=1.0, + generator=torch.Generator("cuda").manual_seed(42), + ), ) first_output = outputs[0] assert first_output.final_output_type == "image" diff --git a/tests/e2e/offline_inference/test_teacache.py b/tests/e2e/offline_inference/test_teacache.py index 7ce915b9635..7d626138819 100644 --- a/tests/e2e/offline_inference/test_teacache.py +++ b/tests/e2e/offline_inference/test_teacache.py @@ -15,6 +15,8 @@ import pytest import torch +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + # ruff: noqa: E402 REPO_ROOT = Path(__file__).resolve().parents[2] if str(REPO_ROOT) not in sys.path: @@ -51,12 +53,14 @@ def test_teacache(model_name: str): outputs = m.generate( "a photo of a cat sitting on a laptop keyboard", - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=0.0, - generator=torch.Generator("cuda").manual_seed(42), - num_outputs_per_prompt=1, # Single output for speed + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, # Single output for speed + ), ) # Extract images from request_output[0]['images'] first_output = outputs[0] diff --git a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py index a9fdec4dc0b..ea8f0acfeb0 100644 --- a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py +++ b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py @@ -12,6 +12,8 @@ from PIL import Image from vllm.distributed.parallel_state import cleanup_dist_env_and_memory +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + # ruff: noqa: E402 REPO_ROOT = Path(__file__).resolve().parents[2] if str(REPO_ROOT) not in sys.path: @@ -90,12 +92,14 @@ def _run_zimage_generate( num_requests = 4 # 1 warmup + 3 timed gen = m.generate( [PROMPT] * num_requests, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=0.0, - seed=seed, - num_outputs_per_prompt=1, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=0.0, + seed=seed, + num_outputs_per_prompt=1, + ), py_generator=True, ) diff --git a/tests/e2e/online_serving/test_async_omni.py b/tests/e2e/online_serving/test_async_omni.py index 112b68ff41e..d90727e1bc5 100644 --- a/tests/e2e/online_serving/test_async_omni.py +++ b/tests/e2e/online_serving/test_async_omni.py @@ -5,6 +5,7 @@ from pathlib import Path import pytest +from vllm import SamplingParams from vllm.inputs import PromptType from vllm_omni.entrypoints.async_omni import AsyncOmni, ClientRequestState @@ -25,15 +26,15 @@ async def generate( ) -> tuple[int, str]: # Ensure generate doesn't complete too fast for cancellation test. await asyncio.sleep(0.2) - thinker_sampling_params = { - "temperature": 0.4, # Deterministic - "top_p": 0.9, - "top_k": 1, - "max_tokens": max_tokens, - "repetition_penalty": 1.05, - "stop_token_ids": [151645], # Qwen EOS token <|im_end|> - "seed": SEED, - } + thinker_sampling_params = SamplingParams( + temperature=0.4, # Deterministic + top_p=0.9, + top_k=1, + max_tokens=max_tokens, + repetition_penalty=1.05, + stop_token_ids=[151645], # Qwen EOS token <|im_end|> + seed=SEED, + ) sampling_params_list = [ thinker_sampling_params, diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 45f4b33ad0c..6130ca81d77 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -14,11 +14,13 @@ import pytest from fastapi.testclient import TestClient from PIL import Image +from vllm import SamplingParams from vllm_omni.entrypoints.openai.image_api_utils import ( encode_image_base64, parse_size, ) +from vllm_omni.inputs.data import OmniDiffusionSamplingParams # Unit Tests @@ -111,7 +113,7 @@ class FakeAsyncOmni: def __init__(self): self.stage_list = ["llm", "diffusion"] - self.default_sampling_params_list = [{"temperature": 0.1}, {"top_p": 0.9}] + self.default_sampling_params_list = [SamplingParams(temperature=0.1), OmniDiffusionSamplingParams()] self.captured_sampling_params_list = None async def generate(self, prompt, request_id, sampling_params_list): @@ -127,7 +129,8 @@ def mock_async_diffusion(): async def generate(**kwargs): # Return n PIL images wrapped in result object - n = kwargs.get("num_outputs_per_prompt", 1) + print("!!!!!!!!!!!!!!!!!!!!! kwargs", kwargs) + n = kwargs["sampling_params_list"][0].num_outputs_per_prompt images = [Image.new("RGB", (64, 64), color="blue") for _ in range(n)] return MockGenerationResult(images) @@ -217,11 +220,11 @@ def test_generate_images_async_omni_sampling_params(async_omni_test_client): captured = engine.captured_sampling_params_list assert captured is not None assert len(captured) == 2 - assert captured[0] == {"temperature": 0.1} - assert captured[1]["num_outputs_per_prompt"] == 2 - assert captured[1]["height"] == 256 - assert captured[1]["width"] == 256 - assert captured[1]["seed"] == 7 + assert captured[0].temperature == 0.1 + assert captured[1].num_outputs_per_prompt == 2 + assert captured[1].height == 256 + assert captured[1].width == 256 + assert captured[1].seed == 7 def test_generate_multiple_images(test_client): @@ -454,30 +457,11 @@ def test_parameters_passed_through(test_client, mock_async_diffusion): # Ensure generate() was called exactly once mock_async_diffusion.generate.assert_awaited_once() - call_kwargs = mock_async_diffusion.generate.call_args[1] - assert call_kwargs["num_inference_steps"] == 100 - assert call_kwargs["guidance_scale"] == 7.5 - assert call_kwargs["true_cfg_scale"] == 3.0 - assert call_kwargs["seed"] == 42 - - -def test_optional_parameters_omitted(test_client, mock_async_diffusion): - """Verify optional parameters not passed when omitted""" - response = test_client.post( - "/v1/images/generations", - json={ - "prompt": "test", - "size": "512x512", - }, - ) - assert response.status_code == 200 - - # Ensure generate() was called exactly once - mock_async_diffusion.generate.assert_awaited_once() - call_kwargs = mock_async_diffusion.generate.call_args[1] - assert "num_inference_steps" not in call_kwargs - assert "guidance_scale" not in call_kwargs - assert "true_cfg_scale" not in call_kwargs + call_kwargs = mock_async_diffusion.generate.call_args[1]["sampling_params_list"][0] + assert call_kwargs.num_inference_steps == 100 + assert call_kwargs.guidance_scale == 7.5 + assert call_kwargs.true_cfg_scale == 3.0 + assert call_kwargs.seed == 42 def test_model_field_omitted_works(test_client): diff --git a/tests/entrypoints/test_omni_diffusion.py b/tests/entrypoints/test_omni_diffusion.py new file mode 100644 index 00000000000..c4884e3abd1 --- /dev/null +++ b/tests/entrypoints/test_omni_diffusion.py @@ -0,0 +1,1103 @@ +import uuid +import warnings +from queue import Empty, Queue +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies. +warnings.filterwarnings( + "ignore", + message=r"builtin type SwigPy.*has no __module__ attribute", + category=DeprecationWarning, +) + + +class _FakeEngineArgs(dict): + """Fake engine args that can be used both as object attributes and as **kwargs.""" + + def __init__(self, args_dict: dict[str, Any]): + super().__init__(args_dict) + # Add required attributes if not present + if "model_stage" not in self: + self["model_stage"] = None + if "engine_output_type" not in self: + self["engine_output_type"] = None + # Also set as attributes for object-style access + for key, value in self.items(): + setattr(self, key, value) + + +class _FakeStageConfig: + """Fake stage config object that mimics the real stage config structure.""" + + def __init__(self, config_dict: dict[str, Any]): + # engine_args needs to work both as object (for OmniStage) and as dict (for **kwargs) + engine_args_dict = config_dict.get("engine_args", {}) + self.engine_args = _FakeEngineArgs(engine_args_dict) + self.final_output = config_dict.get("final_output", False) + self.final_output_type = config_dict.get("final_output_type", None) + self.stage_id = config_dict.get("stage_id", 0) + # Store original dict for reference + self._config_dict = config_dict + + +class _FakeQueue: + """Fake queue using standard library Queue to replace mp.Queue.""" + + def __init__(self, maxsize=0): + self._queue = Queue(maxsize=maxsize) + + def put(self, item): + self._queue.put(item) + + def put_nowait(self, item): + self._queue.put_nowait(item) + + def get(self): + return self._queue.get() + + def get_nowait(self): + return self._queue.get_nowait() + + def empty(self): + return self._queue.empty() + + +class _FakeStage: + """Lightweight Stage stub for multi-process pipeline version with queue support.""" + + def __init__(self, config, stage_init_timeout: int = 300): + # Handle both dict and object configs + if isinstance(config, dict): + config = _FakeStageConfig(config) + self.config = config + self.stage_config = config + self.engine = None + self.engine_outputs = None + # Set attributes that OmniStage expects + self.stage_id = getattr(config, "stage_id", 0) + self.engine_args = config.engine_args + self.model_stage = getattr(config.engine_args, "model_stage", None) + self.stage_type = "diffusion" + # set default sampling params + self.default_sampling_params = OmniDiffusionSamplingParams(num_inference_steps=1) + # Allow configuring final_output and final_output_type + self.final_output = config.final_output if hasattr(config, "final_output") else False + self.final_output_type = getattr(config, "final_output_type", None) + # Configurable processing logic, default returns placeholder + processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"]) + self._processed_input = processed_input + # Queue references (set by attach_queues) + self._in_q = None + self._out_q = None + self._proc = None # Mock process reference + self._stage_init_timeout = max(0, int(stage_init_timeout)) + + def attach_queues(self, in_q, out_q): + """Attach input and output queues.""" + self._in_q = in_q + self._out_q = out_q + + def init_stage_worker( + self, + model: str, + *, + is_async: bool = False, + shm_threshold_bytes: int = 65536, + ctx=None, + batch_timeout: int = 10, + **kwargs, + ): + """Mock init_stage_worker: don't start real process, just send stage_ready message.""" + # Create a mock process object + self._proc = MagicMock() + self._proc.start = MagicMock() + self._proc.join = MagicMock() + self._proc.is_alive = MagicMock(return_value=False) + self._proc.terminate = MagicMock() + # Send stage_ready message to output queue + if self._out_q is not None: + try: + self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id}) + except Exception: + pass + + def stop_stage_worker(self): + """Mock stop_stage_worker: clean up queue references.""" + if self._in_q is not None: + try: + self._in_q.put_nowait(SHUTDOWN_TASK) + except Exception: + pass + + def submit(self, payload: dict[str, Any]): + """Submit task to input queue.""" + if self._in_q is not None: + self._in_q.put(payload) + + def try_collect(self) -> Any: + """Non-blocking collect from output queue.""" + if self._out_q is None: + return None + try: + return self._out_q.get_nowait() + except Empty: + return None + + def set_engine_outputs(self, outputs): + """Set engine outputs for the stage.""" + self.engine_outputs = outputs + + def process_engine_inputs(self, stage_list, prompts): + """Process engine inputs: return preset processed result.""" + return self._processed_input + + +class _FakeEngine: + """Lightweight Engine stub: provides generate iterator output.""" + + def __init__(self, outputs: list[Any]): + self._outputs = outputs + + def generate(self, prompts, sampling_params): + # Record the most recent prompts for outer assertions + self._last_prompts = prompts + # Simplified: return preset list at once, ensuring iterability + yield from self._outputs + + +@pytest.fixture +def fake_stage_config(): + return { + # Don't include 'model' in engine_args since it's passed separately + "engine_args": {}, + "final_output": True, + "final_output_type": "text", + # Second stage will use processed_input to verify the chain + "processed_input": ["processed-by-stage"], + } + + +def _setup_engine_mocks(monkeypatch): + """Helper function to set up common engine mocks.""" + fake_engine = MagicMock() + # Add necessary attributes to fake_engine + fake_engine.tokenizer = MagicMock() + fake_engine.log_stats = False + fake_engine.vllm_config = MagicMock() + fake_engine.vllm_config.model_config = MagicMock() + fake_engine.vllm_config.model_config.io_processor_plugin = None + fake_engine.get_supported_tasks = MagicMock(return_value=[]) + fake_engine.model_config = MagicMock() + fake_engine.model_config.io_processor_plugin = None + # Add registry with resolve_model_cls method + fake_registry = MagicMock() + fake_registry.resolve_model_cls = MagicMock(return_value=(MagicMock(), "test_arch")) + fake_engine.model_config.registry = fake_registry + fake_engine.vllm_config.model_config.registry = fake_registry + + monkeypatch.setattr( + "vllm.v1.engine.llm_engine.LLMEngine.from_engine_args", + lambda **kw: fake_engine, + raising=False, + ) + + # Mock model_config.registry.resolve_model_cls to return a tuple + # Use a real class instead of MagicMock to avoid inspect.getsource issues + class FakeModelClass: + pass + + monkeypatch.setattr( + "vllm.model_executor.model_loader.utils.get_model_architecture", + lambda model_config: (FakeModelClass, "test_arch"), + raising=False, + ) + + monkeypatch.setattr( + "vllm.model_executor.model_loader.utils._get_model_architecture", + lambda model_config: (FakeModelClass, "test_arch"), + raising=False, + ) + + # Mock try_create_mm_pooling_model_cls to return the class as-is + monkeypatch.setattr( + "vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls", + lambda model_cls: model_cls, + raising=False, + ) + + # Mock _enable_processor_cache to return False + monkeypatch.setattr( + "vllm.multimodal.cache._enable_processor_cache", + lambda model_config, mm_registry: False, + raising=False, + ) + + # Mock get_io_processor to return None + monkeypatch.setattr( + "vllm.plugins.io_processors.get_io_processor", + lambda vllm_config, io_processor_plugin: None, + raising=False, + ) + + +def _setup_multiprocessing_mocks(monkeypatch): + """Helper function to set up multiprocessing mocks.""" + import multiprocessing as mp + + # Mock Process + fake_process_class = MagicMock() + fake_process_instance = MagicMock() + fake_process_instance.start = MagicMock() + fake_process_instance.join = MagicMock() + fake_process_instance.is_alive = MagicMock(return_value=False) + fake_process_instance.terminate = MagicMock() + fake_process_class.return_value = fake_process_instance + + # Mock get_context to return a context with Queue that returns _FakeQueue + fake_ctx = MagicMock() + fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize) + fake_ctx.Process = fake_process_class + + def _mock_get_context(method): + return fake_ctx + + monkeypatch.setattr(mp, "get_context", _mock_get_context, raising=False) + monkeypatch.setattr(mp, "Process", fake_process_class, raising=False) + + +def _setup_ipc_mocks(monkeypatch): + """Helper function to set up IPC function mocks.""" + + # Mock _encode: simple serialization + def _fake_encode(obj, threshold, obj_key, shm_key): + return {obj_key: obj} + + # Mock _load: extract object from result + def _fake_load(result, obj_key, shm_key): + return result.get(obj_key) + + # Mock _set: calculate serialization size + def _fake_set(obj): + return str(obj).encode() + + monkeypatch.setattr("vllm_omni.entrypoints.omni._encode", _fake_encode, raising=False) + monkeypatch.setattr("vllm_omni.entrypoints.omni._load", _fake_load, raising=False) + monkeypatch.setattr("vllm_omni.entrypoints.omni._set", _fake_set, raising=False) + + +def _setup_log_mocks(monkeypatch): + """Helper function to set up logging and stats mocks.""" + # Mock OrchestratorMetrics to be a simple class that doesn't require file operations + + class _FakeOrchestratorMetrics: + def __init__(self, num_stages, enable_stats, wall_start_ts): + self.num_stages = num_stages + self.enable_stats = enable_stats + self.stage_first_ts = [None] * num_stages + self.stage_last_ts = [None] * num_stages + self.e2e_done = set() + + def on_stage_metrics(self, stage_id, req_id, metrics): + pass + + def on_finalize_request(self, stage_id, req_id, start_ts): + self.e2e_done.add(req_id) + + def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm): + pass + + def build_and_log_summary(self, final_stage_id): + return "Fake summary" + + monkeypatch.setattr( + "vllm_omni.entrypoints.omni.OrchestratorMetrics", + _FakeOrchestratorMetrics, + raising=False, + ) + + +@pytest.fixture(autouse=True) +def mock_get_config(monkeypatch): + """Auto-mock get_config and related model loading functions to avoid model path validation.""" + # CRITICAL: Mock tokenizer-related imports FIRST, before any module imports + # This prevents ImportError when async_omni is imported (which happens via omni_stage) + import sys + + fake_tokenizer = MagicMock() + fake_tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + fake_tokenizer.decode = MagicMock(return_value="test") + + # Mock init_tokenizer_from_configs (used in async_omni) + def _mock_init_tokenizer_from_configs(model_config=None, **kwargs): + return fake_tokenizer + + # Strategy 1: Mock in the original location (vllm.transformers_utils.tokenizer) + # This works if the module hasn't been imported yet + monkeypatch.setattr( + "vllm.transformers_utils.tokenizer.init_tokenizer_from_configs", + _mock_init_tokenizer_from_configs, + raising=False, + ) + + # Strategy 2: If the module is already in sys.modules, patch it directly + tokenizer_module_path = "vllm.transformers_utils.tokenizer" + if tokenizer_module_path in sys.modules: + tokenizer_module = sys.modules[tokenizer_module_path] + setattr(tokenizer_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) + + # CRITICAL: Mock length_from_prompt_token_ids_or_embeds BEFORE trying to mock async_omni + + # This is because async_omni imports processor.py, which imports this function at module level + # Mock length_from_prompt_token_ids_or_embeds (used in processor.py) + def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None): + # Return a reasonable default length + if prompt_token_ids is not None: + if isinstance(prompt_token_ids, list): + return len(prompt_token_ids) + elif hasattr(prompt_token_ids, "shape"): + return prompt_token_ids.shape[-1] if len(prompt_token_ids.shape) > 0 else 1 + if prompt_embeds is not None: + if hasattr(prompt_embeds, "shape"): + return prompt_embeds.shape[-2] if len(prompt_embeds.shape) > 1 else 1 + return 10 # Default length + + # Mock in vllm.utils + monkeypatch.setattr( + "vllm.utils.length_from_prompt_token_ids_or_embeds", + _mock_length_from_prompt_token_ids_or_embeds, + raising=False, + ) + # Also mock in processor module if it's imported + monkeypatch.setattr( + "vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds", + _mock_length_from_prompt_token_ids_or_embeds, + raising=False, + ) + # If processor module is already imported, patch it directly + processor_module_path = "vllm_omni.engine.input_processor" + if processor_module_path in sys.modules: + processor_module = sys.modules[processor_module_path] + setattr( + processor_module, "length_from_prompt_token_ids_or_embeds", _mock_length_from_prompt_token_ids_or_embeds + ) + + # Strategy 3: Now mock async_omni AFTER length_from_prompt_token_ids_or_embeds is mocked + # This prevents ImportError when async_omni imports processor.py + monkeypatch.setattr( + "vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs", + _mock_init_tokenizer_from_configs, + raising=False, + ) + + # Strategy 4: If async_omni is already imported, patch it directly + async_omni_path = "vllm_omni.entrypoints.async_omni" + if async_omni_path in sys.modules: + async_omni_module = sys.modules[async_omni_path] + setattr(async_omni_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) + + # Now mock get_config and other functions + fake_hf_config = MagicMock() + fake_hf_config.model_type = "qwen2_5_omni" + + def _mock_get_config(model, **kwargs): + return fake_hf_config + + monkeypatch.setattr( + "vllm.transformers_utils.config.get_config", + _mock_get_config, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.get_config", + _mock_get_config, + raising=False, + ) + + # Mock transformers' cached_file to avoid downloading model configs + def _mock_cached_file(path_or_repo_id, *args, **kwargs): + import os + import tempfile + + fake_config_file = os.path.join(tempfile.gettempdir(), "fake_config.json") + if not os.path.exists(fake_config_file): + with open(fake_config_file, "w") as f: + f.write('{"model_type": "qwen2_5_omni"}') + return fake_config_file + + monkeypatch.setattr( + "transformers.utils.hub.cached_file", + _mock_cached_file, + raising=False, + ) + monkeypatch.setattr( + "transformers.utils.hub.cached_files", + lambda path_or_repo_id, filenames, **kwargs: ( + [_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None + ), + raising=False, + ) + + +def test_initialize_stage_configs_called_when_none(monkeypatch, fake_stage_config): + """Test that stage configs are auto-loaded when stage_configs_path is None.""" + + def _fake_loader(model: str, base_engine_args=None): + return [ + _FakeStageConfig(fake_stage_config), + _FakeStageConfig(fake_stage_config), + ] + + # Remove modules from cache BEFORE setting mocks + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + # Set up mocks + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + # Mock load_stage_configs_from_model + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + + # Replace OmniStage + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + # Import the module after mocks are set + import vllm_omni.entrypoints.omni as omni_module + + # Patch the imported function and class in the module + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + # Verify: auto-loaded stage_configs and stage_list have consistent count + assert isinstance(omni.stage_configs, list) + assert len(omni.stage_configs) == 2 + assert len(omni.stage_list) == 2 + # Verify: each Stage is _FakeStage instance + for st in omni.stage_list: + assert isinstance(st, _FakeStage) + # Verify: queues are attached + for st in omni.stage_list: + assert st._in_q is not None + assert st._out_q is not None + # Verify: all stages are ready + assert len(omni._stages_ready) == 2 + + +def test_generate_raises_on_length_mismatch(monkeypatch, fake_stage_config): + """Test that generate raises ValueError when sampling_params_list length doesn't match.""" + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(fake_stage_config)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + with pytest.raises(ValueError): + omni.generate(prompts=["hi"], sampling_params_list=[]) + + +def test_generate_pipeline_and_final_outputs(monkeypatch, fake_stage_config): + """Test multi-stage generation pipeline with queue polling.""" + stage_cfg0 = dict(fake_stage_config) + stage_cfg1 = dict(fake_stage_config) + stage_cfg1["processed_input"] = ["processed-for-stage-1"] + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_" + expected_request_id = f"0_{test_uuid}" + + # Simulate worker behavior: manually put results into output queues + # Note: We put results before calling generate, which simulates worker processes + # that have already completed. The polling loop will collect them in stage order. + # Stage 0 output (will be collected first) + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0, "text": "s0"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + # Stage 1 output (will be collected after stage 0 forwards to it) + # Note: In real flow, stage 1 result would appear after stage 0 forwards, + # but for testing we pre-populate it. The polling loop processes stages + # in order, so stage 0 result will be collected first, then forwarded, + # then stage 1 result will be collected. + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1, "text": "s1"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + sampling_params_list = [ + OmniDiffusionSamplingParams(num_inference_steps=1), + OmniDiffusionSamplingParams(num_inference_steps=1, max_sequence_length=10), + ] + prompts = ["hi"] + outputs = omni.generate(prompts=prompts, sampling_params_list=sampling_params_list) + + # Both stages have final_output=True, so should aggregate two OmniRequestOutput + assert len(outputs) == 2 + # Verify stage outputs are set + assert omni.stage_list[0].engine_outputs == [{"stage": 0, "text": "s0"}] + assert omni.stage_list[1].engine_outputs == [{"stage": 1, "text": "s1"}] + # Verify stage 0 input queue received the task + assert not omni.stage_list[0]._in_q.empty() + # Verify stage 1 received forwarded task (process_engine_inputs was called) + assert omni.stage_list[1].process_engine_inputs([], []) is not None + + +def test_generate_pipeline_with_batch_input(monkeypatch, fake_stage_config): + """Test single-stage generation pipeline with multiple inputs in one batch.""" + stage_cfg0 = dict(fake_stage_config) + stage_cfg1 = dict(fake_stage_config) + stage_cfg0["final_output"] = False + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_" + expected_request_id = f"0_{test_uuid}" + + # Simulate worker behavior: manually put results into output queues + # Note: We put results before calling generate, which simulates worker processes + # that have already completed. The polling loop will collect them in stage order. + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0, "text": "s0"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0, "text": "s0"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + outputs = omni.generate( + prompts=[ + { + "prompt": "hi", + "negative_prompt": "hi", + "multi_modal_data": {"image": ["dog.jpg", "cat.jpg"]}, + }, + { + "prompt": "hi", + "negative_prompt": "hi", + "multi_modal_data": {"image": ["dog.jpg", "cat.jpg"]}, + }, + ], + sampling_params_list=[ + OmniDiffusionSamplingParams(num_inference_steps=1), + OmniDiffusionSamplingParams(num_inference_steps=1), + ], + ) + + assert len(outputs) == 2 + + +def test_generate_no_final_output_returns_empty(monkeypatch, fake_stage_config): + """Test that generate returns empty list when all stages have final_output=False.""" + stage_cfg0 = dict(fake_stage_config) + stage_cfg1 = dict(fake_stage_config) + stage_cfg0["final_output"] = False + stage_cfg1["final_output"] = False + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_" + expected_request_id = f"0_{test_uuid}" + + # Simulate worker behavior: put results into output queues + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + outputs = omni.generate( + prompts=["p"], + sampling_params_list=[ + OmniDiffusionSamplingParams(num_inference_steps=1), + OmniDiffusionSamplingParams(num_inference_steps=1, max_sequence_length=10), + ], + ) + assert outputs == [] + + +def test_generate_sampling_params_none_use_default(monkeypatch, fake_stage_config): + """Test that generate uses default sampling params when sampling_params_list is None.""" + stage_cfg0 = dict(fake_stage_config) + stage_cfg1 = dict(fake_stage_config) + stage_cfg0["final_output"] = False + stage_cfg1["final_output"] = False + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_" + expected_request_id = f"0_{test_uuid}" + + # Simulate worker behavior: put results into output queues + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + # Use the default sampling params + omni.generate(prompts=["p"], sampling_params_list=None) + + +def test_wait_for_stages_ready_timeout(monkeypatch, fake_stage_config): + """Test that _wait_for_stages_ready handles timeout correctly.""" + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(fake_stage_config)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + + # Create a stage that doesn't send stage_ready message + class _FakeStageNoReady(_FakeStage): + def init_stage_worker(self, *args, **kwargs): + # Don't send stage_ready message + self._proc = MagicMock() + self._proc.start = MagicMock() + self._proc.join = MagicMock() + self._proc.is_alive = MagicMock(return_value=False) + self._proc.terminate = MagicMock() + + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs)) + + from vllm_omni.entrypoints.omni import Omni + + # Use very short timeout + omni = Omni(model="any", init_timeout=0.01) + # Verify that no stages are ready + assert len(omni._stages_ready) == 0 + + +def test_generate_handles_error_messages(monkeypatch, fake_stage_config): + """Test that generate handles error messages from stages correctly.""" + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(fake_stage_config)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_" + expected_request_id = f"0_{test_uuid}" + + # Put error message in output queue + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "error": "test error", + } + ) + # Also put a valid result after error to allow the loop to complete + # (error handling continues the loop, so we need a valid result to finish) + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0, "text": "result"}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + # Generate should handle error gracefully (log but continue) + sampling_params_list = [OmniDiffusionSamplingParams(num_inference_steps=1)] + outputs = omni.generate(prompts=["hi"], sampling_params_list=sampling_params_list) + # Should return final output (error was logged but didn't stop processing) + assert isinstance(outputs, list) + # Since final_output=True, should have one output + assert len(outputs) == 1 + + +def test_close_sends_shutdown_signal(monkeypatch, fake_stage_config): + """Test that close() sends shutdown signal to all input queues.""" + + def _fake_loader(model: str, base_engine_args=None): + return [_FakeStageConfig(fake_stage_config)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + _fake_loader, + raising=False, + ) + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs)) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Call close + omni.close() + + # Verify shutdown signal (None) was sent to input queue + # Use get_nowait to avoid blocking (close() uses put_nowait, so should be safe) + try: + shutdown_signal = omni.stage_list[0]._in_q.get_nowait() + assert shutdown_signal == SHUTDOWN_TASK + except Empty: + # If queue was already empty or only had stage_ready, that's also acceptable + # The important thing is that close() was called without error + pass + + # Verify stop_stage_worker was called (process should be set) + assert omni.stage_list[0]._proc is not None diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py index fd2ba26733e..f99c6d8336c 100644 --- a/tests/entrypoints/test_omni_llm.py +++ b/tests/entrypoints/test_omni_llm.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock import pytest +from vllm import SamplingParams from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK @@ -82,8 +83,9 @@ def __init__(self, config, stage_init_timeout: int = 300): self.stage_id = getattr(config, "stage_id", 0) self.engine_args = config.engine_args self.model_stage = getattr(config.engine_args, "model_stage", None) + self.stage_type = "llm" # set default sampling params - self.default_sampling_params = {"temperature": 1.0} + self.default_sampling_params = SamplingParams(temperature=1.0) # Allow configuring final_output and final_output_type self.final_output = config.final_output if hasattr(config, "final_output") else False self.final_output_type = getattr(config, "final_output_type", None) @@ -637,8 +639,10 @@ def _fake_loader(model: str, base_engine_args=None): } ) - # Use dicts instead of object() for serializable sampling params - sampling_params_list = [{"temperature": 0.7}, {"temperature": 0.8}] + sampling_params_list = [ + SamplingParams(temperature=0.7), + SamplingParams(temperature=0.8), + ] prompts = ["hi"] outputs = omni.generate(prompts=prompts, sampling_params_list=sampling_params_list) @@ -722,8 +726,13 @@ def _fake_loader(model: str, base_engine_args=None): } ) - # Use dicts instead of object() for serializable sampling params - outputs = omni.generate(prompts=["p"], sampling_params_list=[{"temperature": 0.7}, {"temperature": 0.8}]) + outputs = omni.generate( + prompts=["p"], + sampling_params_list=[ + SamplingParams(temperature=0.7), + SamplingParams(temperature=0.8), + ], + ) assert outputs == [] @@ -922,8 +931,7 @@ def _fake_loader(model: str, base_engine_args=None): ) # Generate should handle error gracefully (log but continue) - # Use dict instead of object() for serializable sampling params - sampling_params_list = [{"temperature": 0.7}] + sampling_params_list = [SamplingParams(temperature=0.7)] outputs = omni.generate(prompts=["hi"], sampling_params_list=sampling_params_list) # Should return final output (error was logged but didn't stop processing) assert isinstance(outputs, list) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 2e00f41b97f..b0b9b89db0a 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -17,6 +17,7 @@ get_diffusion_pre_process_func, ) from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -60,36 +61,31 @@ def __init__(self, od_config: OmniDiffusionConfig): self.close() raise e - def step(self, requests: list[OmniDiffusionRequest]): + def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: # Apply pre-processing if available if self.pre_process_func is not None: preprocess_start_time = time.time() - requests = self.pre_process_func(requests) + request = self.pre_process_func(request) preprocess_time = time.time() - preprocess_start_time logger.info(f"Pre-processing completed in {preprocess_time:.4f} seconds") - output = self.add_req_and_wait_for_response(requests) + output = self.add_req_and_wait_for_response(request) if output.error: raise Exception(f"{output.error}") logger.info("Generation completed successfully.") if output.output is None: logger.warning("Output is None, returning empty OmniRequestOutput") - # Return empty output for the first request - if len(requests) > 0: - request = requests[0] - request_id = request.request_id or "" - prompt = request.prompt - if isinstance(prompt, list): - prompt = prompt[0] if prompt else None - return OmniRequestOutput.from_diffusion( - request_id=request_id, + return [ + OmniRequestOutput.from_diffusion( + request_id=request.request_ids[i] if i < len(request.request_ids) else "", images=[], prompt=prompt, metrics={}, latents=None, ) - return None + for i, prompt in enumerate(request.prompts) + ] postprocess_start_time = time.time() outputs = self.post_process_func(output.output) if self.post_process_func is not None else output.output @@ -102,13 +98,10 @@ def step(self, requests: list[OmniDiffusionRequest]): outputs = [outputs] if outputs is not None else [] # Handle single request or multiple requests - if len(requests) == 1: + if len(request.prompts) == 1: # Single request: return single OmniRequestOutput - request = requests[0] - request_id = request.request_id or "" - prompt = request.prompt - if isinstance(prompt, list): - prompt = prompt[0] if prompt else None + prompt = request.prompts[0] + request_id = request.request_ids[0] if request.request_ids else "" metrics = {} if output.trajectory_timesteps is not None: @@ -116,37 +109,38 @@ def step(self, requests: list[OmniDiffusionRequest]): if supports_audio_output(self.od_config.model_class_name): audio_payload = outputs[0] if len(outputs) == 1 else outputs - return OmniRequestOutput.from_diffusion( - request_id=request_id, - images=[], - prompt=prompt, - metrics=metrics, - latents=output.trajectory_latents, - multimodal_output={"audio": audio_payload}, - final_output_type="audio", - ) + return [ + OmniRequestOutput.from_diffusion( + request_id=request_id, + images=[], + prompt=prompt, + metrics=metrics, + latents=output.trajectory_latents, + multimodal_output={"audio": audio_payload}, + final_output_type="audio", + ), + ] else: - return OmniRequestOutput.from_diffusion( - request_id=request_id, - images=outputs, - prompt=prompt, - metrics=metrics, - latents=output.trajectory_latents, - ) + return [ + OmniRequestOutput.from_diffusion( + request_id=request_id, + images=outputs, + prompt=prompt, + metrics=metrics, + latents=output.trajectory_latents, + ), + ] else: # Multiple requests: return list of OmniRequestOutput # Split images based on num_outputs_per_prompt for each request results = [] output_idx = 0 - for request in requests: - request_id = request.request_id or "" - prompt = request.prompt - if isinstance(prompt, list): - prompt = prompt[0] if prompt else None + for i, prompt in enumerate(request.prompts): + request_id = request.request_ids[i] if i < len(request.request_ids) else "" # Get images for this request - num_outputs = request.num_outputs_per_prompt + num_outputs = request.sampling_params.num_outputs_per_prompt request_outputs = outputs[output_idx : output_idx + num_outputs] if output_idx < len(outputs) else [] output_idx += num_outputs @@ -192,8 +186,8 @@ def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine": """ return DiffusionEngine(config) - def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]): - return self.executor.add_req(requests) + def add_req_and_wait_for_response(self, request: OmniDiffusionRequest): + return self.executor.add_req(request) def start_profile(self, trace_filename: str | None = None) -> None: """ @@ -316,8 +310,6 @@ def stop_profile(self) -> dict: def _dummy_run(self): """A dummy run to warm up the model.""" - prompt = "dummy run" - # note that num_inference_steps=1 will cause timestep and temb None in the pipeline num_inference_steps = 1 height = 1024 width = 1024 @@ -327,17 +319,19 @@ def _dummy_run(self): dummy_image = PIL.Image.new("RGB", (width, height), color=(0, 0, 0)) else: dummy_image = None + prompt: OmniTextPrompt = {"prompt": "dummy run", "multi_modal_data": {"image": dummy_image}} req = OmniDiffusionRequest( - prompt=prompt, - height=height, - width=width, - pil_image=dummy_image, - num_inference_steps=num_inference_steps, - num_outputs_per_prompt=1, + prompts=[prompt], + sampling_params=OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=num_inference_steps, + num_outputs_per_prompt=1, + ), ) logger.info("dummy run to warm up the model") - requests = self.pre_process_func([req]) if self.pre_process_func is not None else [req] - self.add_req_and_wait_for_response(requests) + request = self.pre_process_func(req) if self.pre_process_func is not None else req + self.add_req_and_wait_for_response(request) def collective_rpc( self, diff --git a/vllm_omni/diffusion/executor/abstract.py b/vllm_omni/diffusion/executor/abstract.py index 021c210efcb..e41f41d119e 100644 --- a/vllm_omni/diffusion/executor/abstract.py +++ b/vllm_omni/diffusion/executor/abstract.py @@ -3,7 +3,7 @@ from vllm.utils.import_utils import resolve_obj_by_qualname -from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -59,7 +59,7 @@ def _init_executor(self) -> None: pass @abstractmethod - def add_req(self, requests: list[OmniDiffusionRequest]): + def add_req(self, requests: OmniDiffusionRequest) -> DiffusionOutput: """Add requests to the execution queue.""" pass diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py index a04a40e86aa..21a94acd35a 100644 --- a/vllm_omni/diffusion/executor/multiproc_executor.py +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -6,7 +6,7 @@ from vllm.logger import init_logger -from vllm_omni.diffusion.data import SHUTDOWN_MESSAGE +from vllm_omni.diffusion.data import SHUTDOWN_MESSAGE, DiffusionOutput from vllm_omni.diffusion.executor.abstract import DiffusionExecutor from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.scheduler import Scheduler @@ -130,8 +130,8 @@ def _launch_workers(self, broadcast_handle): return processes, result_handle - def add_req(self, requests: list[OmniDiffusionRequest]): - return self.scheduler.add_req(requests) + def add_req(self, request: OmniDiffusionRequest) -> DiffusionOutput: + return self.scheduler.add_req(request) def collective_rpc( self, diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index 5d3bcdd11cb..bdb9f1f5c3f 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -272,16 +272,22 @@ def _decode_image_from_latent( @torch.inference_mode() def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: - prompt = req.prompt or "" - if isinstance(prompt, list): - # vllm-omni request supports list; Bagel pipeline currently supports first prompt. - prompt = prompt[0] if prompt else "" + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + # TODO: In online mode, sometimes it receives [{"prompts": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(req.prompts[0], str) else (req.prompts[0].get("prompt") or "") + max_hw = int(self.bagel.max_latent_size * self.bagel.latent_downsample) - if req.height is None and req.width is None: + if req.sampling_params.height is None and req.sampling_params.width is None: height = width = max_hw else: - height = int(req.height) if req.height is not None else max_hw - width = int(req.width) if req.width is not None else max_hw + height = int(req.sampling_params.height) if req.sampling_params.height is not None else max_hw + width = int(req.sampling_params.width) if req.sampling_params.width is not None else max_hw if height > max_hw or width > max_hw: raise ValueError( f"Requested resolution {height}x{width} exceeds Bagel checkpoint limit " @@ -292,7 +298,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # Map request params to Bagel gen params (defaults follow Bagel inferencer) gen_params = BagelGenParams( - num_timesteps=int(req.num_inference_steps or 50), + num_timesteps=int(req.sampling_params.num_inference_steps or 50), timestep_shift=3.0, ) @@ -304,7 +310,7 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # Add text prompt (prefill) on gen context. # [Omni] Check for injected KV Cache from remote transfer - injected_kv = getattr(req, "past_key_values", None) + injected_kv = req.sampling_params.past_key_values if injected_kv is not None: logger.info("Using injected KV Cache (direct)") gen_context["past_key_values"] = injected_kv @@ -316,9 +322,13 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: gen_context["ropes"] = [seq_len] else: - image_input = getattr(req, "pil_image", None) + image_input = ( + None if isinstance(first_prompt, str) else (first_prompt.get("multi_modal_data") or {}).get("image") + ) if image_input and not isinstance(image_input, list): image_input = [image_input] + if image_input: + image_input = [Image.open(image) if isinstance(image, str) else image for image in image_input] if image_input: # If we have an image, we prefill with it @@ -414,10 +424,10 @@ def vae_transforms(img): gen_context["kv_lens"] = newlens gen_context["ropes"] = new_rope - if req.seed is not None: - torch.manual_seed(req.seed) + if req.sampling_params.seed is not None: + torch.manual_seed(req.sampling_params.seed) if self.device.type == "cuda": - torch.cuda.manual_seed(req.seed) + torch.cuda.manual_seed(req.sampling_params.seed) # Prepare latent query and run flow generation_input = self.bagel.prepare_vae_latent( diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index ba29e681c32..0496a7ec3a3 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -19,7 +19,7 @@ import math import os from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, cast import numpy as np import PIL.Image @@ -743,27 +743,54 @@ def forward( `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") - prompt = req.prompt if req.prompt is not None else prompt - image = req.pil_image if req.pil_image is not None else image - height = req.height or height - width = req.width or width - num_inference_steps = req.num_inference_steps or num_inference_steps - guidance_scale = req.guidance_scale if req.guidance_scale is not None else guidance_scale - generator = req.generator or generator - req_num_outputs = getattr(req, "num_outputs_per_prompt", None) - if req_num_outputs and req_num_outputs > 0: - num_images_per_prompt = req_num_outputs - - if isinstance(req.prompt_embeds, torch.Tensor): - prompt_embeds = req.prompt_embeds - if isinstance(req.negative_prompt_embeds, torch.Tensor): - negative_prompt_embeds = req.negative_prompt_embeds - - if req.max_sequence_length is not None: - max_sequence_length = req.max_sequence_length - if getattr(req, "text_encoder_out_layers", None) is not None: - text_encoder_out_layers = req.text_encoder_out_layers + if ( + raw_image := None + if isinstance(first_prompt, str) + else first_prompt.get("multi_modal_data", {}).get("image") + ) is None: + pass # use image from param list + elif isinstance(raw_image, list): + image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image] + else: + image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else cast(PIL.Image.Image, raw_image) + + height = req.sampling_params.height or height + width = req.sampling_params.width or width + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale + ) + generator = req.sampling_params.generator or generator + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + text_encoder_out_layers = req.sampling_params.extra_args.get("text_encoder_out_layers", text_encoder_out_layers) + + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + # If at list one prompt is provided as an embedding, + # Then assume that the user wants to provide embeddings for all prompts, and enter this if block + # If the user in fact provides mixed input format, req_prompt_embeds will have some None's + # And `torch.stack` automatically raises an exception for us + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError # 1. Check inputs. Raise error if not correct self.check_inputs( diff --git a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py index 9a03a934983..d222342b51b 100644 --- a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py +++ b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py @@ -17,6 +17,7 @@ import os import re from collections.abc import Iterable +from typing import cast import numpy as np import PIL.Image @@ -48,6 +49,7 @@ GlmImageTransformer2DModel, ) from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -77,16 +79,26 @@ def get_glm_image_pre_process_func(od_config: OmniDiffusionConfig): # GLM-Image uses patch_size=2 for transformer patch_size = 2 - def pre_process_func(requests: list[OmniDiffusionRequest]): + def pre_process_func(request: OmniDiffusionRequest): """Pre-process condition images for Image Edit mode.""" - for req in requests: - images = req.pil_image - if images is None: + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: # Text-to-image mode, no preprocessing needed continue - if not isinstance(images, list): - images = [images] + if not isinstance(raw_image, list): + raw_image = [raw_image] + images = [ + PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image | np.ndarray | torch.Tensor, im) + for im in raw_image + ] preprocessed = [] height, width = None, None @@ -110,14 +122,19 @@ def pre_process_func(requests: list[OmniDiffusionRequest]): height, width = img_h, img_w # Store in request - req.preprocessed_image = preprocessed - req.prompt_image = images # Keep original PIL images - if req.height is None: - req.height = height - if req.width is None: - req.width = width - - return requests + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt, additional_information={}) + elif "additional_information" not in prompt: + prompt["additional_information"] = {} + prompt["additional_information"]["preprocessed_image"] = processed # type: ignore + prompt["additional_information"]["prompt_image"] = images # type: ignore + request.prompts[i] = prompt + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width + + return request return pre_process_func @@ -821,27 +838,44 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: Returns: DiffusionOutput containing generated image """ - prompt = req.prompt or "" - if isinstance(prompt, list): - prompt = prompt[0] if prompt else "" + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") # Get pre-computed prompt embeddings if provided - prompt_embeds = req.prompt_embeds if isinstance(req.prompt_embeds, torch.Tensor) else None + if isinstance(first_prompt, str): + prompt_embeds = None + else: + prompt_embeds = first_prompt.get("prompt_embeds") + if not isinstance(prompt_embeds, torch.Tensor): + prompt_embeds = None # Get condition images for Image Edit mode # Use pre-processed images from pre_process_func - preprocessed_images = req.preprocessed_image - condition_images = getattr(req, "prompt_image", None) - img_height = req.height - img_width = req.width + preprocessed_images = ( + None + if isinstance(first_prompt, str) + else first_prompt.get("additional_information", {}).get("preprocessed_image") + ) + condition_images = ( + None + if isinstance(first_prompt, str) + else first_prompt.get("additional_information", {}).get("prompt_image") + ) + img_height = req.sampling_params.height + img_width = req.sampling_params.width is_image_edit = preprocessed_images is not None # Use image dimensions as default if available - height = req.height or img_height or self.default_sample_size * self.vae_scale_factor - width = req.width or img_width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.num_inference_steps or 50 - guidance_scale = req.guidance_scale or 1.5 + height = req.sampling_params.height or img_height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or img_width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or 50 + guidance_scale = req.sampling_params.guidance_scale or 1.5 # 0. Validate inputs self.check_inputs(prompt=prompt, height=height, width=width, prompt_embeds=prompt_embeds) @@ -851,13 +885,13 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # Set seed if provided generator = None - if req.seed is not None: - generator = torch.Generator(device=self.device).manual_seed(req.seed) + if req.sampling_params.seed is not None: + generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) # 1. Get prior tokens - either from external source (multistage) or generate internally # Check if prior_token_ids are provided externally (from AR stage in multistage mode) - external_prior_tokens = req.extra.get("prior_token_ids") if req.extra else None - external_prior_image_ids = req.extra.get("prior_token_image_ids") if req.extra else None + external_prior_tokens = req.sampling_params.extra_args.get("prior_token_ids") + external_prior_image_ids = req.sampling_params.extra_args.get("prior_token_image_ids") if external_prior_tokens is not None: # Multistage mode: use externally provided prior tokens from vLLM AR stage diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index e918afe0b2e..8b616ec45f4 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -451,7 +451,7 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + elif prompt is not None and not isinstance(prompt, (str, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: @@ -470,11 +470,11 @@ def forward( num_inference_steps: int = 50, sigmas: list[float] | None = None, guidance_scale: float = 4.5, - num_images_per_prompt: int | None = 1, + num_images_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.FloatTensor | None = None, - prompt_embeds: torch.FloatTensor | None = None, - negative_prompt_embeds: torch.FloatTensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, output_type: str | None = "pil", return_dict: bool = True, joint_attention_kwargs: dict[str, Any] | None = None, @@ -482,20 +482,44 @@ def forward( cfg_renorm_min: float | None = 0.0, enable_prompt_rewrite: bool | None = True, ) -> DiffusionOutput: - prompt = req.prompt if req.prompt is not None else prompt - negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt - - height = req.height or height or self.default_sample_size * self.vae_scale_factor - width = req.width or width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.num_inference_steps or num_inference_steps - generator = req.generator or generator - guidance_scale = req.guidance_scale if getattr(req, "guidance_scale", None) is not None else guidance_scale - num_images_per_prompt = getattr(req, "num_outputs_per_prompt", None) or num_images_per_prompt - enable_prompt_rewrite = getattr(req, "enable_prompt_rewrite", None) or enable_prompt_rewrite - enable_cfg_renorm = getattr(req, "enable_cfg_renorm", None) or enable_cfg_renorm - cfg_renorm_min = getattr(req, "cfg_renorm_min", None) or cfg_renorm_min - prompt_embeds = getattr(req, "prompt_embeds", None) or prompt_embeds - negative_prompt_embeds = getattr(req, "negative_prompt_embeds", None) or negative_prompt_embeds + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + height = req.sampling_params.height or height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + generator = req.sampling_params.generator or generator + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale + ) + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt is not None + else num_images_per_prompt + ) + enable_prompt_rewrite = req.sampling_params.extra_args.get("enable_prompt_rewrite", enable_prompt_rewrite) + enable_cfg_renorm = req.sampling_params.extra_args.get("enable_cfg_renorm", enable_cfg_renorm) + cfg_renorm_min = req.sampling_params.extra_args.get("cfg_renorm_min", cfg_renorm_min) + + req_prompt_embeds = [p.get("prompt_embeds") if not isinstance(p, str) else None for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + # If at list one prompt is provided as an embedding, + # Then assume that the user wants to provide embeddings for all prompts, and enter this if block + # If the user in fact provides mixed input format, req_prompt_embeds will have some None's + # And `torch.stack` automatically raises an exception for us + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError + + req_negative_prompt_embeds = [ + p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts + ] + if any(p is not None for p in req_negative_prompt_embeds): + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError self.check_inputs( prompt, diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index 607553aee03..f2c3fd648ee 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -6,7 +6,7 @@ import os import re from collections.abc import Iterable -from typing import Any +from typing import Any, cast import numpy as np import PIL.Image @@ -34,6 +34,7 @@ ) from vllm_omni.diffusion.models.longcat_image.pipeline_longcat_image import calculate_shift from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -59,22 +60,38 @@ def get_longcat_image_edit_pre_process_func( latent_channels = vae_config.get("latent_channels", 16) def pre_process_func( - requests: list[OmniDiffusionRequest], + request: OmniDiffusionRequest, ): - """Pre-process requests for QwenImageEditPipeline.""" - for req in requests: - image = req.pil_image + """Pre-process requests for LongCatImageEditPipeline.""" + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} - image_size = image[0].size if isinstance(image, list) else image.size + if raw_image is None or isinstance(raw_image, list): + raise ValueError( + """Received no image or a list of image. Only a single image is supported by this model.""" + """Please correctly set `"multi_modal_data": {"image": , …}`""" + ) + + if isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor | np.ndarray, raw_image) + + image_size = image.size calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1]) - height = req.height or calculated_height - width = req.width or calculated_width + height = request.sampling_params.height or calculated_height + width = request.sampling_params.width or calculated_width # Store calculated dimensions in request - req.calculated_height = calculated_height - req.calculated_width = calculated_width - req.height = height - req.width = width + prompt["additional_information"]["calculated_height"] = calculated_height + prompt["additional_information"]["calculated_width"] = calculated_width + request.sampling_params.height = height + request.sampling_params.width = width # Preprocess image if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == latent_channels): @@ -83,9 +100,10 @@ def pre_process_func( image = image_processor.preprocess(image, calculated_height, calculated_width) # Store preprocessed image and prompt image in request - req.preprocessed_image = image - req.prompt_image = prompt_image - return requests + prompt["additional_information"]["preprocessed_image"] = image + prompt["additional_information"]["prompt_image"] = prompt_image + request.prompts[i] = prompt + return request return pre_process_func @@ -500,42 +518,59 @@ def forward( self, req: OmniDiffusionRequest, image: PIL.Image.Image | torch.Tensor | None = None, - prompt: str | list[str] = None, - negative_prompt: str | list[str] = None, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, num_inference_steps: int = 50, sigmas: list[float] | None = None, guidance_scale: float = 3.5, num_images_per_prompt: int | None = 1, generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.FloatTensor | None = None, - prompt_embeds: torch.FloatTensor | None = None, - negative_prompt_embeds: torch.FloatTensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, output_type: str | None = "pil", return_dict: bool = True, joint_attention_kwargs: dict[str, Any] | None = None, ): - prompt = req.prompt if req.prompt is not None else prompt - negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt - negative_prompt = "" if negative_prompt is None else negative_prompt - guidance_scale = req.guidance_scale if req.guidance_scale is not None else guidance_scale - num_inference_steps = req.num_inference_steps or num_inference_steps - num_images_per_prompt = getattr(req, "num_outputs_per_prompt", None) or num_images_per_prompt - generator = req.generator or generator - prompt_embeds = getattr(req, "prompt_embeds", None) or prompt_embeds - negative_prompt_embeds = getattr(req, "negative_prompt_embeds", None) or negative_prompt_embeds - height = req.height or self.default_sample_size * self.vae_scale_factor - width = req.width or self.default_sample_size * self.vae_scale_factor + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") + prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("prompt_embeds") + negative_prompt_embeds = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt_embeds") # type: ignore # Why it is list[torch.Tensor] in OmniTokenInputs or OmniEmbedsPrompt? Doesn't make sense + + sigmas = req.sampling_params.sigmas or sigmas + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale + ) + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt is not None + else num_images_per_prompt + ) + generator = req.sampling_params.generator or generator + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor if prompt is not None: batch_size = 1 if isinstance(prompt, str) else len(prompt) else: batch_size = prompt_embeds.shape[0] - if hasattr(req, "preprocessed_image"): - prompt_image = req.prompt_image - image = req.preprocessed_image - calculated_height = req.calculated_height if hasattr(req, "calculated_height") else height - calculated_width = req.calculated_width if hasattr(req, "calculated_width") else width + if not isinstance(first_prompt, str) and "preprocessed_image" in ( + additional_information := first_prompt.get("additional_information", {}) + ): + prompt_image = additional_information.get("prompt_image") + image = additional_information.get("preprocessed_image") + calculated_height = additional_information.get("calculated_height", height) + calculated_width = additional_information.get("calculated_width", width) else: image_size = image[0].size if isinstance(image, list) else image.size calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1]) diff --git a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py index 660338a3565..a5b583ef535 100644 --- a/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py +++ b/vllm_omni/diffusion/models/ovis_image/pipeline_ovis_image.py @@ -520,8 +520,8 @@ def interrupt(self): def forward( self, req: OmniDiffusionRequest, - prompt: str | list[str] = None, - negative_prompt: str | list[str] = None, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, guidance_scale: float = 5.0, height: int | None = None, width: int | None = None, @@ -607,16 +607,27 @@ def forward( [`~pipelines.ovis_image.OvisImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ - prompt = req.prompt if req.prompt is not None else prompt - negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt - height = req.height or self.default_sample_size * self.vae_scale_factor - width = req.width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.num_inference_steps or num_inference_steps - guidance_scale = req.guidance_scale if req.guidance_scale is not None else guidance_scale - generator = req.generator or generator - req_num_outputs = getattr(req, "num_outputs_per_prompt", None) - if req_num_outputs and req_num_outputs > 0: - num_images_per_prompt = req_num_outputs + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale + ) + generator = req.sampling_params.generator or generator + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) # Steps: # 1. Check Inputs diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 2b9d906d3d4..e0a37b8bc8c 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -662,21 +662,28 @@ def forward( callback_on_step_end_tensor_inputs: list[str] = ["latents"], max_sequence_length: int = 512, ) -> DiffusionOutput: - # # TODO: only support single prompt now - # if req.prompt is not None: - # prompt = req.prompt[0] if isinstance(req.prompt, list) else req.prompt - prompt = req.prompt - negative_prompt = req.negative_prompt - height = req.height or self.default_sample_size * self.vae_scale_factor - width = req.width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.num_inference_steps or num_inference_steps - generator = req.generator or generator - true_cfg_scale = req.true_cfg_scale or true_cfg_scale - if req.guidance_scale_provided: - guidance_scale = req.guidance_scale - req_num_outputs = getattr(req, "num_outputs_per_prompt", None) - if req_num_outputs and req_num_outputs > 0: - num_images_per_prompt = req_num_outputs + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + generator = req.sampling_params.generator or generator + true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) # 1. check inputs # 2. encode prompts # 3. prepare latents and timesteps diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index ac74507c791..ba65bf07e36 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -7,7 +7,7 @@ import math import os from collections.abc import Iterable -from typing import Any +from typing import Any, cast import numpy as np import PIL.Image @@ -39,6 +39,7 @@ ) from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -64,16 +65,33 @@ def get_qwen_image_edit_pre_process_func( latent_channels = vae_config.get("z_dim", 16) def pre_process_func( - requests: list[OmniDiffusionRequest], + request: OmniDiffusionRequest, ): """Pre-process requests for QwenImageEditPipeline.""" - for req in requests: - image = req.pil_image + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + # Only handles single image + if raw_image is None or isinstance(raw_image, list): + raise ValueError( + """Received no image or a list of image. Only a single image is supported by this model.""" + """Please correctly set `"multi_modal_data": {"image": , …}`""" + ) - image_size = image[0].size if isinstance(image, list) else image.size + if isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor | np.ndarray, raw_image) + + image_size = image.size calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) - height = req.height or calculated_height - width = req.width or calculated_width + height = request.sampling_params.height or calculated_height + width = request.sampling_params.width or calculated_width # Ensure dimensions are multiples of vae_scale_factor * 2 multiple_of = vae_scale_factor * 2 @@ -81,10 +99,10 @@ def pre_process_func( width = width // multiple_of * multiple_of # Store calculated dimensions in request - req.calculated_height = calculated_height - req.calculated_width = calculated_width - req.height = height - req.width = width + prompt["additional_information"]["calculated_height"] = calculated_height + prompt["additional_information"]["calculated_width"] = calculated_width + request.sampling_params.height = height + request.sampling_params.width = width # Preprocess image if image is not None and not ( @@ -96,10 +114,10 @@ def pre_process_func( image = image.unsqueeze(2) # Store preprocessed image and prompt image in request - req.preprocessed_image = image - req.prompt_image = prompt_image - - return requests + prompt["additional_information"]["preprocessed_image"] = image + prompt["additional_information"]["prompt_image"] = prompt_image + request.prompts[i] = prompt + return request return pre_process_func @@ -717,17 +735,27 @@ def forward( max_sequence_length: int = 512, ) -> DiffusionOutput: """Forward pass for image editing.""" - prompt = req.prompt - negative_prompt = req.negative_prompt + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") # Get preprocessed image from request (pre-processing is done in DiffusionEngine) - if hasattr(req, "preprocessed_image"): - prompt_image = req.prompt_image - image = req.preprocessed_image - calculated_height = req.calculated_height - calculated_width = req.calculated_width - height = req.height - width = req.width + if not isinstance(first_prompt, str) and "preprocessed_image" in ( + additional_information := first_prompt.get("additional_information", {}) + ): + prompt_image = additional_information.get("prompt_image") + image = additional_information.get("preprocessed_image") + calculated_height = additional_information.get("calculated_height") + calculated_width = additional_information.get("calculated_width") + height = req.sampling_params.height + width = req.sampling_params.width else: # fallback to run pre-processing in pipeline (debug only) image_size = image[0].size if isinstance(image, list) else image.size @@ -745,14 +773,18 @@ def forward( image = self.image_processor.preprocess(image, calculated_height, calculated_width) image = image.unsqueeze(2) - num_inference_steps = req.num_inference_steps or num_inference_steps - generator = req.generator or generator - true_cfg_scale = req.true_cfg_scale or true_cfg_scale - if req.guidance_scale_provided: - guidance_scale = req.guidance_scale - req_num_outputs = getattr(req, "num_outputs_per_prompt", None) - if req_num_outputs and req_num_outputs > 0: - num_images_per_prompt = req_num_outputs + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + generator = req.sampling_params.generator or generator + true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) # 1. check inputs # 2. encode prompts diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 4e2fe35f19c..52f1d64eca3 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -5,7 +5,7 @@ import logging import os from collections.abc import Iterable -from typing import Any +from typing import Any, cast import numpy as np import PIL.Image @@ -42,6 +42,7 @@ ) from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -70,24 +71,33 @@ def get_qwen_image_edit_plus_pre_process_func( latent_channels = vae_config.get("z_dim", 16) def pre_process_func( - requests: list[OmniDiffusionRequest], + request: OmniDiffusionRequest, ): """Pre-process requests for QwenImageEditPlusPipeline.""" - for req in requests: - image = req.pil_image + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} # Handle single image or list of images - if image is None: + if raw_image is None: continue - if not isinstance(image, list): - image = [image] + if not isinstance(raw_image, list): + raw_image = [raw_image] + image = [ + PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image | np.ndarray | torch.Tensor, im) + for im in raw_image + ] # Calculate dimensions based on first image image_size = image[0].size calculated_width, calculated_height = calculate_dimensions(VAE_IMAGE_SIZE, image_size[0] / image_size[1]) - height = req.height or calculated_height - width = req.width or calculated_width + height = request.sampling_params.height or calculated_height + width = request.sampling_params.width or calculated_width # Ensure dimensions are multiples of vae_scale_factor * 2 multiple_of = vae_scale_factor * 2 @@ -95,10 +105,10 @@ def pre_process_func( width = width // multiple_of * multiple_of # Store calculated dimensions in request - req.calculated_height = calculated_height - req.calculated_width = calculated_width - req.height = height - req.width = width + prompt["additional_information"]["calculated_height"] = calculated_height + prompt["additional_information"]["calculated_width"] = calculated_width + request.sampling_params.height = height + request.sampling_params.width = width # Preprocess images into condition_images (for prompt encoding) and vae_images (for VAE encoding) condition_images = [] @@ -124,12 +134,12 @@ def pre_process_func( vae_images.append(image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) # Store preprocessed images in request - req.condition_images = condition_images - req.vae_images = vae_images - req.condition_image_sizes = condition_image_sizes - req.vae_image_sizes = vae_image_sizes - - return requests + prompt["additional_information"]["condition_images"] = condition_images + prompt["additional_information"]["vae_images"] = vae_images + prompt["additional_information"]["condition_image_sizes"] = condition_image_sizes + prompt["additional_information"]["vae_image_sizes"] = vae_image_sizes + request.prompts[i] = prompt + return request return pre_process_func @@ -281,7 +291,7 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor def _get_qwen_prompt_embeds( self, - prompt: str | list[str] = None, + prompt: str | list[str], image: list[torch.Tensor] | torch.Tensor | None = None, dtype: torch.dtype | None = None, ): @@ -655,19 +665,31 @@ def forward( max_sequence_length: int = 512, ) -> DiffusionOutput: """Forward pass for image editing with support for multiple images.""" - prompt = req.prompt - negative_prompt = req.negative_prompt + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") # Get preprocessed images from request (pre-processing is done in DiffusionEngine) - if hasattr(req, "vae_images") and hasattr(req, "condition_images"): - condition_images = req.condition_images - vae_images = req.vae_images - condition_image_sizes = req.condition_image_sizes - vae_image_sizes = req.vae_image_sizes - calculated_height = req.calculated_height - calculated_width = req.calculated_width - height = req.height - width = req.width + if ( + not isinstance(first_prompt, str) + and "vae_images" in (additional_information := first_prompt.get("additional_information", {})) + and "condition_images" in additional_information + ): + condition_images = additional_information.get("condition_images") + vae_images = additional_information.get("vae_images") + condition_image_sizes = additional_information.get("condition_image_sizes") + vae_image_sizes = additional_information.get("vae_image_sizes") + calculated_height = additional_information.get("calculated_height") + calculated_width = additional_information.get("calculated_width") + height = req.sampling_params.height + width = req.sampling_params.width else: # fallback to run pre-processing in pipeline (debug only) if image is None: @@ -701,14 +723,18 @@ def forward( condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) - num_inference_steps = req.num_inference_steps or num_inference_steps - generator = req.generator or generator - true_cfg_scale = req.true_cfg_scale or true_cfg_scale - if req.guidance_scale_provided: - guidance_scale = req.guidance_scale - req_num_outputs = getattr(req, "num_outputs_per_prompt", None) - if req_num_outputs and req_num_outputs > 0: - num_images_per_prompt = req_num_outputs + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + generator = req.sampling_params.generator or generator + true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) # 1. check inputs # 2. encode prompts diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index ae2d039e339..ae27ba708a4 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -7,7 +7,7 @@ import math import os from collections.abc import Iterable -from typing import Any +from typing import Any, cast import numpy as np import PIL.Image @@ -38,6 +38,7 @@ ) from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs +from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -69,17 +70,35 @@ def get_qwen_image_layered_pre_process_func( image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2) def pre_process_func( - requests: list[OmniDiffusionRequest], + request: OmniDiffusionRequest, ): """Pre-process requests for QwenImageLayeredPipeline.""" - for req in requests: - image = req.pil_image + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None or isinstance(raw_image, list): + raise ValueError( + """Received no image or a list of image. Only a single image is supported by this model.""" + """Please correctly set `"multi_modal_data": {"image": , …}`""" + ) + + if isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor | np.ndarray, raw_image) # 1. calculate dimensions - image_size = image[0].size if isinstance(image, list) else image.size - assert req.resolution in [640, 1024], f"resolution must be either 640 or 1024, but got {req.resolution}" + image_size = image.size + assert request.sampling_params.resolution in [640, 1024], ( + f"resolution must be either 640 or 1024, but got {request.sampling_params.resolution}" + ) calculated_width, calculated_height = calculate_dimensions( - req.resolution * req.resolution, image_size[0] / image_size[1] + request.sampling_params.resolution * request.sampling_params.resolution, image_size[0] / image_size[1] ) height = calculated_height width = calculated_width @@ -89,10 +108,10 @@ def pre_process_func( height = height // multiple_of * multiple_of # Store calculated dimensions in request - req.calculated_height = calculated_height - req.calculated_width = calculated_width - req.height = height - req.width = width + prompt["additional_information"]["calculated_height"] = calculated_height + prompt["additional_information"]["calculated_width"] = calculated_width + request.sampling_params.height = height + request.sampling_params.width = width # 2. Preprocess image if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == latent_channels): @@ -103,10 +122,10 @@ def pre_process_func( # image = image.to(dtype=self.text_encoder.dtype) # do it later # Store preprocessed image and prompt image in request - req.preprocessed_image = image - req.prompt_image = prompt_image - - return requests + prompt["additional_information"]["preprocessed_image"] = image + prompt["additional_information"]["prompt_image"] = prompt_image + request.prompts[i] = prompt + return request return pre_process_func @@ -689,7 +708,7 @@ def forward( req: OmniDiffusionRequest, image: PIL.Image.Image | torch.Tensor | None = None, prompt: str | list[str] | None = None, - negative_prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, true_cfg_scale: float = 4.0, layers: int | None = 4, num_inference_steps: int = 50, @@ -713,29 +732,48 @@ def forward( # 1. Get preprocessed image from request (pre-processing is done in DiffusionEngine) # Override parameters from request if provided - prompt = req.prompt if req.prompt is not None else prompt - layers = req.layers if req.layers is not None else layers - resolution = req.resolution if req.resolution is not None else resolution - cfg_normalize = req.cfg_normalize if req.cfg_normalize is not None else cfg_normalize - use_en_prompt = req.use_en_prompt if req.use_en_prompt is not None else use_en_prompt - negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt - num_inference_steps = req.num_inference_steps or num_inference_steps - generator = req.generator or generator - true_cfg_scale = req.true_cfg_scale or true_cfg_scale - if req.guidance_scale_provided: - guidance_scale = req.guidance_scale - req_num_outputs = getattr(req, "num_outputs_per_prompt", None) - if req_num_outputs and req_num_outputs > 0: - num_images_per_prompt = req_num_outputs - - if hasattr(req, "preprocessed_image"): - prompt_image = req.prompt_image - image = req.preprocessed_image + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + if len(req.prompts) > 1: + logger.warning( + """This model only supports a single prompt, not a batched request.""", + """Taking only the first image for now.""", + ) + first_prompt = req.prompts[0] + prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "") + negative_prompt = None if isinstance(first_prompt, str) else first_prompt.get("negative_prompt") + + layers = req.sampling_params.layers if req.sampling_params.layers is not None else layers + resolution = req.sampling_params.resolution if req.sampling_params.resolution is not None else resolution + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + cfg_normalize = ( + req.sampling_params.cfg_normalize if req.sampling_params.cfg_normalize is not None else cfg_normalize + ) + use_en_prompt = ( + req.sampling_params.use_en_prompt if req.sampling_params.use_en_prompt is not None else use_en_prompt + ) + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + generator = req.sampling_params.generator or generator + true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + + if not isinstance(first_prompt, str) and "preprocessed_image" in ( + additional_information := first_prompt.get("additional_information", {}) + ): + prompt_image = additional_information.get("prompt_image") + image = additional_information.get("preprocessed_image") image = image.to(dtype=self.text_encoder.dtype) # Now we get the type - calculated_height = req.calculated_height - calculated_width = req.calculated_width - height = req.height - width = req.width + calculated_height = additional_information.get("calculated_height") + calculated_width = additional_information.get("calculated_width") + height = req.sampling_params.height + width = req.sampling_params.width else: # fallback to run pre-processing in pipeline (debug only) image_size = image[0].size if isinstance(image, list) else image.size diff --git a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py index dd17af8834f..34a0eb6c140 100644 --- a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -562,19 +562,24 @@ def forward( negative_pooled_prompt_embeds: torch.Tensor | None = None, max_sequence_length: int = 256, ) -> DiffusionOutput: - # # TODO: only support single prompt now - # if req.prompt is not None: - # prompt = req.prompt[0] if isinstance(req.prompt, list) else req.prompt - prompt = req.prompt if req.prompt is not None else prompt - negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt - height = req.height or self.default_sample_size * self.vae_scale_factor - width = req.width or self.default_sample_size * self.vae_scale_factor - sigmas = req.sigmas or sigmas - num_inference_steps = req.num_inference_steps or num_inference_steps - generator = req.generator or generator - req_num_outputs = getattr(req, "num_outputs_per_prompt", None) - if req_num_outputs and req_num_outputs > 0: - num_images_per_prompt = req_num_outputs + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + negative_prompt = [ + "" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts + ] or negative_prompt + + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + sigmas = req.sampling_params.sigmas or sigmas + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + generator = req.sampling_params.generator or generator + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) # 1. check inputs # 2. encode prompts # 3. prepare latents and timesteps @@ -595,7 +600,7 @@ def forward( max_sequence_length=max_sequence_length, ) - self._guidance_scale = req.guidance_scale + self._guidance_scale = req.sampling_params.guidance_scale self._current_timestep = None self._interrupt = False diff --git a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py index 14516a9bb3b..f605c86988f 100644 --- a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py +++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py @@ -332,7 +332,7 @@ def prepare_latents( sample_size: int, dtype: torch.dtype, device: torch.device, - generator: torch.Generator | None, + generator: torch.Generator | list[torch.Generator] | None, latents: torch.Tensor | None = None, ) -> torch.Tensor: """Prepare initial latent noise.""" @@ -358,7 +358,7 @@ def forward( num_inference_steps: int = 100, guidance_scale: float = 7.0, num_waveforms_per_prompt: int = 1, - generator: torch.Generator | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, @@ -386,20 +386,26 @@ def forward( DiffusionOutput containing generated audio """ # Extract from request - prompt = req.prompt if req.prompt is not None else prompt - negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt - num_inference_steps = req.num_inference_steps or num_inference_steps - if req.guidance_scale_provided: - guidance_scale = req.guidance_scale + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale if generator is None: - generator = req.generator - if generator is None and req.seed is not None: - generator = torch.Generator(device=self.device).manual_seed(req.seed) + generator = req.sampling_params.generator + if generator is None and req.sampling_params.seed is not None: + generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) # Get audio duration from request extra params or defaults - audio_start_in_s = req.extra.get("audio_start_in_s", audio_start_in_s) - audio_end_in_s = req.extra.get("audio_end_in_s", audio_end_in_s) + audio_start_in_s = req.sampling_params.extra_args.get("audio_start_in_s", audio_start_in_s) + audio_end_in_s = req.sampling_params.extra_args.get("audio_end_in_s", audio_end_in_s) # Calculate audio length downsample_ratio = self.vae.hop_length diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 8ff6271a019..27017f4a226 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -4,8 +4,10 @@ from __future__ import annotations import json +import logging import os from collections.abc import Iterable +from typing import cast import PIL.Image import torch @@ -21,6 +23,9 @@ from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt + +logger = logging.getLogger(__name__) def retrieve_latents( @@ -127,39 +132,54 @@ def get_wan22_pre_process_func( video_processor = VideoProcessor(vae_scale_factor=8) - def pre_process_func(requests: list[OmniDiffusionRequest]) -> list[OmniDiffusionRequest]: - for req in requests: - # Load image if path is provided - if req.image_path is not None and req.pil_image is None: - req.pil_image = PIL.Image.open(req.image_path).convert("RGB") - - if req.pil_image is not None: - image = req.pil_image - - # Calculate dimensions based on aspect ratio if not provided - if req.height is None or req.width is None: - # Default max area for 720P - max_area = 720 * 1280 - aspect_ratio = image.height / image.width - - # Calculate dimensions maintaining aspect ratio - mod_value = 16 # Must be divisible by 16 - height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value - width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value - - if req.height is None: - req.height = height - if req.width is None: - req.width = width - - # Resize image to target dimensions - image = image.resize((req.width, req.height), PIL.Image.Resampling.LANCZOS) - req.pil_image = image - - # Preprocess for VAE - req.preprocessed_image = video_processor.preprocess(image, height=req.height, width=req.width) + def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: + continue + + if not isinstance(raw_image, (str, PIL.Image.Image)): + raise TypeError( + f"""Unsupported image format {raw_image.__class__}.""", + """Please correctly set `"multi_modal_data": {"image": , …}`""", + ) + image = PIL.Image.open(raw_image).convert("RGB") if isinstance(raw_image, str) else raw_image + + # Calculate dimensions based on aspect ratio if not provided + if request.sampling_params.height is None or request.sampling_params.width is None: + # Default max area for 720P + max_area = 720 * 1280 + aspect_ratio = image.height / image.width + + # Calculate dimensions maintaining aspect ratio + mod_value = 16 # Must be divisible by 16 + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width + + # Resize image to target dimensions + image = image.resize( + (request.sampling_params.width, request.sampling_params.height), # type: ignore # Above has ensured that width & height are not None + PIL.Image.Resampling.LANCZOS, + ) + prompt["multi_modal_data"]["image"] = image # type: ignore # key existence already checked above - return requests + # Preprocess for VAE + prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess( + image, height=request.sampling_params.height, width=request.sampling_params.width + ) + request.prompts[i] = prompt + return request return pre_process_func @@ -290,20 +310,27 @@ def forward( guidance_scale: float | tuple[float, float] = 4.0, frame_num: int = 81, output_type: str | None = "np", - generator: torch.Generator | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, attention_kwargs: dict | None = None, **kwargs, ) -> DiffusionOutput: - prompt = req.prompt if req.prompt is not None else prompt - negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt + # Get parameters from request or arguments + if len(req.prompts) > 1: + raise ValueError( + """This model only supports a single prompt, not a batched request.""", + """Please pass in a single prompt object or string, or a single-item list.""", + ) + if len(req.prompts) == 1: # If req.prompt is empty, default to prompt & neg_prompt in param list + prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt") + negative_prompt = None if isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt") if prompt is None and prompt_embeds is None: raise ValueError("Prompt or prompt_embeds is required for Wan2.2 generation.") - height = req.height or height - width = req.width or width - num_frames = req.num_frames if req.num_frames else frame_num + height = req.sampling_params.height or height + width = req.sampling_params.width or width + num_frames = req.sampling_params.num_frames if req.sampling_params.num_frames else frame_num # Ensure dimensions are compatible with VAE and patch size # For expand_timesteps mode, we need latent dims to be even (divisible by patch_size) @@ -311,16 +338,16 @@ def forward( mod_value = self.vae_scale_factor_spatial * patch_size[1] # 16*2=32 for TI2V, 8*2=16 for I2V height = (height // mod_value) * mod_value width = (width // mod_value) * mod_value - num_steps = req.num_inference_steps or num_inference_steps + num_steps = req.sampling_params.num_inference_steps or num_inference_steps # Respect per-request guidance_scale when explicitly provided. - if req.guidance_scale_provided: - guidance_scale = req.guidance_scale + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale guidance_low = guidance_scale if isinstance(guidance_scale, (int, float)) else guidance_scale[0] guidance_high = ( - req.guidance_scale_2 - if req.guidance_scale_2 is not None + req.sampling_params.guidance_scale_2 + if req.sampling_params.guidance_scale_2 is not None else ( guidance_scale[1] if isinstance(guidance_scale, (list, tuple)) and len(guidance_scale) > 1 @@ -352,9 +379,9 @@ def forward( # Seed / generator if generator is None: - generator = req.generator - if generator is None and req.seed is not None: - generator = torch.Generator(device=device).manual_seed(req.seed) + generator = req.sampling_params.generator + if generator is None and req.sampling_params.seed is not None: + generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) # Encode prompts if prompt_embeds is None: @@ -362,8 +389,8 @@ def forward( prompt=prompt, negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, - num_videos_per_prompt=req.num_outputs_per_prompt or 1, - max_sequence_length=req.max_sequence_length or 512, + num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) @@ -385,7 +412,22 @@ def forward( boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps # Handle I2V mode when expand_timesteps=True and image is provided - image = req.pil_image + multi_modal_data = req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(raw_image, list): + if len(raw_image) > 1: + logger.warning( + """Received a list of image. Only a single image is supported by this model.""" + """Taking only the first image for now.""" + ) + raw_image = raw_image[0] + if raw_image is None: + image = None + elif isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor, raw_image) + latent_condition = None first_frame_mask = None @@ -416,7 +458,7 @@ def forward( dtype=torch.float32, device=device, generator=generator, - latents=req.latents, + latents=req.sampling_params.latents, ) # Encode image condition @@ -458,7 +500,7 @@ def forward( dtype=torch.float32, device=device, generator=generator, - latents=req.latents, + latents=req.sampling_params.latents, ) if attention_kwargs is None: diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 635da11e6c4..0e8158671f4 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -3,8 +3,10 @@ from __future__ import annotations +import logging import os from collections.abc import Iterable +from typing import cast import numpy as np import PIL.Image @@ -26,6 +28,9 @@ retrieve_latents, ) from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt + +logger = logging.getLogger(__name__) def _load_model_index(model: str, local_files_only: bool) -> dict: @@ -77,39 +82,56 @@ def get_wan22_i2v_pre_process_func( video_processor = VideoProcessor(vae_scale_factor=8) - def pre_process_func(requests: list[OmniDiffusionRequest]) -> list[OmniDiffusionRequest]: - for req in requests: - # Load image if path is provided - if req.image_path is not None and req.pil_image is None: - req.pil_image = PIL.Image.open(req.image_path).convert("RGB") - - if req.pil_image is not None: - image = req.pil_image - - # Calculate dimensions based on aspect ratio if not provided - if req.height is None or req.width is None: - # Default max area for 480P - max_area = 480 * 832 - aspect_ratio = image.height / image.width - - # Calculate dimensions maintaining aspect ratio - mod_value = 16 # Must be divisible by 16 - height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value - width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value - - if req.height is None: - req.height = height - if req.width is None: - req.width = width - - # Resize image to target dimensions - image = image.resize((req.width, req.height), PIL.Image.Resampling.LANCZOS) - req.pil_image = image - - # Preprocess for VAE - req.preprocessed_image = video_processor.preprocess(image, height=req.height, width=req.width) + def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: + raise ValueError( + """No image is provided. This model requires an image to run.""", + """Please correctly set `"multi_modal_data": {"image": , …}`""", + ) + if not isinstance(raw_image, (str, PIL.Image.Image)): + raise TypeError( + f"""Unsupported image format {raw_image.__class__}.""", + """Please correctly set `"multi_modal_data": {"image": , …}`""", + ) + image = PIL.Image.open(raw_image).convert("RGB") if isinstance(raw_image, str) else raw_image + + # Calculate dimensions based on aspect ratio if not provided + if request.sampling_params.height is None or request.sampling_params.width is None: + # Default max area for 480P + max_area = 480 * 832 + aspect_ratio = image.height / image.width + + # Calculate dimensions maintaining aspect ratio + mod_value = 16 # Must be divisible by 16 + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width + + # Resize image to target dimensions + image = image.resize( + (request.sampling_params.width, request.sampling_params.height), # type: ignore # Above has ensured that width & height are not None + PIL.Image.Resampling.LANCZOS, + ) + prompt["multi_modal_data"]["image"] = image # type: ignore # key existence already checked above - return requests + # Preprocess for VAE + prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess( + image, height=request.sampling_params.height, width=request.sampling_params.width + ) + request.prompts[i] = prompt + return request return pre_process_func @@ -267,7 +289,7 @@ def forward( guidance_scale: float | tuple[float, float] = 5.0, frame_num: int = 81, output_type: str | None = "np", - generator: torch.Generator | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, image_embeds: torch.Tensor | None = None, @@ -276,31 +298,51 @@ def forward( **kwargs, ) -> DiffusionOutput: # Get parameters from request or arguments - prompt = req.prompt if req.prompt is not None else prompt - negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt + if len(req.prompts) > 1: + raise ValueError( + """This model only supports a single prompt, not a batched request.""", + """Please pass in a single prompt object or string, or a single-item list.""", + ) + if len(req.prompts) == 1: # If req.prompt is empty, default to prompt & neg_prompt in param list + prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt") + negative_prompt = None if isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt") if prompt is None and prompt_embeds is None: - raise ValueError("Prompt or prompt_embeds is required for Wan2.2 I2V generation.") + raise ValueError("Prompt or prompt_embeds is required for Wan2.2 generation.") # Get image from request if image is None: - image = req.pil_image - if image is None: - raise ValueError("Image is required for I2V generation.") + multi_modal_data = ( + req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None + ) + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if raw_image is None: + raise ValueError("Image is required for I2V generation.") + if isinstance(raw_image, list): + if len(raw_image) > 1: + logger.warning( + """Received a list of image. Only a single image is supported by this model.""" + """Taking only the first image for now.""" + ) + raw_image = raw_image[0] + if isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor, raw_image) - height = req.height or height - width = req.width or width - num_frames = req.num_frames if req.num_frames else frame_num - num_steps = req.num_inference_steps or num_inference_steps + height = req.sampling_params.height or height + width = req.sampling_params.width or width + num_frames = req.sampling_params.num_frames or frame_num + num_steps = req.sampling_params.num_inference_steps or num_inference_steps # Respect per-request guidance_scale when explicitly provided. - if req.guidance_scale_provided: - guidance_scale = req.guidance_scale + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale # Handle guidance scales guidance_low = guidance_scale if isinstance(guidance_scale, (int, float)) else guidance_scale[0] guidance_high = ( - req.guidance_scale_2 - if req.guidance_scale_2 is not None + req.sampling_params.guidance_scale_2 + if req.sampling_params.guidance_scale_2 is not None else ( guidance_scale[1] if isinstance(guidance_scale, (list, tuple)) and len(guidance_scale) > 1 @@ -334,9 +376,9 @@ def forward( # Generator setup if generator is None: - generator = req.generator - if generator is None and req.seed is not None: - generator = torch.Generator(device=device).manual_seed(req.seed) + generator = req.sampling_params.generator + if generator is None and req.sampling_params.seed is not None: + generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) # Encode prompts if prompt_embeds is None: @@ -344,8 +386,8 @@ def forward( prompt=prompt, negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, - num_videos_per_prompt=req.num_outputs_per_prompt or 1, - max_sequence_length=req.max_sequence_length or 512, + num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) @@ -411,7 +453,7 @@ def forward( dtype=torch.float32, device=device, generator=generator, - latents=req.latents, + latents=req.sampling_params.latents, last_image=last_image_tensor, ) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 6c1e39242f3..b9894063a37 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -16,8 +16,10 @@ from __future__ import annotations +import logging import os from collections.abc import Iterable +from typing import cast import numpy as np import PIL.Image @@ -39,6 +41,9 @@ retrieve_latents, ) from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt + +logger = logging.getLogger(__name__) def get_wan22_ti2v_post_process_func( @@ -67,39 +72,56 @@ def get_wan22_ti2v_pre_process_func( video_processor = VideoProcessor(vae_scale_factor=8) - def pre_process_func(requests: list[OmniDiffusionRequest]) -> list[OmniDiffusionRequest]: - for req in requests: - # Load image if path is provided - if req.image_path is not None and req.pil_image is None: - req.pil_image = PIL.Image.open(req.image_path).convert("RGB") - - if req.pil_image is not None: - image = req.pil_image - - # Calculate dimensions based on aspect ratio if not provided - if req.height is None or req.width is None: - # Default max area for 720P (TI2V-5B default) - max_area = 720 * 1280 - aspect_ratio = image.height / image.width - - # Calculate dimensions maintaining aspect ratio - mod_value = 16 # Must be divisible by 16 - height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value - width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value - - if req.height is None: - req.height = height - if req.width is None: - req.width = width - - # Resize image to target dimensions - image = image.resize((req.width, req.height), PIL.Image.Resampling.LANCZOS) - req.pil_image = image - - # Preprocess for VAE - req.preprocessed_image = video_processor.preprocess(image, height=req.height, width=req.width) + def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: + raise ValueError( + """No image is provided. This model requires an image to run.""", + """Please correctly set `"multi_modal_data": {"image": , …}`""", + ) + if not isinstance(raw_image, (str, PIL.Image.Image)): + raise TypeError( + f"""Unsupported image format {raw_image.__class__}.""", + """Please correctly set `"multi_modal_data": {"image": , …}`""", + ) + image = PIL.Image.open(raw_image).convert("RGB") if isinstance(raw_image, str) else raw_image + + # Calculate dimensions based on aspect ratio if not provided + if request.sampling_params.height is None or request.sampling_params.width is None: + # Default max area for 720P (TI2V-5B default) + max_area = 720 * 1280 + aspect_ratio = image.height / image.width + + # Calculate dimensions maintaining aspect ratio + mod_value = 16 # Must be divisible by 16 + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width + + # Resize image to target dimensions + image = image.resize( + (request.sampling_params.width, request.sampling_params.height), # type: ignore # Above has ensured that width & height are not None + PIL.Image.Resampling.LANCZOS, + ) + prompt["multi_modal_data"]["image"] = image # type: ignore # key existence already checked above - return requests + # Preprocess for VAE + prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess( + image, height=request.sampling_params.height, width=request.sampling_params.width + ) + request.prompts[i] = prompt + return request return pre_process_func @@ -206,31 +228,53 @@ def forward( guidance_scale: float = 5.0, frame_num: int = 81, output_type: str | None = "np", - generator: torch.Generator | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, attention_kwargs: dict | None = None, **kwargs, ) -> DiffusionOutput: # Get parameters from request or arguments - prompt = req.prompt if req.prompt is not None else prompt - negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt + if len(req.prompts) > 1: + raise ValueError( + """This model only supports a single prompt, not a batched request.""", + """Please pass in a single prompt object or string, or a single-item list.""", + ) + if len(req.prompts) == 1: # If req.prompt is empty, default to prompt & neg_prompt in param list + prompt = req.prompts[0] if isinstance(req.prompts[0], str) else req.prompts[0].get("prompt") + negative_prompt = None if isinstance(req.prompts[0], str) else req.prompts[0].get("negative_prompt") if prompt is None and prompt_embeds is None: - raise ValueError("Prompt or prompt_embeds is required for Wan2.2 TI2V generation.") + raise ValueError("Prompt or prompt_embeds is required for Wan2.2 generation.") # Get image from request (optional for TI2V) if image is None: - image = req.pil_image + multi_modal_data = ( + req.prompts[0].get("multi_modal_data", {}) if not isinstance(req.prompts[0], str) else None + ) + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(raw_image, list): + if len(raw_image) > 1: + logger.warning( + """Received a list of image. Only a single image is supported by this model.""" + """Taking only the first image for now.""" + ) + raw_image = raw_image[0] + if raw_image is None: + image = None + elif isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor, raw_image) # Default dimensions for TI2V-5B (720P) - height = req.height or height - width = req.width or width - num_frames = req.num_frames if req.num_frames else frame_num - num_steps = req.num_inference_steps or num_inference_steps + height = req.sampling_params.height or height + width = req.sampling_params.width or width + num_frames = req.sampling_params.num_frames if req.sampling_params.num_frames else frame_num + num_steps = req.sampling_params.num_inference_steps or num_inference_steps # Respect per-request guidance_scale when explicitly provided. - if req.guidance_scale_provided: - guidance_scale = req.guidance_scale + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale self._guidance_scale = guidance_scale @@ -255,9 +299,9 @@ def forward( # Generator setup if generator is None: - generator = req.generator - if generator is None and req.seed is not None: - generator = torch.Generator(device=device).manual_seed(req.seed) + generator = req.sampling_params.generator + if generator is None and req.sampling_params.seed is not None: + generator = torch.Generator(device=device).manual_seed(req.sampling_params.seed) # Encode prompts if prompt_embeds is None: @@ -265,8 +309,8 @@ def forward( prompt=prompt, negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_scale > 1.0, - num_videos_per_prompt=req.num_outputs_per_prompt or 1, - max_sequence_length=req.max_sequence_length or 512, + num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) @@ -311,7 +355,7 @@ def forward( dtype=torch.float32, device=device, generator=generator, - latents=req.latents, + latents=req.sampling_params.latents, ) else: # T2V mode: prepare random latents @@ -324,7 +368,7 @@ def forward( dtype=torch.float32, device=device, generator=generator, - latents=req.latents, + latents=req.sampling_params.latents, ) latent_condition = None first_frame_mask = torch.ones( diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index 7d6fb901d75..7b69521124f 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -317,15 +317,15 @@ def forward( self, req: OmniDiffusionRequest, prompt: str | list[str] | None = None, - height: int | None = None, - width: int | None = None, + height: int = 1024, + width: int = 1024, num_inference_steps: int = 50, sigmas: list[float] | None = None, guidance_scale: float = 5.0, cfg_normalization: bool = False, cfg_truncation: float = 1.0, negative_prompt: str | list[str] | None = None, - num_images_per_prompt: int | None = 1, + num_images_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.FloatTensor | None = None, prompt_embeds: list[torch.FloatTensor] | None = None, @@ -414,16 +414,28 @@ def forward( `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ - prompt = req.prompt - negative_prompt = req.negative_prompt - height: int = req.height or 1024 - width: int = req.width or 1024 - num_inference_steps = req.num_inference_steps or 50 - generator = req.generator - guidance_scale = req.guidance_scale if req.guidance_rescale is not None else guidance_scale - req_num_outputs = getattr(req, "num_outputs_per_prompt", None) - if req_num_outputs and req_num_outputs > 0: - num_images_per_prompt = req_num_outputs + # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") + # TODO: May be some data formatting operations on the API side. Hack for now. + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + height = req.sampling_params.height or height + width = req.sampling_params.width or width + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + generator = req.sampling_params.generator + sigmas = req.sampling_params.sigmas or sigmas + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + guidance_scale = ( + req.sampling_params.guidance_scale if req.sampling_params.guidance_rescale is not None else guidance_scale + ) + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) vae_scale = self.vae_scale_factor * 2 if height % vae_scale != 0: diff --git a/vllm_omni/diffusion/request.py b/vllm_omni/diffusion/request.py index 89c0a79f146..a6005290cdc 100644 --- a/vllm_omni/diffusion/request.py +++ b/vllm_omni/diffusion/request.py @@ -2,14 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pprint -from dataclasses import asdict, dataclass, field -from typing import Any +from dataclasses import dataclass, field -import PIL.Image -import torch - -from vllm_omni.lora.request import LoRARequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType @dataclass @@ -17,9 +12,8 @@ class OmniDiffusionRequest: """ Complete state passed through the pipeline execution. - This dataclass contains all information needed during the diffusion pipeline - execution, allowing methods to update specific components without needing - to manage numerous individual parameters. + This dataclass contains the prompts and sampling parameters for the diffusion pipeline + execution. It also contains a request_id for other components to trace this request and its outputs. """ # TODO(will): double check that args are separate from server_args @@ -27,173 +21,24 @@ class OmniDiffusionRequest: # specific arguments. # data_type: DataType - request_id: str | None = None - - generator: torch.Generator | list[torch.Generator] | None = None - - # Image inputs - image_path: str | None = None - # Image encoder hidden states - image_embeds: list[torch.Tensor] = field(default_factory=list) - pil_image: torch.Tensor | PIL.Image.Image | None = None - pixel_values: torch.Tensor | PIL.Image.Image | None = None - preprocessed_image: torch.Tensor | None = None - - # Text inputs - prompt: str | list[str] | None = None - negative_prompt: str | list[str] | None = None - prompt_path: str | None = None - output_path: str = "outputs/" - # without extension - output_file_name: str | None = None - output_file_ext: str | None = None - # Primary encoder embeddings - prompt_embeds: list[torch.Tensor] | torch.Tensor = field(default_factory=list) - negative_prompt_embeds: list[torch.Tensor] | None = None - prompt_attention_mask: list[torch.Tensor] | None = None - negative_attention_mask: list[torch.Tensor] | None = None - clip_embedding_pos: list[torch.Tensor] | None = None - clip_embedding_neg: list[torch.Tensor] | None = None - - pooled_embeds: list[torch.Tensor] = field(default_factory=list) - neg_pooled_embeds: list[torch.Tensor] = field(default_factory=list) - - # Additional text-related parameters - max_sequence_length: int | None = None - prompt_template: dict[str, Any] | None = None - do_classifier_free_guidance: bool = False - - # Batch info - num_outputs_per_prompt: int = 1 - seed: int | None = None - seeds: list[int] | None = None - - # layered info - layers: int = 4 - - # cfg info - cfg_normalize: bool = False - - # caption language - use_en_prompt: bool = False - - # different bucket in (640, 1024) to determine the condition and output resolution - resolution: int = 640 - - # Tracking if embeddings are already processed - is_prompt_processed: bool = False - - # Latent tensors - latents: torch.Tensor | None = None - raw_latent_shape: torch.Tensor | None = None - noise_pred: torch.Tensor | None = None - image_latent: torch.Tensor | None = None - - # Latent dimensions - height_latents: list[int] | int | None = None - width_latents: list[int] | int | None = None - num_frames: list[int] | int = 1 # Default for image models - num_frames_round_down: bool = False # Whether to round down num_frames if it's not divisible by num_gpus - - # Original dimensions (before VAE scaling) - height: list[int] | int | None = None - width: list[int] | int | None = None - fps: list[int] | int | None = None - height_not_provided: bool = False - width_not_provided: bool = False - - # Timesteps - timesteps: torch.Tensor | None = None - timestep: torch.Tensor | float | int | None = None - step_index: int | None = None - boundary_ratio: float | None = None - - # Scheduler parameters - num_inference_steps: int = 50 - guidance_scale: float = 1.0 - guidance_scale_provided: bool = False - guidance_scale_2: float | None = None - guidance_rescale: float = 0.0 - eta: float = 0.0 - sigmas: list[float] | None = None - - true_cfg_scale: float | None = None # qwen-image specific now + prompts: list[OmniPromptType] # Actually supporting str-based prompts + sampling_params: OmniDiffusionSamplingParams - n_tokens: int | None = None - - # Other parameters that may be needed by specific schedulers - extra_step_kwargs: dict[str, Any] = field(default_factory=dict) - - # [Omni] KV Cache Transfer, for bagel model now - past_key_values: Any | None = None # Injected KV Cache - kv_metadata: dict[str, Any] | None = None # Metadata for KV Cache (e.g., kv_lens, ropes) - need_kv_receive: bool = True # Flag to indicate if this request expects KV transfer - - # Component modules (populated by the pipeline) - modules: dict[str, Any] = field(default_factory=dict) - - return_trajectory_latents: bool = False - return_trajectory_decoded: bool = False - trajectory_timesteps: list[torch.Tensor] | None = None - trajectory_latents: torch.Tensor | None = None - - # Extra parameters that might be needed by specific pipeline implementations - extra: dict[str, Any] = field(default_factory=dict) - - # Misc - save_output: bool = True - return_frames: bool = False - - # LoRA - lora_request: LoRARequest | None = None - lora_scale: float = 1.0 - - # STA parameters - STA_param: list | None = None - is_cfg_negative: bool = False - mask_search_final_result_pos: list[list] | None = None - mask_search_final_result_neg: list[list] | None = None - - # VSA parameters - VSA_sparsity: float = 0.0 - # perf_logger: PerformanceLogger | None = None - - # stage logging - # logging_info: PipelineLoggingInfo = field(default_factory=PipelineLoggingInfo) - - # profile - profile: bool = False - num_profiled_timesteps: int = 8 - - # debugging - debug: bool = False - - # results - output: torch.Tensor | None = None - - @property - def batch_size(self): - # Determine batch size - if isinstance(self.prompt, list): - batch_size = len(self.prompt) - elif self.prompt is not None: - batch_size = 1 - else: - batch_size = self.prompt_embeds[0].shape[0] - - # Adjust batch size for number of videos per prompt - batch_size *= self.num_outputs_per_prompt - return batch_size + request_ids: list[str] = field(default_factory=list) def __post_init__(self): """Initialize dependent fields after dataclass initialization.""" # Set do_classifier_free_guidance based on guidance scale and negative prompt - if self.guidance_scale > 1.0 and self.negative_prompt is not None: - self.do_classifier_free_guidance = True - if self.negative_prompt_embeds is None: - self.negative_prompt_embeds = [] - if self.guidance_scale_2 is None: - self.guidance_scale_2 = self.guidance_scale - - def __str__(self): - return pprint.pformat(asdict(self), indent=2, width=120) + if self.sampling_params.guidance_scale > 1.0 and any( + (not isinstance(p, str) and p.get("negative_prompt")) for p in self.prompts + ): + self.sampling_params.do_classifier_free_guidance = True + if self.sampling_params.guidance_scale_2 is None: + self.sampling_params.guidance_scale_2 = self.sampling_params.guidance_scale + + # The dataclass default value is 0 (false-like), used to detect whether user explicitly provides this value + # After this check is done, reset this value to old default 1 + if self.sampling_params.guidance_scale: + self.sampling_params.guidance_scale_provided = True + else: + self.sampling_params.guidance_scale = 1.0 diff --git a/vllm_omni/diffusion/scheduler.py b/vllm_omni/diffusion/scheduler.py index 21b0d49149b..f104d052266 100644 --- a/vllm_omni/diffusion/scheduler.py +++ b/vllm_omni/diffusion/scheduler.py @@ -41,14 +41,14 @@ def initialize_result_queue(self, handle): def get_broadcast_handle(self): return self.mq.export_handle() - def add_req(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput: + def add_req(self, request: OmniDiffusionRequest) -> DiffusionOutput: """Sends a request to the scheduler and waits for the response.""" try: # Prepare RPC request for generation rpc_request = { "type": "rpc", "method": "generate", - "args": (requests,), + "args": (request,), "kwargs": {}, "output_rank": 0, "exec_all_ranks": True, diff --git a/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py b/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py index 54327be8169..bf56df15590 100644 --- a/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py @@ -173,12 +173,13 @@ def _init_omni_connector(self) -> None: def _receive_kv_cache_for_request(self, req: OmniDiffusionRequest) -> None: """Receive KV cache for a request via OmniConnector.""" # TODO(wzliu)! must get control info from stage queue instead of hardcode - if not req.request_id: + if not req.request_ids: logger.warning("Request has no ID, cannot receive KV cache") return + request_id = req.request_ids[0] try: - logger.info(f"Attempting to receive KV cache for request {req.request_id}") + logger.info(f"Attempting to receive KV cache for request {request_id}") # TODO: Key used for transfer (must match sender side) # key = f"kv_cache_{req.request_id}" @@ -197,7 +198,7 @@ def _receive_kv_cache_for_request(self, req: OmniDiffusionRequest) -> None: from_stage = stage_id - 1 else: raise ValueError("Invalid stage id") - logger.info(f"Wait for KV cache for request {req.request_id} from stage {from_stage} to {to_stage}...") + logger.info(f"Wait for KV cache for request {request_id} from stage {from_stage} to {to_stage}...") # Check if we should receive KV cache based on config need_recv_cache = omni_kv_config.get("need_recv_cache", False) @@ -207,7 +208,7 @@ def _receive_kv_cache_for_request(self, req: OmniDiffusionRequest) -> None: start_time = time.time() while True: - get_key = f"omni_{from_stage}_to_{to_stage}_kv_cache_{req.request_id}" + get_key = f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}" result = self.connector.get( from_stage=from_stage, to_stage=to_stage, @@ -217,18 +218,18 @@ def _receive_kv_cache_for_request(self, req: OmniDiffusionRequest) -> None: break if time.time() - start_time > timeout: - logger.error(f"Timeout waiting for KV cache for request {req.request_id} after {timeout}s") + logger.error(f"Timeout waiting for KV cache for request {request_id} after {timeout}s") result = None break time.sleep(0.5) else: - logger.info(f"Skip receiving KV cache for {req.request_id} (need_recv_cache=False)") + logger.info(f"Skip receiving KV cache for {request_id} (need_recv_cache=False)") result = None if result: data, size = result - logger.info(f"Successfully received KV cache for {req.request_id}") + logger.info(f"Successfully received KV cache for {request_id}") # Assume data structure matches KVCacheTransferData.to_dict() if isinstance(data, dict) and "layer_blocks" in data: @@ -242,16 +243,16 @@ def _receive_kv_cache_for_request(self, req: OmniDiffusionRequest) -> None: cache_list[i] = tensor.to(self.pipeline.device).contiguous() from types import SimpleNamespace - req.past_key_values = SimpleNamespace(**layer_blocks) + req.sampling_params.past_key_values = SimpleNamespace(**layer_blocks) if "metadata" in data: - req.kv_metadata = data["metadata"] + req.sampling_params.kv_metadata = data["metadata"] else: - logger.warning(f"No KV cache received for {req.request_id} (timeout or empty)") + logger.warning(f"No KV cache received for {request_id} (timeout or empty)") except Exception as e: - logger.error(f"Error receiving KV cache for {req.request_id}: {e}") + logger.error(f"Error receiving KV cache for {request_id}: {e}") import traceback traceback.print_exc() @@ -261,33 +262,30 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return self.pipeline.load_weights(weights) @torch.inference_mode() - def execute_model(self, reqs: list[OmniDiffusionRequest]) -> DiffusionOutput: + def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: """ Execute a forward pass for the given requests. Args: - reqs: List of diffusion requests to process. + req: A diffusion request containing a list of prompts to process. Returns: DiffusionOutput with generated results. """ assert self.pipeline is not None, "Model not loaded. Call load_model() first." - if not reqs or len(reqs) == 0: + if len(req.prompts) == 0: raise ValueError("Cannot execute model with empty request list") - # TODO: dealing with first req for now - req = reqs[0] - # [Omni] KV Cache Receiving Logic - if getattr(req, "need_kv_receive", False) and self.connector is not None: + if req.sampling_params.need_kv_receive and self.connector is not None: self._receive_kv_cache_for_request(req) - if req.generator is None and req.seed is not None: - req.generator = torch.Generator(device=self.device).manual_seed(req.seed) + if req.sampling_params.generator is None and req.sampling_params.seed is not None: + req.sampling_params.generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) # Refresh cache context if needed if self.cache_backend is not None and self.cache_backend.is_enabled(): - self.cache_backend.refresh(self.pipeline, req.num_inference_steps) + self.cache_backend.refresh(self.pipeline, req.sampling_params.num_inference_steps) with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): with record_function("pipeline_forward"): diff --git a/vllm_omni/diffusion/worker/gpu_diffusion_worker.py b/vllm_omni/diffusion/worker/gpu_diffusion_worker.py index cadb6b64dc0..d8b962a58ab 100644 --- a/vllm_omni/diffusion/worker/gpu_diffusion_worker.py +++ b/vllm_omni/diffusion/worker/gpu_diffusion_worker.py @@ -124,9 +124,9 @@ def init_device(self) -> None: ) logger.info(f"Worker {self.rank}: Initialization complete.") - def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput: + def generate(self, request: OmniDiffusionRequest) -> DiffusionOutput: """Generate output for the given requests.""" - return self.execute_model(requests, self.od_config) + return self.execute_model(request, self.od_config) @classmethod def start_profile(cls, trace_path_template: str) -> str: @@ -138,37 +138,17 @@ def stop_profile(cls) -> dict | None: """Stop profiling and return the result dictionary.""" return CurrentProfiler.stop() - def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput: + def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> DiffusionOutput: """Execute a forward pass by delegating to the model runner.""" assert self.model_runner is not None, "Model runner not initialized" - if self.lora_manager is not None and reqs: - req = reqs[0] - - if len(reqs) > 1: - # This worker (and the current diffusion model runner) applies - # a single LoRA to the whole batch. Reject inconsistent LoRA - # settings to avoid silently applying the wrong adapter. - def _lora_key(r: OmniDiffusionRequest): - if r.lora_request is None: - return None - lr = r.lora_request - return (lr.lora_name, lr.lora_int_id, lr.lora_path, lr.tensorizer_config_dict) - - key0 = _lora_key(req) - scale0 = req.lora_scale if key0 is not None else None - for other in reqs[1:]: - if _lora_key(other) != key0: - raise ValueError("All requests in a diffusion batch must share the same LoRARequest.") - if key0 is not None and other.lora_scale != scale0: - raise ValueError("All requests in a diffusion batch must share the same lora_scale.") - + if self.lora_manager is not None: try: - self.lora_manager.set_active_adapter(req.lora_request, req.lora_scale) + self.lora_manager.set_active_adapter(req.sampling_params.lora_request, req.sampling_params.lora_scale) except Exception as exc: - if req.lora_request is not None: + if req.sampling_params.lora_request is not None: raise logger.warning("LoRA activation skipped: %s", exc) - return self.model_runner.execute_model(reqs) + return self.model_runner.execute_model(req) def load_weights(self, weights) -> set[str]: """Load weights by delegating to the model runner.""" @@ -360,7 +340,7 @@ def worker_busy_loop(self) -> None: except Exception as e: logger.error(f"Error processing RPC: {e}", exc_info=True) if self.result_mq is not None: - self.return_result({"status": "error", "error": str(e)}) + self.return_result(DiffusionOutput(error=str(e))) elif isinstance(msg, dict) and msg.get("type") == "shutdown": logger.info("Worker %s: Received shutdown message", self.gpu_id) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 3dd687a8b49..a78710fa2fd 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -3,10 +3,10 @@ import asyncio import time import weakref -from collections.abc import AsyncGenerator, Iterable +from collections.abc import AsyncGenerator, Iterable, Sequence from dataclasses import asdict from pprint import pformat -from typing import Any +from typing import Any, cast from vllm.config import VllmConfig from vllm.inputs.preprocess import InputPreprocessor @@ -32,6 +32,7 @@ from vllm_omni.entrypoints.utils import ( get_final_stage_id_for_e2e, ) +from vllm_omni.inputs.data import OmniPromptType, OmniSamplingParams, OmniTokensPrompt # Internal imports (our code) from vllm_omni.lora.request import LoRARequest @@ -66,10 +67,8 @@ class AsyncOmni(OmniBase): asynchronous LLM and Diffusion models. Args: - *args: Variable length argument list. - - args[0]: Model name or path to load. + model: Model name or path to load. **kwargs: Arbitrary keyword arguments. - - model: Model name or path to load (if not in args). - stage_configs_path: Optional path to YAML file containing stage configurations. If None, configurations are loaded from the model. - log_stats: Whether to enable statistics logging @@ -95,7 +94,7 @@ class AsyncOmni(OmniBase): ... print(output) """ - def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: + def __init__(self, model: str, **kwargs: dict[str, Any]) -> None: # Pause/resume control attributes self._pause_cond: asyncio.Condition = asyncio.Condition() self._paused: bool = False @@ -104,7 +103,7 @@ def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: self.request_states: dict[str, ClientRequestState] = {} self.output_handler: asyncio.Task | None = None - super().__init__(*args, **kwargs) + super().__init__(model, **kwargs) # Register weak reference cleanup (called on garbage collection) self._weak_finalizer = weakref.finalize( @@ -233,7 +232,14 @@ def shutdown(self): if hasattr(self, "_weak_finalizer"): self._weak_finalizer() - async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator[OmniRequestOutput, None]: + async def generate( + self, + prompt: OmniPromptType, + request_id: str, + sampling_params_list: Sequence[OmniSamplingParams] | None = None, + *, + output_modalities: list[str] | None = None, + ) -> AsyncGenerator[OmniRequestOutput, None]: """Generate outputs for the given prompt asynchronously. Coordinates multi-stage pipeline through YAML configuration. @@ -243,21 +249,13 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator sampling parameters from the sampling_params_list. Args: - *args: Arguments for generation. - - prompt: Prompt to process. Can be a text string, token IDs, - or multimodal prompt. - - request_id: Unique identifier for this request - - sampling_params_list: List of SamplingParams, one for each stage. - Must have the same length as the number of stages. - If None, uses default sampling params for each stage. - **kwargs: Additional arguments for generation. - - prompt: Prompt to process. Can be a text string, token IDs, - or multimodal prompt. - - request_id: Unique identifier for this request - - sampling_params_list: List of SamplingParams, one for each stage. - Must have the same length as the number of stages. - If None, uses default sampling params for each stage. - - output_modalities: Optional list of output modalities. + prompt: Prompt to process. Can be a text string, token IDs, + or multimodal prompt. + request_id: Unique identifier for this request + sampling_params_list: List of SamplingParams, one for each stage. + Must have the same length as the number of stages. + If None, uses default sampling params for each stage. + output_modalities: Optional list of output modalities. Yields: OmniRequestOutput objects as they are produced by each stage. @@ -276,33 +274,9 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator # Start output handler on the first call to generate() self._run_output_handler() - prompt = args[0] if args else kwargs.get("prompt") - request_id = args[1] if len(args) > 1 else kwargs.get("request_id") - sampling_params_list = args[2] if len(args) > 2 else kwargs.get("sampling_params_list") - output_modalities = kwargs.get("output_modalities", None) # TODO: lora_request, trace_headers, priority are not supported yet - if sampling_params_list is None: - # For Omni LLM, the params are parsed via the yaml file. For the current version, - # diffusion params can parsed via the command line. - omni_params_kwargs = { - k: v for k, v in kwargs.items() if k not in ["prompt", "request_id", "output_modalities"] - } - - per_stage_params: list[Any] = [] - for stage_id, stage in enumerate(self.stage_list): - stage_type = getattr(stage, "stage_type", "llm") - if stage_type == "diffusion": - default_dict = self.default_sampling_params_list[stage_id] - # Merge user-provided kwargs - merged = {**default_dict, **omni_params_kwargs} - # Diffusion only needs to keep diff params, will be used via OmniDiffusionRequest - per_stage_params.append(merged) - else: - # LLM directly constructs SamplingParams, don't use the merged params - per_stage_params.append(self.default_sampling_params_list[stage_id]) - - sampling_params_list = per_stage_params + sampling_params_list = self.default_sampling_params_list if len(sampling_params_list) != len(self.stage_list): raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") @@ -334,11 +308,11 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator stage_queues = {stage_id: asyncio.Queue() for stage_id in range(num_stages)} req_state.stage_queues = stage_queues for i in range(num_stages): - sp: SamplingParams = sampling_params_list[i] - engine_inputs = prompt + sp: SamplingParams = cast(SamplingParams, sampling_params_list[i]) + engine_inputs = cast(OmniTokensPrompt, prompt) if i != 0: - prompt_token_ids = prompt["prompt_token_ids"] - prompt_1 = prompt.copy() + prompt_token_ids = engine_inputs["prompt_token_ids"] + prompt_1 = engine_inputs.copy() prompt_1["prompt_token_ids"] = [0] * compute_talker_prompt_ids_length(prompt_token_ids) prompt_1["multi_modal_data"] = prompt_1["mm_processor_kwargs"] = None engine_inputs = prompt_1 diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index 6d34e3445f4..535f04f7d2e 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -12,16 +12,15 @@ import uuid from collections.abc import AsyncGenerator, Iterable from concurrent.futures import ThreadPoolExecutor -from dataclasses import fields from typing import Any -from PIL import Image from vllm.logger import init_logger from vllm.transformers_utils.config import get_hf_file_to_dict from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType from vllm_omni.lora.request import LoRARequest from vllm_omni.outputs import OmniRequestOutput @@ -109,68 +108,19 @@ def __init__( logger.info("AsyncOmniDiffusion initialized with model: %s", model) - def _prepare_request( - self, - prompt: str, - request_id: str | None = None, - **kwargs: Any, - ) -> OmniDiffusionRequest: - """Prepare a diffusion request from prompt and parameters. - - Args: - prompt: Text prompt for image generation - request_id: Optional unique identifier for the request - **kwargs: Additional generation parameters - - Returns: - OmniDiffusionRequest ready for processing - """ - if request_id is None: - request_id = f"diff-{uuid.uuid4().hex[:16]}" - - field_names = {f.name for f in fields(OmniDiffusionRequest)} - - init_kwargs = { - "prompt": prompt, - "request_id": request_id, - } - - for key, value in kwargs.items(): - if key in field_names: - init_kwargs[key] = value - - if "guidance_scale" in kwargs: - init_kwargs["guidance_scale_provided"] = True - - return OmniDiffusionRequest(**init_kwargs) - async def generate( self, - prompt: str, + prompt: OmniPromptType, + sampling_params: OmniDiffusionSamplingParams, request_id: str | None = None, - num_inference_steps: int = 50, - guidance_scale: float | None = None, - height: int | None = None, - width: int | None = None, - negative_prompt: str | None = None, - num_outputs_per_prompt: int = 1, - seed: int | None = None, - lora_request=None, - **kwargs: Any, + lora_request: LoRARequest | None = None, ) -> OmniRequestOutput: """Generate images asynchronously from a text prompt. Args: prompt: Text prompt describing the desired image + sampling_params: Sampling parameters request_id: Optional unique identifier for tracking the request - num_inference_steps: Number of denoising steps (default: 50) - guidance_scale: Classifier-free guidance scale (optional, uses model defaults if omitted) - height: Optional image height in pixels - width: Optional image width in pixels - negative_prompt: Optional negative prompt for guidance - num_outputs_per_prompt: Number of images to generate (default: 1) - seed: Optional random seed for reproducibility - **kwargs: Additional generation parameters Returns: OmniRequestOutput containing generated images @@ -181,64 +131,38 @@ async def generate( if request_id is None: request_id = f"diff-{uuid.uuid4().hex[:16]}" - # Prepare request - request_kwargs = { - "prompt": prompt, - "request_id": request_id, - "num_inference_steps": num_inference_steps, - "height": height, - "width": width, - "negative_prompt": negative_prompt, - "num_outputs_per_prompt": num_outputs_per_prompt, - "seed": seed, - "lora_request": lora_request, - **kwargs, - } - if guidance_scale is not None: - request_kwargs["guidance_scale"] = guidance_scale - - request = self._prepare_request(**request_kwargs) + if sampling_params.guidance_scale: + sampling_params.guidance_scale_provided = True + + if lora_request is not None: + sampling_params.lora_request = lora_request + + request = OmniDiffusionRequest( + prompts=[prompt], + sampling_params=sampling_params, + request_ids=[request_id], + ) logger.debug("Starting generation for request %s", request_id) # Run engine in thread pool loop = asyncio.get_event_loop() try: + # In async mode, only a single request is submitted at a time result = await loop.run_in_executor( self._executor, self.engine.step, - [request], + request, ) + result = result[0] except Exception as e: logger.error("Generation failed for request %s: %s", request_id, e) raise RuntimeError(f"Diffusion generation failed: {e}") from e - # Check if result is already OmniRequestOutput - if isinstance(result, OmniRequestOutput): - # Update request_id if needed - if not result.request_id: - result.request_id = request_id - return result - - # Process results if not OmniRequestOutput - images: list[Image.Image] = [] - if result is not None: - if isinstance(result, list): - for item in result: - if isinstance(item, Image.Image): - images.append(item) - elif isinstance(result, Image.Image): - images.append(result) - - return OmniRequestOutput.from_diffusion( - request_id=request_id, - images=images, - prompt=prompt, - metrics={ - "num_inference_steps": num_inference_steps, - "guidance_scale": request.guidance_scale, - }, - ) + # Update request_id if needed + if not result.request_id: + result.request_id = request_id + return result async def generate_stream( self, diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 90b344c6fe2..97357dc3b33 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -10,14 +10,13 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict from pprint import pformat -from typing import Any +from typing import Any, Literal, overload from omegaconf import OmegaConf from tqdm.auto import tqdm -from vllm.inputs import PromptType +from vllm import SamplingParams from vllm.logger import init_logger -from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.distributed.omni_connectors import ( get_stage_connector_config, initialize_orchestrator_connectors, @@ -42,6 +41,7 @@ load_stage_configs_from_yaml, resolve_model_config_path, ) +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -82,10 +82,8 @@ class OmniBase: """Base class for serving Omni models. Args: - *args: Variable length argument list. - - args[0]: Model name or path to load. + model: Model name or path to load. **kwargs: Arbitrary keyword arguments. - - model: Model name or path to load (if not in args). - stage_configs_path: Optional path to YAML file containing stage configurations. If None, configurations are loaded from the model. - log_stats: Whether to enable statistics logging @@ -102,14 +100,9 @@ class OmniBase: - Additional keyword arguments passed to stage engines. """ - def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: - model = args[0] if args else kwargs.get("model", "") - assert model != "", "Null model id detected, please specify a model id." + def __init__(self, model: str, **kwargs: Any) -> None: model = omni_snapshot_download(model) - if args: - args[0] = model - elif kwargs.get("model", "") != "": - kwargs["model"] = model + kwargs["model"] = model # Stage management attributes self.stage_list: list[OmniStage] = [] @@ -502,10 +495,8 @@ class Omni(OmniBase): """Unified entrypoint for both LLM and Diffusion models for better usability. Args: - *args: Variable length argument list. - - args[0]: Model name or path to load. + model: Model name or path to load. **kwargs: Arbitrary keyword arguments. - - model: Model name or path to load (if not in args). - stage_configs_path: Optional path to YAML file containing stage configurations. If None, configurations are loaded from the model. - log_stats: Whether to enable statistics logging @@ -527,8 +518,8 @@ class Omni(OmniBase): >>> print(outputs) """ - def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: - super().__init__(*args, **kwargs) + def __init__(self, model: str, **kwargs: Any) -> None: + super().__init__(model, **kwargs) # Register weak reference cleanup (called on garbage collection) self._weak_finalizer = weakref.finalize( @@ -539,8 +530,31 @@ def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: self._ray_pg, ) + @overload def generate( - self, *args: Any, **kwargs: dict[str, Any] + self, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: OmniSamplingParams | Sequence[OmniSamplingParams] | None = None, + *, + py_generator: Literal[True], + ) -> Generator[OmniRequestOutput, None, None]: ... + + @overload + def generate( + self, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: OmniSamplingParams | Sequence[OmniSamplingParams] | None = None, + *, + py_generator: Literal[False] = False, + ) -> list[OmniRequestOutput]: ... + + def generate( + self, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: OmniSamplingParams | Sequence[OmniSamplingParams] | None = None, + *, + py_generator: bool = False, + use_tqdm: bool | Callable[..., tqdm] = True, ) -> Generator[OmniRequestOutput, None, None] | list[OmniRequestOutput]: """Generate outputs for the given prompts. @@ -548,12 +562,10 @@ def generate( Each stage will use OmniLLM or OmniDiffusion based on stage_type. Args: - *args: Variable length argument list. - - args[0]: Input prompts for generation. - - args[1]: Optional list of per-stage parameters. - **kwargs: Arbitrary keyword arguments. - - prompt: Input prompts for generation (if not in args). - - sampling_params_list: Optional list of per-stage parameters (if not in args). + prompts: Input prompt(s) for generation. + sampling_params_list: Optional list of per-stage parameters. + py_generator: Whether the returned result(s) are wrapped in a generator instead of a list. + use_tqdm: Whether to use tqdm progress bar Returns: List of OmniRequestOutput objects, one for each input prompt. @@ -563,40 +575,26 @@ def generate( Raises: ValueError: If sampling_params_list is None or has incorrect length. """ - prompts = args[0] if args else kwargs.get("prompts") - sampling_params_list = args[1] if len(args) > 1 else kwargs.get("sampling_params_list") - py_generator = kwargs.get("py_generator", False) - if prompts is None: - if kwargs.get("prompt") is None: - raise ValueError("prompts is required for generation") - prompts = kwargs.get("prompt") - if sampling_params_list is None: - # For Omni LLM, the params are parsed via the yaml file. For the current version, - # diffusion params can parsed via the command line. - omni_params_kwargs = { - k: v for k, v in kwargs.items() if k not in ["prompt", "request_id", "output_modalities"] - } - - per_stage_params: list[Any] = [] - for stage_id, stage in enumerate(self.stage_list): - stage_type = getattr(stage, "stage_type", "llm") - if stage_type == "diffusion": - default_dict = self.default_sampling_params_list[stage_id] - # Merge user-provided kwargs - merged = {**default_dict, **omni_params_kwargs} - # Diffusion only needs to keep diff params, will be used via OmniDiffusionRequest - per_stage_params.append(merged) + sampling_params_list = self.default_sampling_params_list + elif not isinstance(sampling_params_list, Sequence): + # TODO: After the recent introduction of BAGEL model (one LLM and one Diffusion), + # expect the text_to_image example code to run when only passing one OmniDiffusionSamplingParams + # This behavior may be confusing, and future PR can improve it. + per_stage_params: list[OmniSamplingParams] = [] + for default_stage_sp in self.default_sampling_params_list: + default_sp_type = default_stage_sp.__class__ + if default_sp_type == sampling_params_list.__class__: + per_stage_params.append(sampling_params_list) else: - # LLM directly constructs SamplingParams, don't use the merged params - per_stage_params.append(self.default_sampling_params_list[stage_id]) - + per_stage_params.append(default_stage_sp) sampling_params_list = per_stage_params + try: if py_generator: return self._run_generation_with_generator(prompts, sampling_params_list) else: - outputs = list(self._run_generation(prompts, sampling_params_list)) + outputs = list(self._run_generation(prompts, sampling_params_list, use_tqdm)) return outputs except Exception as e: logger.exception("[Orchestrator] Failed to run generation: %s", e) @@ -606,8 +604,8 @@ def generate( def _run_generation_with_generator( self, - prompts: PromptType | Sequence[PromptType] | OmniDiffusionRequest | Sequence[OmniDiffusionRequest], - sampling_params_list: Any | Sequence[Any] | None, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: Sequence[OmniSamplingParams], ) -> Generator[OmniRequestOutput, None, None]: """Run generation through all stages in the pipeline and return a generator.""" gen = self._run_generation(prompts, sampling_params_list) @@ -622,8 +620,8 @@ def _run_generation_with_generator( def _run_generation( self, - prompts: PromptType | Sequence[PromptType] | OmniDiffusionRequest | Sequence[OmniDiffusionRequest], - sampling_params_list: Any | Sequence[Any] | None = None, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: Sequence[OmniSamplingParams], use_tqdm: bool | Callable[..., tqdm] = True, ) -> Generator[OmniRequestOutput, None, None]: """Run generation through all stages in the pipeline.""" @@ -631,18 +629,20 @@ def _run_generation( if sampling_params_list is None: raise ValueError("sampling_params_list is required for pipelined generation") - # Normalize sampling_params_list to a list - if not isinstance(sampling_params_list, (list, tuple)): - sampling_params_list = [sampling_params_list] - else: - sampling_params_list = list(sampling_params_list) - if len(sampling_params_list) != len(self.stage_list): raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") + for i, (stage, sp) in enumerate(zip(self.stage_list, sampling_params_list)): + ExpectedSPType = OmniDiffusionSamplingParams if stage.stage_type == "diffusion" else SamplingParams + if not isinstance(sp, ExpectedSPType): + raise ValueError( + f"Expected sampling parameters with type {ExpectedSPType} in stage {i}, got {sp.__class__}" + ) + # Normalize prompts to a list for per-request iteration - if not isinstance(prompts, (list, tuple)): - request_prompts: list[PromptType] = [prompts] + # str is also Sequence but only test list-like containers here + if isinstance(prompts, str) or not isinstance(prompts, Sequence): + request_prompts: list[OmniPromptType] = [prompts] else: request_prompts = list(prompts) @@ -650,8 +650,8 @@ def _run_generation( num_stages = len(self.stage_list) # Generate globally unique request IDs and map them to original prompts - request_ids: list[str] = [f"{i}_{uuid.uuid4()}" for i in range(len(request_prompts))] - request_id_to_prompt: dict[str, PromptType] = {rid: p for rid, p in zip(request_ids, request_prompts)} + request_ids = [f"{i}_{uuid.uuid4()}" for i in range(len(request_prompts))] + request_id_to_prompt = {rid: p for rid, p in zip(request_ids, request_prompts)} # Track per-request start time for end-to-end timing _req_start_ts: dict[str, float] = {} diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py index f5b0a0a402d..5ad9a91c80d 100644 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -3,7 +3,7 @@ import logging import uuid -from dataclasses import fields +from collections.abc import Sequence from vllm.logger import init_logger from vllm.transformers_utils.config import get_hf_file_to_dict @@ -11,6 +11,8 @@ from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType +from vllm_omni.outputs import OmniRequestOutput # TODO configure logging properly logging.basicConfig(level=logging.INFO) @@ -18,21 +20,6 @@ logger = init_logger(__name__) -def prepare_requests(prompt: str | list[str], **kwargs): - field_names = {f.name for f in fields(OmniDiffusionRequest)} - - init_kwargs = {"prompt": prompt} - - for key, value in kwargs.items(): - if key in field_names: - init_kwargs[key] = value - - if "guidance_scale" in kwargs: - init_kwargs["guidance_scale_provided"] = True - - return OmniDiffusionRequest(**init_kwargs) - - class OmniDiffusion: """ It is the main class to interact with vLLM-Omni diffusion models. @@ -100,42 +87,24 @@ def __init__(self, od_config: OmniDiffusionConfig | None = None, **kwargs): def generate( self, - prompt: str | list[str], - **kwargs, - ): - prompts = [] - if isinstance(prompt, str): - prompts.append(prompt) - elif isinstance(prompt, list): - prompts.extend(prompt) + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params: OmniDiffusionSamplingParams, + request_ids: list[str] = [], + ) -> list[OmniRequestOutput]: + if isinstance(prompts, (str, dict)): + prompts = [prompts] else: - raise ValueError("Prompt must be a string or a list of strings") - - requests: list[OmniDiffusionRequest] = [] + prompts = list(prompts) # Check if request_id is provided in kwargs - request_id = kwargs.get("request_id") - request_ids = kwargs.pop("request_ids", None) - - for i, p in enumerate(prompts): - req_kwargs = kwargs.copy() - if request_ids and isinstance(request_ids, list) and i < len(request_ids): - req_kwargs["request_id"] = request_ids[i] - elif request_id is None: - # Generate default ID consistent with OmniLLM: "{i}_{uuid}" - req_kwargs["request_id"] = f"{i}_{uuid.uuid4()}" - - requests.append( - prepare_requests( - p, - **req_kwargs, - ) - ) - logger.info(f"Prepared {len(requests)} requests for generation.") - return self._run_engine(requests) + if len(request_ids) < len(prompts): + request_ids.extend(f"{i + len(request_ids)}_{uuid.uuid4()}" for i in range(len(prompts) - len(request_ids))) + + request = OmniDiffusionRequest(prompts, sampling_params, request_ids) + return self._run_engine(request) - def _run_engine(self, requests: list[OmniDiffusionRequest]): - return self.engine.step(requests) + def _run_engine(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: + return self.engine.step(request) def close(self) -> None: self.engine.close() diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index e0e699c9c24..0a5ee55beea 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -15,9 +15,11 @@ import sys import time import traceback +from collections.abc import Sequence from dataclasses import fields -from typing import Any +from typing import Any, Literal, cast +from vllm import PromptType, RequestOutput from vllm.inputs import TextPrompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -31,6 +33,7 @@ from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.distributed.omni_connectors import build_stage_connectors from vllm_omni.distributed.omni_connectors.adapter import try_recv_via_connector +from vllm_omni.distributed.omni_connectors.connectors.base import OmniConnectorBase from vllm_omni.distributed.ray_utils.utils import kill_ray_actor, start_ray_actor from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion @@ -46,7 +49,8 @@ maybe_dump_to_shm, set_stage_devices, ) -from vllm_omni.inputs.data import OmniTokensPrompt +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams, OmniTokensPrompt +from vllm_omni.outputs import OmniRequestOutput from vllm_omni.utils import detect_device_type logger = init_logger(__name__) @@ -64,35 +68,6 @@ def _build_od_config(engine_args: dict[str, Any], model: str) -> dict[str, Any]: return od_config -def prepare_sampling_params(sampling_params: Any, stage_type: str) -> Any: - """Prepare sampling parameters for the given stage type. - - Args: - sampling_params: Raw sampling parameters (dict or SamplingParams) - stage_type: Either "llm" or "diffusion" - - Returns: - Processed sampling parameters ready for engine consumption - """ - if stage_type == "diffusion": - # For diffusion stages: extract kwargs, handling different input types - if isinstance(sampling_params, dict): - diffusion_kwargs = dict(sampling_params) - else: - diffusion_kwargs = getattr(sampling_params, "__dict__", {}) or {} - - # Remove 'prompt' and 'request_id' to avoid conflict with explicit arguments - diffusion_kwargs.pop("prompt", None) - diffusion_kwargs.pop("request_id", None) - return diffusion_kwargs - - else: # stage_type == "llm" - # For LLM stages: ensure we have a SamplingParams object - if isinstance(sampling_params, dict): - return SamplingParams(**sampling_params) - return sampling_params - - class OmniStage: """Stage manager for orchestrating a single stage in the omni pipeline. @@ -123,7 +98,7 @@ def __init__(self, stage_config: Any, stage_init_timeout: int = 300): self.engine_outputs = None self.is_comprehension = getattr(stage_config, "is_comprehension", False) # Support for different stage types: "llm" (default) or "diffusion" - self.stage_type = getattr(stage_config, "stage_type", "llm") + self.stage_type: Literal["llm", "diffusion"] = getattr(stage_config, "stage_type", "llm") if hasattr(stage_config, "custom_process_input_func"): # Import the module specified in the config (already a full module path) module_path, func_name = stage_config.custom_process_input_func.rsplit(".", 1) @@ -137,7 +112,14 @@ def __init__(self, stage_config: Any, stage_init_timeout: int = 300): default_sampling_params = getattr(stage_config, "default_sampling_params", {}) # For LLM stage, this can directly be a SamplingParams-compatible dict; # For diffusion stage, this only serves as default values for diffusion kwargs. - self.default_sampling_params = _to_dict(default_sampling_params) + default_sampling_params = _to_dict(default_sampling_params) + # Further convert it to dataclass to check fields + try: + self.default_sampling_params = ( + SamplingParams if self.stage_type == "llm" else OmniDiffusionSamplingParams + )(**default_sampling_params) + except TypeError as error: + raise TypeError(f"Invalid default_sampling_params for stage {self.stage_id}: {error}") from error # Runtime orchestration state (added) self._in_q: mp.Queue | None = None self._out_q: mp.Queue | None = None @@ -519,7 +501,7 @@ def _stage_worker( runtime_cfg = stage_payload.get("runtime", {}) shm_threshold_bytes = int(stage_payload.get("shm_threshold_bytes", 65536)) connectors_config = stage_payload.get("connectors_config", {}) - stage_type = stage_payload.get("stage_type", "llm") + stage_type: Literal["llm", "diffusion"] = stage_payload.get("stage_type", "llm") # Aggregates for running average _agg_total_tokens = 0 @@ -708,15 +690,14 @@ def _stage_worker( pass logger.debug("Engine initialized") # Initialize OmniConnectors if configured - connectors = {} + connectors: dict[tuple[str, str], OmniConnectorBase] | None = {} if connectors_config: - built_connectors = build_stage_connectors( + connectors = build_stage_connectors( stage_id=stage_id, connectors_config=connectors_config, ) - if built_connectors is None: + if connectors is None: return - connectors = built_connectors # Signal readiness to orchestrator try: @@ -785,6 +766,7 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: continue batch_tasks: list[dict[str, Any]] = [task] + tasks_failed_to_add_to_batch: list[dict[str, Any]] = [] start_time = _time.time() if max_batch_size > 1: while len(batch_tasks) < max_batch_size: @@ -800,7 +782,20 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: if extra_type == OmniStageTaskType.PROFILER_STOP: out_q.put({"type": "profiler_result", "data": p_data}) continue - batch_tasks.append(extra) + # Ensure that all tasks have the same sampling params + # If no, put them in a temporary container and add back to queue + # This should be always true, because user only calls omni.generate() once and it blocks + # User can only pass one sampling param object, but the list of prompts are separated. + if task.get("sampling_params") != extra.get("sampling_params"): + logger.warning( + """In offline mode, expect all prompts in one `omni.generate()` call to share same sampling params""" # noqa: E501 # line too long + f"""However, prompt {task.get("engine_inputs")} has sampling params {task.get("sampling_params")}, """ # noqa: E501 # line too long + f"""whereas the prompt {extra.get("engine_inputs")} has sampling params {extra.get("sampling_params")}.""" # noqa: E501 # line too long + """The two tasks cannot be combined in one batch request.""" + ) + tasks_failed_to_add_to_batch.append(extra) + else: + batch_tasks.append(extra) end_time = _time.time() duration = end_time - start_time if duration > batch_timeout: @@ -815,9 +810,13 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: break else: continue + for task_to_readd in tasks_failed_to_add_to_batch: + in_q.put(task_to_readd) + # Ensure that the popped tasks are with identical sampling params. Take one of them. + batch_engine_sampling_params: OmniSamplingParams = batch_tasks[0]["sampling_params"] batch_request_ids: list[Any] = [] - batch_engine_inputs: list[Any] = [] + batch_engine_inputs: list[OmniPromptType] = [] _rx_bytes_by_rid: dict[Any, int] = {} _rx_decode_ms_by_rid: dict[Any, float] = {} _in_flight_ms_by_rid: dict[Any, float] = {} @@ -839,6 +838,9 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: connectors=connectors, stage_id=stage_id, ) + # TODO: hack type annotation for now. + # A better way is to refine type annotation of connection and task/payloads, maybe using template types. + ein = cast(OmniPromptType | Sequence[OmniPromptType] | None, ein) if ein is None or _rx_metrics is None: raise RuntimeError( @@ -850,17 +852,14 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: _rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0)) batch_request_ids.append(rid) - if isinstance(ein, list): - batch_engine_inputs.extend(ein) - elif isinstance(ein, dict): - batch_engine_inputs.append(ein) - elif isinstance(ein, str): - # For diffusion stage-0, ein might be a string prompt directly + if isinstance(ein, (str, dict)): + # Types like OmniTextPrompt, TextPrompt are TypedDict, essentially dict and enters this branch batch_engine_inputs.append(ein) + elif isinstance(ein, Sequence): + batch_engine_inputs.extend(ein) else: - # For other types (e.g., OmniTokensPrompt, TextPrompt), append as-is + # Other unknown types, append as-is batch_engine_inputs.append(ein) - sampling_params = batch_tasks[0]["sampling_params"] logger.debug( "Received batch size=%d, request_ids=%s", len(batch_tasks), @@ -868,65 +867,30 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: ) try: _batch_seq += 1 - gen_outputs: list[Any] = [] + gen_outputs: list[OmniRequestOutput | RequestOutput] = [] _gen_t0 = _time.time() if stage_type == "diffusion": - # For diffusion, batch_engine_inputs should be prompts (strings) - # Convert to list of strings if needed - prompts = [] - for ein in batch_engine_inputs: - if isinstance(ein, str): - prompts.append(ein) - elif isinstance(ein, dict) and "prompt" in ein: - prompts.append(ein["prompt"]) - elif hasattr(ein, "prompt"): - prompts.append(ein.prompt) - else: - prompts.append(str(ein)) - # Prepare diffusion kwargs from sampling parameters - diffusion_kwargs = prepare_sampling_params(sampling_params, "diffusion") - - # Pass batch_request_ids to ensure correct ID mapping - diffusion_kwargs["request_ids"] = batch_request_ids - + stage_engine = cast(OmniDiffusion, stage_engine) + batch_engine_sampling_params = cast(OmniDiffusionSamplingParams, batch_engine_sampling_params) # Diffusion generate returns results directly, not an iterator - diffusion_results = stage_engine.generate(prompts, **diffusion_kwargs) - # Convert to list format compatible with LLM outputs - # Ensure each result has a request_id for proper mapping - if isinstance(diffusion_results, list): - gen_outputs = diffusion_results - # Assign request_ids if not present - for idx, result in enumerate(gen_outputs): - if not hasattr(result, "request_id") or result.request_id is None: - if idx < len(batch_request_ids): - if hasattr(result, "request_id"): - result.request_id = batch_request_ids[idx] - else: - # Create a wrapper object if result doesn't support request_id - from types import SimpleNamespace - - wrapped = SimpleNamespace() - wrapped.request_id = batch_request_ids[idx] - wrapped.output = result - gen_outputs[idx] = wrapped - else: - gen_outputs = [diffusion_results] - # Assign request_id to single result - if len(batch_request_ids) > 0: - if hasattr(gen_outputs[0], "request_id"): - gen_outputs[0].request_id = batch_request_ids[0] - else: - from types import SimpleNamespace - - wrapped = SimpleNamespace() - wrapped.request_id = batch_request_ids[0] - wrapped.output = gen_outputs[0] - gen_outputs[0] = wrapped + diffusion_results = stage_engine.generate( + batch_engine_inputs, batch_engine_sampling_params, batch_request_ids + ) + gen_outputs.extend(diffusion_results) + # Assign request_ids if not present + for idx, result in enumerate(gen_outputs): + if not hasattr(result, "request_id") or result.request_id is None: + if idx < len(batch_request_ids): + result.request_id = batch_request_ids[idx] else: - # LLM engine: use vLLM native SamplingParams - llm_sampling_params = prepare_sampling_params(sampling_params, "llm") - for ro in stage_engine.generate(batch_engine_inputs, llm_sampling_params, use_tqdm=False): - gen_outputs.append(ro) + stage_engine = cast(OmniLLM, stage_engine) + batch_engine_sampling_params = cast(SamplingParams, batch_engine_sampling_params) + results = stage_engine.generate( + batch_engine_inputs, # type: ignore # silent complaints about list of subclassed TypedDict + batch_engine_sampling_params, + use_tqdm=False, + ) + gen_outputs.extend(results) _gen_t1 = _time.time() _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 logger.debug(f"Generate done: batch={len(batch_tasks)}, req_ids={batch_request_ids}, gen_ms={_gen_ms:.1f}") @@ -935,7 +899,7 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: req_to_outputs: dict[Any, list[Any]] = {rid: [] for rid in batch_request_ids} unmapped: list[Any] = [] for ro in gen_outputs: - rid = getattr(ro, "request_id", None) + rid = ro.request_id if rid in req_to_outputs: req_to_outputs[rid].append(ro) else: @@ -1354,6 +1318,10 @@ async def generation_single_request(task: dict[str, Any]): connectors=connectors, stage_id=stage_id, ) + # TODO: hack type annotation for now. + # A better way is to refine type annotation of connection and task/payloads, maybe using template types. + ein = cast(OmniPromptType | Sequence[OmniPromptType] | None, ein) + if ein is None or _rx_metrics is None: raise RuntimeError( f"[Stage-{stage_id}] Missing connector payload for request {rid}. " @@ -1362,36 +1330,23 @@ async def generation_single_request(task: dict[str, Any]): _rx_decode_ms_by_rid[rid] = float(_rx_metrics.get("rx_decode_time_ms", 0.0)) _rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0)) - sampling_params = task["sampling_params"] logger.debug("Received batch size=1, request_ids=%s", rid) _gen_t0 = _time.time() - if isinstance(ein, list): + if isinstance(ein, Sequence) and not isinstance(ein, str): ein = ein[0] if stage_type == "diffusion": - # For diffusion, ein should be prompts (strings) - # Convert to string if needed - if isinstance(ein, str): - prompt = ein - elif isinstance(ein, dict) and "prompt" in ein: - prompt = ein["prompt"] - elif hasattr(ein, "prompt"): - prompt = ein.prompt - else: - prompt = str(ein) - - # Prepare diffusion kwargs from sampling parameters - diffusion_kwargs = prepare_sampling_params(sampling_params, "diffusion") + diffusion_sampling_params = cast(OmniDiffusionSamplingParams, task["sampling_params"]) # AsyncOmniDiffusion.generate returns a single result, not an async generator - gen_output = await stage_engine.generate(prompt=prompt, request_id=rid, **diffusion_kwargs) + gen_output = await cast(AsyncOmniDiffusion, stage_engine).generate(ein, diffusion_sampling_params, rid) _gen_t1 = _time.time() _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 await generation_out_q.put((rid, gen_output, _gen_ms)) else: - # LLM stages: ensure using SamplingParams - llm_sampling_params = prepare_sampling_params(sampling_params, "llm") + ein = cast(PromptType, ein) + llm_sampling_params: SamplingParams = task["sampling_params"] gen_output = None - async for res in stage_engine.generate(ein, llm_sampling_params, rid): + async for res in cast(AsyncLLM, stage_engine).generate(ein, llm_sampling_params, rid): gen_output = res _gen_t1 = _time.time() _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 3c44da8b32b..e2cc56ddc75 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -10,13 +10,14 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Any +from typing import Any, cast import vllm.envs as envs from fastapi import Depends, HTTPException, Request from fastapi.responses import JSONResponse, StreamingResponse from starlette.datastructures import State from starlette.routing import Route +from vllm import SamplingParams from vllm.engine.protocol import EngineClient from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages from vllm.entrypoints.launcher import serve_http @@ -80,6 +81,7 @@ ) from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams, OmniTextPrompt from vllm_omni.lora.request import LoRARequest from vllm_omni.lora.utils import stable_lora_int_id @@ -793,7 +795,7 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) HTTPException: For validation errors, missing engine, or generation failures """ # Get engine client (AsyncOmni) from app state - engine_client: EngineClient | None = getattr(raw_request.app.state, "engine_client", None) + engine_client: EngineClient | AsyncOmni | None = getattr(raw_request.app.state, "engine_client", None) if engine_client is None or not hasattr(engine_client, "stage_list"): raise HTTPException( status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, @@ -853,10 +855,8 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) try: # Build params - pass through user values directly - gen_params = { - "prompt": request.prompt, - "num_outputs_per_prompt": request.n, - } + prompt: OmniTextPrompt = {"prompt": request.prompt} + gen_params = OmniDiffusionSamplingParams(num_outputs_per_prompt=request.n) # Parse per-request LoRA (compatible with chat's extra_body.lora shape). if request.lora is not None: @@ -888,62 +888,72 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) detail="Invalid lora object: both name and path are required.", ) - gen_params["lora_request"] = LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)) + gen_params.lora_request = LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)) if lora_scale is not None: - gen_params["lora_scale"] = float(lora_scale) + gen_params.lora_scale = float(lora_scale) # Parse and add size if provided if request.size: width, height = parse_size(request.size) - gen_params["height"] = height - gen_params["width"] = width + gen_params.height = height + gen_params.width = width size_str = f"{width}x{height}" else: size_str = "model default" # Add optional parameters ONLY if provided if request.num_inference_steps is not None: - gen_params["num_inference_steps"] = request.num_inference_steps + gen_params.num_inference_steps = request.num_inference_steps if request.negative_prompt is not None: - gen_params["negative_prompt"] = request.negative_prompt + prompt["negative_prompt"] = request.negative_prompt if request.guidance_scale is not None: - gen_params["guidance_scale"] = request.guidance_scale + gen_params.guidance_scale = request.guidance_scale if request.true_cfg_scale is not None: - gen_params["true_cfg_scale"] = request.true_cfg_scale + gen_params.true_cfg_scale = request.true_cfg_scale if request.seed is not None: - gen_params["seed"] = request.seed - gen_params["request_id"] = f"img_gen_{int(time.time())}" + gen_params.seed = request.seed + request_id = f"img_gen_{int(time.time())}" logger.info(f"Generating {request.n} image(s) {size_str}") # Generate images using AsyncOmni (multi-stage mode) + engine_client = cast(AsyncOmni, engine_client) result = None stage_list = getattr(engine_client, "stage_list", None) if isinstance(stage_list, list): - default_params_list = getattr(engine_client, "default_sampling_params_list", None) + default_params_list: list[OmniSamplingParams] | None = getattr( + engine_client, "default_sampling_params_list", None + ) if not isinstance(default_params_list, list): - default_params_list = [{} for _ in stage_types] + default_params_list = [ + OmniDiffusionSamplingParams() if st == "diffusion" else SamplingParams() for st in stage_types + ] else: default_params_list = list(default_params_list) if len(default_params_list) != len(stage_types): - default_params_list = (default_params_list + [{} for _ in stage_types])[: len(stage_types)] + default_params_list = ( + default_params_list + + [OmniDiffusionSamplingParams() if st == "diffusion" else SamplingParams() for st in stage_types] + )[: len(stage_types)] - sampling_params_list: list[dict[str, Any]] = [] + sampling_params_list: list[OmniSamplingParams] = [] for idx, stage_type in enumerate(stage_types): if stage_type == "diffusion": sampling_params_list.append(gen_params) else: base_params = default_params_list[idx] - sampling_params_list.append(dict(base_params) if isinstance(base_params, dict) else base_params) + sampling_params_list.append(base_params) async for output in engine_client.generate( - prompt=gen_params["prompt"], - request_id=gen_params["request_id"], + prompt=prompt, + request_id=request_id, sampling_params_list=sampling_params_list, ): result = output else: - result = await engine_client.generate(**gen_params) + result = await engine_client.generate( + prompt=prompt, request_id=request_id, sampling_params_list=[gen_params] + ) if result is None: raise HTTPException( diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index f35008260ff..db8c972a1fd 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -6,13 +6,16 @@ from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence from datetime import datetime, timedelta, timezone from io import BytesIO -from typing import TYPE_CHECKING, Any, Final, Optional +from typing import TYPE_CHECKING, Any, Final, Optional, cast import jinja2 from fastapi import Request from PIL import Image from pydantic import TypeAdapter +from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt + try: import soundfile except ImportError: @@ -274,16 +277,11 @@ async def create_chat_completion( lora_request=lora_request, ) - trace_headers = None if raw_request is None else await self._get_trace_headers(raw_request.headers) - generator = self.engine_client.generate( prompt=engine_prompt, request_id=request_id, sampling_params_list=sampling_params_list, output_modalities=output_modalities, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, ) generators.append(generator) @@ -1904,29 +1902,30 @@ async def _create_diffusion_chat_completion( logger.warning("Failed to decode reference image: %s", e) # Build generation kwargs - gen_kwargs: dict[str, Any] = { + gen_prompt: OmniTextPrompt = { "prompt": prompt, - "request_id": request_id, - "num_inference_steps": num_inference_steps, - "height": height, - "width": width, "negative_prompt": negative_prompt, - "num_outputs_per_prompt": num_outputs_per_prompt, - "seed": seed, } + gen_params = OmniDiffusionSamplingParams( + num_inference_steps=num_inference_steps, + height=height, + width=width, + num_outputs_per_prompt=num_outputs_per_prompt, + seed=seed, + ) if guidance_scale is not None: - gen_kwargs["guidance_scale"] = guidance_scale + gen_params.guidance_scale = guidance_scale # Add Qwen-Image specific parameter if true_cfg_scale is not None: - gen_kwargs["true_cfg_scale"] = true_cfg_scale + gen_params.true_cfg_scale = true_cfg_scale # Add video generation parameters if set if num_frames is not None: - gen_kwargs["num_frames"] = num_frames + gen_params.num_frames = num_frames if guidance_scale_2 is not None: - gen_kwargs["guidance_scale_2"] = guidance_scale_2 + gen_params.guidance_scale_2 = guidance_scale_2 # Parse per-request LoRA (works for both AsyncOmniDiffusion and AsyncOmni). if lora_body and isinstance(lora_body, dict): @@ -1949,16 +1948,17 @@ async def _create_diffusion_chat_completion( lora_int_id = stable_lora_int_id(str(lora_path)) if lora_name and lora_path: lora_req = LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)) - gen_kwargs["lora_request"] = lora_req + gen_params.lora_request = lora_req if lora_scale is not None: - gen_kwargs["lora_scale"] = float(lora_scale) + gen_params.lora_scale = float(lora_scale) except Exception as e: # pragma: no cover - safeguard logger.warning("Failed to parse LoRA request: %s", e) # Add reference image if provided if pil_images: if len(pil_images) == 1: - gen_kwargs["pil_image"] = pil_images[0] + gen_prompt["multi_modal_data"] = {} + gen_prompt["multi_modal_data"]["image"] = pil_images[0] else: od_config = getattr(self._diffusion_engine, "od_config", None) supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) @@ -1966,7 +1966,8 @@ async def _create_diffusion_chat_completion( # TODO: entry is asyncOmni. We hack the od config here. supports_multimodal_inputs = True if supports_multimodal_inputs: - gen_kwargs["pil_image"] = pil_images + gen_prompt["multi_modal_data"] = {} + gen_prompt["multi_modal_data"]["image"] = pil_images else: return self._create_error_response( "Multiple input images are not supported by the current diffusion model. " @@ -1979,18 +1980,24 @@ async def _create_diffusion_chat_completion( # Handle both AsyncOmniDiffusion (returns OmniRequestOutput) and AsyncOmni (returns AsyncGenerator) if hasattr(self._diffusion_engine, "stage_list"): # AsyncOmni: iterate through async generator to get final output + diffusion_engine = cast(AsyncOmni, self._diffusion_engine) result = None - async for output in self._diffusion_engine.generate( - prompt=gen_kwargs["prompt"], - request_id=gen_kwargs.get("request_id"), - sampling_params_list=[gen_kwargs], # Pass as single-stage params + async for output in diffusion_engine.generate( + prompt=gen_prompt, + sampling_params_list=[gen_params], # Pass as single-stage params + request_id=request_id, ): result = output if result is None: return self._create_error_response("No output generated from AsyncOmni") else: # AsyncOmniDiffusion: direct call - result = await self._diffusion_engine.generate(**gen_kwargs) + diffusion_engine = cast(AsyncOmniDiffusion, self._diffusion_engine) + result = await diffusion_engine.generate( + prompt=gen_prompt, + sampling_params=gen_params, + request_id=request_id, + ) # Extract images from result # Handle nested OmniRequestOutput structure where images might be in request_output images = getattr(result.request_output, "images", []) diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py index ec291e43e29..01de61b96db 100644 --- a/vllm_omni/inputs/data.py +++ b/vllm_omni/inputs/data.py @@ -1,4 +1,11 @@ -from typing import Any +import copy +import pprint +from dataclasses import asdict, dataclass, field +from typing import Any, TypeAlias + +from vllm import PromptType, SamplingParams + +from vllm_omni.lora.request import LoRARequest try: from typing import NotRequired @@ -7,7 +14,25 @@ from typing_extensions import NotRequired import torch -from vllm.inputs.data import EmbedsPrompt, TokenInputs, TokensPrompt +from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokenInputs, TokensPrompt + + +class OmniTextPrompt(TextPrompt): + """Text prompt with optional embeddings and additional information. + + Extends TextPrompt to support prompt embeddings and additional + information payloads for direct transfer between pipeline stages. + + Attributes: + prompt_embeds: Optional tensor containing prompt embeddings + additional_information: Optional dictionary containing additional + information (tensors or lists) to pass along with the prompt + """ + + negative_prompt: NotRequired[str] + prompt_embeds: NotRequired[torch.Tensor] + negative_prompt_embeds: NotRequired[torch.Tensor] + additional_information: NotRequired[dict[str, Any]] class OmniTokensPrompt(TokensPrompt): @@ -22,7 +47,9 @@ class OmniTokensPrompt(TokensPrompt): information (tensors or lists) to pass along with the prompt """ + negative_prompt: NotRequired[str] prompt_embeds: NotRequired[torch.Tensor] + negative_prompt_embeds: NotRequired[list[torch.Tensor] | None] """The embeddings of the prompt.""" # New: optional additional information dictionary @@ -44,7 +71,9 @@ class OmniTokenInputs(TokenInputs): """ # New: optional prompt embeddings aligned with token ids + negative_prompt: NotRequired[str] prompt_embeds: NotRequired[torch.Tensor] + negative_prompt_embeds: NotRequired[list[torch.Tensor] | None] # New: optional additional information dictionary # Values may be torch.Tensor or list @@ -65,12 +94,19 @@ class OmniEmbedsPrompt(EmbedsPrompt): # New: optional prompt embeddings aligned with token ids prompt_embeds: NotRequired[torch.Tensor] + negative_prompt_embeds: NotRequired[list[torch.Tensor] | None] # New: optional additional information dictionary # Values may be torch.Tensor or list additional_information: NotRequired[dict[str, Any]] +# Must ensure that all additional prompt types are inherited from vLLM prompt types +# Because TypedDict doesn't support isinstance and are dict. Cannot distinguish them in runtime. +# Inheritance ensure that there are only additional fields but not removing fields--safe to route to LLM.generate() +OmniPromptType: TypeAlias = PromptType | OmniTextPrompt | OmniTokensPrompt | OmniEmbedsPrompt + + def token_inputs_omni( prompt_token_ids: list[int], prompt: str | None = None, @@ -106,3 +142,140 @@ def token_inputs_omni( inputs["additional_information"] = additional_information return inputs + + +@dataclass +class OmniDiffusionSamplingParams: + """ + The collection of sampling parameters passed to diffusion pipelines. + + This dataclass contains all information needed during the diffusion pipeline + execution, allowing methods to update specific components without needing + to manage numerous individual parameters. + """ + + # Additional text-related parameters + max_sequence_length: int | None = None + prompt_template: dict[str, Any] | None = None + do_classifier_free_guidance: bool = False + + # Batch info + num_outputs_per_prompt: int = 1 + seed: int | None = None + generator: torch.Generator | list[torch.Generator] | None = None + + # layered info + layers: int = 4 + + # cfg info + cfg_normalize: bool = False + + # caption language + use_en_prompt: bool = False + + # different bucket in (640, 1024) to determine the condition and output resolution + resolution: int = 640 + + # Tracking if embeddings are already processed + is_prompt_processed: bool = False + + # Latent tensors + latents: torch.Tensor | None = None + raw_latent_shape: torch.Tensor | None = None + noise_pred: torch.Tensor | None = None + image_latent: torch.Tensor | None = None + + # Latent dimensions + height_latents: list[int] | int | None = None + width_latents: list[int] | int | None = None + num_frames: int = 1 # Default for image models + num_frames_round_down: bool = False # Whether to round down num_frames if it's not divisible by num_gpus + + # Original dimensions (before VAE scaling) + height: int | None = None + width: int | None = None + fps: int | None = None + height_not_provided: bool = False + width_not_provided: bool = False + + # Timesteps + timesteps: torch.Tensor | None = None + timestep: torch.Tensor | float | int | None = None + step_index: int | None = None + boundary_ratio: float | None = None + + # Scheduler parameters + num_inference_steps: int = 50 + guidance_scale: float = 0.0 + guidance_scale_provided: bool = False + guidance_scale_2: float | None = None + guidance_rescale: float = 0.0 + eta: float = 0.0 + sigmas: list[float] | None = None + + true_cfg_scale: float | None = None # qwen-image specific now + + n_tokens: int | None = None + extra_step_kwargs: dict[str, Any] = field(default_factory=dict) + + # [Omni] KV Cache Transfer, for bagel model now + past_key_values: Any | None = None # Injected KV Cache + kv_metadata: dict[str, Any] | None = None # Metadata for KV Cache (e.g., kv_lens, ropes) + need_kv_receive: bool = True # Flag to indicate if this request expects KV transfer + + # Component modules + modules: dict[str, Any] = field(default_factory=dict) + + return_trajectory_latents: bool = False + return_trajectory_decoded: bool = False + trajectory_timesteps: list[torch.Tensor] | None = None + trajectory_latents: torch.Tensor | None = None + + # Extra parameters that might be needed by specific pipeline implementations + extra_args: dict[str, Any] = field(default_factory=dict) + + # Misc + save_output: bool = True + return_frames: bool = False + + # LoRA + lora_request: LoRARequest | None = None + lora_scale: float = 1.0 + + # STA parameters + STA_param: list | None = None + is_cfg_negative: bool = False + mask_search_final_result_pos: list[list] | None = None + mask_search_final_result_neg: list[list] | None = None + + # VSA parameters + VSA_sparsity: float = 0.0 + # perf_logger: PerformanceLogger | None = None + + # stage logging + # logging_info: PipelineLoggingInfo = field(default_factory=PipelineLoggingInfo) + + # profile + profile: bool = False + num_profiled_timesteps: int = 8 + + # debugging + debug: bool = False + + # results + output: torch.Tensor | None = None + + @property + def batch_size(self): + # This class is changed to only represent a single prompt request + # Only adjust batch size for number of videos per prompt + return self.num_outputs_per_prompt + + def __str__(self): + return pprint.pformat(asdict(self), indent=2, width=120) + + def clone(self) -> "OmniDiffusionSamplingParams": + return copy.deepcopy(self) + + +OmniSamplingParams: TypeAlias = SamplingParams | OmniDiffusionSamplingParams diff --git a/vllm_omni/model_executor/stage_configs/bagel.yaml b/vllm_omni/model_executor/stage_configs/bagel.yaml index cf5dff50d53..96bdc01b6ae 100644 --- a/vllm_omni/model_executor/stage_configs/bagel.yaml +++ b/vllm_omni/model_executor/stage_configs/bagel.yaml @@ -58,13 +58,7 @@ stage_args: final_output_type: image is_comprehension: false default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 seed: 52 - detokenize: True - repetition_penalty: 1.0 # Runtime edges runtime: diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index 24b16828705..f47077677fa 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -6,6 +6,8 @@ from vllm.outputs import RequestOutput from vllm.v1.outputs import ModelRunnerOutput +from vllm_omni.inputs.data import OmniPromptType + class OmniModelRunnerOutput(ModelRunnerOutput): """Model runner output for omni models. @@ -54,7 +56,7 @@ class OmniRequestOutput: # Diffusion model fields images: list[Image.Image] = field(default_factory=list) - prompt: str | None = None + prompt: OmniPromptType | None = None latents: torch.Tensor | None = None metrics: dict[str, Any] = field(default_factory=dict) multimodal_output: dict[str, Any] = field(default_factory=dict) @@ -89,7 +91,7 @@ def from_diffusion( cls, request_id: str, images: list[Image.Image], - prompt: str | None = None, + prompt: OmniPromptType | None = None, metrics: dict[str, Any] | None = None, latents: torch.Tensor | None = None, multimodal_output: dict[str, Any] | None = None,