Skip to content
Merged
2 changes: 1 addition & 1 deletion DeepSpeed/deepspeed/sequence/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
# we only need num_heads once
num_heads = input.shape[2]

if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and not scatter_idx < 2):
Comment thread
mengker33 marked this conversation as resolved.
# Assuming here that the number of heads for q is consistent with kv
# If not, additional logic is required for cases like GQA
if get_num_kv_heads() is None:
Expand Down
40 changes: 40 additions & 0 deletions examples/stable-diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,25 @@ python image_to_video_generation.py \
--bf16
```

For multi-cards inference, we support both traditional sequence parallelism (SP) and DeepSpeed Ulysses to accelerate the inference. Traditional SP is the default one, if you want to apply DeepSpeed Ulysses, please set USE_SP=0.
When using DeepSpeed Ulysses, please set PT_HPU_SYNC_LAUNCH=1 to reduce memory consumption. Besides, to slightly improve the accuracy in DeepSpeed Ulysses case, you can enable the mask by setting CP_USE_MASK=1.

```bash
PT_HPU_LAZY_MODE=1 \
python ../gaudi_spawn.py --world_size 2 image_to_video_generation.py \
--model_name_or_path "Wan-AI/Wan2.2-TI2V-5B-Diffusers" \
--image_path "https://raw.githubusercontent.com/Wan-Video/Wan2.2/main/examples/i2v_input.JPG" \
--video_save_dir ./wan2.2-output \
--prompts "The cat removes the glasses from its eyes." \
--use_habana \
--height 1088 \
--width 800 \
--fps 24 \
--num_frames 121 \
--bf16 \
--context_parallel_size 2
```

### Text-to-Video with Wan 2.2
Wan2.2 is a comprehensive and open suite of video foundation models. Please refer to [Huggingface Wan2.2 doc](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B)

Expand All @@ -502,6 +521,27 @@ python text_to_video_generation.py \
--dtype bf16
```

For multi-cards inference, we support both traditional sequence parallelism (SP) and DeepSpeed Ulysses to accelerate the inference. Traditional SP is the default one, if you want to apply DeepSpeed Ulysses, please set USE_SP=0.
When using DeepSpeed Ulysses, please set PT_HPU_SYNC_LAUNCH=1 to reduce memory consumption. Besides, to slightly improve the accuracy in DeepSpeed Ulysses case, you can enable the mask by setting CP_USE_MASK=1.

```bash
PT_HPU_LAZY_MODE=1 \
python ../gaudi_spawn.py --world_size 2 text_to_video_generation.py \
--model_name_or_path "Wan-AI/Wan2.2-TI2V-5B-Diffusers" \
--prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--pipeline_type wan \
--num_videos_per_prompt 1 \
--use_habana \
--height 704 \
--width 1280 \
--num_frames 121 \
--num_inference_steps 50 \
--guidance_scale 5.0 \
--output_type mp4 \
--dtype bf16 \
--context_parallel_size 2
```

### Text-to-Video with CogvideoX

CogVideoX is an open-source version of the video generation model originating from QingYing, unveiled in https://huggingface.co/THUDM/CogVideoX-5b.
Expand Down
17 changes: 17 additions & 0 deletions examples/stable-diffusion/image_to_video_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
GaudiStableVideoDiffusionPipeline,
GaudiWanImageToVideoPipeline,
)
from optimum.habana.distributed import parallel_state
from optimum.habana.transformers.gaudi_configuration import GaudiConfig
from optimum.habana.utils import set_seed

Expand Down Expand Up @@ -224,6 +225,14 @@ def main():
default=None,
help="Number of steps to ignore for throughput calculation.",
)

parser.add_argument(
"--context_parallel_size",
type=int,
default=1,
help="Determines how many ranks are divided into context parallel group.",
)

args = parser.parse_args()

# Setup logging
Expand Down Expand Up @@ -295,6 +304,14 @@ def main():
if args.bf16:
kwargs["torch_dtype"] = torch.bfloat16

if args.context_parallel_size > 1 and parallel_state.is_unitialized():
Comment thread
mengker33 marked this conversation as resolved.
if not torch.distributed.is_initialized():
import deepspeed

torch.distributed.init_process_group(backend="hccl")
deepspeed.init_distributed(dist_backend="hccl")
parallel_state.initialize_model_parallel(sequence_parallel_size=args.context_parallel_size, use_fp8=False)

if args.control_image_path is not None:
from optimum.habana.diffusers import GaudiStableVideoDiffusionControlNetPipeline
from optimum.habana.diffusers.models import ControlNetSDVModel, UNetSpatioTemporalConditionControlNetModel
Expand Down
15 changes: 15 additions & 0 deletions examples/stable-diffusion/text_to_video_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from diffusers.utils.export_utils import export_to_video

from optimum.habana.diffusers import GaudiCogVideoXPipeline, GaudiTextToVideoSDPipeline, GaudiWanPipeline
from optimum.habana.distributed import parallel_state
from optimum.habana.transformers.gaudi_configuration import GaudiConfig
from optimum.habana.utils import set_seed

Expand Down Expand Up @@ -135,6 +136,12 @@ def main():
default="./generated-videos",
help="The directory where videos will be saved.",
)
parser.add_argument(
"--context_parallel_size",
type=int,
default=1,
help="Determines how many ranks are divided into context parallel group.",
)

parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization.")

Expand Down Expand Up @@ -183,6 +190,14 @@ def main():
elif args.dtype == "fp32":
kwargs["torch_dtype"] = torch.float32

