[Train] Support custom attention with context parallel #910
base: main
Conversation
Code Review
This pull request introduces support for custom attention mechanisms, specifically FlexAttention and RingAttention, to enable context parallelism with dynamic attention masks. The changes include new modules for these attention mechanisms, modifications to the GPT model builder to incorporate them, and new command-line arguments for configuration.
My review identifies a critical issue that would cause a runtime error, a high-risk dependency on internal PyTorch APIs, and a maintainability concern regarding a magic number. Addressing these points will improve the robustness and maintainability of the new features.
```diff
+        assert HAVE_FA3 or is_fa_min_version(
+            "2.7.3"
+        ), "flash attn verion v2.7.3 and above is required for dynamic batching."
```
The variable HAVE_FA3 is used here but it's not defined or imported, which will lead to a NameError at runtime. It seems you intended to check for the availability of flash-attention v3.
You should define HAVE_FA3 at the top of this file, similar to how it's done in megatron/core/transformer/attention.py, by adding the following try-except block:
```python
try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_fa3
    from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_func_fa3
    HAVE_FA3 = True
except ImportError:
    HAVE_FA3 = False
```

```diff
+from torch._higher_order_ops.flex_attention import (
+    sdpa_dense_backward,
+    create_fw_bw_graph,
+)
```
This implementation relies on internal PyTorch APIs (sdpa_dense_backward, create_fw_bw_graph from torch._higher_order_ops.flex_attention). These APIs are not part of the public contract, are undocumented, and can change or be removed without notice in future PyTorch releases. This poses a significant maintenance risk and could break the functionality unexpectedly upon a PyTorch upgrade.
It is highly recommended to either:
- Find an alternative implementation that uses public PyTorch APIs.
- If no alternative exists, add a prominent warning in the documentation and code about this dependency, and ideally add version checks for PyTorch to catch breaking changes early.
```diff
+        assert not args.multi_latent_attention, "FlexAttention with Multi Latent Attention is not supported"
+        assert args.transformer_impl == 'transformer_engine', \
+            'FlexAttention is only supported with transformer_engine implementation'
+        assert args.seq_length % (args.context_parallel_size * 128) == 0, "seq_length must be divisible by context_parallel_size*flex_attn_block_size(128)"
```
The value 128 is a magic number. Although the comment in the assertion message mentions it's flex_attn_block_size, it's better to define it as a constant to improve readability and maintainability. This avoids having to search for this hardcoded value if the block size ever needs to be updated.
Consider defining it as a constant at the top of the file or function and using the constant here.
```python
FLEX_ATTN_BLOCK_SIZE = 128

assert args.seq_length % (args.context_parallel_size * FLEX_ATTN_BLOCK_SIZE) == 0, \
    f"seq_length must be divisible by context_parallel_size*flex_attn_block_size({FLEX_ATTN_BLOCK_SIZE})"
```
zhaoyinglia left a comment:
Please add memory and performance results to the PR description.
```diff
+        assert not args.multi_latent_attention, "FlexAttention with Multi Latent Attention is not supported"
+        assert args.transformer_impl == 'transformer_engine', \
+            'FlexAttention is only supported with transformer_engine implementation'
+        assert args.seq_length % (args.context_parallel_size * 128) == 0, "seq_length must be divisible by context_parallel_size*flex_attn_block_size(128)"
```
Does this assert also need to account for tensor parallelism?
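For context on the question: if sequence parallelism is enabled, each tensor-parallel rank additionally holds a 1/TP slice of the sequence, so the divisibility check might need to fold in the TP size as well. A hypothetical combined check (the constant and function names are illustrative, not from the PR) could look like:

```python
FLEX_ATTN_BLOCK_SIZE = 128  # mirrors the hardcoded 128 in the PR


def seq_length_is_valid(seq_length, context_parallel_size,
                        tensor_model_parallel_size=1, sequence_parallel=False):
    # CP * block-size constraint, as in the PR's assert.
    divisor = context_parallel_size * FLEX_ATTN_BLOCK_SIZE
    # Hypothetical extension: with sequence parallelism, the sequence is
    # also split across TP ranks, so fold TP size into the divisor.
    if sequence_parallel:
        divisor *= tensor_model_parallel_size
    return seq_length % divisor == 0
```

Whether the TP factor is actually required depends on where in the model the FlexAttention block sizes are enforced relative to the sequence-parallel scatter.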
```diff
+    return mask_mod
+
+
+def _flex_forward(
```
Please add a sketch showing how you chunk q/k/v and how the communication works.
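For reference, the ring pattern can be summarized in a toy, single-process illustration (this is not the PR's code, just a sketch of the idea): each CP rank keeps its own query block while the key/value blocks rotate around the ring, so after `world_size` steps every rank has attended to every KV block.

```python
import math


def ring_attention_toy(q, k, v):
    """Toy, single-process sketch of ring attention (one token per 'rank').

    Real implementations overlap send/recv with compute and use an online
    softmax to merge partial results per step; here we simply collect all
    scores first for clarity.
    """
    world = len(q)
    out = []
    for rank in range(world):
        # KV "rotates" around the ring: at each step this rank sees the
        # block originating from rank (rank + step) % world.
        scores = []
        for step in range(world):
            src = (rank + step) % world
            scores.append(sum(a * b for a, b in zip(q[rank], k[src])))
        # Softmax over all collected scores, then weighted sum of values.
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([
            sum(w[step] * v[(rank + step) % world][d] for step in range(world)) / z
            for d in range(len(v[0]))
        ])
    return out
```

The key property the sketch demonstrates: each rank's output equals full attention over the whole sequence, even though at any point in time each rank only holds one KV block.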
mslv seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account. You have signed the CLA already but the status is still pending? Let us recheck it.
PR Category: Train
PR Types: New Features
PR Description