Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Key arguments:
- `--num_frames`: Number of frames (Wan default is 81).
- `--guidance_scale` and `--guidance_scale_high`: CFG scale (applied to low/high).
- `--negative_prompt`: optional list of artifacts to suppress (the PR demo used a long Chinese string).
- `--boundary_ratio`: Boundary split ratio for low/high DiT.
- `--boundary_ratio`: Boundary split ratio for low/high DiT. Default `0.875` uses both transformers for best quality. Set to `1.0` to load only the low-noise transformer (noticeably reduces memory use while keeping good quality; recommended if memory is limited). Set to `0.0` to load only the high-noise transformer (not recommended; lower quality).
- `--fps`: frames per second for the saved MP4 (requires `diffusers` export_to_video).
- `--output`: path to save the generated video.
- `--vae_use_slicing`: enable VAE slicing for memory optimization.
Expand Down
7 changes: 6 additions & 1 deletion examples/offline_inference/text_to_video/text_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--width", type=int, default=1280, help="Video width.")
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames (Wan default is 81).")
parser.add_argument("--num_inference_steps", type=int, default=40, help="Sampling steps.")
parser.add_argument("--boundary_ratio", type=float, default=0.875, help="Boundary split ratio for low/high DiT.")
parser.add_argument(
"--boundary_ratio",
type=float,
default=0.875,
help="Boundary split ratio for low/high DiT. Default 0.875 uses both transformers for best quality. Set to 1.0 to load only the low-noise transformer (saves noticeable memory with good quality, recommended if memory is limited).",
)
parser.add_argument(
"--flow_shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)."
)
Expand Down
96 changes: 73 additions & 23 deletions vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,17 +227,30 @@ def __init__(
except Exception:
pass

self.boundary_ratio = od_config.boundary_ratio

# Determine which transformers to load based on boundary_ratio
# boundary_ratio=1.0: only load transformer_2 (low-noise stage only)
# boundary_ratio=0.0: only load transformer (high-noise stage only)
# otherwise: load both transformers
load_transformer = self.boundary_ratio != 1.0 if self.boundary_ratio is not None else True
load_transformer_2 = self.has_transformer_2 and (
self.boundary_ratio != 0.0 if self.boundary_ratio is not None else True
)
Comment thread
faaany marked this conversation as resolved.

# Set up weights sources for transformer(s)
self.weights_sources = [
DiffusersPipelineLoader.ComponentSource(
model_or_path=od_config.model,
subfolder="transformer",
revision=None,
prefix="transformer.",
fall_back_to_pt=True,
),
]
if self.has_transformer_2:
self.weights_sources = []
if load_transformer:
self.weights_sources.append(
DiffusersPipelineLoader.ComponentSource(
model_or_path=od_config.model,
subfolder="transformer",
revision=None,
prefix="transformer.",
fall_back_to_pt=True,
)
)
if load_transformer_2:
self.weights_sources.append(
DiffusersPipelineLoader.ComponentSource(
model_or_path=od_config.model,
Expand All @@ -257,14 +270,26 @@ def __init__(
).to(self.device)

# Initialize transformers with correct config (weights loaded via load_weights)
transformer_config = load_transformer_config(model, "transformer", local_files_only)
self.transformer = create_transformer_from_config(transformer_config)
if self.has_transformer_2:
if load_transformer:
transformer_config = load_transformer_config(model, "transformer", local_files_only)
self.transformer = create_transformer_from_config(transformer_config)
else:
self.transformer = None

if load_transformer_2:
transformer_2_config = load_transformer_config(model, "transformer_2", local_files_only)
self.transformer_2 = create_transformer_from_config(transformer_2_config)
else:
self.transformer_2 = None

# Store the active transformer config
if load_transformer:
self.transformer_config = self.transformer.config
elif load_transformer_2:
self.transformer_config = self.transformer_2.config
else:
raise RuntimeError("No transformer loaded")

# Initialize UniPC scheduler
flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
self.scheduler = FlowUniPCMultistepScheduler(
Expand All @@ -275,7 +300,6 @@ def __init__(

self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.boundary_ratio = od_config.boundary_ratio

self._guidance_scale = None
self._guidance_scale_2 = None
Expand Down Expand Up @@ -333,7 +357,7 @@ def forward(

# Ensure dimensions are compatible with VAE and patch size
# For expand_timesteps mode, we need latent dims to be even (divisible by patch_size)
patch_size = self.transformer.config.patch_size
patch_size = self.transformer_config.patch_size
Comment thread
faaany marked this conversation as resolved.
mod_value = self.vae_scale_factor_spatial * patch_size[1] # 16*2=32 for TI2V, 8*2=16 for I2V
height = (height // mod_value) * mod_value
width = (width // mod_value) * mod_value
Expand Down Expand Up @@ -374,7 +398,14 @@ def forward(
num_frames = max(num_frames, 1)

device = self.device
dtype = self.transformer.dtype
# Get dtype from whichever transformer is loaded
if self.transformer is not None:
dtype = self.transformer.dtype
elif self.transformer_2 is not None:
dtype = self.transformer_2.dtype
else:
# Fallback to text_encoder dtype if no transformer loaded
dtype = self.text_encoder.dtype

# Seed / generator
if generator is None:
Expand Down Expand Up @@ -444,7 +475,7 @@ def forward(
image_tensor = image

# Use out_channels for noise latents (not in_channels which includes condition)
num_channels_latents = self.transformer.config.out_channels
num_channels_latents = self.transformer_config.out_channels
batch_size = prompt_embeds.shape[0]

# Prepare noise latents
Expand Down Expand Up @@ -489,7 +520,7 @@ def forward(
first_frame_mask[:, :, 0] = 0
else:
# T2V mode: standard latent preparation
num_channels_latents = self.transformer.config.in_channels
num_channels_latents = self.transformer_config.in_channels
latents = self.prepare_latents(
batch_size=prompt_embeds.shape[0],
num_channels_latents=num_channels_latents,
Expand All @@ -508,19 +539,38 @@ def forward(
# Denoising
for t in timesteps:
self._current_timestep = t
current_model = self.transformer
current_guidance_scale = guidance_low
if boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None:
current_model = self.transformer_2

# Select model based on timestep and boundary_ratio
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I didn't understand this. When you don't need the first transformer, how do you offload it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When boundary_ratio is set to 1.0, self.transformer will be None. In our current offload logic, a None module is skipped (see

), leaving only self.transformer_2 in the dit_modules list. In this case, the memory-saving strategy still works, because the DiT modules (whether one or two) and the encoders are mutually exclusive.

# High noise stage (t >= boundary_timestep): use transformer
# Low noise stage (t < boundary_timestep): use transformer_2
if boundary_timestep is not None and t < boundary_timestep:
# Low noise stage - always use guidance_high for this stage
current_guidance_scale = guidance_high
if self.transformer_2 is not None:
current_model = self.transformer_2
elif self.transformer is not None:
# Fallback to transformer if transformer_2 not loaded
current_model = self.transformer
else:
raise RuntimeError("No transformer available for low-noise stage")
else:
# High noise stage - always use guidance_low for this stage
current_guidance_scale = guidance_low
if self.transformer is not None:
current_model = self.transformer
elif self.transformer_2 is not None:
# Fallback to transformer_2 if transformer not loaded
current_model = self.transformer_2
else:
raise RuntimeError("No transformer available for high-noise stage")

if self.expand_timesteps and latent_condition is not None:
# I2V mode: blend condition with latents using mask
latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
latent_model_input = latent_model_input.to(dtype)

# Expand timesteps per patch - use floor division to match patch embedding
patch_size = self.transformer.config.patch_size
patch_size = self.transformer_config.patch_size
num_latent_frames = latents.shape[2]
patch_height = latents.shape[3] // patch_size[1]
patch_width = latents.shape[4] // patch_size[2]
Expand Down