19 changes: 19 additions & 0 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -20,6 +20,7 @@
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...hooks.context_parallel import EquipartitionSharder
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
@@ -660,6 +661,15 @@ def forward(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        if ts_seq_len is not None:
            # Check if running under context parallel and split along seq_len dimension
            if hasattr(self, '_parallel_config') and self._parallel_config is not None:
Member:
Could you elaborate why this is needed?

Contributor Author (@sywangyi):
When context parallelism is enabled, the sequence length is split across ranks. The timestep projection has shape [batch_size, seq_len, 6, inner_dim], so it has to be split along dim 1 as well, because the hidden states are split along their seq_len dimension too; otherwise a shape mismatch occurs.
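A minimal, self-contained sketch of the shape bookkeeping (the sizes are made up, and a plain chunk stands in for EquipartitionSharder.shard):

    import torch

    batch_size, seq_len, inner_dim, world_size, rank = 1, 8, 16, 2, 0
    hidden_states = torch.randn(batch_size, seq_len, inner_dim)
    timestep_proj = torch.randn(batch_size, seq_len, 6 * inner_dim)

    # Context parallel shards the hidden states along seq_len (dim 1).
    local_hidden = hidden_states.chunk(world_size, dim=1)[rank]   # [1, 4, 16]

    # The per-token modulation must be sharded the same way, otherwise the later
    # elementwise combination with the local hidden states mismatches on dim 1.
    local_proj = timestep_proj.chunk(world_size, dim=1)[rank]     # [1, 4, 96]
    assert local_proj.shape[1] == local_hidden.shape[1]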

Contributor Author (@sywangyi), Nov 4, 2025:

You mean split timestep in forward? Adding

    "": {
        "timestep": ContextParallelInput(split_dim=1, split_output=False)
    },

to _cp_plan makes 5B work, but 14B fails, since the 5B timestep has 2 dims while the 14B timestep has 1 dim.
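A quick illustration of why a fixed split_dim=1 entry cannot cover both variants (sizes are made up, and chunk stands in for the sharder):

    import torch

    world_size = 2
    timestep_5b = torch.randn(1, 8)   # 2-D: [batch_size, seq_len]; dim 1 is the seq_len axis and can be split
    timestep_14b = torch.randn(1)     # 1-D: [batch_size] only; there is no dim 1 to split

    timestep_5b.chunk(world_size, dim=1)     # works
    # timestep_14b.chunk(world_size, dim=1)  # would raise "dimension out of range"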

Member:
Hmm, this is an interesting situation. To tackle cases like this, I think we might have to revisit the ContextParallelInput and ContextParallelOutput definitions a bit more.

If we had a way to tell the partitioner that the input might have "dynamic" dimensions depending on the model configs (like in this case), and what it should do if that's the case, it might be more flexible as a solution.

@DN6 curious to know what you think.

Collaborator:
Hmm, since this appears to be a bit of an edge case, I think we could add something like this to the init:

    if out_channels == 48:
        self._cp_plan[""] = {"timestep": ContextParallelInput(split_dim=1, split_output=False)}

Alternatively, we could add a callback to ContextParallelInput that splits the tensors based on a condition.

e.g.

"timestep": ContextParallelInput(split_dim=1, split_condition_fn: lambda x: x.dim == 2, split_output=False)

and then in _prepare_cp_input

    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
        if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
            raise ValueError(
                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
            )
        if cp_input.split_condition_fn(x):
            return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
        else:
            return x
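For illustration, a self-contained sketch of what the extended plan-entry definition could look like; split_condition_fn is the assumed new field, while split_dim, split_output, and expected_dims already appear in the snippets above:

    from dataclasses import dataclass
    from typing import Callable, Optional

    import torch

    @dataclass
    class ContextParallelInput:  # sketch of the proposal, not the current diffusers class
        split_dim: int
        split_output: bool = False
        expected_dims: Optional[int] = None
        # Shard only when this predicate holds for the incoming tensor.
        split_condition_fn: Optional[Callable[[torch.Tensor], bool]] = None

    # e.g. split the timestep only when it carries a seq_len dimension:
    timestep_entry = ContextParallelInput(split_dim=1, split_condition_fn=lambda x: x.dim() == 2)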

> Potentially, we can add cp_plan directly as a config and allow the model owner to override it, I think (in this case, we would send a PR to the Wan repo; I think it'd be OK).

I'm not so sure about including it in the model config. You can currently override it through the enable_parallelism method

    cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
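A rough usage sketch of that override path; only the cp_plan parameter comes from the signature above, and the import paths, repo id, and config arguments are assumptions that may differ from the installed diffusers version:

    from diffusers import ContextParallelConfig, WanTransformer3DModel
    from diffusers.models._modeling_parallel import ContextParallelInput  # import path is an assumption

    # Assumes torch.distributed is already initialized, one rank per device.
    transformer = WanTransformer3DModel.from_pretrained("<wan-repo-id>", subfolder="transformer")
    custom_cp_plan = {
        "": {"timestep": ContextParallelInput(split_dim=1, split_output=False)},
    }
    transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2), cp_plan=custom_cp_plan)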

Member:

> I'm not so sure about including it in the model config. You can currently override it through the enable_parallelism method

WDYT about allowing cp_plan from the config itself? The same config can also be passed through from_pretrained().

Collaborator:
Yeah cp_plan included in ContextParallelConfig is fine 👍🏽
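A hypothetical sketch of what this could look like once cp_plan is carried by the config; the cp_plan field does not exist yet and is exactly the proposal being agreed on here (other names follow the sketch above):

    # Hypothetical: cp_plan as a ContextParallelConfig field (not current diffusers API).
    cp_config = ContextParallelConfig(
        ring_degree=2,
        cp_plan={
            "": {"timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False)},
        },
    )
    transformer.enable_parallelism(config=cp_config)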

Contributor Author (@sywangyi), Nov 6, 2025:

How about a change like this:

    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
        if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
            logger.warning(
                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
            )
            return x
        else:
            return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)

then add

    "": {
        "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False)
    },
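A toy, self-contained version of that proposed behavior, with a plain chunk standing in for EquipartitionSharder.shard so it runs without a process group:

    import torch

    def prepare_cp_input_toy(x, split_dim, expected_dims=None, world_size=2, rank=0):
        # Mirrors the proposal: warn and skip when the dims don't match, otherwise split.
        if expected_dims is not None and x.dim() != expected_dims:
            print(f"Expected {expected_dims} dims, got {x.dim()}; leaving tensor unsplit.")
            return x
        return x.chunk(world_size, dim=split_dim)[rank]

    print(prepare_cp_input_toy(torch.randn(1, 8), split_dim=1, expected_dims=2).shape)  # torch.Size([1, 4]): 5B-style timestep is split
    print(prepare_cp_input_toy(torch.randn(1), split_dim=1, expected_dims=2).shape)     # torch.Size([1]): 14B-style timestep passes through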

Member:
@sywangyi that change seems a little too implicit to me, which may not be desirable since it can have silent consequences.

What do you think of passing the latter bit (the timestep-related plan entry) as an input to the ContextParallelConfig and then passing it to from_pretrained()? If we can make it work that way, it would be more flexible, as we have been discussing.

                cp_config = getattr(self._parallel_config, 'context_parallel_config', None)
                if cp_config is not None and cp_config._world_size > 1:
                    timestep_proj = EquipartitionSharder.shard(
                        timestep_proj,
                        dim=1,
                        mesh=cp_config._flattened_mesh
                    )
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
@@ -681,6 +691,15 @@ def forward(

        # 5. Output norm, projection & unpatchify
        if temb.ndim == 3:
            # Check if running under context parallel and split along seq_len dimension
            if hasattr(self, '_parallel_config') and self._parallel_config is not None:
                cp_config = getattr(self._parallel_config, 'context_parallel_config', None)
                if cp_config is not None and cp_config._world_size > 1:
                    temb = EquipartitionSharder.shard(
                        temb,
                        dim=1,
                        mesh=cp_config._flattened_mesh
                    )
            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)