-
Notifications
You must be signed in to change notification settings - Fork 836
[Wan2.2] Optimize memory usage with conditional transformer loading #980
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b7cc39e
56e9f56
0e69f80
481e75e
639c8d6
138cbe5
e314487
6230103
2e6ad08
b1dcb3d
ae45aa3
c1902c1
d24283b
b60b95d
a4cc3ac
9c49cce
67045dc
109ad84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -227,17 +227,30 @@ def __init__( | |||
| except Exception: | ||||
| pass | ||||
|
|
||||
| self.boundary_ratio = od_config.boundary_ratio | ||||
|
|
||||
| # Determine which transformers to load based on boundary_ratio | ||||
| # boundary_ratio=1.0: only load transformer_2 (low-noise stage only) | ||||
| # boundary_ratio=0.0: only load transformer (high-noise stage only) | ||||
| # otherwise: load both transformers | ||||
| load_transformer = self.boundary_ratio != 1.0 if self.boundary_ratio is not None else True | ||||
| load_transformer_2 = self.has_transformer_2 and ( | ||||
| self.boundary_ratio != 0.0 if self.boundary_ratio is not None else True | ||||
| ) | ||||
|
|
||||
| # Set up weights sources for transformer(s) | ||||
| self.weights_sources = [ | ||||
| DiffusersPipelineLoader.ComponentSource( | ||||
| model_or_path=od_config.model, | ||||
| subfolder="transformer", | ||||
| revision=None, | ||||
| prefix="transformer.", | ||||
| fall_back_to_pt=True, | ||||
| ), | ||||
| ] | ||||
| if self.has_transformer_2: | ||||
| self.weights_sources = [] | ||||
| if load_transformer: | ||||
| self.weights_sources.append( | ||||
| DiffusersPipelineLoader.ComponentSource( | ||||
| model_or_path=od_config.model, | ||||
| subfolder="transformer", | ||||
| revision=None, | ||||
| prefix="transformer.", | ||||
| fall_back_to_pt=True, | ||||
| ) | ||||
| ) | ||||
| if load_transformer_2: | ||||
| self.weights_sources.append( | ||||
| DiffusersPipelineLoader.ComponentSource( | ||||
| model_or_path=od_config.model, | ||||
|
|
@@ -257,14 +270,26 @@ def __init__( | |||
| ).to(self.device) | ||||
|
|
||||
| # Initialize transformers with correct config (weights loaded via load_weights) | ||||
| transformer_config = load_transformer_config(model, "transformer", local_files_only) | ||||
| self.transformer = create_transformer_from_config(transformer_config) | ||||
| if self.has_transformer_2: | ||||
| if load_transformer: | ||||
| transformer_config = load_transformer_config(model, "transformer", local_files_only) | ||||
| self.transformer = create_transformer_from_config(transformer_config) | ||||
| else: | ||||
| self.transformer = None | ||||
|
|
||||
| if load_transformer_2: | ||||
| transformer_2_config = load_transformer_config(model, "transformer_2", local_files_only) | ||||
| self.transformer_2 = create_transformer_from_config(transformer_2_config) | ||||
| else: | ||||
| self.transformer_2 = None | ||||
|
|
||||
| # Store the active transformer config | ||||
| if load_transformer: | ||||
| self.transformer_config = self.transformer.config | ||||
| elif load_transformer_2: | ||||
| self.transformer_config = self.transformer_2.config | ||||
| else: | ||||
| raise RuntimeError("No transformer loaded") | ||||
|
|
||||
| # Initialize UniPC scheduler | ||||
| flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p | ||||
| self.scheduler = FlowUniPCMultistepScheduler( | ||||
|
|
@@ -275,7 +300,6 @@ def __init__( | |||
|
|
||||
| self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 | ||||
| self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 | ||||
| self.boundary_ratio = od_config.boundary_ratio | ||||
|
|
||||
| self._guidance_scale = None | ||||
| self._guidance_scale_2 = None | ||||
|
|
@@ -333,7 +357,7 @@ def forward( | |||
|
|
||||
| # Ensure dimensions are compatible with VAE and patch size | ||||
| # For expand_timesteps mode, we need latent dims to be even (divisible by patch_size) | ||||
| patch_size = self.transformer.config.patch_size | ||||
| patch_size = self.transformer_config.patch_size | ||||
|
faaany marked this conversation as resolved.
|
||||
| mod_value = self.vae_scale_factor_spatial * patch_size[1] # 16*2=32 for TI2V, 8*2=16 for I2V | ||||
| height = (height // mod_value) * mod_value | ||||
| width = (width // mod_value) * mod_value | ||||
|
|
@@ -374,7 +398,14 @@ def forward( | |||
| num_frames = max(num_frames, 1) | ||||
|
|
||||
| device = self.device | ||||
| dtype = self.transformer.dtype | ||||
| # Get dtype from whichever transformer is loaded | ||||
| if self.transformer is not None: | ||||
| dtype = self.transformer.dtype | ||||
| elif self.transformer_2 is not None: | ||||
| dtype = self.transformer_2.dtype | ||||
| else: | ||||
| # Fallback to text_encoder dtype if no transformer loaded | ||||
| dtype = self.text_encoder.dtype | ||||
|
|
||||
| # Seed / generator | ||||
| if generator is None: | ||||
|
|
@@ -444,7 +475,7 @@ def forward( | |||
| image_tensor = image | ||||
|
|
||||
| # Use out_channels for noise latents (not in_channels which includes condition) | ||||
| num_channels_latents = self.transformer.config.out_channels | ||||
| num_channels_latents = self.transformer_config.out_channels | ||||
| batch_size = prompt_embeds.shape[0] | ||||
|
|
||||
| # Prepare noise latents | ||||
|
|
@@ -489,7 +520,7 @@ def forward( | |||
| first_frame_mask[:, :, 0] = 0 | ||||
| else: | ||||
| # T2V mode: standard latent preparation | ||||
| num_channels_latents = self.transformer.config.in_channels | ||||
| num_channels_latents = self.transformer_config.in_channels | ||||
| latents = self.prepare_latents( | ||||
| batch_size=prompt_embeds.shape[0], | ||||
| num_channels_latents=num_channels_latents, | ||||
|
|
@@ -508,19 +539,38 @@ def forward( | |||
| # Denoising | ||||
| for t in timesteps: | ||||
| self._current_timestep = t | ||||
| current_model = self.transformer | ||||
| current_guidance_scale = guidance_low | ||||
| if boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None: | ||||
| current_model = self.transformer_2 | ||||
|
|
||||
| # Select model based on timestep and boundary_ratio | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I didn't understand this. When you don't need the first transformer, how to offload it?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When vllm-omni/vllm_omni/diffusion/offload.py Line 151 in c4220f0
self.transformer_2 in the dit_modules list. In this case, the memory-saving strategy still works, because DiT modules (no matter 1 or 2) and encoders are mutual exclusive.
|
||||
| # High noise stage (t >= boundary_timestep): use transformer | ||||
| # Low noise stage (t < boundary_timestep): use transformer_2 | ||||
| if boundary_timestep is not None and t < boundary_timestep: | ||||
| # Low noise stage - always use guidance_high for this stage | ||||
| current_guidance_scale = guidance_high | ||||
| if self.transformer_2 is not None: | ||||
| current_model = self.transformer_2 | ||||
| elif self.transformer is not None: | ||||
| # Fallback to transformer if transformer_2 not loaded | ||||
| current_model = self.transformer | ||||
| else: | ||||
| raise RuntimeError("No transformer available for low-noise stage") | ||||
| else: | ||||
| # High noise stage - always use guidance_low for this stage | ||||
| current_guidance_scale = guidance_low | ||||
| if self.transformer is not None: | ||||
| current_model = self.transformer | ||||
| elif self.transformer_2 is not None: | ||||
| # Fallback to transformer_2 if transformer not loaded | ||||
| current_model = self.transformer_2 | ||||
| else: | ||||
| raise RuntimeError("No transformer available for high-noise stage") | ||||
|
|
||||
| if self.expand_timesteps and latent_condition is not None: | ||||
| # I2V mode: blend condition with latents using mask | ||||
| latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents | ||||
| latent_model_input = latent_model_input.to(dtype) | ||||
|
|
||||
| # Expand timesteps per patch - use floor division to match patch embedding | ||||
| patch_size = self.transformer.config.patch_size | ||||
| patch_size = self.transformer_config.patch_size | ||||
| num_latent_frames = latents.shape[2] | ||||
| patch_height = latents.shape[3] // patch_size[1] | ||||
| patch_width = latents.shape[4] // patch_size[2] | ||||
|
|
||||
Uh oh!
There was an error while loading. Please reload this page.