diff --git a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py deleted file mode 100644 index 17dec2ce06a..00000000000 --- a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py +++ /dev/null @@ -1,190 +0,0 @@ -from types import SimpleNamespace - -import PIL.Image -import pytest -import torch -from torch import nn - -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( - WAN22_MAX_SEQUENCE_LENGTH, - Wan22Pipeline, -) -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import ( - Wan22I2VPipeline, -) -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_ti2v import ( - Wan22TI2VPipeline, -) - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - - -class _RejectingTextEncoder: - dtype = torch.float32 - - def __call__(self, *args, **kwargs): - raise AssertionError("text encoder should not run for prompts that exceed max_sequence_length") - - -class _FakeTokenBatch: - def __init__(self, total_sequence_length: int, padded_length: int | None = None): - padded_length = padded_length or total_sequence_length - attention_mask = torch.zeros((1, padded_length), dtype=torch.long) - attention_mask[:, :total_sequence_length] = 1 - self.input_ids = attention_mask.clone() - self.attention_mask = attention_mask - - -class _FakeTokenizer: - def __init__(self, total_sequence_length: int | list[int]): - if isinstance(total_sequence_length, list): - self.total_sequence_lengths = list(total_sequence_length) - else: - self.total_sequence_lengths = [total_sequence_length] - - def __call__(self, *args, **kwargs): - if len(self.total_sequence_lengths) > 1: - total_sequence_length = self.total_sequence_lengths.pop(0) - else: - total_sequence_length = self.total_sequence_lengths[0] - return _FakeTokenBatch(total_sequence_length, kwargs.get("max_length")) - - -class _RecordingTextEncoder: - dtype = torch.float32 - - def __init__(self): - self.input_lengths: list[int] = [] - - def __call__(self, input_ids, attention_mask): - self.input_lengths.append(input_ids.shape[1]) - return SimpleNamespace(last_hidden_state=torch.zeros((1, input_ids.shape[1], 4), dtype=torch.float32)) - - -PIPELINE_CASES = [ - pytest.param(Wan22Pipeline, id="wan22-t2v"), - pytest.param(Wan22I2VPipeline, id="wan22-i2v"), - pytest.param(Wan22TI2VPipeline, id="wan22-ti2v"), -] - - -def _make_pipeline(pipeline_class: type, *, total_sequence_length: int): - pipeline = object.__new__(pipeline_class) - nn.Module.__init__(pipeline) - pipeline.device = torch.device("cpu") - pipeline.text_encoder = _RejectingTextEncoder() - pipeline.tokenizer = _FakeTokenizer(total_sequence_length) - pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH - return pipeline - - -@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES) -def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length(pipeline_class: type): - pipeline = _make_pipeline(pipeline_class, total_sequence_length=WAN22_MAX_SEQUENCE_LENGTH + 1) - - with pytest.raises( - ValueError, - match=rf"got {WAN22_MAX_SEQUENCE_LENGTH + 1} tokens, but `max_sequence_length` is {WAN22_MAX_SEQUENCE_LENGTH}", - ): - pipeline.encode_prompt(prompt="prompt") - - -@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES) -def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(pipeline_class: type): - pipeline = _make_pipeline(pipeline_class, total_sequence_length=17) - - with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"): - pipeline.encode_prompt(prompt="prompt", max_sequence_length=16) - - -@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES) -def test_encode_prompt_uses_actual_prompt_length_for_text_encoder_padding(pipeline_class: type): - pipeline = object.__new__(pipeline_class) - nn.Module.__init__(pipeline) - pipeline.device = torch.device("cpu") - pipeline.text_encoder = _RecordingTextEncoder() - # prompt validation, prompt encode, negative validation, negative encode - pipeline.tokenizer = _FakeTokenizer([17, 17, 9, 9]) - pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH - - prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt( - prompt="prompt", - negative_prompt="neg", - do_classifier_free_guidance=True, - ) - - assert pipeline.text_encoder.input_lengths == [17, 17] - assert prompt_embeds.shape[1] == 17 - assert negative_prompt_embeds is not None - assert negative_prompt_embeds.shape[1] == 17 - - -def _sampling_params(**overrides): - defaults = dict( - height=None, - width=None, - num_frames=None, - num_inference_steps=None, - generator=None, - guidance_scale_provided=False, - guidance_scale_2=None, - boundary_ratio=None, - num_outputs_per_prompt=0, - max_sequence_length=None, - seed=None, - extra_args={}, - prompt_embeds=None, - negative_prompt_embeds=None, - ) - defaults.update(overrides) - return SimpleNamespace(**defaults) - - -@pytest.mark.parametrize( - ("pipeline_class", "prompt_value", "forward_kwargs"), - [ - pytest.param(Wan22Pipeline, "prompt", {}, id="wan22-t2v"), - pytest.param( - Wan22I2VPipeline, - {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}}, - {"image": PIL.Image.new("RGB", (64, 64))}, - id="wan22-i2v", - ), - pytest.param( - Wan22TI2VPipeline, - {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}}, - {"image": PIL.Image.new("RGB", (64, 64))}, - id="wan22-ti2v", - ), - ], -) -def test_forward_defaults_to_wan22_tokenizer_max_length( - pipeline_class: type, - prompt_value, - forward_kwargs, -): - pipeline = object.__new__(pipeline_class) - nn.Module.__init__(pipeline) - pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH - pipeline.boundary_ratio = None - pipeline.vae_scale_factor_temporal = 4 - pipeline.vae_scale_factor_spatial = 8 - pipeline.transformer_config = SimpleNamespace(patch_size=(1, 2, 2)) - - captured = {} - - def _fake_check_inputs(*args, **kwargs): - captured["max_sequence_length"] = kwargs["max_sequence_length"] - raise RuntimeError("stop after capture") - - pipeline.check_inputs = _fake_check_inputs - - req = SimpleNamespace( - prompts=[prompt_value], - sampling_params=_sampling_params(), - ) - - with pytest.raises(RuntimeError, match="stop after capture"): - pipeline.forward(req, **forward_kwargs) - - assert captured["max_sequence_length"] == WAN22_MAX_SEQUENCE_LENGTH diff --git a/vllm_omni/diffusion/layers/adalayernorm.py b/vllm_omni/diffusion/layers/adalayernorm.py index 7f963a01d9b..e2ff041f6af 100644 --- a/vllm_omni/diffusion/layers/adalayernorm.py +++ b/vllm_omni/diffusion/layers/adalayernorm.py @@ -1,7 +1,6 @@ from importlib.util import find_spec import torch -import torch.nn as nn from vllm.logger import init_logger from vllm_omni.diffusion.layers.custom_op import CustomOp diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 016c5177a22..4b8d01dfa9c 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -28,15 +28,11 @@ from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.utils.prompt_utils import ( - validate_prompt_sequence_lengths, -) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) DEBUG_PERF = False -WAN22_MAX_SEQUENCE_LENGTH = 2048 def retrieve_latents( @@ -252,7 +248,6 @@ def __init__( pass self.boundary_ratio = od_config.boundary_ratio - self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH # Determine which transformers to load based on boundary_ratio # boundary_ratio=1.0: only load transformer_2 (low-noise stage only) @@ -428,7 +423,6 @@ def forward( negative_prompt_embeds=negative_prompt_embeds, guidance_scale_2=guidance_high if boundary_ratio is not None else None, boundary_ratio=boundary_ratio, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -462,7 +456,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) @@ -749,7 +743,7 @@ def encode_prompt( negative_prompt: str | list[str] | None = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, - max_sequence_length: int = WAN22_MAX_SEQUENCE_LENGTH, + max_sequence_length: int = 512, device: torch.device | None = None, dtype: torch.dtype | None = None, ): @@ -759,51 +753,11 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) - text_inputs_untruncated = self.tokenizer( - prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - error_context="for Wan2.2 text encoding", - ) - prompt_encode_length = max(int(text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), 1) - - negative_prompt_embeds = None - if do_classifier_free_guidance: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] - neg_text_inputs_untruncated = self.tokenizer( - negative_prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - neg_text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - prompt_name="negative_prompt", - error_context="for Wan2.2 text encoding", - ) - prompt_encode_length = max( - prompt_encode_length, - int(neg_text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), - ) text_inputs = self.tokenizer( prompt_clean, padding="max_length", - max_length=prompt_encode_length, + max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_attention_mask=True, @@ -816,18 +770,21 @@ def encode_prompt( prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + negative_prompt_embeds = None if do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt neg_text_inputs = self.tokenizer( - negative_prompt_clean, + [self._prompt_clean(p) for p in negative_prompt], padding="max_length", - max_length=prompt_encode_length, + max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_attention_mask=True, @@ -840,7 +797,7 @@ def encode_prompt( negative_prompt_embeds = [u[:v] for u, v in zip(negative_prompt_embeds, seq_lens_neg)] negative_prompt_embeds = torch.stack( [ - torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) + torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in negative_prompt_embeds ], dim=0, @@ -897,7 +854,6 @@ def check_inputs( negative_prompt_embeds=None, guidance_scale_2=None, boundary_ratio=None, - max_sequence_length=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -924,10 +880,5 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: - raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" - ) - if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 96814299ec6..f9d063d6dee 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -27,7 +27,6 @@ from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( - WAN22_MAX_SEQUENCE_LENGTH, create_transformer_from_config, load_transformer_config, retrieve_latents, @@ -35,9 +34,6 @@ from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.utils.prompt_utils import ( - validate_prompt_sequence_lengths, -) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -216,7 +212,6 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) - self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -400,7 +395,6 @@ def forward( image_embeds=image_embeds, guidance_scale_2=guidance_high if boundary_ratio is not None else None, boundary_ratio=boundary_ratio, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) # Adjust num_frames to be compatible with VAE temporal scaling @@ -429,7 +423,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) @@ -489,7 +483,7 @@ def forward( # Handle last_image if provided if last_image is not None: if isinstance(last_image, PIL.Image.Image): - image = TF.to_tensor(last_image).to(device) + last_image = TF.to_tensor(last_image).to(device) last_image_tensor = video_processor.preprocess(last_image, height=height, width=width) else: last_image_tensor = last_image @@ -664,7 +658,7 @@ def encode_prompt( negative_prompt: str | list[str] | None = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, - max_sequence_length: int = WAN22_MAX_SEQUENCE_LENGTH, + max_sequence_length: int = 512, device: torch.device | None = None, dtype: torch.dtype | None = None, ): @@ -675,51 +669,11 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) - text_inputs_untruncated = self.tokenizer( - prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - error_context="for Wan2.2 text encoding", - ) - prompt_encode_length = max(int(text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), 1) - - negative_prompt_embeds = None - if do_classifier_free_guidance: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] - neg_text_inputs_untruncated = self.tokenizer( - negative_prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - neg_text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - prompt_name="negative_prompt", - error_context="for Wan2.2 text encoding", - ) - prompt_encode_length = max( - prompt_encode_length, - int(neg_text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), - ) text_inputs = self.tokenizer( prompt_clean, padding="max_length", - max_length=prompt_encode_length, + max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_attention_mask=True, @@ -732,18 +686,21 @@ def encode_prompt( prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + negative_prompt_embeds = None if do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt neg_text_inputs = self.tokenizer( - negative_prompt_clean, + [self._prompt_clean(p) for p in negative_prompt], padding="max_length", - max_length=prompt_encode_length, + max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_attention_mask=True, @@ -756,7 +713,7 @@ def encode_prompt( negative_prompt_embeds = [u[:v] for u, v in zip(negative_prompt_embeds, seq_lens_neg)] negative_prompt_embeds = torch.stack( [ - torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) + torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in negative_prompt_embeds ], dim=0, @@ -885,7 +842,6 @@ def check_inputs( image_embeds=None, guidance_scale_2=None, boundary_ratio=None, - max_sequence_length=None, ): if image is None and image_embeds is None: raise ValueError("Provide either `image` or `image_embeds`. Cannot leave both undefined.") @@ -907,11 +863,6 @@ def check_inputs( if prompt is None and prompt_embeds is None: raise ValueError("Provide either `prompt` or `prompt_embeds`.") - if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: - raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" - ) - if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 0f4caf9fd16..4953a809d80 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -38,16 +38,12 @@ from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( - WAN22_MAX_SEQUENCE_LENGTH, create_transformer_from_config, load_transformer_config, retrieve_latents, ) from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.utils.prompt_utils import ( - validate_prompt_sequence_lengths, -) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -187,7 +183,6 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) - self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -308,7 +303,6 @@ def forward( width=width, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) # Adjust num_frames to be compatible with VAE temporal scaling @@ -332,7 +326,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_scale > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) @@ -508,7 +502,7 @@ def encode_prompt( negative_prompt: str | list[str] | None = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, - max_sequence_length: int = WAN22_MAX_SEQUENCE_LENGTH, + max_sequence_length: int = 512, device: torch.device | None = None, dtype: torch.dtype | None = None, ): @@ -519,51 +513,11 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) - text_inputs_untruncated = self.tokenizer( - prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - error_context="for Wan2.2 text encoding", - ) - prompt_encode_length = max(int(text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), 1) - - negative_prompt_embeds = None - if do_classifier_free_guidance: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] - neg_text_inputs_untruncated = self.tokenizer( - negative_prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - neg_text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - prompt_name="negative_prompt", - error_context="for Wan2.2 text encoding", - ) - prompt_encode_length = max( - prompt_encode_length, - int(neg_text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), - ) text_inputs = self.tokenizer( prompt_clean, padding="max_length", - max_length=prompt_encode_length, + max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_attention_mask=True, @@ -576,18 +530,21 @@ def encode_prompt( prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + negative_prompt_embeds = None if do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt neg_text_inputs = self.tokenizer( - negative_prompt_clean, + [self._prompt_clean(p) for p in negative_prompt], padding="max_length", - max_length=prompt_encode_length, + max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_attention_mask=True, @@ -600,7 +557,7 @@ def encode_prompt( negative_prompt_embeds = [u[:v] for u, v in zip(negative_prompt_embeds, seq_lens_neg)] negative_prompt_embeds = torch.stack( [ - torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) + torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in negative_prompt_embeds ], dim=0, @@ -710,7 +667,6 @@ def check_inputs( width, prompt_embeds=None, negative_prompt_embeds=None, - max_sequence_length=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -726,11 +682,6 @@ def check_inputs( if prompt is None and prompt_embeds is None: raise ValueError("Provide either `prompt` or `prompt_embeds`.") - if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: - raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" - ) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index 006f2d6717b..946192882fc 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -11,7 +11,6 @@ from diffusers.models.attention import FeedForward from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers.models.normalization import FP32LayerNorm from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -24,14 +23,13 @@ from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.layer import Attention -from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNorm -from vllm_omni.diffusion.layers.norm import LayerNorm, RMSNorm from vllm_omni.diffusion.distributed.sp_plan import ( SequenceParallelInput, SequenceParallelOutput, ) from vllm_omni.diffusion.forward_context import get_forward_context from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNorm +from vllm_omni.diffusion.layers.norm import LayerNorm, RMSNorm logger = init_logger(__name__)