diff --git a/examples/offline_inference/text_to_video/text_to_video.md b/examples/offline_inference/text_to_video/text_to_video.md
index cd5bd2be44..b5d0b2adc2 100644
--- a/examples/offline_inference/text_to_video/text_to_video.md
+++ b/examples/offline_inference/text_to_video/text_to_video.md
@@ -26,7 +26,7 @@ Key arguments:
 - `--num_frames`: Number of frames (Wan default is 81).
 - `--guidance_scale` and `--guidance_scale_high`: CFG scale (applied to low/high).
 - `--negative_prompt`: optional list of artifacts to suppress (the PR demo used a long Chinese string).
-- `--boundary_ratio`: Boundary split ratio for low/high DiT.
+- `--boundary_ratio`: Boundary split ratio for low/high DiT. Default `0.875` uses both transformers for best quality. Set to `1.0` to load only the low-noise transformer (saves noticeable memory with good quality, recommended if memory is limited). Set to `0.0` to load only the high-noise transformer (not recommended, lower quality).
 - `--fps`: frames per second for the saved MP4 (requires `diffusers` export_to_video).
 - `--output`: path to save the generated video.
 - `--vae_use_slicing`: enable VAE slicing for memory optimization.
diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py
index 464ff9ccad..b5c003dfb2 100644
--- a/examples/offline_inference/text_to_video/text_to_video.py
+++ b/examples/offline_inference/text_to_video/text_to_video.py
@@ -32,7 +32,12 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--width", type=int, default=1280, help="Video width.")
     parser.add_argument("--num_frames", type=int, default=81, help="Number of frames (Wan default is 81).")
     parser.add_argument("--num_inference_steps", type=int, default=40, help="Sampling steps.")
-    parser.add_argument("--boundary_ratio", type=float, default=0.875, help="Boundary split ratio for low/high DiT.")
+    parser.add_argument(
+        "--boundary_ratio",
+        type=float,
+        default=0.875,
+        help="Boundary split ratio for low/high DiT. Default 0.875 uses both transformers for best quality. Set to 1.0 to load only the low-noise transformer (saves noticeable memory with good quality, recommended if memory is limited).",
+    )
     parser.add_argument(
         "--flow_shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)."
     )
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
index 712c054001..562d8eec51 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
@@ -227,17 +227,30 @@ def __init__(
         except Exception:
             pass

+        self.boundary_ratio = od_config.boundary_ratio
+
+        # Determine which transformers to load based on boundary_ratio
+        # boundary_ratio=1.0: only load transformer_2 (low-noise stage only)
+        # boundary_ratio=0.0: only load transformer (high-noise stage only)
+        # otherwise: load both transformers
+        load_transformer = self.boundary_ratio != 1.0 if self.boundary_ratio is not None else True
+        load_transformer_2 = self.has_transformer_2 and (
+            self.boundary_ratio != 0.0 if self.boundary_ratio is not None else True
+        )
+
         # Set up weights sources for transformer(s)
-        self.weights_sources = [
-            DiffusersPipelineLoader.ComponentSource(
-                model_or_path=od_config.model,
-                subfolder="transformer",
-                revision=None,
-                prefix="transformer.",
-                fall_back_to_pt=True,
-            ),
-        ]
-        if self.has_transformer_2:
+        self.weights_sources = []
+        if load_transformer:
+            self.weights_sources.append(
+                DiffusersPipelineLoader.ComponentSource(
+                    model_or_path=od_config.model,
+                    subfolder="transformer",
+                    revision=None,
+                    prefix="transformer.",
+                    fall_back_to_pt=True,
+                )
+            )
+        if load_transformer_2:
             self.weights_sources.append(
                 DiffusersPipelineLoader.ComponentSource(
                     model_or_path=od_config.model,
@@ -257,14 +270,26 @@ def __init__(
         ).to(self.device)

         # Initialize transformers with correct config (weights loaded via load_weights)
-        transformer_config = load_transformer_config(model, "transformer", local_files_only)
-        self.transformer = create_transformer_from_config(transformer_config)
-        if self.has_transformer_2:
+        if load_transformer:
+            transformer_config = load_transformer_config(model, "transformer", local_files_only)
+            self.transformer = create_transformer_from_config(transformer_config)
+        else:
+            self.transformer = None
+
+        if load_transformer_2:
             transformer_2_config = load_transformer_config(model, "transformer_2", local_files_only)
             self.transformer_2 = create_transformer_from_config(transformer_2_config)
         else:
             self.transformer_2 = None

+        # Store the active transformer config
+        if load_transformer:
+            self.transformer_config = self.transformer.config
+        elif load_transformer_2:
+            self.transformer_config = self.transformer_2.config
+        else:
+            raise RuntimeError("No transformer loaded")
+
         # Initialize UniPC scheduler
         flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0  # default for 720p
         self.scheduler = FlowUniPCMultistepScheduler(
@@ -275,7 +300,6 @@ def __init__(

         self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
-        self.boundary_ratio = od_config.boundary_ratio

         self._guidance_scale = None
         self._guidance_scale_2 = None
@@ -333,7 +357,7 @@ def forward(

         # Ensure dimensions are compatible with VAE and patch size
         # For expand_timesteps mode, we need latent dims to be even (divisible by patch_size)
-        patch_size = self.transformer.config.patch_size
+        patch_size = self.transformer_config.patch_size
         mod_value = self.vae_scale_factor_spatial * patch_size[1]  # 16*2=32 for TI2V, 8*2=16 for I2V
         height = (height // mod_value) * mod_value
         width = (width // mod_value) * mod_value
@@ -374,7 +398,14 @@ def forward(
         num_frames = max(num_frames, 1)

         device = self.device
-        dtype = self.transformer.dtype
+        # Get dtype from whichever transformer is loaded
+        if self.transformer is not None:
+            dtype = self.transformer.dtype
+        elif self.transformer_2 is not None:
+            dtype = self.transformer_2.dtype
+        else:
+            # Fallback to text_encoder dtype if no transformer loaded
+            dtype = self.text_encoder.dtype

         # Seed / generator
         if generator is None:
@@ -444,7 +475,7 @@ def forward(
             image_tensor = image

         # Use out_channels for noise latents (not in_channels which includes condition)
-        num_channels_latents = self.transformer.config.out_channels
+        num_channels_latents = self.transformer_config.out_channels
         batch_size = prompt_embeds.shape[0]

         # Prepare noise latents
@@ -489,7 +520,7 @@ def forward(
             first_frame_mask[:, :, 0] = 0
         else:
             # T2V mode: standard latent preparation
-            num_channels_latents = self.transformer.config.in_channels
+            num_channels_latents = self.transformer_config.in_channels
             latents = self.prepare_latents(
                 batch_size=prompt_embeds.shape[0],
                 num_channels_latents=num_channels_latents,
@@ -508,11 +539,30 @@ def forward(
         # Denoising
         for t in timesteps:
             self._current_timestep = t
-            current_model = self.transformer
-            current_guidance_scale = guidance_low
-            if boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None:
-                current_model = self.transformer_2
+
+            # Select model based on timestep and boundary_ratio
+            # High noise stage (t >= boundary_timestep): use transformer
+            # Low noise stage (t < boundary_timestep): use transformer_2
+            if boundary_timestep is not None and t < boundary_timestep:
+                # Low noise stage - always use guidance_high for this stage
                 current_guidance_scale = guidance_high
+                if self.transformer_2 is not None:
+                    current_model = self.transformer_2
+                elif self.transformer is not None:
+                    # Fallback to transformer if transformer_2 not loaded
+                    current_model = self.transformer
+                else:
+                    raise RuntimeError("No transformer available for low-noise stage")
+            else:
+                # High noise stage - always use guidance_low for this stage
+                current_guidance_scale = guidance_low
+                if self.transformer is not None:
+                    current_model = self.transformer
+                elif self.transformer_2 is not None:
+                    # Fallback to transformer_2 if transformer not loaded
+                    current_model = self.transformer_2
+                else:
+                    raise RuntimeError("No transformer available for high-noise stage")

             if self.expand_timesteps and latent_condition is not None:
                 # I2V mode: blend condition with latents using mask
@@ -520,7 +570,7 @@ def forward(
                 latent_model_input = latent_model_input.to(dtype)

                 # Expand timesteps per patch - use floor division to match patch embedding
-                patch_size = self.transformer.config.patch_size
+                patch_size = self.transformer_config.patch_size
                 num_latent_frames = latents.shape[2]
                 patch_height = latents.shape[3] // patch_size[1]
                 patch_width = latents.shape[4] // patch_size[2]
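For reference, the sketch below distills the `boundary_ratio` semantics this diff implements into a standalone snippet. The helper names `plan_loading` and `select_stage` are hypothetical (the real logic lives in `Wan2_2Pipeline.__init__` and the denoising loop above), and the derivation `boundary_timestep = boundary_ratio * num_train_timesteps` is an assumption based on how Wan-style two-expert pipelines typically compute the boundary; the diff itself only consumes an already-computed `boundary_timestep`.

```python
# Minimal sketch of the two-stage boundary_ratio behavior in this PR.
# Assumption: 1000 training timesteps and boundary_timestep derived as
# boundary_ratio * num_train_timesteps (not shown in the diff itself).

NUM_TRAIN_TIMESTEPS = 1000


def plan_loading(boundary_ratio, has_transformer_2):
    """Mirror the PR's load decision: skip a DiT whose stage is unreachable."""
    load_transformer = boundary_ratio != 1.0 if boundary_ratio is not None else True
    load_transformer_2 = has_transformer_2 and (
        boundary_ratio != 0.0 if boundary_ratio is not None else True
    )
    return load_transformer, load_transformer_2


def select_stage(t, boundary_timestep):
    """Mirror the per-step choice: high-noise DiT with guidance_low above the
    boundary, low-noise DiT with guidance_high below it."""
    if boundary_timestep is not None and t < boundary_timestep:
        return "transformer_2", "guidance_high"  # low-noise stage
    return "transformer", "guidance_low"  # high-noise stage


if __name__ == "__main__":
    ratio = 0.875
    boundary = ratio * NUM_TRAIN_TIMESTEPS  # 875.0
    assert plan_loading(ratio, has_transformer_2=True) == (True, True)
    assert select_stage(900, boundary) == ("transformer", "guidance_low")
    assert select_stage(100, boundary) == ("transformer_2", "guidance_high")
    # boundary_ratio=1.0: every t falls below the boundary, so the high-noise
    # DiT is never used and can be skipped at load time (the memory saving
    # described in the docs). boundary_ratio=0.0 is the mirror case.
    assert plan_loading(1.0, has_transformer_2=True) == (False, True)
```

With `boundary_ratio=1.0` no timestep can reach the high-noise branch, which is why the pipeline can avoid instantiating `transformer` entirely; the in-loop fallbacks in the diff then route every step to whichever DiT was actually loaded.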