Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion flagscale/backends/Megatron-LM/gpt_builders.py.patch
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
diff --git a/gpt_builders.py b/gpt_builders.py
index 89b228815..804676ba1 100644
index 89b228815..9e48bace4 100644
--- a/gpt_builders.py
+++ b/gpt_builders.py
@@ -26,7 +26,14 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None):
Expand All @@ -18,3 +18,11 @@ index 89b228815..804676ba1 100644
if args.use_legacy_models:
model = megatron.legacy.model.GPTModel(
config,
@@ -115,6 +122,7 @@ def _get_transformer_layer_spec(use_te, config):
moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
qk_l2_norm=args.qk_l2_norm,
use_kitchen=config.use_kitchen,
+ flex_attention=args.flex_attention,
)
else:
return get_gpt_layer_local_spec(
Original file line number Diff line number Diff line change
@@ -1,18 +1,49 @@
diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py
old mode 100755
new mode 100644
index 68c1eb8c9..1b60a4a13
index 68c1eb8c9..2e3c5966d
--- a/megatron/core/models/gpt/gpt_layer_specs.py
+++ b/megatron/core/models/gpt/gpt_layer_specs.py
@@ -408,6 +408,7 @@ def get_gpt_decoder_block_spec(
@@ -32,6 +32,7 @@ from megatron.core.transformer.transformer_layer import (
TransformerLayerSubmodules,
get_transformer_layer_offset,
)
+from megatron.core.transformer.flex_attention import FlexAttention

try:
import transformer_engine as te # pylint: disable=unused-import
@@ -80,6 +81,7 @@ def get_gpt_layer_with_transformer_engine_spec(
use_te_op_fuser: Optional[bool] = False,
use_kitchen: bool = False,
use_te_activation_func: bool = False,
+ flex_attention: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).

@@ -162,12 +164,13 @@ def get_gpt_layer_with_transformer_engine_spec(
),
)
else:
+ attention_module = FlexAttention if flex_attention else SelfAttention
qk_norm = backend.layer_norm(for_qk=True)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
- module=SelfAttention,
+ module=attention_module,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=backend.column_parallel_layer_norm_linear(),
@@ -408,6 +411,7 @@ def get_gpt_decoder_block_spec(
qk_l2_norm: Optional[bool] = False,
vp_stage: Optional[int] = None,
pp_rank: Optional[int] = None,
+ is_dualpipev_first_chunk: Optional[bool] = False,
) -> TransformerBlockSubmodules:
"""GPT block spec."""
if use_transformer_engine:
@@ -487,7 +488,8 @@ def get_gpt_decoder_block_spec(
@@ -487,7 +491,8 @@ def get_gpt_decoder_block_spec(

# Slice the layer specs to only include the layers that are built in this pipeline stage.
# Note: MCore layer_number starts at 1
Expand All @@ -22,7 +53,7 @@ index 68c1eb8c9..1b60a4a13

if config.pipeline_model_parallel_layout is not None:
local_layer_specs = [
@@ -497,7 +499,8 @@ def get_gpt_decoder_block_spec(
@@ -497,7 +502,8 @@ def get_gpt_decoder_block_spec(
)
]
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
diff --git a/megatron/core/transformer/flex_attention.py b/megatron/core/transformer/flex_attention.py
new file mode 100644
index 000000000..7f51015b9
--- /dev/null
+++ b/megatron/core/transformer/flex_attention.py
@@ -0,0 +1,207 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import NoReturn, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+
+from megatron.core import tensor_parallel
+from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.models.common.embeddings.rope_utils import (
+ apply_rotary_pos_emb,
+ apply_rotary_pos_emb_with_cos_sin,
+)
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.parallel_state import (
+ get_context_parallel_group,
+)
+from megatron.core.utils import (
+ deprecate_inference_params,
+ divide,
+ get_pg_size,
+ is_fa_min_version,
+ is_te_min_version,
+ nvtx_range_pop,
+ nvtx_range_push,
+)
+from .enums import AttnMaskType
+from .transformer_config import TransformerConfig
+from megatron.core.process_groups_config import ProcessGroupCollection
+from megatron.core.transformer.attention import SelfAttention
+from megatron.core.transformer.attention import SelfAttentionSubmodules
+
+from megatron.core.transformer.ring_attention import ring_attn
+
+
+class FlexAttention(SelfAttention):
+ def __init__(
+ self,
+ config: TransformerConfig,
+ submodules: SelfAttentionSubmodules,
+ layer_number: int,
+ attn_mask_type=AttnMaskType.padding,
+ cp_comm_type: str = None,
+ pg_collection: ProcessGroupCollection = None,
+ ):
+ super().__init__(
+ config=config,
+ submodules=submodules,
+ layer_number=layer_number,
+ attn_mask_type=attn_mask_type,
+ cp_comm_type=cp_comm_type,
+ pg_collection=pg_collection,
+ )
+
+ self.config = config
+
+ def forward(
+ self,
+ hidden_states: Tensor,
+ attention_mask: Tensor,
+ key_value_states: Optional[Tensor] = None,
+ inference_context: Optional[BaseInferenceContext] = None,
+ rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None,
+ rotary_pos_cos: Optional[Tensor] = None,
+ rotary_pos_sin: Optional[Tensor] = None,
+ attention_bias: Optional[Tensor] = None,
+ packed_seq_params: Optional[PackedSeqParams] = None,
+ sequence_len_offset: Optional[int] = None,
+ *,
+ inference_params: Optional[BaseInferenceContext] = None,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Perform a forward pass through the attention module.
+
+ Args:
+ hidden_states (Tensor): Hidden states.
+ attention_mask (Tensor): Attention mask.
+ key_value_states (Optional[Tensor]): Key/value states (for cross attention).
+ inference_context (Optional[BaseInferenceContext]): Inference context that manages
+ KV cache.
+ rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary
+ embedding tensor(s).
+ rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine.
+ rotary_pos_sin (Optional[Tensor]): Rotary embedding sine.
+ attention_bias (Optional[Tensor]): Attention bias.
+ packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format.
+ sequence_len_offset (Optional[int]): Sequence length offset used for
+ inference CUDA graphs.
+
+ Return:
+ (Tuple[Tensor, Tensor]) Attention output and bias.
+
+ """
+ # Check if we need to skip RoPE
+ # no_rope is 0-indexed array and self.layer_number is 1-indexed
+ assert (
+ attention_mask is not None
+ ), "Flex attention is used for customed attention mask, which must be provided"
+
+ no_rope = (
+ self.config.no_rope_freq[self.layer_number - 1] if self.config.no_rope_freq else False
+ )
+ if no_rope:
+ rotary_pos_emb = None
+
+ inference_context = deprecate_inference_params(inference_context, inference_params)
+
+ if inference_context and inference_context.is_dynamic_batching():
+ assert HAVE_FA3 or is_fa_min_version(
+ "2.7.3"
+ ), "flash attn verion v2.7.3 and above is required for dynamic batching."
Comment on lines +117 to +119
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The variable HAVE_FA3 is used here but it's not defined or imported, which will lead to a NameError at runtime. It seems you intended to check for the availability of flash-attention v3.

You should define HAVE_FA3 at the top of this file, similar to how it's done in megatron/core/transformer/attention.py, by adding the following try-except block:

try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_fa3
    from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_func_fa3

    HAVE_FA3 = True
except ImportError:
    HAVE_FA3 = False

+
+ # hidden_states: [sq, b, h]
+ if self.config.flash_decode and not self.training and inference_context is not None:
+ rotary_pos_emb = None
+ else:
+ assert rotary_pos_cos is None and rotary_pos_sin is None
+
+ # For self attention we just duplicate the rotary_pos_emb if it isn't already
+ if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
+ rotary_pos_emb = (rotary_pos_emb,) * 2
+
+ # =====================
+ # Query, Key, and Value
+ # =====================
+ # Get the query, key and value tensors based on the type of attention -
+ # self or cross attn.
+ nvtx_range_push(suffix="qkv")
+ query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
+ nvtx_range_pop(suffix="qkv")
+
+
+ # ================================================
+ # relative positional embedding (rotary embedding)
+ # ================================================
+ nvtx_range_push(suffix="rotary_pos_emb")
+ if rotary_pos_emb is not None and not self.config.flash_decode:
+ q_pos_emb, k_pos_emb = rotary_pos_emb
+
+ if packed_seq_params is not None:
+ if packed_seq_params.cu_seqlens_q_padded is not None:
+ cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded
+ else:
+ cu_seqlens_q = packed_seq_params.cu_seqlens_q
+ if packed_seq_params.cu_seqlens_kv_padded is not None:
+ cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded
+ else:
+ cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
+ else:
+ cu_seqlens_q = cu_seqlens_kv = None
+
+ if q_pos_emb is not None:
+ # TODO VIJAY: simplify
+ if inference_context is None or inference_context.is_static_batching():
+ query = apply_rotary_pos_emb(
+ query,
+ q_pos_emb,
+ config=self.config,
+ cu_seqlens=cu_seqlens_q,
+ cp_group=self.pg_collection.cp,
+ )
+ else:
+ query = inference_context.apply_rotary_emb_query(
+ query, q_pos_emb, self.config, cu_seqlens_q, self.pg_collection.cp
+ )
+ if k_pos_emb is not None:
+ key = apply_rotary_pos_emb(
+ key,
+ k_pos_emb,
+ config=self.config,
+ cu_seqlens=cu_seqlens_kv,
+ cp_group=self.pg_collection.cp,
+ )
+
+ # TODO, can apply positional embedding to value_layer so it has
+ # absolute positional embedding.
+ # otherwise, only relative positional embedding takes effect
+ # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
+ nvtx_range_pop(suffix="rotary_pos_emb")
+
+ # ==================================
+ # flex attention computation
+ # ==================================
+ query = query.permute(
+ 1, 2, 0, 3
+ ).contiguous() # [seq_len, batch_size, num_heads/tp, head_dim] -> [batch_size, num_heads/tp, seq_len, head_dim]
+ key = key.permute(1, 2, 0, 3).contiguous()
+ value = value.permute(1, 2, 0, 3).contiguous()
+ attention_mask = attention_mask.contiguous()
+ core_attn_out, _ = ring_attn(
+ query, key, value, attention_mask, group=get_context_parallel_group(),
+ )
+ core_attn_out = core_attn_out.permute(
+ 2, 0, 1, 3
+ ).contiguous()
+ core_attn_out = core_attn_out.view(core_attn_out.shape[0], core_attn_out.shape[1], -1)
+
+
+ # =================
+ # Output. [sq, b, h]
+ # =================
+
+ output, bias = self.linear_proj(core_attn_out)
+
+ return output, bias
Loading
Loading