[GPT-NeoX] Add SDPA support
#31031
Conversation
Bigger models were needed to match in generation. Same as in Llama; not sure if that's an issue.
Exchanged the old attention mask implementation here. Can revert, but I thought it kept things cleaner.
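For context, this is roughly the shared mask-utility pattern being swapped in (a sketch, not the exact diff; `_prepare_4d_causal_attention_mask` is the private helper from `transformers.modeling_attn_mask_utils`):

```python
# Rough sketch of the shared mask-preparation pattern; the exact call
# site in the PR may differ.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_len, hidden = 2, 5, 8
padding_mask = torch.ones(batch_size, seq_len, dtype=torch.long)  # 1 = keep token
inputs_embeds = torch.zeros(batch_size, seq_len, hidden)

# Expands the 2D padding mask into the 4D additive causal mask
# [batch, 1, seq, kv_seq] that the eager attention consumes.
causal_mask = _prepare_4d_causal_attention_mask(
    padding_mask, (batch_size, seq_len), inputs_embeds, past_key_values_length=0
)
```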
Ran the `test_eager_matches_sdpa_inference` tests across the different configurations; the captured failures are summarised below:
[
{
"task_env": "padding_side=left, use_mask=True, batch_size=5, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.02271,
"std_of_mean_differences": 0.03539,
"min_of_mean_differences": 0.00616,
"max_of_mean_differences": 0.3262,
"total_fails": 87
},
{
"task_env": "padding_side=right, use_mask=True, batch_size=5, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.02072,
"std_of_mean_differences": 0.01623,
"min_of_mean_differences": 0.00693,
"max_of_mean_differences": 0.1079,
"total_fails": 86
},
{
"task_env": "padding_side=left, use_mask=False, batch_size=5, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.02172,
"std_of_mean_differences": 0.01623,
"min_of_mean_differences": 0.01025,
"max_of_mean_differences": 0.09912,
"total_fails": 80
},
{
"task_env": "padding_side=right, use_mask=False, batch_size=5, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.02172,
"std_of_mean_differences": 0.01623,
"min_of_mean_differences": 0.01025,
"max_of_mean_differences": 0.09912,
"total_fails": 80
},
{
"task_env": "padding_side=left, use_mask=False, batch_size=1, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.03034,
"std_of_mean_differences": 0.05975,
"min_of_mean_differences": 0.00812,
"max_of_mean_differences": 0.2949,
"total_fails": 21
},
{
"task_env": "padding_side=right, use_mask=False, batch_size=1, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.03034,
"std_of_mean_differences": 0.05975,
"min_of_mean_differences": 0.00812,
"max_of_mean_differences": 0.2949,
"total_fails": 21
},
{
"task_env": "padding_side=left, use_mask=False, batch_size=5, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.00205,
"std_of_mean_differences": 0.0008,
"min_of_mean_differences": 0.00095,
"max_of_mean_differences": 0.00397,
"total_fails": 17
},
{
"task_env": "padding_side=right, use_mask=False, batch_size=5, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.00205,
"std_of_mean_differences": 0.0008,
"min_of_mean_differences": 0.00095,
"max_of_mean_differences": 0.00397,
"total_fails": 17
},
{
"task_env": "padding_side=right, use_mask=True, batch_size=1, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.01624,
"std_of_mean_differences": 0.00679,
"min_of_mean_differences": 0.00653,
"max_of_mean_differences": 0.02954,
"total_fails": 17
},
{
"task_env": "padding_side=left, use_mask=True, batch_size=5, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.0026,
"std_of_mean_differences": 0.0014,
"min_of_mean_differences": 0.00119,
"max_of_mean_differences": 0.00635,
"total_fails": 15
},
{
"task_env": "padding_side=right, use_mask=True, batch_size=5, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.00245,
"std_of_mean_differences": 0.00108,
"min_of_mean_differences": 0.00119,
"max_of_mean_differences": 0.00473,
"total_fails": 15
},
{
"task_env": "padding_side=left, use_mask=True, batch_size=1, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
"mean_of_mean_differences": 0.01135,
"std_of_mean_differences": 0.00861,
"min_of_mean_differences": 0.00629,
"max_of_mean_differences": 0.03516,
"total_fails": 9
},
{
"task_env": "padding_side=left, use_mask=False, batch_size=1, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.00454,
"std_of_mean_differences": 0.00205,
"min_of_mean_differences": 0.002,
"max_of_mean_differences": 0.00702,
"total_fails": 3
},
{
"task_env": "padding_side=right, use_mask=False, batch_size=1, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.00454,
"std_of_mean_differences": 0.00205,
"min_of_mean_differences": 0.002,
"max_of_mean_differences": 0.00702,
"total_fails": 3
},
{
"task_env": "padding_side=left, use_mask=True, batch_size=1, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
"mean_of_mean_differences": 0.002,
"std_of_mean_differences": 0.0,
"min_of_mean_differences": 0.002,
"max_of_mean_differences": 0.002,
"total_fails": 1
}
]

The tests seem pretty flaky.
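For context on what the numbers above mean, here is a minimal sketch of the kind of comparison the test performs (illustrative only, not the actual `test_eager_matches_sdpa_inference` code; the tolerances mirror the `torch_atol`/`torch_rtol` values reported above):

```python
# Illustrative eager-vs-SDPA logits comparison, not the real test code.
import torch

def compare_outputs(logits_eager: torch.Tensor, logits_sdpa: torch.Tensor,
                    atol: float = 0.01, rtol: float = 0.03) -> float:
    # "mean difference" as reported above: mean absolute deviation per run.
    mean_diff = (logits_eager - logits_sdpa).abs().mean().item()
    if not torch.allclose(logits_eager, logits_sdpa, atol=atol, rtol=rtol):
        print(f"FAIL: mean abs diff {mean_diff:.5f} exceeds tolerances")
    return mean_diff
```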
@fxmarty Could you do a first review?
I've generalised the projections and RoPE application; a lot of code was being repeated across all the attention implementations. Can revert if needed.
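Roughly, the shared step looks like this (a self-contained sketch of the pattern; names, the fused-QKV layout, and shapes are illustrative, not the PR's exact code):

```python
# Sketch of the "project to q/k/v, then apply RoPE" step that the eager,
# SDPA and FA2 paths all share. Illustrative, not the final implementation.
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) along the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # Standard rotary embedding; cos/sin broadcast over batch and heads.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

def project_and_rope(hidden, w_qkv, cos, sin, num_heads):
    # hidden: [batch, seq, hidden]; w_qkv fuses the q, k and v projections.
    batch, seq, hidden_size = hidden.shape
    head_dim = hidden_size // num_heads
    qkv = hidden @ w_qkv  # [batch, seq, 3 * hidden]
    qkv = qkv.view(batch, seq, num_heads, 3 * head_dim).permute(0, 2, 1, 3)
    q, k, v = qkv.chunk(3, dim=-1)  # each [batch, heads, seq, head_dim]
    q, k = apply_rope(q, k, cos, sin)
    return q, k, v
```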
padding_mask is never used in any attention implementation, but to keep the signatures common between the different attention classes I added it to the method signature.
I've removed it from the function signature 👍
The new commits are all cosmetic/refactoring changes which don't affect the logic, apart from one addition.
I also removed an incorrect artefact in the code copied from Llama.
fxmarty
left a comment
LGTM
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
New commits:
- added head mask check to sdpa mask creation
- handle sdpa memory backend bug via own version flag
- fix flash_attn_2 stuff
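For context on the first two commits, a rough sketch of those guards (names and the exact fallback behaviour are illustrative; the real PR falls back to eager attention rather than raising):

```python
# Illustrative sketch of the head-mask check and the torch-version flag.
import torch
from packaging import version

# SDPA's mem-efficient backend in older torch versions can misbehave with
# non-contiguous q/k/v plus a custom mask, hence a version flag (assumed
# cutoff shown here; the PR pins whichever version fixes the bug).
_IS_TORCH_GTE_2_2 = version.parse(torch.__version__) >= version.parse("2.2.0")

def sdpa_attention(query, key, value, attention_mask=None, head_mask=None):
    if head_mask is not None:
        # SDPA cannot apply a per-head mask; callers should fall back to
        # the eager implementation in this case (shown here as an error).
        raise NotImplementedError("head_mask is not supported with SDPA")
    if not _IS_TORCH_GTE_2_2:
        # Work around the mem-efficient backend bug on older torch versions.
        query, key, value = (t.contiguous() for t in (query, key, value))
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    )
```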
amyeroberts
left a comment
Great - thanks for all the work adding this!
All LGTM; I'd just like a second look from @ArthurZucker regarding the RoPE scaling logic.
Before merge, we'll need to do a run on all the slow tests for the model, by pushing a commit with the trigger message.
Should be committed with the msg.
@vasqu Thanks for pushing and again for adding this capability. All looks good - let's merge 🤗
What does this PR do?
Adds torch's SDPA to the `GPT-NeoX` model architecture (as another attention module). Possibly relevant: #28005.
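For reference, once merged the new path should be selectable through the standard `attn_implementation` argument (a sketch; the checkpoint name is taken from the benchmark setup below):

```python
import torch
from transformers import AutoModelForCausalLM

# Load a GPT-NeoX checkpoint with the new SDPA attention path;
# attn_implementation="eager" restores the original implementation.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-410m-deduped",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
```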
Added benchmarks based on @fxmarty's scripts for training and inference. Setup: RTX 3080 Ti 16GB, PyTorch 2.2.1, Ubuntu 22.04, using `float16` with `pythia-410m-deduped`.

Training results:
Inference results:
Remaining relevant issues:

`RUN_SLOW=True pytest tests/models/gpt_neox -k "test_eager_matches_sdpa_inference" -s -vvvvv` leads to failures on `bf16` --> not sure if it's an implementation issue on my side or because the model's RoPE is naturally stored in fp32? Help/advice appreciated. Edit: see the comment below for a summary of the captured failing tests.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@fxmarty @ArthurZucker @amyeroberts