From 80b784294557289acd8db572ee327d212cc0fc07 Mon Sep 17 00:00:00 2001 From: Huy Vu <86480512+huvunvidia@users.noreply.github.com> Date: Tue, 10 Mar 2026 23:12:41 -0400 Subject: [PATCH 1/4] Add interleaved parameter to fused_apply_rotary_pos_emb_thd --- megatron/core/extensions/transformer_engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index e901e40597a..544f2fc092e 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -2575,6 +2575,7 @@ def fused_apply_rotary_pos_emb_thd( freqs: torch.Tensor, cp_size: int = 1, cp_rank: int = 0, + interleaved: bool = False, ) -> torch.Tensor: """ Apply rotary positional embedding to input tensor T in `thd` format with CP support. @@ -2588,6 +2589,7 @@ def fused_apply_rotary_pos_emb_thd( cu_seqlens=cu_seqlens, cp_size=cp_size, cp_rank=cp_rank, + interleaved=interleaved, ) else: assert cp_size == 1, "Only TE >= 1.12 supports RoPE fusion for THD format with CP." From 37add724acf5851dbc7d7a811aae3450ddfd1c60 Mon Sep 17 00:00:00 2001 From: Huy Vu <86480512+huvunvidia@users.noreply.github.com> Date: Tue, 10 Mar 2026 23:13:36 -0400 Subject: [PATCH 2/4] Pass config.rotary_interleaved to fused_apply_rotary_pos_emb_thd --- megatron/core/models/common/embeddings/rope_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/megatron/core/models/common/embeddings/rope_utils.py b/megatron/core/models/common/embeddings/rope_utils.py index e39540eb1d1..2fd19194813 100644 --- a/megatron/core/models/common/embeddings/rope_utils.py +++ b/megatron/core/models/common/embeddings/rope_utils.py @@ -288,7 +288,12 @@ def apply_rotary_pos_emb( else: assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available." 
return fused_apply_rotary_pos_emb_thd( - t, cu_seqlens, freqs, cp_size=cp_group.size(), cp_rank=cp_group.rank() + t, + cu_seqlens, + freqs, + cp_size=cp_group.size(), + cp_rank=cp_group.rank(), + interleaved=config.rotary_interleaved, ) # use unfused implementation if cu_seqlens is None: From b8d4e9aaa9c082d223273c06f093481eea38d984 Mon Sep 17 00:00:00 2001 From: Huy Vu <86480512+huvunvidia@users.noreply.github.com> Date: Thu, 12 Mar 2026 23:11:24 -0400 Subject: [PATCH 3/4] Add interleaved support for fused RoPE in TE --- megatron/core/extensions/transformer_engine.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 544f2fc092e..89313b0aaae 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -2580,7 +2580,10 @@ def fused_apply_rotary_pos_emb_thd( """ Apply rotary positional embedding to input tensor T in `thd` format with CP support. """ - if is_te_min_version("1.12.0", check_equality=True): + if interleaved: + assert is_te_min_version("2.3.0"), "Only TE >= 2.3.0 supports interleaved fused RoPE." + + if is_te_min_version("2.3.0", check_equality=True): return apply_rotary_pos_emb( t, freqs, @@ -2590,6 +2593,16 @@ def fused_apply_rotary_pos_emb_thd( cp_size=cp_size, cp_rank=cp_rank, interleaved=interleaved, + ) + elif is_te_min_version("1.12.0", check_equality=True): + return apply_rotary_pos_emb( + t, + freqs, + tensor_format="thd", + fused=True, + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, ) else: assert cp_size == 1, "Only TE >= 1.12 supports RoPE fusion for THD format with CP." 
From 9a96cf975ad48cc11cfe4d04592aa07eeb38188a Mon Sep 17 00:00:00 2001 From: Huy Vu <86480512+huvunvidia@users.noreply.github.com> Date: Thu, 12 Mar 2026 23:15:06 -0400 Subject: [PATCH 4/4] Fix whitespace of closing parenthesis in fused_apply_rotary_pos_emb_thd --- megatron/core/extensions/transformer_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 89313b0aaae..78454a6fbf4 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -2593,7 +2593,7 @@ def fused_apply_rotary_pos_emb_thd( cp_size=cp_size, cp_rank=cp_rank, interleaved=interleaved, - ) + ) elif is_te_min_version("1.12.0", check_equality=True): return apply_rotary_pos_emb( t,