Conversation


@lxd-cumt lxd-cumt commented Nov 6, 2025

PR Category

Train

PR Types

New Features

PR Description

  • Support FlexAttention with context parallelism, which enables dynamic attention masks in training. Set `flex_attention: true` in the config to use it.
  • Add a unified RingAttention interface to support dynamic attention mechanisms such as sparse attention, which can be implemented by simply replacing `forward_impl` and `backward_impl` in `RingAttention`.
  • Support additional optimization strategies, such as computation/communication overlap via GPU multi-streams, pruning of redundant computation for sparse attention, and load balancing of per-CP-rank computation.
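As a rough illustration of the extension point the second bullet describes, here is a minimal sketch. Everything beyond the names `RingAttention`, `forward_impl`, and `backward_impl` is hypothetical: the real implementation operates on attention tensors and communicates across CP ranks, while this toy version uses plain lists to show only the subclassing pattern.

```python
class RingAttention:
    """Base ring attention. Subclasses swap forward_impl/backward_impl
    to plug in dynamic attention variants (e.g. sparse attention)."""

    def forward_impl(self, q_chunk, kv_chunk):
        # Placeholder dense step: a dot product stands in for one
        # block of attention computation.
        return sum(a * b for a, b in zip(q_chunk, kv_chunk))

    def backward_impl(self, grad):
        # Placeholder backward step.
        return grad

    def forward(self, q_chunk, kv_chunks):
        # Accumulate partial results as k/v chunks rotate around the ring.
        return sum(self.forward_impl(q_chunk, kv) for kv in kv_chunks)


class SparseRingAttention(RingAttention):
    """Variant that prunes fully-masked blocks, as the third bullet hints."""

    def forward_impl(self, q_chunk, kv_chunk):
        if all(v == 0 for v in kv_chunk):
            return 0.0  # skip redundant computation for an all-zero block
        return super().forward_impl(q_chunk, kv_chunk)
```

With this shape, a new attention pattern only overrides the two `_impl` hooks; the ring accumulation loop in `forward` stays untouched.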


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for custom attention mechanisms, specifically FlexAttention and RingAttention, to enable context parallelism with dynamic attention masks. The changes include new modules for these attention mechanisms, modifications to the GPT model builder to incorporate them, and new command-line arguments for configuration.

My review identifies a critical issue that would cause a runtime error, a high-risk dependency on internal PyTorch APIs, and a maintainability concern regarding a magic number. Addressing these points will improve the robustness and maintainability of the new features.

Comment on lines +117 to +119
+ assert HAVE_FA3 or is_fa_min_version(
+ "2.7.3"
+ ), "flash attn version v2.7.3 and above is required for dynamic batching."

critical

The variable HAVE_FA3 is used here but it's not defined or imported, which will lead to a NameError at runtime. It seems you intended to check for the availability of flash-attention v3.

You should define HAVE_FA3 at the top of this file, similar to how it's done in megatron/core/transformer/attention.py, by adding the following try-except block:

try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_fa3
    from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_func_fa3

    HAVE_FA3 = True
except ImportError:
    HAVE_FA3 = False

Comment on lines +16 to +19
+from torch._higher_order_ops.flex_attention import (
+ sdpa_dense_backward,
+ create_fw_bw_graph,
+)

high

This implementation relies on internal PyTorch APIs (sdpa_dense_backward, create_fw_bw_graph from torch._higher_order_ops.flex_attention). These APIs are not part of the public contract, are undocumented, and can change or be removed without notice in future PyTorch releases. This poses a significant maintenance risk and could break the functionality unexpectedly upon a PyTorch upgrade.

It is highly recommended to either:

  1. Find an alternative implementation that uses public PyTorch APIs.
  2. If no alternative exists, add a prominent warning in the documentation and code about this dependency, and ideally add version checks for PyTorch to catch breaking changes early.
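The version check in option 2 could look like the following sketch. The bounds and the helper name are illustrative, not taken from the PR; the idea is simply to fail fast (or warn) when the installed PyTorch drifts outside the range the internal APIs were validated against.

```python
def check_supported_version(actual: str, lo: str, hi: str) -> bool:
    """Return True if lo <= actual < hi, comparing dotted numeric versions.

    Local version suffixes such as "+cu121" are stripped before comparing.
    """
    def parse(v):
        return tuple(int(p) for p in v.split("+")[0].split(".")[:3])
    return parse(lo) <= parse(actual) < parse(hi)


# At import time, one might guard the fragile imports like this (illustrative):
# import warnings, torch
# if not check_supported_version(torch.__version__, "2.3.0", "2.6.0"):
#     warnings.warn(
#         "torch._higher_order_ops.flex_attention internals were only "
#         "validated against torch>=2.3,<2.6 and may have changed."
#     )
```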

+ assert not args.multi_latent_attention, "FlexAttention with Multi Latent Attention is not supported"
+ assert args.transformer_impl == 'transformer_engine', \
+     'FlexAttention is only supported with transformer_engine implementation'
+ assert args.seq_length % (args.context_parallel_size * 128) == 0, "seq_length must be divisible by context_parallel_size*flex_attn_block_size(128)"

high

The value 128 is a magic number. Although the comment in the assertion message mentions it's flex_attn_block_size, it's better to define it as a constant to improve readability and maintainability. This avoids having to search for this hardcoded value if the block size ever needs to be updated.

Consider defining it as a constant at the top of the file or function and using the constant here.

        FLEX_ATTN_BLOCK_SIZE = 128
        assert args.seq_length % (args.context_parallel_size * FLEX_ATTN_BLOCK_SIZE) == 0, f"seq_length must be divisible by context_parallel_size*flex_attn_block_size({FLEX_ATTN_BLOCK_SIZE})"

@lxd-cumt lxd-cumt changed the title Support custom attention with context parallel [Train] Support custom attention with context parallel Nov 6, 2025

@zhaoyinglia zhaoyinglia left a comment


Add memory and performance results to the PR description.

+ assert not args.multi_latent_attention, "FlexAttention with Multi Latent Attention is not supported"
+ assert args.transformer_impl == 'transformer_engine', \
+     'FlexAttention is only supported with transformer_engine implementation'
+ assert args.seq_length % (args.context_parallel_size * 128) == 0, "seq_length must be divisible by context_parallel_size*flex_attn_block_size(128)"

Does this assert also need to account for tensor parallelism?

+ return mask_mod
+
+
+def _flex_forward(

Draw a sketch to show how you chunk q/k/v and how the communication works.
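For readers of this thread, a toy single-process sketch of one common ring schedule may help; it is illustrative only and not necessarily how this PR chunks q/k/v or implements the communication.

```python
def ring_schedule(world_size):
    """Return, for each of `world_size` CP ranks, the (q_chunk, kv_chunk)
    block it computes at each step.

    q stays local to its rank; k/v chunks rotate one hop around the ring
    per step (send to the next rank, receive from the previous one), so
    after world_size steps every rank has seen every k/v chunk.
    """
    schedule = []
    for rank in range(world_size):
        # After `step` rotations, this rank holds the kv chunk that
        # originated on rank (rank - step) mod world_size.
        schedule.append([(rank, (rank - step) % world_size)
                         for step in range(world_size)])
    return schedule
```

For example, with 3 CP ranks, rank 0 computes against kv chunks 0, 2, 1 in that order, while each step's partial outputs are accumulated into the running attention result.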


CLAassistant commented Nov 18, 2025

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
0 out of 2 committers have signed the CLA.

❌ mslv
❌ lxd-cumt


mslv does not appear to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.
