From 9528def208a3dddeda9328bb4163c9a2fde64d5c Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Mon, 20 Apr 2026 08:16:53 +0000 Subject: [PATCH 1/2] [Revert] drop Wan2.2 prompt-length enforcement from #2847 Signed-off-by: david6666666 <530634352@qq.com> --- .../wan2_2/test_wan22_max_sequence_length.py | 150 ------------------ .../models/wan2_2/pipeline_wan2_2.py | 46 +----- .../models/wan2_2/pipeline_wan2_2_i2v.py | 46 +----- .../models/wan2_2/pipeline_wan2_2_ti2v.py | 46 +----- .../models/wan2_2/pipeline_wan2_2_vace.py | 5 +- 5 files changed, 7 insertions(+), 286 deletions(-) delete mode 100644 tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py diff --git a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py deleted file mode 100644 index 64c2b271c9c..00000000000 --- a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py +++ /dev/null @@ -1,150 +0,0 @@ -from types import SimpleNamespace - -import PIL.Image -import pytest -import torch -from torch import nn - -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( - WAN22_MAX_SEQUENCE_LENGTH, - Wan22Pipeline, -) -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import ( - Wan22I2VPipeline, -) -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_ti2v import ( - Wan22TI2VPipeline, -) -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace import ( - Wan22VACEPipeline, -) - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - - -class _RejectingTextEncoder: - dtype = torch.float32 - - def __call__(self, *args, **kwargs): - raise AssertionError("text encoder should not run for prompts that exceed max_sequence_length") - - -class _FakeTokenBatch: - def __init__(self, total_sequence_length: int): - attention_mask = torch.ones((1, total_sequence_length), dtype=torch.long) - self.input_ids = attention_mask.clone() - self.attention_mask = attention_mask - - -class _FakeTokenizer: - def __init__(self, total_sequence_length: int): - self.total_sequence_length = total_sequence_length - - def __call__(self, *args, **kwargs): - return _FakeTokenBatch(self.total_sequence_length) - - -PIPELINE_CASES = [ - pytest.param(Wan22Pipeline, id="wan22-t2v"), - pytest.param(Wan22I2VPipeline, id="wan22-i2v"), - pytest.param(Wan22TI2VPipeline, id="wan22-ti2v"), - pytest.param(Wan22VACEPipeline, id="wan22-vace"), -] - - -def _make_pipeline(pipeline_class: type, *, total_sequence_length: int): - pipeline = object.__new__(pipeline_class) - nn.Module.__init__(pipeline) - pipeline.device = torch.device("cpu") - pipeline.text_encoder = _RejectingTextEncoder() - pipeline.tokenizer = _FakeTokenizer(total_sequence_length) - pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH - return pipeline - - -@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES) -def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length(pipeline_class: type): - pipeline = _make_pipeline(pipeline_class, total_sequence_length=WAN22_MAX_SEQUENCE_LENGTH + 1) - - with pytest.raises(ValueError, match=r"got 513 tokens, but `max_sequence_length` is 512"): - pipeline.encode_prompt(prompt="prompt") - - -@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES) -def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(pipeline_class: type): - pipeline = _make_pipeline(pipeline_class, total_sequence_length=17) - - with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"): - pipeline.encode_prompt(prompt="prompt", max_sequence_length=16) - - -def _sampling_params(**overrides): - defaults = dict( - height=None, - width=None, - num_frames=None, - num_inference_steps=None, - generator=None, - guidance_scale_provided=False, - guidance_scale_2=None, - boundary_ratio=None, - num_outputs_per_prompt=0, - max_sequence_length=None, - seed=None, - extra_args={}, - prompt_embeds=None, - negative_prompt_embeds=None, - ) - defaults.update(overrides) - return SimpleNamespace(**defaults) - - -@pytest.mark.parametrize( - ("pipeline_class", "prompt_value", "forward_kwargs"), - [ - pytest.param(Wan22Pipeline, "prompt", {}, id="wan22-t2v"), - pytest.param( - Wan22I2VPipeline, - {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}}, - {"image": PIL.Image.new("RGB", (64, 64))}, - id="wan22-i2v", - ), - pytest.param( - Wan22TI2VPipeline, - {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}}, - {"image": PIL.Image.new("RGB", (64, 64))}, - id="wan22-ti2v", - ), - pytest.param(Wan22VACEPipeline, "prompt", {}, id="wan22-vace"), - ], -) -def test_forward_defaults_to_wan22_tokenizer_max_length( - pipeline_class: type, - prompt_value, - forward_kwargs, -): - pipeline = object.__new__(pipeline_class) - nn.Module.__init__(pipeline) - pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH - pipeline.boundary_ratio = None - pipeline.vae_scale_factor_temporal = 4 - pipeline.vae_scale_factor_spatial = 8 - pipeline.transformer_config = SimpleNamespace(patch_size=(1, 2, 2)) - - captured = {} - - def _fake_check_inputs(*args, **kwargs): - captured["max_sequence_length"] = kwargs["max_sequence_length"] - raise RuntimeError("stop after capture") - - pipeline.check_inputs = _fake_check_inputs - - req = SimpleNamespace( - prompts=[prompt_value], - sampling_params=_sampling_params(), - ) - - with pytest.raises(RuntimeError, match="stop after capture"): - pipeline.forward(req, **forward_kwargs) - - assert captured["max_sequence_length"] == WAN22_MAX_SEQUENCE_LENGTH diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 652425d5097..5b067614446 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -29,16 +29,12 @@ from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.utils.prompt_utils import ( - validate_prompt_sequence_lengths, -) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) DEBUG_PERF = False WAN_SAMPLE_SOLVER_CHOICES = {"unipc", "euler"} -WAN22_MAX_SEQUENCE_LENGTH = 512 def build_wan_scheduler(sample_solver: str, flow_shift: float) -> Any: @@ -293,7 +289,6 @@ def __init__( pass self.boundary_ratio = od_config.boundary_ratio - self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH # Determine which transformers to load based on boundary_ratio # boundary_ratio=1.0: only load transformer_2 (low-noise stage only) @@ -565,7 +560,6 @@ def forward( negative_prompt_embeds=negative_prompt_embeds, guidance_scale_2=guidance_high if boundary_ratio is not None else None, boundary_ratio=boundary_ratio, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -599,7 +593,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) @@ -832,20 +826,6 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) - text_inputs_untruncated = self.tokenizer( - prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - error_context="for Wan2.2 text encoding", - ) text_inputs = self.tokenizer( prompt_clean, @@ -874,24 +854,8 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] - neg_text_inputs_untruncated = self.tokenizer( - negative_prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - neg_text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - prompt_name="negative_prompt", - error_context="for Wan2.2 text encoding", - ) neg_text_inputs = self.tokenizer( - negative_prompt_clean, + [self._prompt_clean(p) for p in negative_prompt], padding="max_length", max_length=max_sequence_length, truncation=True, @@ -963,7 +927,6 @@ def check_inputs( negative_prompt_embeds=None, guidance_scale_2=None, boundary_ratio=None, - max_sequence_length=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -990,10 +953,5 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: - raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" - ) - if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 95d1e08bbc7..4eedd9d207f 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -26,7 +26,6 @@ from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( - WAN22_MAX_SEQUENCE_LENGTH, build_wan_scheduler, create_transformer_from_config, load_transformer_config, @@ -37,9 +36,6 @@ from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.utils.prompt_utils import ( - validate_prompt_sequence_lengths, -) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -218,7 +214,6 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) - self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -474,7 +469,6 @@ def forward( image_embeds=image_embeds, guidance_scale_2=guidance_high if boundary_ratio is not None else None, boundary_ratio=boundary_ratio, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) # Adjust num_frames to be compatible with VAE temporal scaling @@ -503,7 +497,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) @@ -708,20 +702,6 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) - text_inputs_untruncated = self.tokenizer( - prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - error_context="for Wan2.2 text encoding", - ) text_inputs = self.tokenizer( prompt_clean, @@ -750,24 +730,8 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] - neg_text_inputs_untruncated = self.tokenizer( - negative_prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - neg_text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - prompt_name="negative_prompt", - error_context="for Wan2.2 text encoding", - ) neg_text_inputs = self.tokenizer( - negative_prompt_clean, + [self._prompt_clean(p) for p in negative_prompt], padding="max_length", max_length=max_sequence_length, truncation=True, @@ -911,7 +875,6 @@ def check_inputs( image_embeds=None, guidance_scale_2=None, boundary_ratio=None, - max_sequence_length=None, ): if image is None and image_embeds is None: raise ValueError("Provide either `image` or `image_embeds`. Cannot leave both undefined.") @@ -933,11 +896,6 @@ def check_inputs( if prompt is None and prompt_embeds is None: raise ValueError("Provide either `prompt` or `prompt_embeds`.") - if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: - raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" - ) - if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index dba76ba8af8..8170a8f5ab5 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -37,7 +37,6 @@ from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( - WAN22_MAX_SEQUENCE_LENGTH, build_wan_scheduler, create_transformer_from_config, load_transformer_config, @@ -47,9 +46,6 @@ ) from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.utils.prompt_utils import ( - validate_prompt_sequence_lengths, -) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -189,7 +185,6 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) - self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -377,7 +372,6 @@ def forward( width=width, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) # Adjust num_frames to be compatible with VAE temporal scaling @@ -401,7 +395,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_scale > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) @@ -551,20 +545,6 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) - text_inputs_untruncated = self.tokenizer( - prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - error_context="for Wan2.2 text encoding", - ) text_inputs = self.tokenizer( prompt_clean, @@ -593,24 +573,8 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] - neg_text_inputs_untruncated = self.tokenizer( - negative_prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - neg_text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - prompt_name="negative_prompt", - error_context="for Wan2.2 text encoding", - ) neg_text_inputs = self.tokenizer( - negative_prompt_clean, + [self._prompt_clean(p) for p in negative_prompt], padding="max_length", max_length=max_sequence_length, truncation=True, @@ -735,7 +699,6 @@ def check_inputs( width, prompt_embeds=None, negative_prompt_embeds=None, - max_sequence_length=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -751,11 +714,6 @@ def check_inputs( if prompt is None and prompt_embeds is None: raise ValueError("Provide either `prompt` or `prompt_embeds`.") - if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: - raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" - ) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py index 11408e2d24b..0458f88597e 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py @@ -243,7 +243,6 @@ def check_inputs( video=None, mask=None, reference_images=None, - max_sequence_length=None, ): super().check_inputs( prompt=prompt, @@ -252,7 +251,6 @@ def check_inputs( width=width, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - max_sequence_length=max_sequence_length, ) # VACE-specific: validate video/mask/reference_images consistency @@ -549,7 +547,6 @@ def forward( video=source_video, mask=source_mask, reference_images=reference_images, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) device = self.device @@ -568,7 +565,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_scale > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) From d4ac17c4ba1a7b94a4413aa0418cbc79e31f13cf Mon Sep 17 00:00:00 2001 From: david6666666 <530634352@qq.com> Date: Mon, 20 Apr 2026 08:25:57 +0000 Subject: [PATCH 2/2] [CI] re-enable Wan2.2 I2V online serving test Signed-off-by: david6666666 <530634352@qq.com> --- tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py index 24daa8ccf54..3aa5da85c24 100644 --- a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py +++ b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py @@ -567,7 +567,6 @@ def test_wan22_i2v_diffusers_offline_generates_video( @pytest.mark.benchmark @pytest.mark.diffusion @hardware_test(res={"cuda": "H100"}, num_cards=2) -@pytest.mark.skip(reason="issue: #2874") @pytest.mark.parametrize("omni_server", SERVER_CASES, indirect=True) def test_wan22_i2v_online_serving_generates_video( omni_server,