From 02ddf241482b70b2964281a1b291d34dc72d4bf8 Mon Sep 17 00:00:00 2001 From: fan2956 Date: Thu, 16 Apr 2026 19:40:51 +0800 Subject: [PATCH 1/3] [Perf] Optimize Wan2.2 device free on image preprocess Signed-off-by: fan2956 --- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index ddc6e0bc2b9..7d8973058c7 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -12,6 +12,7 @@ import numpy as np import PIL.Image import torch +import torchvision.transforms.functional as TF from diffusers.utils.torch_utils import randn_tensor from torch import nn from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel @@ -478,6 +479,7 @@ def forward( video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) if isinstance(image, PIL.Image.Image): + image = TF.to_tensor(image).to(device) image_tensor = video_processor.preprocess(image, height=height, width=width) else: image_tensor = image @@ -486,6 +488,7 @@ def forward( # Handle last_image if provided if last_image is not None: if isinstance(last_image, PIL.Image.Image): + image = TF.to_tensor(image).to(device) last_image_tensor = video_processor.preprocess(last_image, height=height, width=width) else: last_image_tensor = last_image @@ -808,12 +811,12 @@ def prepare_latents( return latents, latent_condition, first_frame_mask # Wan2.1 style: create mask and concatenate with condition - mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width, device=latent_condition.device) if last_image is None: - mask_lat_size[:, :, list(range(1, num_frames))] = 0 + mask_lat_size[:, :, 1:] = 0 else: - mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + mask_lat_size[:, :, 1 : num_frames - 1] = 0 first_frame_mask = mask_lat_size[:, :, 0:1] first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) From 0033df9a05122238c3179e75548ce431f84e05ac Mon Sep 17 00:00:00 2001 From: fan2956 Date: Thu, 16 Apr 2026 19:48:07 +0800 Subject: [PATCH 2/3] [Bugfix] fix pre-commit Signed-off-by: fan2956 --- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 7d8973058c7..35e00884ab5 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -811,7 +811,9 @@ def prepare_latents( return latents, latent_condition, first_frame_mask # Wan2.1 style: create mask and concatenate with condition - mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width, device=latent_condition.device) + mask_lat_size = torch.ones( + batch_size, 1, num_frames, latent_height, latent_width, device=latent_condition.device + ) if last_image is None: mask_lat_size[:, :, 1:] = 0 From 37770a4013fe623dce96a8ac0495c39b14efe94e Mon Sep 17 00:00:00 2001 From: fan2956 Date: Fri, 17 Apr 2026 09:47:45 +0800 Subject: [PATCH 3/3] [Bugfix] fix last image to_tensor error Signed-off-by: fan2956 --- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 35e00884ab5..f912199d0f0 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -488,7 +488,7 @@ def forward( # Handle last_image if provided if last_image is not None: if isinstance(last_image, PIL.Image.Image): - image = TF.to_tensor(image).to(device) + image = TF.to_tensor(last_image).to(device) last_image_tensor = video_processor.preprocess(last_image, height=height, width=width) else: last_image_tensor = last_image