Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 0 additions & 190 deletions tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py

This file was deleted.

1 change: 0 additions & 1 deletion vllm_omni/diffusion/layers/adalayernorm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from importlib.util import find_spec

import torch
import torch.nn as nn
from vllm.logger import init_logger

from vllm_omni.diffusion.layers.custom_op import CustomOp
Expand Down
69 changes: 10 additions & 59 deletions vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,11 @@
from vllm_omni.diffusion.postprocess import interpolate_video_tensor
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.prompt_utils import (
validate_prompt_sequence_lengths,
)
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.platforms import current_omni_platform

logger = logging.getLogger(__name__)
DEBUG_PERF = False
WAN22_MAX_SEQUENCE_LENGTH = 2048


def retrieve_latents(
Expand Down Expand Up @@ -252,7 +248,6 @@ def __init__(
pass

self.boundary_ratio = od_config.boundary_ratio
self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH

# Determine which transformers to load based on boundary_ratio
# boundary_ratio=1.0: only load transformer_2 (low-noise stage only)
Expand Down Expand Up @@ -428,7 +423,6 @@ def forward(
negative_prompt_embeds=negative_prompt_embeds,
guidance_scale_2=guidance_high if boundary_ratio is not None else None,
boundary_ratio=boundary_ratio,
max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
)

if num_frames % self.vae_scale_factor_temporal != 1:
Expand Down Expand Up @@ -462,7 +456,7 @@ def forward(
negative_prompt=negative_prompt,
do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0,
num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1,
max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
max_sequence_length=req.sampling_params.max_sequence_length or 512,
device=device,
dtype=dtype,
)
Expand Down Expand Up @@ -749,7 +743,7 @@ def encode_prompt(
negative_prompt: str | list[str] | None = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
max_sequence_length: int = WAN22_MAX_SEQUENCE_LENGTH,
max_sequence_length: int = 512,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
Expand All @@ -759,51 +753,11 @@ def encode_prompt(
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_clean = [self._prompt_clean(p) for p in prompt]
batch_size = len(prompt_clean)
text_inputs_untruncated = self.tokenizer(
prompt_clean,
padding=True,
truncation=False,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
validate_prompt_sequence_lengths(
text_inputs_untruncated.attention_mask,
max_sequence_length=max_sequence_length,
supported_max_sequence_length=self.tokenizer_max_length,
error_context="for Wan2.2 text encoding",
)
prompt_encode_length = max(int(text_inputs_untruncated.attention_mask.sum(dim=1).max().item()), 1)

negative_prompt_embeds = None
if do_classifier_free_guidance:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt]
neg_text_inputs_untruncated = self.tokenizer(
negative_prompt_clean,
padding=True,
truncation=False,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
validate_prompt_sequence_lengths(
neg_text_inputs_untruncated.attention_mask,
max_sequence_length=max_sequence_length,
supported_max_sequence_length=self.tokenizer_max_length,
prompt_name="negative_prompt",
error_context="for Wan2.2 text encoding",
)
prompt_encode_length = max(
prompt_encode_length,
int(neg_text_inputs_untruncated.attention_mask.sum(dim=1).max().item()),
)

text_inputs = self.tokenizer(
prompt_clean,
padding="max_length",
max_length=prompt_encode_length,
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
Expand All @@ -816,18 +770,21 @@ def encode_prompt(
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

negative_prompt_embeds = None
if do_classifier_free_guidance:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
neg_text_inputs = self.tokenizer(
negative_prompt_clean,
[self._prompt_clean(p) for p in negative_prompt],
padding="max_length",
max_length=prompt_encode_length,
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
Expand All @@ -840,7 +797,7 @@ def encode_prompt(
negative_prompt_embeds = [u[:v] for u, v in zip(negative_prompt_embeds, seq_lens_neg)]
negative_prompt_embeds = torch.stack(
[
torch.cat([u, u.new_zeros(prompt_encode_length - u.size(0), u.size(1))])
torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))])
for u in negative_prompt_embeds
],
dim=0,
Expand Down Expand Up @@ -897,7 +854,6 @@ def check_inputs(
negative_prompt_embeds=None,
guidance_scale_2=None,
boundary_ratio=None,
max_sequence_length=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
Expand All @@ -924,10 +880,5 @@ def check_inputs(
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
raise ValueError(
f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
)

if boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.")
Loading