From deb7edb3a4ee92f3c5e0344fcb8657d646df5628 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Thu, 30 Oct 2025 13:07:11 +0800
Subject: [PATCH 1/3] fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled

Signed-off-by: Wang, Yi
---
 .../models/transformers/transformer_wan.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index dd75fb124f1a..38ba7d64c424 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -20,6 +20,7 @@
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...hooks.context_parallel import EquipartitionSharder
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
@@ -660,6 +661,15 @@ def forward(
             timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
         )
         if ts_seq_len is not None:
+            # Check if running under context parallel and split along seq_len dimension
+            if hasattr(self, '_parallel_config') and self._parallel_config is not None:
+                cp_config = getattr(self._parallel_config, 'context_parallel_config', None)
+                if cp_config is not None and cp_config._world_size > 1:
+                    timestep_proj = EquipartitionSharder.shard(
+                        timestep_proj,
+                        dim=1,
+                        mesh=cp_config._flattened_mesh
+                    )
             # batch_size, seq_len, 6, inner_dim
             timestep_proj = timestep_proj.unflatten(2, (6, -1))
         else:
@@ -681,6 +691,15 @@ def forward(
 
         # 5. Output norm, projection & unpatchify
         if temb.ndim == 3:
+            # Check if running under context parallel and split along seq_len dimension
+            if hasattr(self, '_parallel_config') and self._parallel_config is not None:
+                cp_config = getattr(self._parallel_config, 'context_parallel_config', None)
+                if cp_config is not None and cp_config._world_size > 1:
+                    temb = EquipartitionSharder.shard(
+                        temb,
+                        dim=1,
+                        mesh=cp_config._flattened_mesh
+                    )
             # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
             shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
             shift = shift.squeeze(2)
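Why this change is needed: in the Wan 2.2 TI2V-5B checkpoint the timestep is expanded per token, so temb and timestep_proj leave the condition embedder with shape (batch, seq_len, ...), while the context-parallel hooks have already sharded hidden_states along the sequence dimension. The per-token modulation then combines a full-length tensor with a sharded one and crashes on a sequence-length mismatch; the EquipartitionSharder calls in patch 1 split temb and timestep_proj the same way to avoid that. The snippet below is a minimal standalone sketch of the mismatch, using torch.chunk in place of the device-mesh sharder; the world_size/rank values and the simplified modulation are illustrative only.

    # Standalone sketch: why a per-token temb must be sharded like hidden_states.
    import torch

    batch, seq_len, dim, world_size, rank = 1, 8, 4, 2, 0

    hidden_states = torch.randn(batch, seq_len, dim)
    temb = torch.randn(batch, seq_len, dim)  # per-token timestep embedding (temb.ndim == 3)

    # Context parallelism gives each rank a slice of the sequence dimension.
    hidden_local = hidden_states.chunk(world_size, dim=1)[rank]  # (1, 4, 4)

    # An unsharded temb no longer matches the local sequence length...
    try:
        hidden_local * (1 + temb)  # (1, 4, 4) vs (1, 8, 4)
    except RuntimeError as err:
        print("mismatch without sharding temb:", err)

    # ...so it has to be split the same way (EquipartitionSharder.shard does this on the mesh).
    temb_local = temb.chunk(world_size, dim=1)[rank]
    print((hidden_local * (1 + temb_local)).shape)  # torch.Size([1, 4, 4])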
From a732b15855a5963b924f54154fd117c5d17c30ff Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Thu, 6 Nov 2025 02:12:41 -0800
Subject: [PATCH 2/3] address review comment

Signed-off-by: Wang, Yi A
---
 src/diffusers/hooks/context_parallel.py      |  8 ++++---
 .../models/transformers/transformer_wan.py   | 22 +++----------------
 2 files changed, 8 insertions(+), 22 deletions(-)

diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
index 915fe453b90b..26aea9267e23 100644
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -203,10 +203,12 @@ def post_forward(self, module, output):
 
     def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
         if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
-            raise ValueError(
-                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
+            logger.warning_once(
+                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions. split will not be applied"
             )
-        return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
+            return x
+        else:
+            return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
 
 
 class ContextParallelGatherHook(ModelHook):
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 38ba7d64c424..6f3993eb3f64 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -20,7 +20,6 @@
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...hooks.context_parallel import EquipartitionSharder
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
@@ -556,6 +555,9 @@ class WanTransformer3DModel(
             "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
         },
         "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        "": {
+            "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+        },
     }
 
     @register_to_config
@@ -661,15 +663,6 @@ def forward(
             timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
         )
         if ts_seq_len is not None:
-            # Check if running under context parallel and split along seq_len dimension
-            if hasattr(self, '_parallel_config') and self._parallel_config is not None:
-                cp_config = getattr(self._parallel_config, 'context_parallel_config', None)
-                if cp_config is not None and cp_config._world_size > 1:
-                    timestep_proj = EquipartitionSharder.shard(
-                        timestep_proj,
-                        dim=1,
-                        mesh=cp_config._flattened_mesh
-                    )
             # batch_size, seq_len, 6, inner_dim
             timestep_proj = timestep_proj.unflatten(2, (6, -1))
         else:
@@ -691,15 +684,6 @@ def forward(
 
         # 5. Output norm, projection & unpatchify
         if temb.ndim == 3:
-            # Check if running under context parallel and split along seq_len dimension
-            if hasattr(self, '_parallel_config') and self._parallel_config is not None:
-                cp_config = getattr(self._parallel_config, 'context_parallel_config', None)
-                if cp_config is not None and cp_config._world_size > 1:
-                    temb = EquipartitionSharder.shard(
-                        temb,
-                        dim=1,
-                        mesh=cp_config._flattened_mesh
-                    )
             # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
             shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
             shift = shift.squeeze(2)
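Patch 2 replaces the in-model workaround with a declarative entry in the transformer's context-parallel plan: registering "timestep" on the root module ("") with split_dim=1 and expected_dims=2 lets the existing input hook shard the per-token (batch, seq_len) timestep of the TI2V checkpoint before it reaches the condition embedder, while the 1-D per-sample timestep used by other Wan checkpoints no longer matches expected_dims and is passed through with a warning instead of raising. Below is a simplified stand-in for that dimension-gated behaviour of _prepare_cp_input, again using torch.chunk instead of the device-mesh sharder; world_size and rank are illustrative.

    # Dimension-gated sharding: only tensors with the expected number of dimensions
    # are split; anything else is returned unchanged with a warning (simplified, no
    # device mesh).
    import torch

    def prepare_cp_input(x, split_dim, expected_dims, world_size, rank):
        if expected_dims is not None and x.dim() != expected_dims:
            print(f"expected {expected_dims} dims, got {x.dim()}: split will not be applied")
            return x
        return x.chunk(world_size, dim=split_dim)[rank]

    timestep_2d = torch.zeros(1, 8)  # Wan 2.2 TI2V: per-token timestep (batch, seq_len)
    timestep_1d = torch.zeros(1)     # other Wan checkpoints: per-sample timestep (batch,)

    print(prepare_cp_input(timestep_2d, split_dim=1, expected_dims=2, world_size=2, rank=0).shape)  # (1, 4)
    print(prepare_cp_input(timestep_1d, split_dim=1, expected_dims=2, world_size=2, rank=0).shape)  # (1,)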
From b3a0f80acf13dcd51737e2c623886e653c9d7279 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Thu, 6 Nov 2025 02:15:45 -0800
Subject: [PATCH 3/3] refine

Signed-off-by: Wang, Yi A
---
 src/diffusers/hooks/context_parallel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
index 26aea9267e23..6491d17b4f46 100644
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -204,7 +204,7 @@ def post_forward(self, module, output):
     def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
         if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
             logger.warning_once(
-                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions. split will not be applied"
+                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
             )
             return x
         else:
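With the series applied, the previously crashing path can be exercised end to end on two GPUs. The outline below assumes the experimental context-parallel entry points these hooks serve (ContextParallelConfig and enable_parallelism, as documented for recent diffusers releases); exact names, arguments, and any required attention backend may differ between versions, so treat it as a sketch rather than a reference invocation. Launch with torchrun --nproc_per_node=2.

    # Sketch of a 2-rank context-parallel run against Wan-AI/Wan2.2-TI2V-5B-Diffusers.
    # ContextParallelConfig / enable_parallelism are assumed from the diffusers
    # context-parallel feature this series patches; verify against the installed version.
    import torch
    from diffusers import ContextParallelConfig, WanPipeline

    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    pipe = WanPipeline.from_pretrained(
        "Wan-AI/Wan2.2-TI2V-5B-Diffusers", torch_dtype=torch.bfloat16
    ).to(device)
    # 2-way context parallelism; the "" -> "timestep" plan entry from patch 2 keeps the
    # per-token timestep sharded consistently with hidden_states.
    pipe.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

    frames = pipe("a cat surfing a wave", num_frames=49, num_inference_steps=30).frames[0]
    if rank == 0:
        print(f"generated {len(frames)} frames")
    torch.distributed.destroy_process_group()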