diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index 20f0ece635e..f82dabcb618 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -2385,6 +2385,7 @@ def fused_apply_rotary_pos_emb_thd(
     freqs: torch.Tensor,
     cp_size: int = 1,
     cp_rank: int = 0,
+    interleaved: bool = False,
 ) -> torch.Tensor:
     """
     Apply rotary positional embedding to input tensor T in `thd` format with CP support.
@@ -2398,6 +2399,7 @@ def fused_apply_rotary_pos_emb_thd(
             cu_seqlens=cu_seqlens,
             cp_size=cp_size,
             cp_rank=cp_rank,
+            interleaved=interleaved,
         )
     else:
         assert cp_size == 1, "Only TE >= 1.12 supports RoPE fusion for THD format with CP."
diff --git a/megatron/core/models/common/embeddings/rope_utils.py b/megatron/core/models/common/embeddings/rope_utils.py
index 0e00c6340ed..b990615da29 100644
--- a/megatron/core/models/common/embeddings/rope_utils.py
+++ b/megatron/core/models/common/embeddings/rope_utils.py
@@ -313,7 +313,12 @@ def apply_rotary_pos_emb(
         else:
             assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available."
             return fused_apply_rotary_pos_emb_thd(
-                t, cu_seqlens, freqs, cp_size=cp_group.size(), cp_rank=cp_group.rank()
+                t,
+                cu_seqlens,
+                freqs,
+                cp_size=cp_group.size(),
+                cp_rank=cp_group.rank(),
+                interleaved=config.rotary_interleaved,
             )
     # use unfused implementation
     if cu_seqlens is None:
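
Note on the new flag: `interleaved` selects between the two standard RoPE layouts that `config.rotary_interleaved` can request. The sketch below is a minimal, unfused reference (not the TE fused kernel; `rope_reference` and the rotate helpers are illustrative names, and the assumed shapes are `t: [..., dim]`, `freqs: [..., dim // 2]`): with `interleaved=False` the rotation pairs dimension i with i + dim/2 (GPT-NeoX style), while `interleaved=True` pairs adjacent even/odd dimensions (GPT-J style).

import torch


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Non-interleaved layout: pair dim i with dim i + dim/2.
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def _rotate_interleaved(x: torch.Tensor) -> torch.Tensor:
    # Interleaved layout: pair adjacent even/odd dims (0,1), (2,3), ...
    even, odd = x[..., 0::2], x[..., 1::2]
    return torch.stack((-odd, even), dim=-1).flatten(-2)


def rope_reference(t: torch.Tensor, freqs: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    # t: [..., dim] activations, freqs: [..., dim // 2] rotation angles (broadcastable against t).
    if interleaved:
        emb = torch.repeat_interleave(freqs, 2, dim=-1)  # a a b b c c ...
        rotated = _rotate_interleaved(t)
    else:
        emb = torch.cat((freqs, freqs), dim=-1)          # a b c ... a b c ...
        rotated = _rotate_half(t)
    return t * emb.cos() + rotated * emb.sin()

For example, rope_reference(torch.randn(8, 64), torch.randn(8, 32), interleaved=True) applies the even/odd pairing that the fused path is now asked to use when config.rotary_interleaved is set.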