From 1486f39ef1be63f6aeb5c1cc3d1d15dfdb79a3ed Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 09:07:34 +0000 Subject: [PATCH 01/34] [Fix] enforce Qwen-Image max_sequence_length before encoding Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit adda9a6aefafe6260140b4b28e6db2f2920366b4) --- .../test_qwen_image_max_sequence_length.py | 208 ++++++++++++++++++ .../models/qwen_image/pipeline_qwen_image.py | 35 ++- .../qwen_image/pipeline_qwen_image_edit.py | 40 +++- .../pipeline_qwen_image_edit_plus.py | 39 +++- .../qwen_image/pipeline_qwen_image_layered.py | 30 ++- .../models/qwen_image/prompt_utils.py | 24 ++ 6 files changed, 356 insertions(+), 20 deletions(-) create mode 100644 tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py create mode 100644 vllm_omni/diffusion/models/qwen_image/prompt_utils.py diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py b/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py new file mode 100644 index 00000000000..f5491388e7a --- /dev/null +++ b/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py @@ -0,0 +1,208 @@ +import inspect +from types import SimpleNamespace + +import pytest +import torch +from torch import nn + +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import ( + QwenImagePipeline, +) +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit import ( + QwenImageEditPipeline, +) +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import ( + QwenImageEditPlusPipeline, +) +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_layered import ( + QwenImageLayeredPipeline, +) + + +class _RejectingTextEncoder: + dtype = torch.float32 + + def __call__(self, *args, **kwargs): + raise AssertionError("text encoder should not run for prompts that exceed max_sequence_length") + + +class _FakeModelInputs: + def __init__(self, total_sequence_length: int): + attention_mask = torch.ones((1, total_sequence_length), dtype=torch.long) + self.input_ids = attention_mask.clone() + self.attention_mask = attention_mask + self.pixel_values = None + self.image_grid_thw = None + + def to(self, device): + return self + + +class _FakeTokenizer: + def __init__(self, total_sequence_length: int): + self.total_sequence_length = total_sequence_length + + def __call__(self, *args, **kwargs): + return _FakeModelInputs(self.total_sequence_length) + + +class _FakeProcessor(_FakeTokenizer): + pass + + +class _FakeScheduler: + def __init__(self): + self.begin_index = None + + def set_begin_index(self, begin_index: int): + self.begin_index = begin_index + + +PIPELINE_CASES = [ + pytest.param(QwenImagePipeline, 34, "tokenizer", id="qwen-image"), + pytest.param(QwenImageLayeredPipeline, 34, "tokenizer", id="qwen-image-layered"), + pytest.param(QwenImageEditPipeline, 64, "processor", id="qwen-image-edit"), + pytest.param(QwenImageEditPlusPipeline, 64, "processor", id="qwen-image-edit-plus"), +] + + +def _make_pipeline( + pipeline_class: type, + *, + total_sequence_length: int, + drop_idx: int, + input_kind: str, +): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.device = torch.device("cpu") + pipeline.text_encoder = _RejectingTextEncoder() + pipeline.tokenizer_max_length = 1024 + pipeline.prompt_template_encode = "{}" + pipeline.prompt_template_encode_start_idx = drop_idx + pipeline.tokenizer = _FakeTokenizer(total_sequence_length) + if input_kind == "processor": + pipeline.processor = _FakeProcessor(total_sequence_length) + return pipeline + + +@pytest.mark.parametrize(("pipeline_class", "drop_idx", "input_kind"), PIPELINE_CASES) +def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length( + pipeline_class: type, + drop_idx: int, + input_kind: str, +): + pipeline = _make_pipeline( + pipeline_class, + total_sequence_length=drop_idx + 1025, + drop_idx=drop_idx, + input_kind=input_kind, + ) + + with pytest.raises(ValueError, match=r"got 1025 tokens, but `max_sequence_length` is 1024"): + pipeline.encode_prompt(prompt="prompt") + + +@pytest.mark.parametrize(("pipeline_class", "drop_idx", "input_kind"), PIPELINE_CASES) +def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length( + pipeline_class: type, + drop_idx: int, + input_kind: str, +): + pipeline = _make_pipeline( + pipeline_class, + total_sequence_length=drop_idx + 17, + drop_idx=drop_idx, + input_kind=input_kind, + ) + + with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"): + pipeline.encode_prompt(prompt="prompt", max_sequence_length=16) + + +def test_prepare_encode_defaults_to_tokenizer_max_length(): + pipeline = object.__new__(QwenImagePipeline) + nn.Module.__init__(pipeline) + pipeline.tokenizer_max_length = 1024 + pipeline.vae_scale_factor = 8 + pipeline.default_sample_size = 128 + pipeline.scheduler = _FakeScheduler() + pipeline._extract_prompts = lambda prompts: (["prompt"], None) + + captured = {} + + def _fake_prepare_generation_context(**kwargs): + captured["max_sequence_length"] = kwargs["max_sequence_length"] + embeds = torch.ones((1, 1, 1)) + mask = torch.ones((1, 1), dtype=torch.long) + return { + "prompt_embeds": embeds, + "prompt_embeds_mask": mask, + "negative_prompt_embeds": None, + "negative_prompt_embeds_mask": None, + "latents": embeds, + "timesteps": torch.tensor([1]), + "do_true_cfg": False, + "guidance": None, + "img_shapes": [[(1, 1, 1)]], + "txt_seq_lens": [1], + "negative_txt_seq_lens": None, + } + + pipeline._prepare_generation_context = _fake_prepare_generation_context + state = SimpleNamespace( + prompts=["prompt"], + sampling=SimpleNamespace( + height=None, + width=None, + num_inference_steps=None, + sigmas=None, + guidance_scale_provided=False, + num_outputs_per_prompt=0, + generator=None, + true_cfg_scale=None, + max_sequence_length=None, + ), + ) + + pipeline.prepare_encode(state) + + assert captured["max_sequence_length"] == 1024 + + +@pytest.mark.parametrize( + ("pipeline_class", "drop_idx"), + [ + pytest.param(QwenImageEditPipeline, 64, id="qwen-image-edit"), + pytest.param(QwenImageEditPlusPipeline, 64, id="qwen-image-edit-plus"), + ], +) +def test_edit_pipelines_validate_text_prompt_length_before_image_token_expansion( + pipeline_class: type, + drop_idx: int, +): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.device = torch.device("cpu") + pipeline.text_encoder = _RejectingTextEncoder() + pipeline.tokenizer_max_length = 1024 + pipeline.prompt_template_encode = "{}" + pipeline.prompt_template_encode_start_idx = drop_idx + pipeline.tokenizer = _FakeTokenizer(drop_idx + 8) + pipeline.processor = _FakeProcessor(drop_idx + 1500) + + with pytest.raises(AssertionError, match="text encoder should not run"): + pipeline.encode_prompt(prompt="short prompt") + + +@pytest.mark.parametrize( + "pipeline_class", + [ + QwenImagePipeline, + QwenImageLayeredPipeline, + QwenImageEditPipeline, + QwenImageEditPlusPipeline, + ], +) +def test_forward_max_sequence_length_default_is_1024(pipeline_class: type): + assert inspect.signature(pipeline_class.forward).parameters["max_sequence_length"].default == 1024 diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 8a8697ebbc4..f59c31528b2 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -29,6 +29,9 @@ from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( QwenImageCFGParallelMixin, ) +from vllm_omni.diffusion.models.qwen_image.prompt_utils import ( + validate_qwen_prompt_sequence_lengths, +) from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) @@ -363,8 +366,11 @@ def check_inputs( "that was used to generate `negative_prompt_embeds`." ) - if max_sequence_length is not None and max_sequence_length > 1024: - raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " + f"{max_sequence_length}" + ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -378,6 +384,8 @@ def _get_qwen_prompt_embeds( self, prompt: str | list[str] = None, dtype: torch.dtype | None = None, + max_sequence_length: int | None = None, + prompt_name: str = "prompt", ): dtype = dtype or self.text_encoder.dtype @@ -388,12 +396,17 @@ def _get_qwen_prompt_embeds( txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( txt, - max_length=self.tokenizer_max_length + drop_idx, padding=True, - truncation=True, + truncation=False, return_tensors="pt", ).to(self.device) - # print(f"attention mask: {txt_tokens.attention_mask}") + validate_qwen_prompt_sequence_lengths( + txt_tokens.attention_mask, + drop_idx=drop_idx, + max_sequence_length=max_sequence_length or self.tokenizer_max_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name=prompt_name, + ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, @@ -422,6 +435,7 @@ def encode_prompt( prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, + prompt_name: str = "prompt", ): r""" @@ -439,7 +453,11 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, + max_sequence_length=max_sequence_length, + prompt_name=prompt_name, + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] @@ -632,6 +650,7 @@ def _prepare_generation_context( prompt_embeds_mask=negative_prompt_embeds_mask, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + prompt_name="negative_prompt", ) else: negative_prompt_embeds = None @@ -703,7 +722,7 @@ def prepare_encode( num_images_per_prompt=sampling.num_outputs_per_prompt if sampling.num_outputs_per_prompt > 0 else 1, generator=sampling.generator, true_cfg_scale=sampling.true_cfg_scale or 4.0, - max_sequence_length=sampling.max_sequence_length or 512, + max_sequence_length=sampling.max_sequence_length or self.tokenizer_max_length, attention_kwargs=kwargs.get("attention_kwargs"), ) @@ -934,7 +953,7 @@ def forward( output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, + max_sequence_length: int = 1024, ) -> DiffusionOutput: extracted_prompt, negative_prompt = self._extract_prompts(req.prompts) prompt = extracted_prompt or prompt diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 59c63526663..056ef4ecdfa 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -32,6 +32,9 @@ QwenImageCFGParallelMixin, ) from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift +from vllm_omni.diffusion.models.qwen_image.prompt_utils import ( + validate_qwen_prompt_sequence_lengths, +) from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) @@ -323,8 +326,11 @@ def check_inputs( "that was used to generate `negative_prompt_embeds`." ) - if max_sequence_length is not None and max_sequence_length > 1024: - raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " + f"{max_sequence_length}" + ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -384,6 +390,8 @@ def _get_qwen_prompt_embeds( prompt: str | list[str] = None, image: PIL.Image.Image | torch.Tensor | None = None, dtype: torch.dtype | None = None, + max_sequence_length: int | None = None, + prompt_name: str = "prompt", ): """Get prompt embeddings with image support for editing.""" dtype = dtype or self.text_encoder.dtype @@ -393,6 +401,23 @@ def _get_qwen_prompt_embeds( template = self.prompt_template_encode drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) + # Qwen-Image-Edit expands image placeholders into many vision tokens + # inside the processor. `max_sequence_length` is meant to constrain the + # prompt text length, so validate on the text template before image + # token expansion. + validate_qwen_prompt_sequence_lengths( + txt_tokens.attention_mask, + drop_idx=drop_idx, + max_sequence_length=max_sequence_length or self.tokenizer_max_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name=prompt_name, + ) # Use processor to handle both text and image inputs model_inputs = self.processor( @@ -434,6 +459,7 @@ def encode_prompt( prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, + prompt_name: str = "prompt", ): r""" @@ -453,7 +479,12 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, + image, + max_sequence_length=max_sequence_length, + prompt_name=prompt_name, + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -624,7 +655,7 @@ def forward( output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, + max_sequence_length: int = 1024, ) -> DiffusionOutput: """Forward pass for image editing.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") @@ -739,6 +770,7 @@ def forward( prompt_embeds_mask=negative_prompt_embeds_mask, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + prompt_name="negative_prompt", ) num_channels_latents = self.transformer.in_channels // 4 diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 59b9975436d..0523c51db75 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -35,6 +35,9 @@ retrieve_latents, retrieve_timesteps, ) +from vllm_omni.diffusion.models.qwen_image.prompt_utils import ( + validate_qwen_prompt_sequence_lengths, +) from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) @@ -283,8 +286,11 @@ def check_inputs( "that was used to generate `negative_prompt_embeds`." ) - if max_sequence_length is not None and max_sequence_length > 1024: - raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " + f"{max_sequence_length}" + ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -299,6 +305,8 @@ def _get_qwen_prompt_embeds( prompt: str | list[str], image: list[torch.Tensor] | torch.Tensor | None = None, dtype: torch.dtype | None = None, + max_sequence_length: int | None = None, + prompt_name: str = "prompt", ): """Get prompt embeddings with support for multiple images.""" dtype = dtype or self.text_encoder.dtype @@ -319,6 +327,22 @@ def _get_qwen_prompt_embeds( template = self.prompt_template_encode drop_idx = self.prompt_template_encode_start_idx txt = [template.format(base_img_prompt + e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) + # The processor expands image placeholders into many vision tokens. + # `max_sequence_length` should guard the prompt text length before that + # multimodal expansion happens. + validate_qwen_prompt_sequence_lengths( + txt_tokens.attention_mask, + drop_idx=drop_idx, + max_sequence_length=max_sequence_length or self.tokenizer_max_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name=prompt_name, + ) # Use processor to handle both text and image inputs model_inputs = self.processor( @@ -360,6 +384,7 @@ def encode_prompt( prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, + prompt_name: str = "prompt", ): r""" @@ -379,7 +404,12 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, + image, + max_sequence_length=max_sequence_length, + prompt_name=prompt_name, + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -557,7 +587,7 @@ def forward( output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, + max_sequence_length: int = 1024, ) -> DiffusionOutput: """Forward pass for image editing with support for multiple images.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") @@ -692,6 +722,7 @@ def forward( prompt_embeds_mask=negative_prompt_embeds_mask, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + prompt_name="negative_prompt", ) num_channels_latents = self.transformer.in_channels // 4 diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index df59b463960..1d42224cb63 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -31,6 +31,9 @@ from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( QwenImageCFGParallelMixin, ) +from vllm_omni.diffusion.models.qwen_image.prompt_utils import ( + validate_qwen_prompt_sequence_lengths, +) from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) @@ -340,8 +343,11 @@ def check_inputs( "generate `negative_prompt_embeds`." ) - if max_sequence_length is not None and max_sequence_length > 1024: - raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " + f"{max_sequence_length}" + ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -356,6 +362,8 @@ def _get_qwen_prompt_embeds( prompt: str | list[str] | None = None, device: torch.device | None = None, dtype: torch.dtype | None = None, + max_sequence_length: int | None = None, + prompt_name: str = "prompt", ): device = device or self.device dtype = dtype or self.text_encoder.dtype @@ -368,8 +376,16 @@ def _get_qwen_prompt_embeds( txt_tokens = self.tokenizer( txt, padding=True, + truncation=False, return_tensors="pt", ).to(device) + validate_qwen_prompt_sequence_lengths( + txt_tokens.attention_mask, + drop_idx=drop_idx, + max_sequence_length=max_sequence_length or self.tokenizer_max_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name=prompt_name, + ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, @@ -399,6 +415,7 @@ def encode_prompt( prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, + prompt_name: str = "prompt", ): r""" @@ -419,7 +436,11 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, + max_sequence_length=max_sequence_length, + prompt_name=prompt_name, + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] @@ -603,7 +624,7 @@ def forward( negative_prompt_embeds_mask: torch.Tensor | None = None, output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, - max_sequence_length: int = 512, + max_sequence_length: int = 1024, resolution: int = 640, cfg_normalize: bool = False, use_en_prompt: bool = False, @@ -736,6 +757,7 @@ def forward( device=self.device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + prompt_name="negative_prompt", ) # 4. Prepare latent variables diff --git a/vllm_omni/diffusion/models/qwen_image/prompt_utils.py b/vllm_omni/diffusion/models/qwen_image/prompt_utils.py new file mode 100644 index 00000000000..b7bdc7e091b --- /dev/null +++ b/vllm_omni/diffusion/models/qwen_image/prompt_utils.py @@ -0,0 +1,24 @@ +import torch + + +def validate_qwen_prompt_sequence_lengths( + attention_mask: torch.Tensor, + *, + drop_idx: int, + max_sequence_length: int, + supported_max_sequence_length: int, + prompt_name: str = "prompt", +) -> None: + sequence_lengths = torch.clamp(attention_mask.sum(dim=1) - drop_idx, min=0) + too_long = torch.nonzero(sequence_lengths > max_sequence_length, as_tuple=False) + if too_long.numel() == 0: + return + + batch_idx = int(too_long[0].item()) + actual_length = int(sequence_lengths[batch_idx].item()) + prompt_ref = f"`{prompt_name}` at batch index {batch_idx}" if attention_mask.shape[0] > 1 else f"`{prompt_name}`" + raise ValueError( + f"{prompt_ref} is too long after applying the Qwen prompt template: got {actual_length} tokens, but " + f"`max_sequence_length` is {max_sequence_length}. Shorten the prompt or increase " + f"`max_sequence_length` up to {supported_max_sequence_length}." + ) From efde2823dc096f7f5b47169dce30c0f298f90b90 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 09:34:28 +0000 Subject: [PATCH 02/34] [Fix] enforce Wan2.2 max_sequence_length before encoding Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 281e14a3e0c6a6d63de18c9cbf28a1f773130ad8) --- .../wan2_2/test_wan22_max_sequence_length.py | 145 ++++++++++++++++++ .../models/wan2_2/pipeline_wan2_2.py | 45 +++++- .../models/wan2_2/pipeline_wan2_2_i2v.py | 45 +++++- .../models/wan2_2/pipeline_wan2_2_ti2v.py | 45 +++++- .../diffusion/models/wan2_2/prompt_utils.py | 23 +++ 5 files changed, 297 insertions(+), 6 deletions(-) create mode 100644 tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py create mode 100644 vllm_omni/diffusion/models/wan2_2/prompt_utils.py diff --git a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py new file mode 100644 index 00000000000..04d39406d6d --- /dev/null +++ b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py @@ -0,0 +1,145 @@ +from types import SimpleNamespace + +import PIL.Image +import pytest +import torch +from torch import nn + +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + WAN22_MAX_SEQUENCE_LENGTH, + Wan22Pipeline, +) +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import ( + Wan22I2VPipeline, +) +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_ti2v import ( + Wan22TI2VPipeline, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class _RejectingTextEncoder: + dtype = torch.float32 + + def __call__(self, *args, **kwargs): + raise AssertionError("text encoder should not run for prompts that exceed max_sequence_length") + + +class _FakeTokenBatch: + def __init__(self, total_sequence_length: int): + attention_mask = torch.ones((1, total_sequence_length), dtype=torch.long) + self.input_ids = attention_mask.clone() + self.attention_mask = attention_mask + + +class _FakeTokenizer: + def __init__(self, total_sequence_length: int): + self.total_sequence_length = total_sequence_length + + def __call__(self, *args, **kwargs): + return _FakeTokenBatch(self.total_sequence_length) + + +PIPELINE_CASES = [ + pytest.param(Wan22Pipeline, id="wan22-t2v"), + pytest.param(Wan22I2VPipeline, id="wan22-i2v"), + pytest.param(Wan22TI2VPipeline, id="wan22-ti2v"), +] + + +def _make_pipeline(pipeline_class: type, *, total_sequence_length: int): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.device = torch.device("cpu") + pipeline.text_encoder = _RejectingTextEncoder() + pipeline.tokenizer = _FakeTokenizer(total_sequence_length) + pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH + return pipeline + + +@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES) +def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length(pipeline_class: type): + pipeline = _make_pipeline(pipeline_class, total_sequence_length=WAN22_MAX_SEQUENCE_LENGTH + 1) + + with pytest.raises(ValueError, match=r"got 513 tokens, but `max_sequence_length` is 512"): + pipeline.encode_prompt(prompt="prompt") + + +@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES) +def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(pipeline_class: type): + pipeline = _make_pipeline(pipeline_class, total_sequence_length=17) + + with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"): + pipeline.encode_prompt(prompt="prompt", max_sequence_length=16) + + +def _sampling_params(**overrides): + defaults = dict( + height=None, + width=None, + num_frames=None, + num_inference_steps=None, + generator=None, + guidance_scale_provided=False, + guidance_scale_2=None, + boundary_ratio=None, + num_outputs_per_prompt=0, + max_sequence_length=None, + seed=None, + extra_args={}, + prompt_embeds=None, + negative_prompt_embeds=None, + ) + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +@pytest.mark.parametrize( + ("pipeline_class", "prompt_value", "forward_kwargs"), + [ + pytest.param(Wan22Pipeline, "prompt", {}, id="wan22-t2v"), + pytest.param( + Wan22I2VPipeline, + {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}}, + {"image": PIL.Image.new("RGB", (64, 64))}, + id="wan22-i2v", + ), + pytest.param( + Wan22TI2VPipeline, + {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}}, + {"image": PIL.Image.new("RGB", (64, 64))}, + id="wan22-ti2v", + ), + ], +) +def test_forward_defaults_to_wan22_tokenizer_max_length( + pipeline_class: type, + prompt_value, + forward_kwargs, +): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH + pipeline.boundary_ratio = None + pipeline.vae_scale_factor_temporal = 4 + pipeline.vae_scale_factor_spatial = 8 + pipeline.transformer_config = SimpleNamespace(patch_size=(1, 2, 2)) + + captured = {} + + def _fake_check_inputs(*args, **kwargs): + captured["max_sequence_length"] = kwargs["max_sequence_length"] + raise RuntimeError("stop after capture") + + pipeline.check_inputs = _fake_check_inputs + + req = SimpleNamespace( + prompts=[prompt_value], + sampling_params=_sampling_params(), + ) + + with pytest.raises(RuntimeError, match="stop after capture"): + pipeline.forward(req, **forward_kwargs) + + assert captured["max_sequence_length"] == WAN22_MAX_SEQUENCE_LENGTH diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 4b8d01dfa9c..5116a8b6149 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -24,6 +24,9 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler +from vllm_omni.diffusion.models.wan2_2.prompt_utils import ( + validate_wan_prompt_sequence_lengths, +) from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin @@ -33,6 +36,7 @@ logger = logging.getLogger(__name__) DEBUG_PERF = False +WAN22_MAX_SEQUENCE_LENGTH = 512 def retrieve_latents( @@ -248,6 +252,7 @@ def __init__( pass self.boundary_ratio = od_config.boundary_ratio + self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH # Determine which transformers to load based on boundary_ratio # boundary_ratio=1.0: only load transformer_2 (low-noise stage only) @@ -423,6 +428,7 @@ def forward( negative_prompt_embeds=negative_prompt_embeds, guidance_scale_2=guidance_high if boundary_ratio is not None else None, boundary_ratio=boundary_ratio, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -456,7 +462,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, device=device, dtype=dtype, ) @@ -753,6 +759,19 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) + text_inputs_untruncated = self.tokenizer( + prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_wan_prompt_sequence_lengths( + text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + ) text_inputs = self.tokenizer( prompt_clean, @@ -781,8 +800,23 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] + neg_text_inputs_untruncated = self.tokenizer( + negative_prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_wan_prompt_sequence_lengths( + neg_text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name="negative_prompt", + ) neg_text_inputs = self.tokenizer( - [self._prompt_clean(p) for p in negative_prompt], + negative_prompt_clean, padding="max_length", max_length=max_sequence_length, truncation=True, @@ -854,6 +888,7 @@ def check_inputs( negative_prompt_embeds=None, guidance_scale_2=None, boundary_ratio=None, + max_sequence_length=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -880,5 +915,11 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " + f"{max_sequence_length}" + ) + if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 112dfdfdda1..8eb11265f43 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -26,10 +26,14 @@ from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + WAN22_MAX_SEQUENCE_LENGTH, create_transformer_from_config, load_transformer_config, retrieve_latents, ) +from vllm_omni.diffusion.models.wan2_2.prompt_utils import ( + validate_wan_prompt_sequence_lengths, +) from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -211,6 +215,7 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -394,6 +399,7 @@ def forward( image_embeds=image_embeds, guidance_scale_2=guidance_high if boundary_ratio is not None else None, boundary_ratio=boundary_ratio, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) # Adjust num_frames to be compatible with VAE temporal scaling @@ -422,7 +428,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, device=device, dtype=dtype, ) @@ -666,6 +672,19 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) + text_inputs_untruncated = self.tokenizer( + prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_wan_prompt_sequence_lengths( + text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + ) text_inputs = self.tokenizer( prompt_clean, @@ -694,8 +713,23 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] + neg_text_inputs_untruncated = self.tokenizer( + negative_prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_wan_prompt_sequence_lengths( + neg_text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name="negative_prompt", + ) neg_text_inputs = self.tokenizer( - [self._prompt_clean(p) for p in negative_prompt], + negative_prompt_clean, padding="max_length", max_length=max_sequence_length, truncation=True, @@ -837,6 +871,7 @@ def check_inputs( image_embeds=None, guidance_scale_2=None, boundary_ratio=None, + max_sequence_length=None, ): if image is None and image_embeds is None: raise ValueError("Provide either `image` or `image_embeds`. Cannot leave both undefined.") @@ -858,6 +893,12 @@ def check_inputs( if prompt is None and prompt_embeds is None: raise ValueError("Provide either `prompt` or `prompt_embeds`.") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " + f"{max_sequence_length}" + ) + if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 4953a809d80..f0eccae99ba 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -38,10 +38,14 @@ from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + WAN22_MAX_SEQUENCE_LENGTH, create_transformer_from_config, load_transformer_config, retrieve_latents, ) +from vllm_omni.diffusion.models.wan2_2.prompt_utils import ( + validate_wan_prompt_sequence_lengths, +) from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniTextPrompt @@ -183,6 +187,7 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -303,6 +308,7 @@ def forward( width=width, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) # Adjust num_frames to be compatible with VAE temporal scaling @@ -326,7 +332,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_scale > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, device=device, dtype=dtype, ) @@ -513,6 +519,19 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) + text_inputs_untruncated = self.tokenizer( + prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_wan_prompt_sequence_lengths( + text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + ) text_inputs = self.tokenizer( prompt_clean, @@ -541,8 +560,23 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] + neg_text_inputs_untruncated = self.tokenizer( + negative_prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_wan_prompt_sequence_lengths( + neg_text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name="negative_prompt", + ) neg_text_inputs = self.tokenizer( - [self._prompt_clean(p) for p in negative_prompt], + negative_prompt_clean, padding="max_length", max_length=max_sequence_length, truncation=True, @@ -667,6 +701,7 @@ def check_inputs( width, prompt_embeds=None, negative_prompt_embeds=None, + max_sequence_length=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -682,6 +717,12 @@ def check_inputs( if prompt is None and prompt_embeds is None: raise ValueError("Provide either `prompt` or `prompt_embeds`.") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " + f"{max_sequence_length}" + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) diff --git a/vllm_omni/diffusion/models/wan2_2/prompt_utils.py b/vllm_omni/diffusion/models/wan2_2/prompt_utils.py new file mode 100644 index 00000000000..23aec7e76f8 --- /dev/null +++ b/vllm_omni/diffusion/models/wan2_2/prompt_utils.py @@ -0,0 +1,23 @@ +import torch + + +def validate_wan_prompt_sequence_lengths( + attention_mask: torch.Tensor, + *, + max_sequence_length: int, + supported_max_sequence_length: int, + prompt_name: str = "prompt", +) -> None: + sequence_lengths = attention_mask.sum(dim=1) + too_long = torch.nonzero(sequence_lengths > max_sequence_length, as_tuple=False) + if too_long.numel() == 0: + return + + batch_idx = int(too_long[0].item()) + actual_length = int(sequence_lengths[batch_idx].item()) + prompt_ref = f"`{prompt_name}` at batch index {batch_idx}" if attention_mask.shape[0] > 1 else f"`{prompt_name}`" + raise ValueError( + f"{prompt_ref} is too long for Wan2.2 text encoding: got {actual_length} tokens, but " + f"`max_sequence_length` is {max_sequence_length}. Shorten the prompt or increase " + f"`max_sequence_length` up to {supported_max_sequence_length}." + ) From cae196ae368247e106da5e94ef6bac5d15c83f35 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 09:39:30 +0000 Subject: [PATCH 03/34] [Fix] apply pre-commit formatting Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 66151f00b9290ad9875c0e9d86014aca63bca56d) --- vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py | 3 +-- .../diffusion/models/qwen_image/pipeline_qwen_image_edit.py | 3 +-- .../models/qwen_image/pipeline_qwen_image_edit_plus.py | 3 +-- .../diffusion/models/qwen_image/pipeline_qwen_image_layered.py | 3 +-- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py | 3 +-- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py | 3 +-- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py | 3 +-- 7 files changed, 7 insertions(+), 14 deletions(-) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index f59c31528b2..ed26bc799a4 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -368,8 +368,7 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " - f"{max_sequence_length}" + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 056ef4ecdfa..4ab65da9543 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -328,8 +328,7 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " - f"{max_sequence_length}" + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 0523c51db75..0b848196604 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -288,8 +288,7 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " - f"{max_sequence_length}" + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index 1d42224cb63..bb974178bf4 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -345,8 +345,7 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " - f"{max_sequence_length}" + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 5116a8b6149..3c1379ca2ee 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -917,8 +917,7 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " - f"{max_sequence_length}" + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" ) if boundary_ratio is None and guidance_scale_2 is not None: diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 8eb11265f43..45ad7f87399 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -895,8 +895,7 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " - f"{max_sequence_length}" + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" ) if boundary_ratio is None and guidance_scale_2 is not None: diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index f0eccae99ba..91c40333c9a 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -719,8 +719,7 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is " - f"{max_sequence_length}" + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: From 0ec989b363976905bf7b50e04b0b1dbbde38ceb5 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 09:47:44 +0000 Subject: [PATCH 04/34] [Refactor] merge diffusion prompt length validators Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 1e8fa70ab1622a40ffec51466b713638c679ef24) --- .../models/qwen_image/pipeline_qwen_image.py | 11 +++++---- .../qwen_image/pipeline_qwen_image_edit.py | 11 +++++---- .../pipeline_qwen_image_edit_plus.py | 11 +++++---- .../qwen_image/pipeline_qwen_image_layered.py | 11 +++++---- .../models/wan2_2/pipeline_wan2_2.py | 12 ++++++---- .../models/wan2_2/pipeline_wan2_2_i2v.py | 12 ++++++---- .../models/wan2_2/pipeline_wan2_2_ti2v.py | 12 ++++++---- .../diffusion/models/wan2_2/prompt_utils.py | 23 ------------------- .../qwen_image => utils}/prompt_utils.py | 9 ++++---- 9 files changed, 50 insertions(+), 62 deletions(-) delete mode 100644 vllm_omni/diffusion/models/wan2_2/prompt_utils.py rename vllm_omni/diffusion/{models/qwen_image => utils}/prompt_utils.py (73%) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index ed26bc799a4..cc4e7108979 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -29,14 +29,14 @@ from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( QwenImageCFGParallelMixin, ) -from vllm_omni.diffusion.models.qwen_image.prompt_utils import ( - validate_qwen_prompt_sequence_lengths, -) from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.diffusion.utils.size_utils import ( normalize_min_aligned_size, ) @@ -399,12 +399,13 @@ def _get_qwen_prompt_embeds( truncation=False, return_tensors="pt", ).to(self.device) - validate_qwen_prompt_sequence_lengths( + validate_prompt_sequence_lengths( txt_tokens.attention_mask, - drop_idx=drop_idx, max_sequence_length=max_sequence_length or self.tokenizer_max_length, supported_max_sequence_length=self.tokenizer_max_length, prompt_name=prompt_name, + length_offset=drop_idx, + error_context="after applying the Qwen prompt template", ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 4ab65da9543..9572e3c8ba6 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -32,14 +32,14 @@ QwenImageCFGParallelMixin, ) from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift -from vllm_omni.diffusion.models.qwen_image.prompt_utils import ( - validate_qwen_prompt_sequence_lengths, -) from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.diffusion.utils.size_utils import ( normalize_min_aligned_size, ) @@ -410,12 +410,13 @@ def _get_qwen_prompt_embeds( # inside the processor. `max_sequence_length` is meant to constrain the # prompt text length, so validate on the text template before image # token expansion. - validate_qwen_prompt_sequence_lengths( + validate_prompt_sequence_lengths( txt_tokens.attention_mask, - drop_idx=drop_idx, max_sequence_length=max_sequence_length or self.tokenizer_max_length, supported_max_sequence_length=self.tokenizer_max_length, prompt_name=prompt_name, + length_offset=drop_idx, + error_context="after applying the Qwen prompt template", ) # Use processor to handle both text and image inputs diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 0b848196604..9495ce4003d 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -35,14 +35,14 @@ retrieve_latents, retrieve_timesteps, ) -from vllm_omni.diffusion.models.qwen_image.prompt_utils import ( - validate_qwen_prompt_sequence_lengths, -) from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.diffusion.utils.size_utils import ( normalize_min_aligned_size, ) @@ -335,12 +335,13 @@ def _get_qwen_prompt_embeds( # The processor expands image placeholders into many vision tokens. # `max_sequence_length` should guard the prompt text length before that # multimodal expansion happens. - validate_qwen_prompt_sequence_lengths( + validate_prompt_sequence_lengths( txt_tokens.attention_mask, - drop_idx=drop_idx, max_sequence_length=max_sequence_length or self.tokenizer_max_length, supported_max_sequence_length=self.tokenizer_max_length, prompt_name=prompt_name, + length_offset=drop_idx, + error_context="after applying the Qwen prompt template", ) # Use processor to handle both text and image inputs diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index bb974178bf4..91e83128fe9 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -31,14 +31,14 @@ from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( QwenImageCFGParallelMixin, ) -from vllm_omni.diffusion.models.qwen_image.prompt_utils import ( - validate_qwen_prompt_sequence_lengths, -) from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.diffusion.utils.size_utils import ( normalize_min_aligned_size, ) @@ -378,12 +378,13 @@ def _get_qwen_prompt_embeds( truncation=False, return_tensors="pt", ).to(device) - validate_qwen_prompt_sequence_lengths( + validate_prompt_sequence_lengths( txt_tokens.attention_mask, - drop_idx=drop_idx, max_sequence_length=max_sequence_length or self.tokenizer_max_length, supported_max_sequence_length=self.tokenizer_max_length, prompt_name=prompt_name, + length_offset=drop_idx, + error_context="after applying the Qwen prompt template", ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 3c1379ca2ee..753081067b0 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -24,13 +24,13 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler -from vllm_omni.diffusion.models.wan2_2.prompt_utils import ( - validate_wan_prompt_sequence_lengths, -) from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -767,10 +767,11 @@ def encode_prompt( return_attention_mask=True, return_tensors="pt", ) - validate_wan_prompt_sequence_lengths( + validate_prompt_sequence_lengths( text_inputs_untruncated.attention_mask, max_sequence_length=max_sequence_length, supported_max_sequence_length=self.tokenizer_max_length, + error_context="for Wan2.2 text encoding", ) text_inputs = self.tokenizer( @@ -809,11 +810,12 @@ def encode_prompt( return_attention_mask=True, return_tensors="pt", ) - validate_wan_prompt_sequence_lengths( + validate_prompt_sequence_lengths( neg_text_inputs_untruncated.attention_mask, max_sequence_length=max_sequence_length, supported_max_sequence_length=self.tokenizer_max_length, prompt_name="negative_prompt", + error_context="for Wan2.2 text encoding", ) neg_text_inputs = self.tokenizer( negative_prompt_clean, diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 45ad7f87399..9438bca96f0 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -31,12 +31,12 @@ load_transformer_config, retrieve_latents, ) -from vllm_omni.diffusion.models.wan2_2.prompt_utils import ( - validate_wan_prompt_sequence_lengths, -) from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -680,10 +680,11 @@ def encode_prompt( return_attention_mask=True, return_tensors="pt", ) - validate_wan_prompt_sequence_lengths( + validate_prompt_sequence_lengths( text_inputs_untruncated.attention_mask, max_sequence_length=max_sequence_length, supported_max_sequence_length=self.tokenizer_max_length, + error_context="for Wan2.2 text encoding", ) text_inputs = self.tokenizer( @@ -722,11 +723,12 @@ def encode_prompt( return_attention_mask=True, return_tensors="pt", ) - validate_wan_prompt_sequence_lengths( + validate_prompt_sequence_lengths( neg_text_inputs_untruncated.attention_mask, max_sequence_length=max_sequence_length, supported_max_sequence_length=self.tokenizer_max_length, prompt_name="negative_prompt", + error_context="for Wan2.2 text encoding", ) neg_text_inputs = self.tokenizer( negative_prompt_clean, diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 91c40333c9a..b829678e2ad 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -43,11 +43,11 @@ load_transformer_config, retrieve_latents, ) -from vllm_omni.diffusion.models.wan2_2.prompt_utils import ( - validate_wan_prompt_sequence_lengths, -) from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -527,10 +527,11 @@ def encode_prompt( return_attention_mask=True, return_tensors="pt", ) - validate_wan_prompt_sequence_lengths( + validate_prompt_sequence_lengths( text_inputs_untruncated.attention_mask, max_sequence_length=max_sequence_length, supported_max_sequence_length=self.tokenizer_max_length, + error_context="for Wan2.2 text encoding", ) text_inputs = self.tokenizer( @@ -569,11 +570,12 @@ def encode_prompt( return_attention_mask=True, return_tensors="pt", ) - validate_wan_prompt_sequence_lengths( + validate_prompt_sequence_lengths( neg_text_inputs_untruncated.attention_mask, max_sequence_length=max_sequence_length, supported_max_sequence_length=self.tokenizer_max_length, prompt_name="negative_prompt", + error_context="for Wan2.2 text encoding", ) neg_text_inputs = self.tokenizer( negative_prompt_clean, diff --git a/vllm_omni/diffusion/models/wan2_2/prompt_utils.py b/vllm_omni/diffusion/models/wan2_2/prompt_utils.py deleted file mode 100644 index 23aec7e76f8..00000000000 --- a/vllm_omni/diffusion/models/wan2_2/prompt_utils.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch - - -def validate_wan_prompt_sequence_lengths( - attention_mask: torch.Tensor, - *, - max_sequence_length: int, - supported_max_sequence_length: int, - prompt_name: str = "prompt", -) -> None: - sequence_lengths = attention_mask.sum(dim=1) - too_long = torch.nonzero(sequence_lengths > max_sequence_length, as_tuple=False) - if too_long.numel() == 0: - return - - batch_idx = int(too_long[0].item()) - actual_length = int(sequence_lengths[batch_idx].item()) - prompt_ref = f"`{prompt_name}` at batch index {batch_idx}" if attention_mask.shape[0] > 1 else f"`{prompt_name}`" - raise ValueError( - f"{prompt_ref} is too long for Wan2.2 text encoding: got {actual_length} tokens, but " - f"`max_sequence_length` is {max_sequence_length}. Shorten the prompt or increase " - f"`max_sequence_length` up to {supported_max_sequence_length}." - ) diff --git a/vllm_omni/diffusion/models/qwen_image/prompt_utils.py b/vllm_omni/diffusion/utils/prompt_utils.py similarity index 73% rename from vllm_omni/diffusion/models/qwen_image/prompt_utils.py rename to vllm_omni/diffusion/utils/prompt_utils.py index b7bdc7e091b..dea28ce2374 100644 --- a/vllm_omni/diffusion/models/qwen_image/prompt_utils.py +++ b/vllm_omni/diffusion/utils/prompt_utils.py @@ -1,15 +1,16 @@ import torch -def validate_qwen_prompt_sequence_lengths( +def validate_prompt_sequence_lengths( attention_mask: torch.Tensor, *, - drop_idx: int, max_sequence_length: int, supported_max_sequence_length: int, prompt_name: str = "prompt", + length_offset: int = 0, + error_context: str, ) -> None: - sequence_lengths = torch.clamp(attention_mask.sum(dim=1) - drop_idx, min=0) + sequence_lengths = torch.clamp(attention_mask.sum(dim=1) - length_offset, min=0) too_long = torch.nonzero(sequence_lengths > max_sequence_length, as_tuple=False) if too_long.numel() == 0: return @@ -18,7 +19,7 @@ def validate_qwen_prompt_sequence_lengths( actual_length = int(sequence_lengths[batch_idx].item()) prompt_ref = f"`{prompt_name}` at batch index {batch_idx}" if attention_mask.shape[0] > 1 else f"`{prompt_name}`" raise ValueError( - f"{prompt_ref} is too long after applying the Qwen prompt template: got {actual_length} tokens, but " + f"{prompt_ref} is too long {error_context}: got {actual_length} tokens, but " f"`max_sequence_length` is {max_sequence_length}. Shorten the prompt or increase " f"`max_sequence_length` up to {supported_max_sequence_length}." ) From 973addeeaa60fbfdc110e32adf3e13912a895d0d Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 11:27:26 +0000 Subject: [PATCH 05/34] [Fix] exclude Qwen template overhead from prompt length checks Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 0a6d6180f0111778229a28739066c18448dcdf72) --- .../test_qwen_image_max_sequence_length.py | 64 +++++++++++++++++-- .../models/qwen_image/pipeline_qwen_image.py | 8 ++- .../qwen_image/pipeline_qwen_image_edit.py | 8 ++- .../pipeline_qwen_image_edit_plus.py | 8 ++- .../qwen_image/pipeline_qwen_image_layered.py | 8 ++- vllm_omni/diffusion/utils/prompt_utils.py | 11 +++- 6 files changed, 95 insertions(+), 12 deletions(-) diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py b/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py index f5491388e7a..6843f4d5a2b 100644 --- a/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py +++ b/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py @@ -39,11 +39,18 @@ def to(self, device): class _FakeTokenizer: - def __init__(self, total_sequence_length: int): - self.total_sequence_length = total_sequence_length + def __init__(self, total_sequence_length: int | list[int]): + if isinstance(total_sequence_length, list): + self.total_sequence_lengths = list(total_sequence_length) + else: + self.total_sequence_lengths = [total_sequence_length] def __call__(self, *args, **kwargs): - return _FakeModelInputs(self.total_sequence_length) + if len(self.total_sequence_lengths) > 1: + total_sequence_length = self.total_sequence_lengths.pop(0) + else: + total_sequence_length = self.total_sequence_lengths[0] + return _FakeModelInputs(total_sequence_length) class _FakeProcessor(_FakeTokenizer): @@ -80,7 +87,7 @@ def _make_pipeline( pipeline.tokenizer_max_length = 1024 pipeline.prompt_template_encode = "{}" pipeline.prompt_template_encode_start_idx = drop_idx - pipeline.tokenizer = _FakeTokenizer(total_sequence_length) + pipeline.tokenizer = _FakeTokenizer([total_sequence_length, 0]) if input_kind == "processor": pipeline.processor = _FakeProcessor(total_sequence_length) return pipeline @@ -94,7 +101,7 @@ def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length( ): pipeline = _make_pipeline( pipeline_class, - total_sequence_length=drop_idx + 1025, + total_sequence_length=1025, drop_idx=drop_idx, input_kind=input_kind, ) @@ -111,7 +118,7 @@ def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length( ): pipeline = _make_pipeline( pipeline_class, - total_sequence_length=drop_idx + 17, + total_sequence_length=17, drop_idx=drop_idx, input_kind=input_kind, ) @@ -188,13 +195,56 @@ def test_edit_pipelines_validate_text_prompt_length_before_image_token_expansion pipeline.tokenizer_max_length = 1024 pipeline.prompt_template_encode = "{}" pipeline.prompt_template_encode_start_idx = drop_idx - pipeline.tokenizer = _FakeTokenizer(drop_idx + 8) + pipeline.tokenizer = _FakeTokenizer([8, 0]) pipeline.processor = _FakeProcessor(drop_idx + 1500) with pytest.raises(AssertionError, match="text encoder should not run"): pipeline.encode_prompt(prompt="short prompt") +@pytest.mark.parametrize( + "pipeline_class", + [ + pytest.param(QwenImagePipeline, id="qwen-image"), + pytest.param(QwenImageLayeredPipeline, id="qwen-image-layered"), + ], +) +def test_qwen_generation_validator_excludes_template_suffix_from_budget(pipeline_class: type): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.device = torch.device("cpu") + pipeline.text_encoder = _RejectingTextEncoder() + pipeline.tokenizer_max_length = 1024 + pipeline.prompt_template_encode = "{}" + pipeline.prompt_template_encode_start_idx = 34 + pipeline.tokenizer = _FakeTokenizer([1029, 5]) + + with pytest.raises(AssertionError, match="text encoder should not run"): + pipeline.encode_prompt(prompt="boundary prompt") + + +@pytest.mark.parametrize( + "pipeline_class", + [ + pytest.param(QwenImageEditPipeline, id="qwen-image-edit"), + pytest.param(QwenImageEditPlusPipeline, id="qwen-image-edit-plus"), + ], +) +def test_qwen_edit_validator_excludes_image_placeholders_from_budget(pipeline_class: type): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.device = torch.device("cpu") + pipeline.text_encoder = _RejectingTextEncoder() + pipeline.tokenizer_max_length = 1024 + pipeline.prompt_template_encode = "{}" + pipeline.prompt_template_encode_start_idx = 64 + pipeline.tokenizer = _FakeTokenizer([30, 20]) + pipeline.processor = _FakeProcessor(1500) + + with pytest.raises(AssertionError, match="text encoder should not run"): + pipeline.encode_prompt(prompt="short prompt") + + @pytest.mark.parametrize( "pipeline_class", [ diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index cc4e7108979..8a6fd080322 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -399,12 +399,18 @@ def _get_qwen_prompt_embeds( truncation=False, return_tensors="pt", ).to(self.device) + template_tokens = self.tokenizer( + [template.format("")], + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) validate_prompt_sequence_lengths( txt_tokens.attention_mask, max_sequence_length=max_sequence_length or self.tokenizer_max_length, supported_max_sequence_length=self.tokenizer_max_length, prompt_name=prompt_name, - length_offset=drop_idx, + baseline_attention_mask=template_tokens.attention_mask, error_context="after applying the Qwen prompt template", ) encoder_hidden_states = self.text_encoder( diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index 9572e3c8ba6..def0916762c 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -406,6 +406,12 @@ def _get_qwen_prompt_embeds( truncation=False, return_tensors="pt", ).to(self.device) + template_tokens = self.tokenizer( + [template.format("")], + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) # Qwen-Image-Edit expands image placeholders into many vision tokens # inside the processor. `max_sequence_length` is meant to constrain the # prompt text length, so validate on the text template before image @@ -415,7 +421,7 @@ def _get_qwen_prompt_embeds( max_sequence_length=max_sequence_length or self.tokenizer_max_length, supported_max_sequence_length=self.tokenizer_max_length, prompt_name=prompt_name, - length_offset=drop_idx, + baseline_attention_mask=template_tokens.attention_mask, error_context="after applying the Qwen prompt template", ) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 9495ce4003d..7de2e3cdb60 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -332,6 +332,12 @@ def _get_qwen_prompt_embeds( truncation=False, return_tensors="pt", ).to(self.device) + template_tokens = self.tokenizer( + [template.format(base_img_prompt)], + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) # The processor expands image placeholders into many vision tokens. # `max_sequence_length` should guard the prompt text length before that # multimodal expansion happens. @@ -340,7 +346,7 @@ def _get_qwen_prompt_embeds( max_sequence_length=max_sequence_length or self.tokenizer_max_length, supported_max_sequence_length=self.tokenizer_max_length, prompt_name=prompt_name, - length_offset=drop_idx, + baseline_attention_mask=template_tokens.attention_mask, error_context="after applying the Qwen prompt template", ) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index 91e83128fe9..07b564d2134 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -378,12 +378,18 @@ def _get_qwen_prompt_embeds( truncation=False, return_tensors="pt", ).to(device) + template_tokens = self.tokenizer( + [template.format("")], + padding=True, + truncation=False, + return_tensors="pt", + ).to(device) validate_prompt_sequence_lengths( txt_tokens.attention_mask, max_sequence_length=max_sequence_length or self.tokenizer_max_length, supported_max_sequence_length=self.tokenizer_max_length, prompt_name=prompt_name, - length_offset=drop_idx, + baseline_attention_mask=template_tokens.attention_mask, error_context="after applying the Qwen prompt template", ) encoder_hidden_states = self.text_encoder( diff --git a/vllm_omni/diffusion/utils/prompt_utils.py b/vllm_omni/diffusion/utils/prompt_utils.py index dea28ce2374..ab6bf558443 100644 --- a/vllm_omni/diffusion/utils/prompt_utils.py +++ b/vllm_omni/diffusion/utils/prompt_utils.py @@ -8,9 +8,18 @@ def validate_prompt_sequence_lengths( supported_max_sequence_length: int, prompt_name: str = "prompt", length_offset: int = 0, + baseline_attention_mask: torch.Tensor | None = None, error_context: str, ) -> None: - sequence_lengths = torch.clamp(attention_mask.sum(dim=1) - length_offset, min=0) + sequence_lengths = attention_mask.sum(dim=1) + if baseline_attention_mask is not None: + baseline_lengths = baseline_attention_mask.sum(dim=1) + if baseline_lengths.shape[0] == 1 and sequence_lengths.shape[0] > 1: + baseline_lengths = baseline_lengths.expand(sequence_lengths.shape[0]) + sequence_lengths = sequence_lengths - baseline_lengths + if length_offset: + sequence_lengths = sequence_lengths - length_offset + sequence_lengths = torch.clamp(sequence_lengths, min=0) too_long = torch.nonzero(sequence_lengths > max_sequence_length, as_tuple=False) if too_long.numel() == 0: return From 1f190d9ceca038cc221302cf83737689684b08c7 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 11:37:34 +0000 Subject: [PATCH 06/34] [Docs] annotate Qwen prompt length validation invariants Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit bd9bfaf28310db8df7d9a0de48b9e21c7fe5c9a4) --- vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py | 3 +++ .../diffusion/models/qwen_image/pipeline_qwen_image_edit.py | 3 +++ .../models/qwen_image/pipeline_qwen_image_edit_plus.py | 3 +++ .../models/qwen_image/pipeline_qwen_image_layered.py | 3 +++ vllm_omni/diffusion/utils/prompt_utils.py | 4 ++++ 5 files changed, 16 insertions(+) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 8a6fd080322..105445f31c7 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -399,6 +399,9 @@ def _get_qwen_prompt_embeds( truncation=False, return_tensors="pt", ).to(self.device) + # Validate only the user prompt contribution. The Qwen template also + # adds a fixed suffix after the user text, so subtracting only + # prompt_template_encode_start_idx would overcount near-limit prompts. template_tokens = self.tokenizer( [template.format("")], padding=True, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index def0916762c..f73cefade9b 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -406,6 +406,9 @@ def _get_qwen_prompt_embeds( truncation=False, return_tensors="pt", ).to(self.device) + # The edit template contains fixed multimodal scaffolding around the + # instruction. Validate against the empty-template baseline so image + # placeholder text does not consume the user's text budget. template_tokens = self.tokenizer( [template.format("")], padding=True, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 7de2e3cdb60..337d77c2e44 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -332,6 +332,9 @@ def _get_qwen_prompt_embeds( truncation=False, return_tensors="pt", ).to(self.device) + # Multi-image edit prepends "Picture N" placeholders before the user + # instruction. Subtract the placeholder-aware baseline so attached + # images do not reduce the remaining prompt budget. template_tokens = self.tokenizer( [template.format(base_img_prompt)], padding=True, diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index 07b564d2134..f670d8044a1 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -378,6 +378,9 @@ def _get_qwen_prompt_embeds( truncation=False, return_tensors="pt", ).to(device) + # The layered template also appends fixed non-user tokens after the + # editable text, so use the empty-template tokenized baseline instead of + # counting everything after prompt_template_encode_start_idx. template_tokens = self.tokenizer( [template.format("")], padding=True, diff --git a/vllm_omni/diffusion/utils/prompt_utils.py b/vllm_omni/diffusion/utils/prompt_utils.py index ab6bf558443..fc1769f4d54 100644 --- a/vllm_omni/diffusion/utils/prompt_utils.py +++ b/vllm_omni/diffusion/utils/prompt_utils.py @@ -13,6 +13,10 @@ def validate_prompt_sequence_lengths( ) -> None: sequence_lengths = attention_mask.sum(dim=1) if baseline_attention_mask is not None: + # Some callers need to validate only the user-controlled portion of a + # templated prompt. In those cases we subtract the fully-tokenized + # template baseline instead of only removing a fixed prefix length, + # because the template may also contribute a suffix or image markers. baseline_lengths = baseline_attention_mask.sum(dim=1) if baseline_lengths.shape[0] == 1 and sequence_lengths.shape[0] > 1: baseline_lengths = baseline_lengths.expand(sequence_lengths.shape[0]) From 7eeafcfe34d28c55181c6e4fb5a1b216a07c2ce8 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 01:46:34 +0000 Subject: [PATCH 07/34] [Test] add pytest marks for max sequence UTs Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 21851d628fd52601fc37fcbb5a4f40dac860158e) --- .../models/qwen_image/test_qwen_image_max_sequence_length.py | 2 ++ tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py b/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py index 6843f4d5a2b..f5676a0056f 100644 --- a/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py +++ b/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py @@ -18,6 +18,8 @@ QwenImageLayeredPipeline, ) +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + class _RejectingTextEncoder: dtype = torch.float32 diff --git a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py index 04d39406d6d..1b594b4b01e 100644 --- a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py +++ b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py @@ -18,6 +18,8 @@ pytestmark = [pytest.mark.core_model, pytest.mark.cpu] +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + class _RejectingTextEncoder: dtype = torch.float32 From b63a3b73e4e41480ecac8cb7577581268bad6d2c Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Tue, 14 Apr 2026 16:19:34 +0800 Subject: [PATCH 08/34] [Bugfix] Preserve default diffusion sampling params in default stage Signed-off-by: David Chen <530634352@qq.com> (cherry picked from commit 896b0b83e945c10672f0d65adb326f7b6c9a7b8e) --- tests/entrypoints/openai_api/test_image_server.py | 10 ++++++---- .../test_async_omni_diffusion_config.py | 14 ++++++++++++++ vllm_omni/engine/async_omni_engine.py | 11 +++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 8695311fb5b..dffa8440954 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -171,7 +171,7 @@ def test_client(mock_async_diffusion): app.state.stage_configs = [SimpleNamespace(stage_type="diffusion")] app.state.diffusion_model_name = "Qwen/Qwen-Image" # For models endpoint app.state.args = Namespace( - default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5}}', + default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}', max_generated_image_size=1024 * 1792, ) @@ -194,7 +194,7 @@ def async_omni_test_client(): SimpleNamespace(stage_type="diffusion"), ] app.state.args = Namespace( - default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}', + default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}', max_generated_image_size=1048576, # 1024*1024 to support resolution tests ) return TestClient(app) @@ -216,7 +216,7 @@ def async_omni_rgba_test_client(): SimpleNamespace(stage_type="diffusion"), ] app.state.args = Namespace( - default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}', + default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}', max_generated_image_size=1048576, ) return TestClient(app) @@ -238,7 +238,7 @@ def async_omni_stage_configs_only_client(): # Intentionally do not populate app.state.stage_configs. Refactored # AsyncOmni exposes stage_configs on the engine instance. app.state.args = Namespace( - default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}', + default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}', max_generated_image_size=1024 * 1792, ) return TestClient(app) @@ -954,6 +954,7 @@ def test_image_edit_parameter_default(async_omni_test_client): assert captured_sampling_params.num_outputs_per_prompt == 1 assert captured_sampling_params.num_inference_steps == 4 assert captured_sampling_params.guidance_scale == 7.5 + assert captured_sampling_params.generator_device == "cpu" # Test that a size exceeding max_generated_image_size returns 400 response = async_omni_test_client.post( @@ -987,6 +988,7 @@ def test_image_edit_parameter_default_single_stage(test_client): assert captured_sampling_params.num_outputs_per_prompt == 1 assert captured_sampling_params.num_inference_steps == 4 assert captured_sampling_params.guidance_scale == 7.5 + assert captured_sampling_params.generator_device == "cpu" # Size exceeding max_generated_image_size (1024*1792) returns 400 response = test_client.post( diff --git a/tests/entrypoints/test_async_omni_diffusion_config.py b/tests/entrypoints/test_async_omni_diffusion_config.py index ca5624f2d4c..1e38c4438f4 100644 --- a/tests/entrypoints/test_async_omni_diffusion_config.py +++ b/tests/entrypoints/test_async_omni_diffusion_config.py @@ -69,6 +69,20 @@ def test_default_stage_config_propagates_ulysses_mode(): assert parallel_config.ulysses_mode == "advanced_uaa" +def test_default_stage_config_includes_default_sampling_params(): + """Ensure default sampling params survive the default diffusion-stage builder.""" + stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg( + { + "default_sampling_params": '{"0": {"generator_device":"cpu", "guidance_scale":7.5}}', + } + )[0] + + assert stage_cfg["default_sampling_params"] == { + "generator_device": "cpu", + "guidance_scale": 7.5, + } + + def test_serve_cli_accepts_ulysses_mode(): """Ensure diffusion serve CLI exposes ulysses_mode and wires it to parallel_config.""" parser = FlexibleArgumentParser() diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 34d9d240441..2828b05d0cf 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -800,6 +800,16 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: # We temporally create a default config for diffusion stage. # In the future, we should merge the default config with the user-provided config. normalized_kwargs = dict(kwargs) + default_sampling_params = normalized_kwargs.get("default_sampling_params") + if isinstance(default_sampling_params, str): + try: + default_sampling_params = json.loads(default_sampling_params) + except json.JSONDecodeError: + logger.warning("Invalid default_sampling_params JSON, ignoring stage defaults.") + default_sampling_params = None + if not isinstance(default_sampling_params, dict): + default_sampling_params = None + stage_default_sampling_params = default_sampling_params.get("0", {}) if default_sampling_params else {} # TODO: hack, convert dtype to string to avoid non-premitive omegaconf create error. if "dtype" in normalized_kwargs and not isinstance(normalized_kwargs["dtype"], str): @@ -888,6 +898,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: else {} ), }, + "default_sampling_params": stage_default_sampling_params, "final_output": True, "final_output_type": "image", } From ee577d755ab473ed8325ddc4393ad529043cd955 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 07:13:26 +0000 Subject: [PATCH 09/34] Fix Qwen image edit plus prompt encoding memory Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 0e2f00989523f56d4eb19b948796f2c316a94220) --- .../qwen_image/test_qwen_image_edit_plus.py | 79 +++++++++++++++++++ .../pipeline_qwen_image_edit_plus.py | 13 +-- 2 files changed, 87 insertions(+), 5 deletions(-) create mode 100644 tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py new file mode 100644 index 00000000000..01644ada8b6 --- /dev/null +++ b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 + +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import ( + QwenImageEditPlusPipeline, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +class _FakeProcessorOutput(SimpleNamespace): + def to(self, _device: str): + return self + + +class _FakeProcessor: + def __call__(self, *, text, images, padding, return_tensors): + assert padding is True + assert return_tensors == "pt" + assert len(text) == 1 + assert len(images) == 2 + return _FakeProcessorOutput( + input_ids=torch.tensor([[1, 2, 3, 4]]), + attention_mask=torch.tensor([[1, 1, 1, 0]]), + pixel_values=torch.tensor([1.0]), + image_grid_thw=torch.tensor([[1, 1, 1]]), + ) + + +class _FakeTextEncoder: + dtype = torch.float32 + + def __init__(self) -> None: + self.model_calls = [] + + def __call__(self, *args, **kwargs): + raise AssertionError("full CausalLM forward should not be used for prompt encoding") + + def model(self, **kwargs): + self.model_calls.append(kwargs) + return SimpleNamespace( + last_hidden_state=torch.tensor( + [ + [ + [1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [3.0, 3.0, 3.0], + [4.0, 4.0, 4.0], + ] + ] + ) + ) + + +def test_qwen_image_edit_plus_prompt_encoding_skips_lm_head(): + pipeline = QwenImageEditPlusPipeline.__new__(QwenImageEditPlusPipeline) + pipeline.device = "cpu" + pipeline.text_encoder = _FakeTextEncoder() + pipeline.processor = _FakeProcessor() + pipeline.prompt_template_encode = "{}" + pipeline.prompt_template_encode_start_idx = 1 + + prompt_embeds, attention_mask = pipeline._get_qwen_prompt_embeds( + prompt="combine", + image=[object(), object()], + ) + + assert len(pipeline.text_encoder.model_calls) == 1 + model_call = pipeline.text_encoder.model_calls[0] + assert model_call["output_hidden_states"] is False + assert model_call["return_dict"] is True + assert tuple(prompt_embeds.shape) == (1, 2, 3) + assert tuple(attention_mask.shape) == (1, 2) + assert torch.equal(prompt_embeds[0], torch.tensor([[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])) + assert torch.equal(attention_mask[0], torch.tensor([1, 1])) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 337d77c2e44..548ed665b58 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -361,15 +361,18 @@ def _get_qwen_prompt_embeds( return_tensors="pt", ).to(self.device) - outputs = self.text_encoder( + # We only need the fused multimodal hidden states for diffusion conditioning. + # Calling the full CausalLM forward also materializes logits for the entire + # prompt, which becomes unnecessarily expensive for many-image edit prompts. + outputs = self.text_encoder.model( input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, - output_hidden_states=True, + output_hidden_states=False, + return_dict=True, ) - - hidden_states = outputs.hidden_states[-1] + hidden_states = outputs.last_hidden_state split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] @@ -381,7 +384,7 @@ def _get_qwen_prompt_embeds( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] ) - prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device) return prompt_embeds, encoder_attention_mask From 4abda0cc7d187297cb5000cf909c0bb67259a7a9 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 07:17:30 +0000 Subject: [PATCH 10/34] Clamp Qwen image edit plus prompt length Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit eec0785a55371b79427b93a407353e679e9412f9) --- .../qwen_image/test_qwen_image_edit_plus.py | 18 ++++++++++++++++++ .../pipeline_qwen_image_edit_plus.py | 3 +++ 2 files changed, 21 insertions(+) diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py index 01644ada8b6..eae5c1b82ba 100644 --- a/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py +++ b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py @@ -77,3 +77,21 @@ def test_qwen_image_edit_plus_prompt_encoding_skips_lm_head(): assert tuple(attention_mask.shape) == (1, 2) assert torch.equal(prompt_embeds[0], torch.tensor([[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])) assert torch.equal(attention_mask[0], torch.tensor([1, 1])) + + +def test_qwen_image_edit_plus_encode_prompt_applies_max_sequence_length(): + pipeline = QwenImageEditPlusPipeline.__new__(QwenImageEditPlusPipeline) + pipeline._get_qwen_prompt_embeds = lambda *args, **kwargs: ( + torch.arange(24, dtype=torch.float32).view(1, 4, 6), + torch.tensor([[1, 1, 1, 1]]), + ) + + prompt_embeds, attention_mask = pipeline.encode_prompt( + prompt="combine", + image=[object()], + max_sequence_length=2, + ) + + assert tuple(prompt_embeds.shape) == (1, 2, 6) + assert tuple(attention_mask.shape) == (1, 2) + assert torch.equal(attention_mask[0], torch.tensor([1, 1])) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 548ed665b58..432916a7a6a 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -423,6 +423,9 @@ def encode_prompt( prompt_name=prompt_name, ) + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) From d8f0ff95dc31925c36c717a6de9cd1f857940e8f Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 07:24:14 +0000 Subject: [PATCH 11/34] Limit Qwen image edit plus input images Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 72af6037cb30b13ba7918c470ba2ff9ab6069cc7) --- .../qwen_image/test_qwen_image_edit_plus.py | 105 ++++-------------- .../pipeline_qwen_image_edit_plus.py | 22 ++-- 2 files changed, 33 insertions(+), 94 deletions(-) diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py index eae5c1b82ba..7f17fa0da16 100644 --- a/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py +++ b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py @@ -1,97 +1,36 @@ # SPDX-License-Identifier: Apache-2.0 +import json +from pathlib import Path from types import SimpleNamespace +import numpy as np import pytest -import torch +from PIL import Image from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import ( - QwenImageEditPlusPipeline, + get_qwen_image_edit_plus_pre_process_func, ) pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] -class _FakeProcessorOutput(SimpleNamespace): - def to(self, _device: str): - return self - - -class _FakeProcessor: - def __call__(self, *, text, images, padding, return_tensors): - assert padding is True - assert return_tensors == "pt" - assert len(text) == 1 - assert len(images) == 2 - return _FakeProcessorOutput( - input_ids=torch.tensor([[1, 2, 3, 4]]), - attention_mask=torch.tensor([[1, 1, 1, 0]]), - pixel_values=torch.tensor([1.0]), - image_grid_thw=torch.tensor([[1, 1, 1]]), - ) - - -class _FakeTextEncoder: - dtype = torch.float32 - - def __init__(self) -> None: - self.model_calls = [] - - def __call__(self, *args, **kwargs): - raise AssertionError("full CausalLM forward should not be used for prompt encoding") - - def model(self, **kwargs): - self.model_calls.append(kwargs) - return SimpleNamespace( - last_hidden_state=torch.tensor( - [ - [ - [1.0, 1.0, 1.0], - [2.0, 2.0, 2.0], - [3.0, 3.0, 3.0], - [4.0, 4.0, 4.0], - ] - ] - ) - ) - - -def test_qwen_image_edit_plus_prompt_encoding_skips_lm_head(): - pipeline = QwenImageEditPlusPipeline.__new__(QwenImageEditPlusPipeline) - pipeline.device = "cpu" - pipeline.text_encoder = _FakeTextEncoder() - pipeline.processor = _FakeProcessor() - pipeline.prompt_template_encode = "{}" - pipeline.prompt_template_encode_start_idx = 1 - - prompt_embeds, attention_mask = pipeline._get_qwen_prompt_embeds( - prompt="combine", - image=[object(), object()], - ) - - assert len(pipeline.text_encoder.model_calls) == 1 - model_call = pipeline.text_encoder.model_calls[0] - assert model_call["output_hidden_states"] is False - assert model_call["return_dict"] is True - assert tuple(prompt_embeds.shape) == (1, 2, 3) - assert tuple(attention_mask.shape) == (1, 2) - assert torch.equal(prompt_embeds[0], torch.tensor([[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])) - assert torch.equal(attention_mask[0], torch.tensor([1, 1])) - - -def test_qwen_image_edit_plus_encode_prompt_applies_max_sequence_length(): - pipeline = QwenImageEditPlusPipeline.__new__(QwenImageEditPlusPipeline) - pipeline._get_qwen_prompt_embeds = lambda *args, **kwargs: ( - torch.arange(24, dtype=torch.float32).view(1, 4, 6), - torch.tensor([[1, 1, 1, 1]]), - ) - - prompt_embeds, attention_mask = pipeline.encode_prompt( - prompt="combine", - image=[object()], - max_sequence_length=2, +def test_qwen_image_edit_plus_rejects_too_many_input_images(tmp_path: Path): + vae_dir = tmp_path / "vae" + vae_dir.mkdir() + (vae_dir / "config.json").write_text(json.dumps({"z_dim": 16})) + + pre_process = get_qwen_image_edit_plus_pre_process_func(SimpleNamespace(model=str(tmp_path))) + image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)) + request = SimpleNamespace( + prompts=[ + { + "prompt": "combine", + "multi_modal_data": {"image": [image, image, image, image, image]}, + } + ], + sampling_params=SimpleNamespace(height=None, width=None), ) - assert tuple(prompt_embeds.shape) == (1, 2, 6) - assert tuple(attention_mask.shape) == (1, 2) - assert torch.equal(attention_mask[0], torch.tensor([1, 1])) + with pytest.raises(ValueError, match=r"At most 4 images are supported by this model"): + pre_process(request) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 432916a7a6a..bff5e11059e 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -56,6 +56,7 @@ CONDITION_IMAGE_SIZE = 384 * 384 VAE_IMAGE_SIZE = 1024 * 1024 +MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES = 4 def get_qwen_image_edit_plus_pre_process_func( @@ -93,6 +94,11 @@ def pre_process_func( if not isinstance(raw_image, list): raw_image = [raw_image] + if len(raw_image) > MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES: + raise ValueError( + f"Received {len(raw_image)} input images. " + f"At most {MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES} images are supported by this model." + ) image = [ PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image | np.ndarray | torch.Tensor, im) for im in raw_image @@ -361,18 +367,15 @@ def _get_qwen_prompt_embeds( return_tensors="pt", ).to(self.device) - # We only need the fused multimodal hidden states for diffusion conditioning. - # Calling the full CausalLM forward also materializes logits for the entire - # prompt, which becomes unnecessarily expensive for many-image edit prompts. - outputs = self.text_encoder.model( + outputs = self.text_encoder( input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, - output_hidden_states=False, - return_dict=True, + output_hidden_states=True, ) - hidden_states = outputs.last_hidden_state + + hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] @@ -384,7 +387,7 @@ def _get_qwen_prompt_embeds( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] ) - prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device) + prompt_embeds = prompt_embeds.to(dtype=dtype) return prompt_embeds, encoder_attention_mask @@ -423,9 +426,6 @@ def encode_prompt( prompt_name=prompt_name, ) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] - _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) From 2596b22549d299c7d8c017e53e53ffc1eb2c77fe Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 07:37:44 +0000 Subject: [PATCH 12/34] Translate omni image edit input errors Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit f1900fef6344e059dd3afa6b6db474c62e5859dd) --- vllm_omni/entrypoints/openai/api_server.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index cec9869d277..c2bdf4d932a 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1682,11 +1682,17 @@ async def _generate_with_async_omni( pass sampling_params_list.append(default_stage_params) - async for output in engine_client.generate( - sampling_params_list=sampling_params_list, - **kwargs, - ): - result = output + try: + async for output in engine_client.generate( + sampling_params_list=sampling_params_list, + **kwargs, + ): + result = output + except RuntimeError as e: + payload = e.args[0] if e.args else None + if isinstance(payload, dict) and "error" in payload: + raise ValueError(str(payload["error"])) from e + raise if result is None: raise HTTPException( From a1b3e50bfa2273e6ac3865f66fb4a54611dce28f Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 07:56:38 +0000 Subject: [PATCH 13/34] Validate image edit limits in API layer Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit c95d20ca881bc257d4b0420aa787d98a860958a5) --- .../openai_api/test_image_server.py | 24 ++++++++++++ vllm_omni/entrypoints/openai/api_server.py | 39 +++++++++++++------ 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index dffa8440954..c53396b6b38 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -776,6 +776,30 @@ def test_image_edit_rejects_multiple_images_when_model_does_not_support_them(asy assert engine.captured_prompt is None +def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511(async_omni_test_client): + engine = async_omni_test_client.app.state.engine_client + engine.get_diffusion_od_config = lambda: SimpleNamespace( + supports_multimodal_inputs=True, + model="Qwen/Qwen-Image-Edit-2511", + ) + + response = async_omni_test_client.post( + "/v1/images/edits", + files=[ + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ], + data={"prompt": "hello world."}, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model." + assert engine.captured_prompt is None + + def test_image_edit_parameter_pass(async_omni_test_client): img_bytes_1 = make_test_image_bytes((16, 16)) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index c2bdf4d932a..548492b4b05 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1438,6 +1438,12 @@ async def edit_images( status_code=HTTPStatus.BAD_REQUEST.value, detail="Received multiple input images. Only a single image is supported by this model.", ) + max_input_images = _get_max_edit_input_images(raw_request, engine_client, model_name) + if max_input_images is not None and len(pil_images) > max_input_images: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"Received {len(pil_images)} input images. At most {max_input_images} images are supported by this model.", + ) prompt["multi_modal_data"] = {} prompt["multi_modal_data"]["image"] = pil_images @@ -1615,6 +1621,23 @@ def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) return bool(getattr(od_config, "supports_multimodal_inputs", False)) +def _get_max_edit_input_images(raw_request: Request, engine_client: Any, model_name: str) -> int | None: + diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client + get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None) + od_config = ( + get_diffusion_od_config() if callable(get_diffusion_od_config) else getattr(diffusion_engine, "od_config", None) + ) + + model_identifiers = [model_name] + if od_config is not None: + model_identifiers.append(getattr(od_config, "model", None)) + + if any(isinstance(identifier, str) and "Qwen-Image-Edit-2511" in identifier for identifier in model_identifiers): + return 4 + + return None + + def _get_lora_from_json_str(lora_body): if lora_body is None: return None @@ -1682,17 +1705,11 @@ async def _generate_with_async_omni( pass sampling_params_list.append(default_stage_params) - try: - async for output in engine_client.generate( - sampling_params_list=sampling_params_list, - **kwargs, - ): - result = output - except RuntimeError as e: - payload = e.args[0] if e.args else None - if isinstance(payload, dict) and "error" in payload: - raise ValueError(str(payload["error"])) from e - raise + async for output in engine_client.generate( + sampling_params_list=sampling_params_list, + **kwargs, + ): + result = output if result is None: raise HTTPException( From a8763b37201d8716d01f2f01916cd1f94ce8d05b Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 08:03:08 +0000 Subject: [PATCH 14/34] Refine image edit input limit helpers Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 731c5361eb06915c1576f979351fa20dfa4bf02e) --- vllm_omni/entrypoints/openai/api_server.py | 32 +++++++++++----------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 548492b4b05..bd2c1566a64 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1433,16 +1433,16 @@ async def edit_images( if not input_images_list: raise HTTPException(status_code=422, detail="Field 'image' or 'url' is required") pil_images = await _load_input_images(input_images_list) - if len(pil_images) > 1 and not _supports_multimodal_image_inputs(raw_request, engine_client): - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail="Received multiple input images. Only a single image is supported by this model.", - ) max_input_images = _get_max_edit_input_images(raw_request, engine_client, model_name) if max_input_images is not None and len(pil_images) > max_input_images: + detail = ( + "Received multiple input images. Only a single image is supported by this model." + if max_input_images == 1 + else f"Received {len(pil_images)} input images. At most {max_input_images} images are supported by this model." + ) raise HTTPException( status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"Received {len(pil_images)} input images. At most {max_input_images} images are supported by this model.", + detail=detail, ) prompt["multi_modal_data"] = {} prompt["multi_modal_data"]["image"] = pil_images @@ -1608,11 +1608,7 @@ def _get_engine_and_model(raw_request: Request): def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) -> bool: - diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client - get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None) - od_config = ( - get_diffusion_od_config() if callable(get_diffusion_od_config) else getattr(diffusion_engine, "od_config", None) - ) + od_config = _get_diffusion_od_config(raw_request, engine_client) if od_config is None: # Preserve the existing compatibility behavior when the diffusion @@ -1621,16 +1617,20 @@ def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) return bool(getattr(od_config, "supports_multimodal_inputs", False)) -def _get_max_edit_input_images(raw_request: Request, engine_client: Any, model_name: str) -> int | None: +def _get_diffusion_od_config(raw_request: Request, engine_client: Any) -> Any: diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None) - od_config = ( + return ( get_diffusion_od_config() if callable(get_diffusion_od_config) else getattr(diffusion_engine, "od_config", None) ) - model_identifiers = [model_name] - if od_config is not None: - model_identifiers.append(getattr(od_config, "model", None)) + +def _get_max_edit_input_images(raw_request: Request, engine_client: Any, model_name: str) -> int | None: + if not _supports_multimodal_image_inputs(raw_request, engine_client): + return 1 + + od_config = _get_diffusion_od_config(raw_request, engine_client) + model_identifiers = [model_name, getattr(od_config, "model", None)] if any(isinstance(identifier, str) and "Qwen-Image-Edit-2511" in identifier for identifier in model_identifiers): return 4 From 09d8847c79b1a4e54e164d7f0d4e0f4481ea7c6f Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 08:08:42 +0000 Subject: [PATCH 15/34] Simplify image edit input limit helper Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 8c857c3f77596f96a9761881ed448588f9e3be2f) --- vllm_omni/entrypoints/openai/api_server.py | 24 +++++++--------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index bd2c1566a64..7d4737a8138 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1608,7 +1608,11 @@ def _get_engine_and_model(raw_request: Request): def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) -> bool: - od_config = _get_diffusion_od_config(raw_request, engine_client) + diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client + get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None) + od_config = ( + get_diffusion_od_config() if callable(get_diffusion_od_config) else getattr(diffusion_engine, "od_config", None) + ) if od_config is None: # Preserve the existing compatibility behavior when the diffusion @@ -1617,25 +1621,11 @@ def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) return bool(getattr(od_config, "supports_multimodal_inputs", False)) -def _get_diffusion_od_config(raw_request: Request, engine_client: Any) -> Any: - diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client - get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None) - return ( - get_diffusion_od_config() if callable(get_diffusion_od_config) else getattr(diffusion_engine, "od_config", None) - ) - - -def _get_max_edit_input_images(raw_request: Request, engine_client: Any, model_name: str) -> int | None: +def _get_max_edit_input_images(raw_request: Request, engine_client: Any, model_name: str) -> int: if not _supports_multimodal_image_inputs(raw_request, engine_client): return 1 - od_config = _get_diffusion_od_config(raw_request, engine_client) - model_identifiers = [model_name, getattr(od_config, "model", None)] - - if any(isinstance(identifier, str) and "Qwen-Image-Edit-2511" in identifier for identifier in model_identifiers): - return 4 - - return None + return 4 def _get_lora_from_json_str(lora_body): From 9acab70d431f35afb0e0942a0777799c09680675 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 08:12:56 +0000 Subject: [PATCH 16/34] Wrap image edit limit error message Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 0c25a06e8296e070ff079535d5cf2deac1579c14) --- vllm_omni/entrypoints/openai/api_server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 7d4737a8138..df3b0f0230a 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1438,7 +1438,10 @@ async def edit_images( detail = ( "Received multiple input images. Only a single image is supported by this model." if max_input_images == 1 - else f"Received {len(pil_images)} input images. At most {max_input_images} images are supported by this model." + else ( + f"Received {len(pil_images)} input images. " + f"At most {max_input_images} images are supported by this model." + ) ) raise HTTPException( status_code=HTTPStatus.BAD_REQUEST.value, From 518640da8c2eb4aa014f766ca822b1b8f93a9ac0 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 09:40:54 +0000 Subject: [PATCH 17/34] Use model-specific image edit limits Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 826c74a001d881f2a05d5260b3b5f8c5b62370e1) --- .../pipeline_qwen_image_edit_plus.py | 3 ++ vllm_omni/entrypoints/openai/api_server.py | 30 ++++++++++++++----- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index bff5e11059e..2927b7cabf1 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -56,6 +56,9 @@ CONDITION_IMAGE_SIZE = 384 * 384 VAE_IMAGE_SIZE = 1024 * 1024 +# Keep this in sync with the practical conditioning-token budget for +# Qwen-Image-Edit-2511. Empirically, 4 images stays within the supported range +# while 5 images overflows the prompt/conditioning path and fails downstream. MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES = 4 diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index df3b0f0230a..9eb02459ae1 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1611,11 +1611,7 @@ def _get_engine_and_model(raw_request: Request): def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) -> bool: - diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client - get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None) - od_config = ( - get_diffusion_od_config() if callable(get_diffusion_od_config) else getattr(diffusion_engine, "od_config", None) - ) + od_config = _get_diffusion_od_config(raw_request, engine_client) if od_config is None: # Preserve the existing compatibility behavior when the diffusion @@ -1624,11 +1620,31 @@ def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) return bool(getattr(od_config, "supports_multimodal_inputs", False)) -def _get_max_edit_input_images(raw_request: Request, engine_client: Any, model_name: str) -> int: +def _get_diffusion_od_config(raw_request: Request, engine_client: Any) -> Any: + diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client + get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None) + return ( + get_diffusion_od_config() if callable(get_diffusion_od_config) else getattr(diffusion_engine, "od_config", None) + ) + + +def _get_max_edit_input_images(raw_request: Request, engine_client: Any, model_name: str) -> int | None: if not _supports_multimodal_image_inputs(raw_request, engine_client): return 1 - return 4 + od_config = _get_diffusion_od_config(raw_request, engine_client) + model_identifiers = [model_name] + if od_config is not None: + model_identifiers.append(getattr(od_config, "model", None)) + + if any(isinstance(identifier, str) and "Qwen-Image-Edit-2511" in identifier for identifier in model_identifiers): + from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import ( + MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES, + ) + + return MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES + + return None def _get_lora_from_json_str(lora_body): From 4814a9a3be984ad2c740df78f67cf6bbad173c34 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 11:24:41 +0000 Subject: [PATCH 18/34] Reject over-limit image edits before loading Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 297d06b36b088df2ff05ee313060e7422cadc125) --- .../openai_api/test_image_server.py | 33 +++++++++++++++++++ vllm_omni/entrypoints/openai/api_server.py | 6 ++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index c53396b6b38..3c6516f082a 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -800,6 +800,39 @@ def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511(async_omni_ assert engine.captured_prompt is None +def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511_before_loading( + async_omni_test_client, monkeypatch: pytest.MonkeyPatch +): + import vllm_omni.entrypoints.openai.api_server as api_server_module + + engine = async_omni_test_client.app.state.engine_client + engine.get_diffusion_od_config = lambda: SimpleNamespace( + supports_multimodal_inputs=True, + model="Qwen/Qwen-Image-Edit-2511", + ) + + def _fail_load(*args, **kwargs): + raise AssertionError("_load_input_images should not run for over-limit requests") + + monkeypatch.setattr(api_server_module, "_load_input_images", _fail_load) + + response = async_omni_test_client.post( + "/v1/images/edits", + files=[ + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ("image", make_test_image_bytes((16, 16))), + ], + data={"prompt": "hello world."}, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model." + assert engine.captured_prompt is None + + def test_image_edit_parameter_pass(async_omni_test_client): img_bytes_1 = make_test_image_bytes((16, 16)) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 9eb02459ae1..6e903510fde 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1432,14 +1432,13 @@ async def edit_images( input_images_list.extend(urls) if not input_images_list: raise HTTPException(status_code=422, detail="Field 'image' or 'url' is required") - pil_images = await _load_input_images(input_images_list) max_input_images = _get_max_edit_input_images(raw_request, engine_client, model_name) - if max_input_images is not None and len(pil_images) > max_input_images: + if max_input_images is not None and len(input_images_list) > max_input_images: detail = ( "Received multiple input images. Only a single image is supported by this model." if max_input_images == 1 else ( - f"Received {len(pil_images)} input images. " + f"Received {len(input_images_list)} input images. " f"At most {max_input_images} images are supported by this model." ) ) @@ -1447,6 +1446,7 @@ async def edit_images( status_code=HTTPStatus.BAD_REQUEST.value, detail=detail, ) + pil_images = await _load_input_images(input_images_list) prompt["multi_modal_data"] = {} prompt["multi_modal_data"]["image"] = pil_images From 854fad6c838a1d400feb7f54c9e24692eb443ad6 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Thu, 16 Apr 2026 11:37:32 +0000 Subject: [PATCH 19/34] Document image edit limit handling Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 4ea227185b41c3a662dc835470e9138a58b3063f) --- vllm_omni/entrypoints/openai/api_server.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 6e903510fde..4aeb29e069e 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1432,6 +1432,9 @@ async def edit_images( input_images_list.extend(urls) if not input_images_list: raise HTTPException(status_code=422, detail="Field 'image' or 'url' is required") + # Reject oversized multi-image edit requests before fetching or decoding + # any inputs. This keeps over-limit URL requests from burning network, + # CPU, and memory on work that will be rejected anyway. max_input_images = _get_max_edit_input_images(raw_request, engine_client, model_name) if max_input_images is not None and len(input_images_list) > max_input_images: detail = ( @@ -1632,6 +1635,10 @@ def _get_max_edit_input_images(raw_request: Request, engine_client: Any, model_n if not _supports_multimodal_image_inputs(raw_request, engine_client): return 1 + # Keep the API-side limit model-specific: this helper should not hardcode a + # generic "multi-image means 4" rule because future edit pipelines may have + # different conditioning budgets. Query the serving model / OD config first, + # then defer to the owning pipeline constant. od_config = _get_diffusion_od_config(raw_request, engine_client) model_identifiers = [model_name] if od_config is not None: From 1a9dbd31851eb5743f2230cbf7bc32865251ca94 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 02:55:11 +0000 Subject: [PATCH 20/34] [Fix] Make image edit input limits config-driven Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit f414061fad9d74994a3a76645a954b5fa45bdd5e) --- .../qwen_image/test_qwen_image_edit_plus.py | 2 + .../openai_api/test_image_server.py | 4 +- tests/test_diffusion_config_propagation.py | 9 +++++ vllm_omni/diffusion/data.py | 12 +++++- vllm_omni/entrypoints/openai/api_server.py | 38 +++++-------------- 5 files changed, 33 insertions(+), 32 deletions(-) diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py index 7f17fa0da16..873b52bf7a6 100644 --- a/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py +++ b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py @@ -18,6 +18,8 @@ def test_qwen_image_edit_plus_rejects_too_many_input_images(tmp_path: Path): vae_dir = tmp_path / "vae" vae_dir.mkdir() + # Keep the mock config intentionally minimal: this test only needs the + # fields touched during pre-process initialization. (vae_dir / "config.json").write_text(json.dumps({"z_dim": 16})) pre_process = get_qwen_image_edit_plus_pre_process_func(SimpleNamespace(model=str(tmp_path))) diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 3c6516f082a..e49571a3f38 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -780,7 +780,7 @@ def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511(async_omni_ engine = async_omni_test_client.app.state.engine_client engine.get_diffusion_od_config = lambda: SimpleNamespace( supports_multimodal_inputs=True, - model="Qwen/Qwen-Image-Edit-2511", + max_multimodal_image_inputs=4, ) response = async_omni_test_client.post( @@ -808,7 +808,7 @@ def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511_before_load engine = async_omni_test_client.app.state.engine_client engine.get_diffusion_od_config = lambda: SimpleNamespace( supports_multimodal_inputs=True, - model="Qwen/Qwen-Image-Edit-2511", + max_multimodal_image_inputs=4, ) def _fail_load(*args, **kwargs): diff --git a/tests/test_diffusion_config_propagation.py b/tests/test_diffusion_config_propagation.py index 58eb6097cad..85a6db35cc9 100644 --- a/tests/test_diffusion_config_propagation.py +++ b/tests/test_diffusion_config_propagation.py @@ -106,3 +106,12 @@ def test_extra_kwargs_forwarded(self): ea = stages[0]["engine_args"] assert ea["enforce_eager"] is True assert ea["lora_path"] == "/tmp/lora" + + +def test_qwen_image_edit_plus_sets_generic_multimodal_limit(): + od_config = OmniDiffusionConfig(model="Qwen/Qwen-Image-Edit-2511", model_class_name="QwenImageEditPlusPipeline") + + od_config.update_multimodal_support() + + assert od_config.supports_multimodal_inputs is True + assert od_config.max_multimodal_image_inputs == 4 diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 390231bf99a..eac1ede8795 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -468,8 +468,10 @@ class OmniDiffusionConfig: # Scheduler flow_shift for Wan2.2 (12.0 for 480p, 5.0 for 720p) flow_shift: float | None = None - # support multi images input + # Support multi-image inputs and expose any model-specific request limit + # through a generic config field so serving code stays model-agnostic. supports_multimodal_inputs: bool = False + max_multimodal_image_inputs: int | None = None log_level: str = "info" @@ -616,6 +618,14 @@ def __post_init__(self): def update_multimodal_support(self) -> None: self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"} + self.max_multimodal_image_inputs = None + + if self.model_class_name == "QwenImageEditPlusPipeline": + from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import ( + MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES, + ) + + self.max_multimodal_image_inputs = MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES def enrich_config(self) -> None: """Load model metadata from HuggingFace and populate config fields. diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 4aeb29e069e..cd896f2609e 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1435,7 +1435,7 @@ async def edit_images( # Reject oversized multi-image edit requests before fetching or decoding # any inputs. This keeps over-limit URL requests from burning network, # CPU, and memory on work that will be rejected anyway. - max_input_images = _get_max_edit_input_images(raw_request, engine_client, model_name) + max_input_images = _get_max_edit_input_images(raw_request, engine_client) if max_input_images is not None and len(input_images_list) > max_input_images: detail = ( "Received multiple input images. Only a single image is supported by this model." @@ -1613,16 +1613,6 @@ def _get_engine_and_model(raw_request: Request): return engine_client, model_name, normalized_stage_configs -def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) -> bool: - od_config = _get_diffusion_od_config(raw_request, engine_client) - - if od_config is None: - # Preserve the existing compatibility behavior when the diffusion - # config is not exposed on the serving surface. - return True - return bool(getattr(od_config, "supports_multimodal_inputs", False)) - - def _get_diffusion_od_config(raw_request: Request, engine_client: Any) -> Any: diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None) @@ -1631,27 +1621,17 @@ def _get_diffusion_od_config(raw_request: Request, engine_client: Any) -> Any: ) -def _get_max_edit_input_images(raw_request: Request, engine_client: Any, model_name: str) -> int | None: - if not _supports_multimodal_image_inputs(raw_request, engine_client): - return 1 - - # Keep the API-side limit model-specific: this helper should not hardcode a - # generic "multi-image means 4" rule because future edit pipelines may have - # different conditioning budgets. Query the serving model / OD config first, - # then defer to the owning pipeline constant. +def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int | None: od_config = _get_diffusion_od_config(raw_request, engine_client) - model_identifiers = [model_name] - if od_config is not None: - model_identifiers.append(getattr(od_config, "model", None)) - - if any(isinstance(identifier, str) and "Qwen-Image-Edit-2511" in identifier for identifier in model_identifiers): - from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import ( - MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES, - ) + if od_config is None: + # Preserve the existing compatibility behavior when the diffusion + # config is not exposed on the serving surface. + return None - return MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES + if not bool(getattr(od_config, "supports_multimodal_inputs", False)): + return 1 - return None + return getattr(od_config, "max_multimodal_image_inputs", None) def _get_lora_from_json_str(lora_body): From c2a0230c1497bc72a19a363966792815a424656c Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 03:14:10 +0000 Subject: [PATCH 21/34] [Refactor] Move diffusion image limits into shared metadata Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 05a7a5de43fae3391d40ffeb0998803975347b23) --- tests/test_diffusion_config_propagation.py | 3 ++- vllm_omni/diffusion/data.py | 13 +++------ vllm_omni/diffusion/model_metadata.py | 27 +++++++++++++++++++ .../pipeline_qwen_image_edit_plus.py | 3 ++- 4 files changed, 35 insertions(+), 11 deletions(-) create mode 100644 vllm_omni/diffusion/model_metadata.py diff --git a/tests/test_diffusion_config_propagation.py b/tests/test_diffusion_config_propagation.py index 85a6db35cc9..18c6135c02d 100644 --- a/tests/test_diffusion_config_propagation.py +++ b/tests/test_diffusion_config_propagation.py @@ -14,6 +14,7 @@ DiffusionParallelConfig, OmniDiffusionConfig, ) +from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES def _roundtrip_diffusion_config(**kwargs) -> OmniDiffusionConfig: @@ -114,4 +115,4 @@ def test_qwen_image_edit_plus_sets_generic_multimodal_limit(): od_config.update_multimodal_support() assert od_config.supports_multimodal_inputs is True - assert od_config.max_multimodal_image_inputs == 4 + assert od_config.max_multimodal_image_inputs == QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index eac1ede8795..a5281501cdb 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -17,6 +17,7 @@ QuantizationConfig, ) +from vllm_omni.diffusion.model_metadata import get_diffusion_model_metadata from vllm_omni.diffusion.utils.network_utils import is_port_available from vllm_omni.quantization import build_quant_config @@ -617,15 +618,9 @@ def __post_init__(self): raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA") def update_multimodal_support(self) -> None: - self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"} - self.max_multimodal_image_inputs = None - - if self.model_class_name == "QwenImageEditPlusPipeline": - from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import ( - MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES, - ) - - self.max_multimodal_image_inputs = MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES + metadata = get_diffusion_model_metadata(self.model_class_name) + self.supports_multimodal_inputs = metadata.supports_multimodal_inputs + self.max_multimodal_image_inputs = metadata.max_multimodal_image_inputs def enrich_config(self) -> None: """Load model metadata from HuggingFace and populate config fields. diff --git a/vllm_omni/diffusion/model_metadata.py b/vllm_omni/diffusion/model_metadata.py new file mode 100644 index 00000000000..4fbfe7c6286 --- /dev/null +++ b/vllm_omni/diffusion/model_metadata.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class DiffusionModelMetadata: + supports_multimodal_inputs: bool = False + max_multimodal_image_inputs: int | None = None + + +QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES = 4 + + +_DIFFUSION_MODEL_METADATA: dict[str, DiffusionModelMetadata] = { + "QwenImageEditPlusPipeline": DiffusionModelMetadata( + supports_multimodal_inputs=True, + max_multimodal_image_inputs=QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES, + ), +} + + +def get_diffusion_model_metadata(model_class_name: str | None) -> DiffusionModelMetadata: + if model_class_name is None: + return DiffusionModelMetadata() + return _DIFFUSION_MODEL_METADATA.get(model_class_name, DiffusionModelMetadata()) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 2927b7cabf1..327ee823cd2 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -23,6 +23,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput @@ -59,7 +60,7 @@ # Keep this in sync with the practical conditioning-token budget for # Qwen-Image-Edit-2511. Empirically, 4 images stays within the supported range # while 5 images overflows the prompt/conditioning path and fails downstream. -MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES = 4 +MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES = QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES def get_qwen_image_edit_plus_pre_process_func( From 125dc9de1825ec16d5f5c598c40ec17ecf7bcc30 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 03:22:24 +0000 Subject: [PATCH 22/34] [Fix] Apply pre-commit import ordering Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit f3e7ce99199b00e3cc7acd16f0de2d2cdfb4190d) --- .../models/qwen_image/pipeline_qwen_image_edit_plus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 327ee823cd2..3a444556ea9 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -23,9 +23,9 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( QwenImageCFGParallelMixin, From 13ca70fabc6af383f1861e02f2b4550c315e5839 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 03:42:41 +0000 Subject: [PATCH 23/34] [Docs] Clarify shared diffusion image limit metadata Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 3015646878f1e188c7554535f26ed2c189b5f998) --- vllm_omni/diffusion/data.py | 2 ++ vllm_omni/diffusion/model_metadata.py | 4 ++++ .../models/qwen_image/pipeline_qwen_image_edit_plus.py | 2 ++ 3 files changed, 8 insertions(+) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index a5281501cdb..5a1121ca94c 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -618,6 +618,8 @@ def __post_init__(self): raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA") def update_multimodal_support(self) -> None: + # Resolve serving-visible multimodal behavior from shared metadata + # instead of importing concrete pipeline modules into the config layer. metadata = get_diffusion_model_metadata(self.model_class_name) self.supports_multimodal_inputs = metadata.supports_multimodal_inputs self.max_multimodal_image_inputs = metadata.max_multimodal_image_inputs diff --git a/vllm_omni/diffusion/model_metadata.py b/vllm_omni/diffusion/model_metadata.py index 4fbfe7c6286..ec133e7380e 100644 --- a/vllm_omni/diffusion/model_metadata.py +++ b/vllm_omni/diffusion/model_metadata.py @@ -6,6 +6,8 @@ @dataclass(frozen=True) class DiffusionModelMetadata: + # Keep serving-facing capability metadata in a lightweight shared module so + # config/model plumbing can read it without importing concrete pipelines. supports_multimodal_inputs: bool = False max_multimodal_image_inputs: int | None = None @@ -22,6 +24,8 @@ class DiffusionModelMetadata: def get_diffusion_model_metadata(model_class_name: str | None) -> DiffusionModelMetadata: + # Unknown models fall back to "no special multimodal capabilities" so new + # pipelines do not accidentally inherit limits meant for other models. if model_class_name is None: return DiffusionModelMetadata() return _DIFFUSION_MODEL_METADATA.get(model_class_name, DiffusionModelMetadata()) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 3a444556ea9..beca8c8c7df 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -60,6 +60,8 @@ # Keep this in sync with the practical conditioning-token budget for # Qwen-Image-Edit-2511. Empirically, 4 images stays within the supported range # while 5 images overflows the prompt/conditioning path and fails downstream. +# Re-export the shared metadata value locally so this pipeline keeps a nearby, +# descriptive constant for validation and tests without becoming the source of truth. MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES = QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES From bbaf682070ca962dba98bc80fb9f3015f81e0bc9 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 08:06:52 +0000 Subject: [PATCH 24/34] Fix RIFE device selection for CPU transport Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit ecbb6d45fc8232bf00a5d976aabcafa89b585f06) --- tests/entrypoints/openai_api/test_video_api_utils.py | 5 +++-- vllm_omni/diffusion/postprocess/rife_interpolator.py | 9 ++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/openai_api/test_video_api_utils.py b/tests/entrypoints/openai_api/test_video_api_utils.py index 3868b657370..ce51759e7b4 100644 --- a/tests/entrypoints/openai_api/test_video_api_utils.py +++ b/tests/entrypoints/openai_api/test_video_api_utils.py @@ -123,7 +123,7 @@ def test_frame_interpolator_runs_actual_torch_tensor_path(monkeypatch): assert torch.isfinite(output_video).all() -def test_frame_interpolator_prefers_input_tensor_device(monkeypatch): +def test_frame_interpolator_uses_platform_device_when_tensor_is_cpu(monkeypatch): chosen_devices = [] model = rife_interpolator.Model().eval() @@ -134,10 +134,11 @@ def _fake_ensure_model_loaded(*, preferred_device=None): interpolator = rife_interpolator.FrameInterpolator() monkeypatch.setattr(interpolator, "_ensure_model_loaded", _fake_ensure_model_loaded) monkeypatch.setattr(model.flownet, "to", lambda device: model.flownet) + monkeypatch.setattr(rife_interpolator, "_select_torch_device", lambda: torch.device("cuda")) video = torch.zeros(1, 3, 2, 32, 32) output_video, multiplier = interpolator.interpolate_tensor(video, exp=1, scale=1.0) - assert chosen_devices == [video.device] + assert chosen_devices == [torch.device("cuda")] assert multiplier == 2 assert output_video.shape == (1, 3, 3, 32, 32) diff --git a/vllm_omni/diffusion/postprocess/rife_interpolator.py b/vllm_omni/diffusion/postprocess/rife_interpolator.py index b2b4a931914..89297d0a446 100644 --- a/vllm_omni/diffusion/postprocess/rife_interpolator.py +++ b/vllm_omni/diffusion/postprocess/rife_interpolator.py @@ -412,9 +412,12 @@ def interpolate_tensor( return restore_layout(video), 1 video, restore_range = _normalize_video_tensor_range(video) - # Prefer the decoded video's current device so CPU-offloaded requests do - # not move the tensor back to GPU just for interpolation. - model = self._ensure_model_loaded(preferred_device=video.device) + # A CPU tensor may be transport/offload state rather than an execution + # choice, so only trust it when it is already on an accelerator. + preferred_device = video.device + if preferred_device.type == "cpu": + preferred_device = _select_torch_device() + model = self._ensure_model_loaded(preferred_device=preferred_device) video = video.to(model.device()) intermediates_per_pair = 2**exp // 2 From beeb333a1da8ca1be9d113265912cb2e707fc104 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 08:59:50 +0000 Subject: [PATCH 25/34] [Chore] run pre-commit formatting Signed-off-by: david6666666 <530634352@qq.com> --- tests/engine/test_async_omni_engine_stage_init.py | 1 + tests/entrypoints/openai_api/test_video_api_utils.py | 4 +++- tests/entrypoints/openai_api/test_video_server.py | 5 +---- vllm_omni/diffusion/diffusion_engine.py | 2 +- vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py | 8 +++----- vllm_omni/engine/stage_init_utils.py | 3 +-- 6 files changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/engine/test_async_omni_engine_stage_init.py b/tests/engine/test_async_omni_engine_stage_init.py index f95aa53af45..e5ba2f31d58 100644 --- a/tests/engine/test_async_omni_engine_stage_init.py +++ b/tests/engine/test_async_omni_engine_stage_init.py @@ -70,6 +70,7 @@ def _fake_setup_stage_devices(_stage_id, _runtime_cfg): else: os.environ[env_var] = old_env + def test_initialize_stages_uses_inline_diffusion_client_for_single_stage(monkeypatch): """Single-stage diffusion init should request the inline client path.""" import vllm_omni.engine.async_omni_engine as engine_mod diff --git a/tests/entrypoints/openai_api/test_video_api_utils.py b/tests/entrypoints/openai_api/test_video_api_utils.py index ce51759e7b4..7a0494226b9 100644 --- a/tests/entrypoints/openai_api/test_video_api_utils.py +++ b/tests/entrypoints/openai_api/test_video_api_utils.py @@ -87,7 +87,9 @@ def test_encode_video_bytes_without_audio_uses_diffusers_export(monkeypatch): _install_fake_export_to_video(monkeypatch, export_calls) monkeypatch.setattr( "vllm_omni.diffusion.utils.media_utils.mux_video_audio_bytes", - lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no-audio path should not use mux_video_audio_bytes")), + lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError("no-audio path should not use mux_video_audio_bytes") + ), ) video = np.linspace(0.0, 1.0, num=4 * 2 * 2 * 3, dtype=np.float32).reshape(4, 2, 2, 3) diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py index 9ebcb689726..258ea76f40a 100644 --- a/tests/entrypoints/openai_api/test_video_server.py +++ b/tests/entrypoints/openai_api/test_video_server.py @@ -14,8 +14,7 @@ from types import SimpleNamespace import pytest -from fastapi import FastAPI -from fastapi import HTTPException +from fastapi import FastAPI, HTTPException from fastapi.testclient import TestClient from PIL import Image from pytest_mock import MockerFixture @@ -981,5 +980,3 @@ async def test_run_generation_maps_omni_request_error_to_http_exception(): "error_type": "OutOfMemoryError", "detail": {"retryable": False}, } - - diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 17dfe6d3a3a..e4eee39de2a 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -308,7 +308,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: def make_engine( config: OmniDiffusionConfig, scheduler: SchedulerInterface | None = None, - ) -> "DiffusionEngine": + ) -> DiffusionEngine: """Factory method to create a DiffusionEngine instance. Args: diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index 833ce31124d..ea791501173 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -24,12 +24,12 @@ from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.layer import Attention -from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNorm from vllm_omni.diffusion.distributed.sp_plan import ( SequenceParallelInput, SequenceParallelOutput, ) from vllm_omni.diffusion.forward_context import get_forward_context +from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNorm logger = init_logger(__name__) @@ -621,7 +621,7 @@ def __init__( # 1. Self-attention self.norm1 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) - + self.attn1 = WanSelfAttention( dim=dim, num_heads=num_heads, @@ -682,9 +682,7 @@ def forward( hidden_states = hidden_states + attn_output # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states, c_scale_msa, c_shift_msa).type_as( - hidden_states - ) + norm_hidden_states = self.norm3(hidden_states, c_scale_msa, c_shift_msa).type_as(hidden_states) ff_output = self.ffn(norm_hidden_states) hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states) diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py index ffe90cb3b71..e127a305014 100644 --- a/vllm_omni/engine/stage_init_utils.py +++ b/vllm_omni/engine/stage_init_utils.py @@ -450,9 +450,8 @@ def initialize_diffusion_stage( and ultimately to ``AsyncOmniDiffusion``. use_inline: If True, uses the inline diffusion client instead of subprocess. """ - from vllm_omni.diffusion.stage_diffusion_client import create_diffusion_client - from vllm_omni.diffusion.data import OmniDiffusionConfig + from vllm_omni.diffusion.stage_diffusion_client import create_diffusion_client od_config = OmniDiffusionConfig.from_kwargs( model=model, From d7233cbd7a7212f6e0b9fd61320ed69a5437e089 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 08:52:31 +0000 Subject: [PATCH 26/34] [Fix] align Wan2.2 max_sequence_length with model config Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit 072bfa26ac9d4f4b60c1156a2cd05b9a3dfaca35) --- .../models/wan2_2/test_wan22_max_sequence_length.py | 8 ++++---- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py | 12 ++++++++---- .../diffusion/models/wan2_2/pipeline_wan2_2_i2v.py | 7 +++++-- .../diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py | 6 ++++-- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py index 1b594b4b01e..b8e248f1efd 100644 --- a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py +++ b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py @@ -18,8 +18,6 @@ pytestmark = [pytest.mark.core_model, pytest.mark.cpu] -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - class _RejectingTextEncoder: dtype = torch.float32 @@ -64,7 +62,10 @@ def _make_pipeline(pipeline_class: type, *, total_sequence_length: int): def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length(pipeline_class: type): pipeline = _make_pipeline(pipeline_class, total_sequence_length=WAN22_MAX_SEQUENCE_LENGTH + 1) - with pytest.raises(ValueError, match=r"got 513 tokens, but `max_sequence_length` is 512"): + with pytest.raises( + ValueError, + match=rf"got {WAN22_MAX_SEQUENCE_LENGTH + 1} tokens, but `max_sequence_length` is {WAN22_MAX_SEQUENCE_LENGTH}", + ): pipeline.encode_prompt(prompt="prompt") @@ -75,7 +76,6 @@ def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(p with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"): pipeline.encode_prompt(prompt="prompt", max_sequence_length=16) - def _sampling_params(**overrides): defaults = dict( height=None, diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 753081067b0..e98a59a8d18 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) DEBUG_PERF = False -WAN22_MAX_SEQUENCE_LENGTH = 512 +WAN22_MAX_SEQUENCE_LENGTH = 2048 def retrieve_latents( @@ -77,7 +77,6 @@ def load_transformer_config(model_path: str, subfolder: str = "transformer", loc pass return {} - def create_transformer_from_config(config: dict) -> WanTransformer3DModel: """Create WanTransformer3DModel from config dict.""" kwargs = {} @@ -252,7 +251,6 @@ def __init__( pass self.boundary_ratio = od_config.boundary_ratio - self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH # Determine which transformers to load based on boundary_ratio # boundary_ratio=1.0: only load transformer_2 (low-noise stage only) @@ -295,18 +293,22 @@ def __init__( ).to(self.device) # Initialize transformers with correct config (weights loaded via load_weights) + transformer_config: dict = {} if load_transformer: transformer_config = load_transformer_config(model, "transformer", local_files_only) self.transformer = create_transformer_from_config(transformer_config) else: self.transformer = None + transformer_2_config: dict = {} if load_transformer_2: transformer_2_config = load_transformer_config(model, "transformer_2", local_files_only) self.transformer_2 = create_transformer_from_config(transformer_2_config) else: self.transformer_2 = None + self.tokenizer_max_length = resolve_wan22_tokenizer_max_length(transformer_config, transformer_2_config) + # Store the active transformer config if load_transformer: self.transformer_config = self.transformer.config @@ -749,10 +751,12 @@ def encode_prompt( negative_prompt: str | list[str] | None = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, - max_sequence_length: int = 512, + max_sequence_length: int | None = None, device: torch.device | None = None, dtype: torch.dtype | None = None, ): + if max_sequence_length is None: + max_sequence_length = self.tokenizer_max_length device = device or self.device dtype = dtype or self.text_encoder.dtype diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 9438bca96f0..d7b0dbee8b5 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -215,7 +215,6 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) - self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -247,7 +246,9 @@ def __init__( transformer_2_config = load_transformer_config(model, "transformer_2", local_files_only) self.transformer_2 = create_transformer_from_config(transformer_2_config) else: + transformer_2_config = None self.transformer_2 = None + self.tokenizer_max_length = resolve_wan22_tokenizer_max_length(transformer_config, transformer_2_config) # Initialize UniPC scheduler flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p @@ -661,11 +662,13 @@ def encode_prompt( negative_prompt: str | list[str] | None = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, - max_sequence_length: int = 512, + max_sequence_length: int = WAN22_MAX_SEQUENCE_LENGTH, device: torch.device | None = None, dtype: torch.dtype | None = None, ): """Encode text prompts using T5 text encoder.""" + if max_sequence_length is None: + max_sequence_length = self.tokenizer_max_length device = device or self.device dtype = dtype or self.text_encoder.dtype diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index b829678e2ad..6fec41070b2 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -187,7 +187,6 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) - self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -201,6 +200,7 @@ def __init__( # Load config from model to get correct dimensions transformer_config = load_transformer_config(model, "transformer", local_files_only) self.transformer = create_transformer_from_config(transformer_config) + self.tokenizer_max_length = resolve_wan22_tokenizer_max_length(transformer_config) # Initialize UniPC scheduler flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p @@ -508,11 +508,13 @@ def encode_prompt( negative_prompt: str | list[str] | None = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, - max_sequence_length: int = 512, + max_sequence_length: int = WAN22_MAX_SEQUENCE_LENGTH, device: torch.device | None = None, dtype: torch.dtype | None = None, ): """Encode text prompts using T5 text encoder.""" + if max_sequence_length is None: + max_sequence_length = self.tokenizer_max_length device = device or self.device dtype = dtype or self.text_encoder.dtype From 25ac7cd84e9c89220a6c0347e12d62cf7967414a Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 09:01:02 +0000 Subject: [PATCH 27/34] [Fix] raise Wan2.2 max_sequence_length to 2048 Signed-off-by: david6666666 <530634352@qq.com> (cherry picked from commit ea6ce237f91062bfe7d59591c7326766030d4d59) --- .../models/wan2_2/test_wan22_max_sequence_length.py | 1 - vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py | 9 ++------- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py | 5 +---- .../diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py | 4 +--- 4 files changed, 4 insertions(+), 15 deletions(-) diff --git a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py index b8e248f1efd..b4d8f9e9ec4 100644 --- a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py +++ b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py @@ -75,7 +75,6 @@ def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(p with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"): pipeline.encode_prompt(prompt="prompt", max_sequence_length=16) - def _sampling_params(**overrides): defaults = dict( height=None, diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index e98a59a8d18..56b3d12595b 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -251,6 +251,7 @@ def __init__( pass self.boundary_ratio = od_config.boundary_ratio + self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH # Determine which transformers to load based on boundary_ratio # boundary_ratio=1.0: only load transformer_2 (low-noise stage only) @@ -293,22 +294,18 @@ def __init__( ).to(self.device) # Initialize transformers with correct config (weights loaded via load_weights) - transformer_config: dict = {} if load_transformer: transformer_config = load_transformer_config(model, "transformer", local_files_only) self.transformer = create_transformer_from_config(transformer_config) else: self.transformer = None - transformer_2_config: dict = {} if load_transformer_2: transformer_2_config = load_transformer_config(model, "transformer_2", local_files_only) self.transformer_2 = create_transformer_from_config(transformer_2_config) else: self.transformer_2 = None - self.tokenizer_max_length = resolve_wan22_tokenizer_max_length(transformer_config, transformer_2_config) - # Store the active transformer config if load_transformer: self.transformer_config = self.transformer.config @@ -751,12 +748,10 @@ def encode_prompt( negative_prompt: str | list[str] | None = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, - max_sequence_length: int | None = None, + max_sequence_length: int = WAN22_MAX_SEQUENCE_LENGTH, device: torch.device | None = None, dtype: torch.dtype | None = None, ): - if max_sequence_length is None: - max_sequence_length = self.tokenizer_max_length device = device or self.device dtype = dtype or self.text_encoder.dtype diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index d7b0dbee8b5..71578138862 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -215,6 +215,7 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -246,9 +247,7 @@ def __init__( transformer_2_config = load_transformer_config(model, "transformer_2", local_files_only) self.transformer_2 = create_transformer_from_config(transformer_2_config) else: - transformer_2_config = None self.transformer_2 = None - self.tokenizer_max_length = resolve_wan22_tokenizer_max_length(transformer_config, transformer_2_config) # Initialize UniPC scheduler flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p @@ -667,8 +666,6 @@ def encode_prompt( dtype: torch.dtype | None = None, ): """Encode text prompts using T5 text encoder.""" - if max_sequence_length is None: - max_sequence_length = self.tokenizer_max_length device = device or self.device dtype = dtype or self.text_encoder.dtype diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 6fec41070b2..1ad2bd23276 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -187,6 +187,7 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -200,7 +201,6 @@ def __init__( # Load config from model to get correct dimensions transformer_config = load_transformer_config(model, "transformer", local_files_only) self.transformer = create_transformer_from_config(transformer_config) - self.tokenizer_max_length = resolve_wan22_tokenizer_max_length(transformer_config) # Initialize UniPC scheduler flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p @@ -513,8 +513,6 @@ def encode_prompt( dtype: torch.dtype | None = None, ): """Encode text prompts using T5 text encoder.""" - if max_sequence_length is None: - max_sequence_length = self.tokenizer_max_length device = device or self.device dtype = dtype or self.text_encoder.dtype From 67e52e8676dc769aea391138aa0b98e3ac2d1024 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 09:18:15 +0000 Subject: [PATCH 28/34] [Chore] run pre-commit after PR2877 backport Signed-off-by: david6666666 <530634352@qq.com> --- tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py | 2 ++ vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py index b4d8f9e9ec4..7ceed6a5cbc 100644 --- a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py +++ b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py @@ -75,6 +75,8 @@ def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(p with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"): pipeline.encode_prompt(prompt="prompt", max_sequence_length=16) + + def _sampling_params(**overrides): defaults = dict( height=None, diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 56b3d12595b..3b8fd697c45 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -77,6 +77,7 @@ def load_transformer_config(model_path: str, subfolder: str = "transformer", loc pass return {} + def create_transformer_from_config(config: dict) -> WanTransformer3DModel: """Create WanTransformer3DModel from config dict.""" kwargs = {} From 5041f7e078df1c9e4262b607a1dc09c9c72baa16 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 13:52:53 +0000 Subject: [PATCH 29/34] Fix async video job profiler metadata persistence --- vllm_omni/entrypoints/openai/protocol/videos.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm_omni/entrypoints/openai/protocol/videos.py b/vllm_omni/entrypoints/openai/protocol/videos.py index d23bc1d2e4c..c7184989e4c 100644 --- a/vllm_omni/entrypoints/openai/protocol/videos.py +++ b/vllm_omni/entrypoints/openai/protocol/videos.py @@ -273,6 +273,14 @@ class VideoResponse(BaseModel): description="Filename of the saved output video files for this job.", ) inference_time_s: float | None = Field(default=None, description="End-to-end inference time in seconds.") + stage_durations: dict[str, float] = Field( + default_factory=dict, + description="Per-stage profiler timings captured during generation, when available.", + ) + peak_memory_mb: float | None = Field( + default=None, + description="Peak device memory used during generation in MiB, when reported by the pipeline.", + ) @property def file_extension(self) -> str: From 6e87f2071a5a9e22d066853c2c1fa8326095432c Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 14:00:20 +0000 Subject: [PATCH 30/34] Revert "Fix async video job profiler metadata persistence" This reverts commit 5041f7e078df1c9e4262b607a1dc09c9c72baa16. --- vllm_omni/entrypoints/openai/protocol/videos.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm_omni/entrypoints/openai/protocol/videos.py b/vllm_omni/entrypoints/openai/protocol/videos.py index c7184989e4c..d23bc1d2e4c 100644 --- a/vllm_omni/entrypoints/openai/protocol/videos.py +++ b/vllm_omni/entrypoints/openai/protocol/videos.py @@ -273,14 +273,6 @@ class VideoResponse(BaseModel): description="Filename of the saved output video files for this job.", ) inference_time_s: float | None = Field(default=None, description="End-to-end inference time in seconds.") - stage_durations: dict[str, float] = Field( - default_factory=dict, - description="Per-stage profiler timings captured during generation, when available.", - ) - peak_memory_mb: float | None = Field( - default=None, - description="Peak device memory used during generation in MiB, when reported by the pipeline.", - ) @property def file_extension(self) -> str: From f0c4c1f8b101f7b82446374e392589ca9dff3f64 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 14:11:59 +0000 Subject: [PATCH 31/34] Fix release backport unit regressions Signed-off-by: david6666666 <530634352@qq.com> --- .../diffusion/test_multiproc_engine_concurrency.py | 11 ++++------- vllm_omni/diffusion/diffusion_engine.py | 6 ------ vllm_omni/entrypoints/openai/api_server.py | 14 ++++++++++++-- vllm_omni/entrypoints/openai/protocol/videos.py | 8 ++++++++ 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/tests/diffusion/test_multiproc_engine_concurrency.py b/tests/diffusion/test_multiproc_engine_concurrency.py index adb8dc338c6..61cf42777f7 100644 --- a/tests/diffusion/test_multiproc_engine_concurrency.py +++ b/tests/diffusion/test_multiproc_engine_concurrency.py @@ -8,7 +8,7 @@ import pytest import torch -from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.data import DiffusionOutput, OmniRequestError from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor from vllm_omni.diffusion.sched import RequestScheduler @@ -304,16 +304,13 @@ def test_serial_add_req_then_collective_rpc(self): assert rpc_out.error == "result_for_rpc" def test_serial_add_req_error_propagation(self): - """``add_req`` should raise when the worker reports an error.""" + """``add_req`` should surface worker failures as ``OmniRequestError``.""" engine, _, _, res_q = _make_engine() # Put an error response directly res_q.put({"status": "error", "error": "boom"}) - out = engine.add_req_and_wait_for_response(_mock_request("fail")) - - assert isinstance(out, DiffusionOutput) - assert out.error is not None - assert "boom" in out.error + with pytest.raises(OmniRequestError, match="boom"): + engine.add_req_and_wait_for_response(_mock_request("fail")) def test_serial_collective_rpc_error_propagation(self): """``collective_rpc`` should raise when the worker reports an error.""" diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index e4eee39de2a..6f992ec8594 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -356,12 +356,6 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus finished_req_ids = self.scheduler.update_from_output(sched_output, output) - if output.error: - raise OmniRequestError( - output.error, - status_code=500, - error_type="DiffusionExecutionError", - ) if target_sched_req_id in finished_req_ids: # self.scheduler.pop_request_state(target_sched_req_id) return output diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index cd896f2609e..928f96b5cba 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1628,10 +1628,20 @@ def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int # config is not exposed on the serving surface. return None - if not bool(getattr(od_config, "supports_multimodal_inputs", False)): + supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", None) + if not isinstance(supports_multimodal_inputs, bool): + # Serving-side mocks and older engine surfaces may not expose + # structured diffusion capability metadata yet. Keep the legacy + # "no hard limit" behavior instead of comparing against mock objects. + return None + + if not supports_multimodal_inputs: return 1 - return getattr(od_config, "max_multimodal_image_inputs", None) + max_input_images = getattr(od_config, "max_multimodal_image_inputs", None) + if isinstance(max_input_images, bool): + return None + return max_input_images if isinstance(max_input_images, int) else None def _get_lora_from_json_str(lora_body): diff --git a/vllm_omni/entrypoints/openai/protocol/videos.py b/vllm_omni/entrypoints/openai/protocol/videos.py index d23bc1d2e4c..c7184989e4c 100644 --- a/vllm_omni/entrypoints/openai/protocol/videos.py +++ b/vllm_omni/entrypoints/openai/protocol/videos.py @@ -273,6 +273,14 @@ class VideoResponse(BaseModel): description="Filename of the saved output video files for this job.", ) inference_time_s: float | None = Field(default=None, description="End-to-end inference time in seconds.") + stage_durations: dict[str, float] = Field( + default_factory=dict, + description="Per-stage profiler timings captured during generation, when available.", + ) + peak_memory_mb: float | None = Field( + default=None, + description="Peak device memory used during generation in MiB, when reported by the pipeline.", + ) @property def file_extension(self) -> str: From 210bb37fcdca24a408c9f775467c028b780071f0 Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Fri, 17 Apr 2026 14:29:10 +0000 Subject: [PATCH 32/34] Revert "Fix release backport unit regressions" This reverts commit f0c4c1f8b101f7b82446374e392589ca9dff3f64. --- .../diffusion/test_multiproc_engine_concurrency.py | 11 +++++++---- vllm_omni/diffusion/diffusion_engine.py | 6 ++++++ vllm_omni/entrypoints/openai/api_server.py | 14 ++------------ vllm_omni/entrypoints/openai/protocol/videos.py | 8 -------- 4 files changed, 15 insertions(+), 24 deletions(-) diff --git a/tests/diffusion/test_multiproc_engine_concurrency.py b/tests/diffusion/test_multiproc_engine_concurrency.py index 61cf42777f7..adb8dc338c6 100644 --- a/tests/diffusion/test_multiproc_engine_concurrency.py +++ b/tests/diffusion/test_multiproc_engine_concurrency.py @@ -8,7 +8,7 @@ import pytest import torch -from vllm_omni.diffusion.data import DiffusionOutput, OmniRequestError +from vllm_omni.diffusion.data import DiffusionOutput from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor from vllm_omni.diffusion.sched import RequestScheduler @@ -304,13 +304,16 @@ def test_serial_add_req_then_collective_rpc(self): assert rpc_out.error == "result_for_rpc" def test_serial_add_req_error_propagation(self): - """``add_req`` should surface worker failures as ``OmniRequestError``.""" + """``add_req`` should raise when the worker reports an error.""" engine, _, _, res_q = _make_engine() # Put an error response directly res_q.put({"status": "error", "error": "boom"}) - with pytest.raises(OmniRequestError, match="boom"): - engine.add_req_and_wait_for_response(_mock_request("fail")) + out = engine.add_req_and_wait_for_response(_mock_request("fail")) + + assert isinstance(out, DiffusionOutput) + assert out.error is not None + assert "boom" in out.error def test_serial_collective_rpc_error_propagation(self): """``collective_rpc`` should raise when the worker reports an error.""" diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 6f992ec8594..e4eee39de2a 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -356,6 +356,12 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus finished_req_ids = self.scheduler.update_from_output(sched_output, output) + if output.error: + raise OmniRequestError( + output.error, + status_code=500, + error_type="DiffusionExecutionError", + ) if target_sched_req_id in finished_req_ids: # self.scheduler.pop_request_state(target_sched_req_id) return output diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 928f96b5cba..cd896f2609e 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1628,20 +1628,10 @@ def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int # config is not exposed on the serving surface. return None - supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", None) - if not isinstance(supports_multimodal_inputs, bool): - # Serving-side mocks and older engine surfaces may not expose - # structured diffusion capability metadata yet. Keep the legacy - # "no hard limit" behavior instead of comparing against mock objects. - return None - - if not supports_multimodal_inputs: + if not bool(getattr(od_config, "supports_multimodal_inputs", False)): return 1 - max_input_images = getattr(od_config, "max_multimodal_image_inputs", None) - if isinstance(max_input_images, bool): - return None - return max_input_images if isinstance(max_input_images, int) else None + return getattr(od_config, "max_multimodal_image_inputs", None) def _get_lora_from_json_str(lora_body): diff --git a/vllm_omni/entrypoints/openai/protocol/videos.py b/vllm_omni/entrypoints/openai/protocol/videos.py index c7184989e4c..d23bc1d2e4c 100644 --- a/vllm_omni/entrypoints/openai/protocol/videos.py +++ b/vllm_omni/entrypoints/openai/protocol/videos.py @@ -273,14 +273,6 @@ class VideoResponse(BaseModel): description="Filename of the saved output video files for this job.", ) inference_time_s: float | None = Field(default=None, description="End-to-end inference time in seconds.") - stage_durations: dict[str, float] = Field( - default_factory=dict, - description="Per-stage profiler timings captured during generation, when available.", - ) - peak_memory_mb: float | None = Field( - default=None, - description="Peak device memory used during generation in MiB, when reported by the pipeline.", - ) @property def file_extension(self) -> str: From c3540ee7d98b12ded433080cdb1e058e0aec8e1d Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Sat, 18 Apr 2026 01:24:25 +0000 Subject: [PATCH 33/34] Fix 400 status for prompt length validation Signed-off-by: david6666666 <530634352@qq.com> --- vllm_omni/diffusion/diffusion_engine.py | 3 ++- vllm_omni/entrypoints/openai/api_server.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index e4eee39de2a..f78ec1f059c 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -357,9 +357,10 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus finished_req_ids = self.scheduler.update_from_output(sched_output, output) if output.error: + status_code = 400 if "`max_sequence_length`" in output.error else 500 raise OmniRequestError( output.error, - status_code=500, + status_code=status_code, error_type="DiffusionExecutionError", ) if target_sched_req_id in finished_req_ids: diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index cd896f2609e..df1960dfb83 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -81,6 +81,7 @@ from vllm.utils import random_uuid from vllm.utils.system_utils import decorate_logs +from vllm_omni.diffusion.data import OmniRequestError from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.openai.errors import InvalidInputReferenceError from vllm_omni.entrypoints.openai.image_api_utils import ( @@ -1701,11 +1702,14 @@ async def _generate_with_async_omni( pass sampling_params_list.append(default_stage_params) - async for output in engine_client.generate( - sampling_params_list=sampling_params_list, - **kwargs, - ): - result = output + try: + async for output in engine_client.generate( + sampling_params_list=sampling_params_list, + **kwargs, + ): + result = output + except OmniRequestError as e: + raise HTTPException(status_code=e.status_code, detail=str(e)) from e if result is None: raise HTTPException( From 5be6ff566e816de60daee0461b38ce75a1926cea Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Sun, 19 Apr 2026 11:08:59 +0000 Subject: [PATCH 34/34] [Fix] avoid padding short Wan2.2 prompts to max_sequence_length Signed-off-by: david6666666 <530634352@qq.com> --- .../wan2_2/test_wan22_max_sequence_length.py | 52 ++++++++++++++++-- .../models/wan2_2/pipeline_wan2_2.py | 53 +++++++++++-------- .../models/wan2_2/pipeline_wan2_2_i2v.py | 53 +++++++++++-------- .../models/wan2_2/pipeline_wan2_2_ti2v.py | 53 +++++++++++-------- 4 files changed, 137 insertions(+), 74 deletions(-) diff --git a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py index 7ceed6a5cbc..17dec2ce06a 100644 --- a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py +++ b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py @@ -27,18 +27,38 @@ def __call__(self, *args, **kwargs): class _FakeTokenBatch: - def __init__(self, total_sequence_length: int): - attention_mask = torch.ones((1, total_sequence_length), dtype=torch.long) + def __init__(self, total_sequence_length: int, padded_length: int | None = None): + padded_length = padded_length or total_sequence_length + attention_mask = torch.zeros((1, padded_length), dtype=torch.long) + attention_mask[:, :total_sequence_length] = 1 self.input_ids = attention_mask.clone() self.attention_mask = attention_mask class _FakeTokenizer: - def __init__(self, total_sequence_length: int): - self.total_sequence_length = total_sequence_length + def __init__(self, total_sequence_length: int | list[int]): + if isinstance(total_sequence_length, list): + self.total_sequence_lengths = list(total_sequence_length) + else: + self.total_sequence_lengths = [total_sequence_length] def __call__(self, *args, **kwargs): - return _FakeTokenBatch(self.total_sequence_length) + if len(self.total_sequence_lengths) > 1: + total_sequence_length = self.total_sequence_lengths.pop(0) + else: + total_sequence_length = self.total_sequence_lengths[0] + return _FakeTokenBatch(total_sequence_length, kwargs.get("max_length")) + + +class _RecordingTextEncoder: + dtype = torch.float32 + + def __init__(self): + self.input_lengths: list[int] = [] + + def __call__(self, input_ids, attention_mask): + self.input_lengths.append(input_ids.shape[1]) + return SimpleNamespace(last_hidden_state=torch.zeros((1, input_ids.shape[1], 4), dtype=torch.float32)) PIPELINE_CASES = [ @@ -77,6 +97,28 @@ def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(p pipeline.encode_prompt(prompt="prompt", max_sequence_length=16) +@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES) +def test_encode_prompt_uses_actual_prompt_length_for_text_encoder_padding(pipeline_class: type): + pipeline = object.__new__(pipeline_class) + nn.Module.__init__(pipeline) + pipeline.device = torch.device("cpu") + pipeline.text_encoder = _RecordingTextEncoder() + # prompt validation, prompt encode, negative validation, negative encode + pipeline.tokenizer = _FakeTokenizer([17, 17, 9, 9]) + pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH + + prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt( + prompt="prompt", + negative_prompt="neg", + do_classifier_free_guidance=True, + ) + + assert pipeline.text_encoder.input_lengths == [17, 17] + assert prompt_embeds.shape[1] == 17 + assert negative_prompt_embeds is not None + assert negative_prompt_embeds.shape[1] == 17 + + def _sampling_params(**overrides): defaults = dict( height=None, diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 3b8fd697c45..016c5177a22 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -773,11 +773,37 @@ def encode_prompt( supported_max_sequence_length=self.tokenizer_max_length, error_context="for Wan2.2 text encoding", ) + prompt_encode_length = max(int(text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), 1) + + negative_prompt_embeds = None + if do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] + neg_text_inputs_untruncated = self.tokenizer( + negative_prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + neg_text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name="negative_prompt", + error_context="for Wan2.2 text encoding", + ) + prompt_encode_length = max( + prompt_encode_length, + int(neg_text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), + ) text_inputs = self.tokenizer( prompt_clean, padding="max_length", - max_length=max_sequence_length, + max_length=prompt_encode_length, truncation=True, add_special_tokens=True, return_attention_mask=True, @@ -790,37 +816,18 @@ def encode_prompt( prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + [torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - negative_prompt_embeds = None if do_classifier_free_guidance: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] - neg_text_inputs_untruncated = self.tokenizer( - negative_prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - neg_text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - prompt_name="negative_prompt", - error_context="for Wan2.2 text encoding", - ) neg_text_inputs = self.tokenizer( negative_prompt_clean, padding="max_length", - max_length=max_sequence_length, + max_length=prompt_encode_length, truncation=True, add_special_tokens=True, return_attention_mask=True, @@ -833,7 +840,7 @@ def encode_prompt( negative_prompt_embeds = [u[:v] for u, v in zip(negative_prompt_embeds, seq_lens_neg)] negative_prompt_embeds = torch.stack( [ - torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) + torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) for u in negative_prompt_embeds ], dim=0, diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 71578138862..6166df78ec8 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -686,11 +686,37 @@ def encode_prompt( supported_max_sequence_length=self.tokenizer_max_length, error_context="for Wan2.2 text encoding", ) + prompt_encode_length = max(int(text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), 1) + + negative_prompt_embeds = None + if do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] + neg_text_inputs_untruncated = self.tokenizer( + negative_prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + neg_text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name="negative_prompt", + error_context="for Wan2.2 text encoding", + ) + prompt_encode_length = max( + prompt_encode_length, + int(neg_text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), + ) text_inputs = self.tokenizer( prompt_clean, padding="max_length", - max_length=max_sequence_length, + max_length=prompt_encode_length, truncation=True, add_special_tokens=True, return_attention_mask=True, @@ -703,37 +729,18 @@ def encode_prompt( prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + [torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - negative_prompt_embeds = None if do_classifier_free_guidance: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] - neg_text_inputs_untruncated = self.tokenizer( - negative_prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - neg_text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - prompt_name="negative_prompt", - error_context="for Wan2.2 text encoding", - ) neg_text_inputs = self.tokenizer( negative_prompt_clean, padding="max_length", - max_length=max_sequence_length, + max_length=prompt_encode_length, truncation=True, add_special_tokens=True, return_attention_mask=True, @@ -746,7 +753,7 @@ def encode_prompt( negative_prompt_embeds = [u[:v] for u, v in zip(negative_prompt_embeds, seq_lens_neg)] negative_prompt_embeds = torch.stack( [ - torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) + torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) for u in negative_prompt_embeds ], dim=0, diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 1ad2bd23276..0f4caf9fd16 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -533,11 +533,37 @@ def encode_prompt( supported_max_sequence_length=self.tokenizer_max_length, error_context="for Wan2.2 text encoding", ) + prompt_encode_length = max(int(text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), 1) + + negative_prompt_embeds = None + if do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] + neg_text_inputs_untruncated = self.tokenizer( + negative_prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + neg_text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name="negative_prompt", + error_context="for Wan2.2 text encoding", + ) + prompt_encode_length = max( + prompt_encode_length, + int(neg_text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), + ) text_inputs = self.tokenizer( prompt_clean, padding="max_length", - max_length=max_sequence_length, + max_length=prompt_encode_length, truncation=True, add_special_tokens=True, return_attention_mask=True, @@ -550,37 +576,18 @@ def encode_prompt( prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + [torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - negative_prompt_embeds = None if do_classifier_free_guidance: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] - neg_text_inputs_untruncated = self.tokenizer( - negative_prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - neg_text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - prompt_name="negative_prompt", - error_context="for Wan2.2 text encoding", - ) neg_text_inputs = self.tokenizer( negative_prompt_clean, padding="max_length", - max_length=max_sequence_length, + max_length=prompt_encode_length, truncation=True, add_special_tokens=True, return_attention_mask=True, @@ -593,7 +600,7 @@ def encode_prompt( negative_prompt_embeds = [u[:v] for u, v in zip(negative_prompt_embeds, seq_lens_neg)] negative_prompt_embeds = torch.stack( [ - torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) + torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) for u in negative_prompt_embeds ], dim=0,