forked from huggingface/optimum-habana
-
Notifications
You must be signed in to change notification settings - Fork 20
Enable Deepspeed Ulysses for Wan #353
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Wei-Lin-Intel
merged 9 commits into
HabanaAI:aice/v1.22.0
from
mengker33:oh_fork_wan_enable_cp
Nov 6, 2025
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
d1dcfdd
Wan: Enable deepspeed ulysses for ti2v pipeline
mengker33 1fd46b8
Deepspeed: Fix uneven head sequence parallelism bug
mengker33 21f1007
README: Add Wan i2v example using deepspeed ulysses
mengker33 9a5fc12
Enable deepspeed ulysses for Wan t2v pipeline
mengker33 4dced85
Add attn_mask when there is padding in cp case
mengker33 86a67e5
README: Add Wan2.2 t2v example
mengker33 6c46e38
Wan: Add traditional SP in wan attention
mengker33 8372385
Wan: Replace vae WanBlockAttention with FusedSDPA
mengker33 c5aeffc
README: Update readme after adding SP support
mengker33 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,8 +22,10 @@ | |
| from diffusers.models.transformers.transformer_wan import WanAttention, _get_added_kv_projections, _get_qkv_projections | ||
| from diffusers.utils import deprecate, logging | ||
| from diffusers.utils.import_utils import is_xformers_available | ||
| from habana_frameworks.torch.hpex.kernels import FusedSDPA | ||
| from torch import nn | ||
|
|
||
| from ...distributed import parallel_state | ||
| from .embeddings import RotaryPosEmbedding | ||
|
|
||
|
|
||
|
|
@@ -206,8 +208,92 @@ def __init__(self, fusedSDPA): | |
| super().__init__() | ||
| self._hpu_kernel_fsdpa = fusedSDPA | ||
|
|
||
| def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): | ||
| return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) | ||
| def forward( | ||
| self, | ||
| query, | ||
| key, | ||
| value, | ||
| attn_mask, | ||
| dropout_p, | ||
| is_causal, | ||
| scale, | ||
| softmax_mode, | ||
| recompute_mode, | ||
| valid_sequence_lengths, | ||
| padding_side="left", | ||
| ): | ||
| query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value)) | ||
| out = self._hpu_kernel_fsdpa.apply( | ||
| query, | ||
| key, | ||
| value, | ||
| attn_mask, | ||
| dropout_p, | ||
| is_causal, | ||
| scale, | ||
| softmax_mode, | ||
| recompute_mode, | ||
| valid_sequence_lengths, | ||
| padding_side, | ||
| ) | ||
| return out.permute(0, 2, 1, 3) | ||
|
|
||
|
|
||
| class GaudiDistributedAttention(torch.nn.Module): | ||
| def __init__(self, hpu_module_fsdpa: ModuleFusedSDPA): | ||
| super().__init__() | ||
| self._hpu_module_fsdpa = hpu_module_fsdpa | ||
| if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: | ||
| from deepspeed.sequence.layer import DistributedAttention | ||
|
|
||
| self._hpu_module_fsdpa_distributed = DistributedAttention( | ||
| self._hpu_module_fsdpa, parallel_state.get_sequence_parallel_group(), 2, 1 | ||
| ) | ||
|
|
||
| def forward( | ||
| self, | ||
| query: torch.Tensor, | ||
| key: torch.Tensor, | ||
| value: torch.Tensor, | ||
| attn_mask: torch.Tensor, | ||
| dropout_p: float, | ||
| is_casual, | ||
| scale, | ||
| softmax_mode, | ||
| recompute_mode, | ||
| valid_sequence_lengths, | ||
| padding_side="left", | ||
| ): | ||
| if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1: | ||
| return self._hpu_module_fsdpa_distributed( | ||
| query, | ||
| key, | ||
| value, | ||
| 0, # As the shape for inputs is [B, S, N, H] | ||
| None, | ||
| attn_mask, | ||
| dropout_p, | ||
| is_casual, | ||
| scale, | ||
| softmax_mode, | ||
| recompute_mode, | ||
| valid_sequence_lengths, | ||
| padding_side, | ||
| ) | ||
| else: | ||
| return self._hpu_module_fsdpa( | ||
| query, | ||
| key, | ||
| value, | ||
| attn_mask, | ||
| dropout_p, | ||
| is_casual, | ||
| scale, | ||
| softmax_mode, | ||
| recompute_mode, | ||
| valid_sequence_lengths, | ||
| padding_side, | ||
| ) | ||
|
|
||
|
|
||
| class CogVideoXAttnProcessorGaudi: | ||
|
|
@@ -262,17 +348,20 @@ def __call__( | |
|
|
||
| softmax_mode = "None" if attn.training else "fast" | ||
| hidden_states = self.fused_scaled_dot_product_attention( | ||
| query, | ||
| key, | ||
| value, | ||
| attn_mask=attention_mask, | ||
| dropout_p=0.0, | ||
| is_casual=False, | ||
| scale=None, | ||
| softmax_mode=softmax_mode, | ||
| query.transpose(1, 2).contiguous(), | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why CogVideoXAttnProcessorGaudi need to change? Since you put qkv transpose in ModuleFusedSDPA, need to check other function which not use sp or ulysses. |
||
| key.transpose(1, 2).contiguous(), | ||
| value.transpose(1, 2).contiguous(), | ||
| attention_mask, | ||
| 0.0, | ||
| False, | ||
| None, | ||
| softmax_mode, | ||
| False, | ||
| None, | ||
| "None", | ||
| ) | ||
|
|
||
| hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) | ||
| hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) | ||
|
|
||
| # linear proj | ||
| hidden_states = attn.to_out[0](hidden_states) | ||
|
|
@@ -553,6 +642,18 @@ def __init__(self, is_training=False): | |
| "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." | ||
| ) | ||
| self.is_training = is_training | ||
| self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None | ||
| self.fused_scaled_dot_product_attention_distributed = None | ||
| self.use_sp = os.getenv("USE_SP", "True").lower() not in ("0", "false", "False") | ||
| self.cp_size = parallel_state.get_sequence_parallel_world_size() | ||
|
|
||
| if not self.use_sp and parallel_state.sequence_parallel_is_initialized() \ | ||
| and self.cp_size > 1: | ||
| self.fused_scaled_dot_product_attention_distributed = ( | ||
| GaudiDistributedAttention(self.fused_scaled_dot_product_attention) | ||
| if FusedSDPA | ||
| else None | ||
| ) | ||
|
|
||
| def _native_attention( | ||
| self, | ||
|
|
@@ -565,14 +666,28 @@ def _native_attention( | |
| scale: Optional[float] = None, | ||
| enable_gqa: bool = False, | ||
| ) -> torch.Tensor: | ||
| # apply gaudi fused SDPA | ||
| from habana_frameworks.torch.hpex.kernels import FusedSDPA | ||
|
|
||
| # Fast FSDPA is not supported in training mode | ||
| fsdpa_mode = "None" if self.is_training else "fast" | ||
| query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) | ||
| out = FusedSDPA.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, fsdpa_mode, None) | ||
| out = out.permute(0, 2, 1, 3) | ||
|
|
||
| if self.fused_scaled_dot_product_attention_distributed: | ||
| out = self.fused_scaled_dot_product_attention_distributed( | ||
|
mengker33 marked this conversation as resolved.
|
||
| query, | ||
| key, | ||
| value, | ||
| attn_mask, | ||
| 0.0, | ||
| False, | ||
| None, | ||
| fsdpa_mode, | ||
| False, | ||
| None, | ||
| "None", | ||
| ) | ||
| else: | ||
| out = self.fused_scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_causal, scale, fsdpa_mode, | ||
| False, | ||
| None, | ||
| "None",) | ||
| return out | ||
|
|
||
| def __call__( | ||
|
|
@@ -634,8 +749,39 @@ def apply_rotary_emb( | |
| hidden_states_img = hidden_states_img.flatten(2, 3) | ||
| hidden_states_img = hidden_states_img.type_as(query) | ||
|
|
||
| # Add traditional SP: | ||
| if self.use_sp and self.cp_size > 1: | ||
| bs, kv_seq, num_head, head_dim = key.shape | ||
| key = key.reshape(bs, kv_seq, -1) | ||
| value = value.reshape(bs, kv_seq, -1) | ||
| full_key = torch.empty(bs, kv_seq * self.cp_size, num_head * head_dim, dtype=key.dtype, device=key.device) | ||
| full_value = torch.empty(bs, kv_seq * self.cp_size, num_head * head_dim, dtype=value.dtype, device=value.device) | ||
| gather1 = torch.distributed.all_gather_into_tensor( | ||
| full_key, | ||
| key, | ||
| group=parallel_state.get_sequence_parallel_group(), | ||
| async_op=True, | ||
| ) | ||
| torch.distributed.all_gather_into_tensor( | ||
| full_value, | ||
| value, | ||
| group=parallel_state.get_sequence_parallel_group(), | ||
| async_op=False, | ||
| ) | ||
| gather1.wait() | ||
| key = full_key.reshape(bs, kv_seq * self.cp_size, num_head, head_dim) | ||
| value = full_value.reshape(bs, kv_seq * self.cp_size, num_head, head_dim) | ||
|
|
||
| if attention_mask is not None: | ||
| logger.warning(f"Applying attention_mask in SP is not well supported, set it as None.") | ||
| attention_mask = None | ||
|
|
||
| hidden_states = self._native_attention(query, key, value, attention_mask, 0.0, False, None) | ||
|
|
||
| if self.use_sp and self.cp_size > 1: | ||
| torch.hpu.synchronize() | ||
|
|
||
|
|
||
| hidden_states = hidden_states.flatten(2, 3) | ||
| hidden_states = hidden_states.type_as(query) | ||
|
|
||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.