if args.context_parallel_size > 1 and parallel_state.is_unitialized():
if not torch.distributed.is_initialized():
import deepspeed

torch.distributed.init_process_group(backend="hccl")
deepspeed.init_distributed(dist_backend="hccl")
parallel_state.initialize_model_parallel(sequence_parallel_size=args.context_parallel_size, use_fp8=False)

# Generate images
if args.pipeline_type == "stable_diffusion":
pipeline: GaudiTextToVideoSDPipeline = GaudiTextToVideoSDPipeline.from_pretrained(
Expand Down
180 changes: 163 additions & 17 deletions optimum/habana/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
from diffusers.models.transformers.transformer_wan import WanAttention, _get_added_kv_projections, _get_qkv_projections
from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_xformers_available
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from torch import nn

from ...distributed import parallel_state
from .embeddings import RotaryPosEmbedding


Expand Down Expand Up @@ -206,8 +208,92 @@ def __init__(self, fusedSDPA):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA

def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode)
def forward(
self,
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side="left",
):
query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
out = self._hpu_kernel_fsdpa.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side,
)
return out.permute(0, 2, 1, 3)


class GaudiDistributedAttention(torch.nn.Module):
def __init__(self, hpu_module_fsdpa: ModuleFusedSDPA):
super().__init__()
self._hpu_module_fsdpa = hpu_module_fsdpa
if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
from deepspeed.sequence.layer import DistributedAttention

self._hpu_module_fsdpa_distributed = DistributedAttention(
self._hpu_module_fsdpa, parallel_state.get_sequence_parallel_group(), 2, 1
)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor,
dropout_p: float,
is_casual,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side="left",
):
if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
return self._hpu_module_fsdpa_distributed(
query,
key,
value,
0, # As the shape for inputs is [B, S, N, H]
None,
attn_mask,
dropout_p,
is_casual,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side,
)
else:
return self._hpu_module_fsdpa(
query,
key,
value,
attn_mask,
dropout_p,
is_casual,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side,
)


class CogVideoXAttnProcessorGaudi:
Expand Down Expand Up @@ -262,17 +348,20 @@ def __call__(

softmax_mode = "None" if attn.training else "fast"
hidden_states = self.fused_scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_casual=False,
scale=None,
softmax_mode=softmax_mode,
query.transpose(1, 2).contiguous(),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why CogVideoXAttnProcessorGaudi need to change? Since you put qkv transpose in ModuleFusedSDPA, need to check other function which not use sp or ulysses.

key.transpose(1, 2).contiguous(),
value.transpose(1, 2).contiguous(),
attention_mask,
0.0,
False,
None,
softmax_mode,
False,
None,
"None",
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
Expand Down Expand Up @@ -553,6 +642,18 @@ def __init__(self, is_training=False):
"WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
)
self.is_training = is_training
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
self.fused_scaled_dot_product_attention_distributed = None
self.use_sp = os.getenv("USE_SP", "True").lower() not in ("0", "false", "False")
self.cp_size = parallel_state.get_sequence_parallel_world_size()

if not self.use_sp and parallel_state.sequence_parallel_is_initialized() \
and self.cp_size > 1:
self.fused_scaled_dot_product_attention_distributed = (
GaudiDistributedAttention(self.fused_scaled_dot_product_attention)
if FusedSDPA
else None
)

def _native_attention(
self,
Expand All @@ -565,14 +666,28 @@ def _native_attention(
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
# apply gaudi fused SDPA
from habana_frameworks.torch.hpex.kernels import FusedSDPA

# Fast FSDPA is not supported in training mode
fsdpa_mode = "None" if self.is_training else "fast"
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = FusedSDPA.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, fsdpa_mode, None)
out = out.permute(0, 2, 1, 3)

if self.fused_scaled_dot_product_attention_distributed:
out = self.fused_scaled_dot_product_attention_distributed(
Comment thread
mengker33 marked this conversation as resolved.
query,
key,
value,
attn_mask,
0.0,
False,
None,
fsdpa_mode,
False,
None,
"None",
)
else:
out = self.fused_scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_causal, scale, fsdpa_mode,
False,
None,
"None",)
return out

def __call__(
Expand Down Expand Up @@ -634,8 +749,39 @@ def apply_rotary_emb(
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)

# Add traditional SP:
if self.use_sp and self.cp_size > 1:
bs, kv_seq, num_head, head_dim = key.shape
key = key.reshape(bs, kv_seq, -1)
value = value.reshape(bs, kv_seq, -1)
full_key = torch.empty(bs, kv_seq * self.cp_size, num_head * head_dim, dtype=key.dtype, device=key.device)
full_value = torch.empty(bs, kv_seq * self.cp_size, num_head * head_dim, dtype=value.dtype, device=value.device)
gather1 = torch.distributed.all_gather_into_tensor(
full_key,
key,
group=parallel_state.get_sequence_parallel_group(),
async_op=True,
)
torch.distributed.all_gather_into_tensor(
full_value,
value,
group=parallel_state.get_sequence_parallel_group(),
async_op=False,
)
gather1.wait()
key = full_key.reshape(bs, kv_seq * self.cp_size, num_head, head_dim)
value = full_value.reshape(bs, kv_seq * self.cp_size, num_head, head_dim)

if attention_mask is not None:
logger.warning(f"Applying attention_mask in SP is not well supported, set it as None.")
attention_mask = None

hidden_states = self._native_attention(query, key, value, attention_mask, 0.0, False, None)

if self.use_sp and self.cp_size > 1:
torch.hpu.synchronize()


hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)

Expand Down
Loading