Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 0 additions & 150 deletions tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,6 @@ def test_wan22_i2v_diffusers_offline_generates_video(
@pytest.mark.benchmark
@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100"}, num_cards=2)
-@pytest.mark.skip(reason="issue: #2874")
@pytest.mark.parametrize("omni_server", SERVER_CASES, indirect=True)
def test_wan22_i2v_online_serving_generates_video(
omni_server,
Expand Down
46 changes: 2 additions & 44 deletions vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,12 @@
from vllm_omni.diffusion.postprocess import interpolate_video_tensor
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.diffusion.utils.prompt_utils import (
-validate_prompt_sequence_lengths,
-)
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.platforms import current_omni_platform

logger = logging.getLogger(__name__)
DEBUG_PERF = False
WAN_SAMPLE_SOLVER_CHOICES = {"unipc", "euler"}
-WAN22_MAX_SEQUENCE_LENGTH = 512


def build_wan_scheduler(sample_solver: str, flow_shift: float) -> Any:
Expand Down Expand Up @@ -293,7 +289,6 @@ def __init__(
pass

self.boundary_ratio = od_config.boundary_ratio
-self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH

# Determine which transformers to load based on boundary_ratio
# boundary_ratio=1.0: only load transformer_2 (low-noise stage only)
Expand Down Expand Up @@ -565,7 +560,6 @@ def forward(
negative_prompt_embeds=negative_prompt_embeds,
guidance_scale_2=guidance_high if boundary_ratio is not None else None,
boundary_ratio=boundary_ratio,
-max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
)

if num_frames % self.vae_scale_factor_temporal != 1:
Expand Down Expand Up @@ -599,7 +593,7 @@ def forward(
negative_prompt=negative_prompt,
do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0,
num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1,
-max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
+max_sequence_length=req.sampling_params.max_sequence_length or 512,
device=device,
dtype=dtype,
)
Expand Down Expand Up @@ -832,20 +826,6 @@ def encode_prompt(
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_clean = [self._prompt_clean(p) for p in prompt]
batch_size = len(prompt_clean)
-text_inputs_untruncated = self.tokenizer(
-prompt_clean,
-padding=True,
-truncation=False,
-add_special_tokens=True,
-return_attention_mask=True,
-return_tensors="pt",
-)
-validate_prompt_sequence_lengths(
-text_inputs_untruncated.attention_mask,
-max_sequence_length=max_sequence_length,
-supported_max_sequence_length=self.tokenizer_max_length,
-error_context="for Wan2.2 text encoding",
-)

text_inputs = self.tokenizer(
prompt_clean,
Expand Down Expand Up @@ -874,24 +854,8 @@ def encode_prompt(
if do_classifier_free_guidance:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt]
-neg_text_inputs_untruncated = self.tokenizer(
-negative_prompt_clean,
-padding=True,
-truncation=False,
-add_special_tokens=True,
-return_attention_mask=True,
-return_tensors="pt",
-)
-validate_prompt_sequence_lengths(
-neg_text_inputs_untruncated.attention_mask,
-max_sequence_length=max_sequence_length,
-supported_max_sequence_length=self.tokenizer_max_length,
-prompt_name="negative_prompt",
-error_context="for Wan2.2 text encoding",
-)
neg_text_inputs = self.tokenizer(
-negative_prompt_clean,
+[self._prompt_clean(p) for p in negative_prompt],
padding="max_length",
max_length=max_sequence_length,
truncation=True,
Expand Down Expand Up @@ -963,7 +927,6 @@ def check_inputs(
negative_prompt_embeds=None,
guidance_scale_2=None,
boundary_ratio=None,
-max_sequence_length=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
Expand All @@ -990,10 +953,5 @@ def check_inputs(
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

-if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
-raise ValueError(
-f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
-)

if boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.")
46 changes: 2 additions & 44 deletions vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
-WAN22_MAX_SEQUENCE_LENGTH,
build_wan_scheduler,
create_transformer_from_config,
load_transformer_config,
Expand All @@ -37,9 +36,6 @@
from vllm_omni.diffusion.postprocess import interpolate_video_tensor
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.diffusion.utils.prompt_utils import (
-validate_prompt_sequence_lengths,
-)
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.platforms import current_omni_platform

Expand Down Expand Up @@ -218,7 +214,6 @@ def __init__(

# Text encoder
self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
-self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH
self.text_encoder = UMT5EncoderModel.from_pretrained(
model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only
).to(self.device)
Expand Down Expand Up @@ -474,7 +469,6 @@ def forward(
image_embeds=image_embeds,
guidance_scale_2=guidance_high if boundary_ratio is not None else None,
boundary_ratio=boundary_ratio,
-max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
)

# Adjust num_frames to be compatible with VAE temporal scaling
Expand Down Expand Up @@ -503,7 +497,7 @@ def forward(
negative_prompt=negative_prompt,
do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0,
num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1,
-max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
+max_sequence_length=req.sampling_params.max_sequence_length or 512,
device=device,
dtype=dtype,
)
Expand Down Expand Up @@ -708,20 +702,6 @@ def encode_prompt(
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_clean = [self._prompt_clean(p) for p in prompt]
batch_size = len(prompt_clean)
-text_inputs_untruncated = self.tokenizer(
-prompt_clean,
-padding=True,
-truncation=False,
-add_special_tokens=True,
-return_attention_mask=True,
-return_tensors="pt",
-)
-validate_prompt_sequence_lengths(
-text_inputs_untruncated.attention_mask,
-max_sequence_length=max_sequence_length,
-supported_max_sequence_length=self.tokenizer_max_length,
-error_context="for Wan2.2 text encoding",
-)

text_inputs = self.tokenizer(
prompt_clean,
Expand Down Expand Up @@ -750,24 +730,8 @@ def encode_prompt(
if do_classifier_free_guidance:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt]
-neg_text_inputs_untruncated = self.tokenizer(
-negative_prompt_clean,
-padding=True,
-truncation=False,
-add_special_tokens=True,
-return_attention_mask=True,
-return_tensors="pt",
-)
-validate_prompt_sequence_lengths(
-neg_text_inputs_untruncated.attention_mask,
-max_sequence_length=max_sequence_length,
-supported_max_sequence_length=self.tokenizer_max_length,
-prompt_name="negative_prompt",
-error_context="for Wan2.2 text encoding",
-)
neg_text_inputs = self.tokenizer(
-negative_prompt_clean,
+[self._prompt_clean(p) for p in negative_prompt],
padding="max_length",
max_length=max_sequence_length,
truncation=True,
Expand Down Expand Up @@ -911,7 +875,6 @@ def check_inputs(
image_embeds=None,
guidance_scale_2=None,
boundary_ratio=None,
-max_sequence_length=None,
):
if image is None and image_embeds is None:
raise ValueError("Provide either `image` or `image_embeds`. Cannot leave both undefined.")
Expand All @@ -933,11 +896,6 @@ def check_inputs(
if prompt is None and prompt_embeds is None:
raise ValueError("Provide either `prompt` or `prompt_embeds`.")

-if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
-raise ValueError(
-f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
-)

if boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.")

Expand Down
Loading
Loading