
[ORTModule] ATen Efficient Attention and Triton Flash Attention #17959

Merged: 4 commits into main on Oct 27, 2023

Conversation

centwang (Contributor) commented:

This PR adds support for efficient attention and flash attention in ORTModule, including:

  • Use ATen to call efficient attention, which requires PyTorch 2.2.0 dev or newer. Set ORTMODULE_USE_EFFICIENT_ATTENTION=1 to enable (see the sketch after this list).
  • Integrate Triton flash attention, which requires triton==2.0.0.dev20221202 and an A100 or H100 GPU. Set ORTMODULE_USE_FLASH_ATTENTION=1 to enable.
  • A Python graph-transformer tool that matches sub-graphs by config, making it quick to write new transformers.
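A minimal enabling sketch (not part of the PR itself): it assumes an onnxruntime-training build with ORTModule, a PyTorch 2.2.0 dev build, and a CUDA device. The toy module and tensor shapes are placeholders, and whether a given attention pattern actually gets rewritten depends on the sub-graph configs added in this PR.

```python
import os

# Opt-in flags introduced by this PR; set them before ORTModule builds its graph.
os.environ["ORTMODULE_USE_EFFICIENT_ATTENTION"] = "1"   # ATen efficient attention
# os.environ["ORTMODULE_USE_FLASH_ATTENTION"] = "1"     # Triton flash attention (A100/H100)

import torch
import torch.nn.functional as F
from onnxruntime.training.ortmodule import ORTModule


class ToyAttention(torch.nn.Module):
    # Placeholder module; its attention call may be matched by the new transformers.
    def forward(self, q, k, v, mask):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)


model = ORTModule(ToyAttention().cuda())
q = k = v = torch.randn(2, 8, 128, 64, device="cuda", requires_grad=True)
mask = torch.zeros(2, 8, 128, 128, device="cuda")
model(q, k, v, mask).sum().backward()
```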

The current transformers support an attention mask for both efficient attention and flash attention, and dropout for efficient attention only. Supporting more training scenarios (such as the causal mask in GPT2) will require additional transformers.
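For orientation only, the PyTorch-side calls that roughly correspond to these scenarios are sketched below; whether a particular call ends up matched depends on the exported sub-graph and the configs in this PR, so treat the mapping as an assumption.

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 8, 128, 64)
mask = torch.zeros(2, 8, 128, 128)

# Attention mask: covered by both the efficient-attention and flash-attention transformers.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Dropout: covered by the efficient-attention transformer only.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.1)

# Causal mask (e.g. GPT2): not matched yet; would need an additional transformer.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```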

The feature is guarded by system environment variables and won't affect any current behavior if not enabled. Since it requires specific PyTorch/Triton versions, related tests are not added for now.

@centwang requested a review from @askhade on October 16, 2023 08:27

@github-advanced-security (bot) left a comment:


lintrunner found more than 10 potential problems in the proposed changes. Check the Files changed tab for more details.

@centwang merged commit b7408f7 into main on Oct 27, 2023
87 of 90 checks passed
@centwang deleted the weicwang/attn branch on October 27, 2023 02:29
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024