diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 24e8965a39e..95d1e08bbc7 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -12,6 +12,7 @@ import numpy as np import PIL.Image import torch +import torchvision.transforms.functional as TF from diffusers.utils.torch_utils import randn_tensor from torch import nn from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel @@ -560,6 +561,7 @@ def forward( video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) if isinstance(image, PIL.Image.Image): + image = TF.to_tensor(image).to(device) image_tensor = video_processor.preprocess(image, height=height, width=width) else: image_tensor = image @@ -568,6 +570,7 @@ def forward( # Handle last_image if provided if last_image is not None: if isinstance(last_image, PIL.Image.Image): + image = TF.to_tensor(last_image).to(device) last_image_tensor = video_processor.preprocess(last_image, height=height, width=width) else: last_image_tensor = last_image @@ -872,12 +875,14 @@ def prepare_latents( return latents, latent_condition, first_frame_mask # Wan2.1 style: create mask and concatenate with condition - mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + mask_lat_size = torch.ones( + batch_size, 1, num_frames, latent_height, latent_width, device=latent_condition.device + ) if last_image is None: - mask_lat_size[:, :, list(range(1, num_frames))] = 0 + mask_lat_size[:, :, 1:] = 0 else: - mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + mask_lat_size[:, :, 1 : num_frames - 1] = 0 first_frame_mask = mask_lat_size[:, :, 0:1] first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)