
Conversation

@vasqu (Contributor) commented May 25, 2024

What does this PR do?

Adds torch's SDPA to the GPT-NeoX model architecture (as another attention module). Possibly relevant #28005.
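For context, a minimal usage sketch (not part of this PR's diff) of how the new backend would be selected through the standard `attn_implementation` argument of `from_pretrained`; the checkpoint is the pythia model used for the benchmarks below:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/pythia-410m-deduped"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # requires a transformers version that includes this PR
).to("cuda")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```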

Added benchmarks based on @fxmarty's scripts for training and inference. Setup: RTX 3080 Ti (16GB), PyTorch 2.2.1, Ubuntu 22.04, float16, pythia-410m-deduped.

Training results:

| batch_size | seq_len | Time per batch, eager (s) | Time per batch, SDPA (s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|
| 1 | 128 | 0.024 | 0.019 | 28.945 | 1789.95 | 1789.95 | 0 |
| 1 | 256 | 0.039 | 0.031 | 23.18 | 1845.83 | 1844.84 | 0.053 |
| 1 | 512 | 0.08 | 0.055 | 45.524 | 2278.38 | 1953.76 | 16.615 |
| 1 | 1024 | 0.19 | 0.102 | 86.777 | 4772.36 | 2408.35 | 98.159 |
| 1 | 2048 | 0.565 | 0.204 | 177.098 | 13484.1 | 3882.01 | 247.348 |
| 2 | 128 | 0.037 | 0.032 | 15.121 | 1843.86 | 1844.78 | -0.05 |
| 2 | 256 | 0.067 | 0.055 | 21.706 | 1999.72 | 1951.67 | 2.462 |
| 2 | 512 | 0.144 | 0.096 | 50.046 | 3613.16 | 2406.77 | 50.125 |
| 2 | 1024 | 0.366 | 0.193 | 89.666 | 8707.55 | 3878.86 | 124.487 |
| 2 | 2048 | OOM | 0.379 | / | OOM | 6825.13 | SDPA does not OOM |
| 4 | 128 | 0.06 | 0.054 | 11.539 | 1947.6 | 1952.06 | -0.228 |
| 4 | 256 | 0.119 | 0.093 | 28.072 | 3008.39 | 2405.99 | 25.038 |
| 4 | 512 | 0.275 | 0.187 | 47.145 | 6290.58 | 3877.29 | 62.242 |
| 4 | 1024 | OOM | 0.36 | / | OOM | 6821.98 | SDPA does not OOM |
| 4 | 2048 | OOM | 0.731 | / | OOM | 12705.1 | SDPA does not OOM |

Inference results:

| batch_size | seq_len | Per-token latency, eager (ms) | Per-token latency, SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|
| 1 | 128 | 6.569 | 5.858 | 12.14 | 974.831 | 974.826 | 0 |
| 1 | 256 | 7.009 | 5.863 | 19.542 | 1029.01 | 1028.08 | 0.09 |
| 1 | 512 | 7.157 | 5.965 | 19.983 | 1137.54 | 1137.52 | 0.001 |
| 1 | 1024 | 7.523 | 6.506 | 15.637 | 1329.3 | 1329.26 | 0.003 |
| 1 | 2048 | 9.271 | 9.205 | 0.713 | 1752.47 | 1734.51 | 1.036 |
| 2 | 128 | 7.239 | 5.959 | 21.493 | 1044.8 | 1028.37 | 1.597 |
| 2 | 256 | 7.228 | 6.036 | 19.757 | 1167.32 | 1137.73 | 2.601 |
| 2 | 512 | 7.538 | 6.693 | 12.628 | 1352.93 | 1329.55 | 1.758 |
| 2 | 1024 | 8.916 | 8.632 | 3.291 | 1752.56 | 1734.62 | 1.034 |
| 2 | 2048 | 12.628 | 12.606 | 0.181 | 2558.72 | 2545.8 | 0.508 |
| 4 | 128 | 7.278 | 6.046 | 20.373 | 1168.41 | 1137.79 | 2.691 |
| 4 | 256 | 7.614 | 6.588 | 15.574 | 1353.1 | 1329.79 | 1.753 |
| 4 | 512 | 8.798 | 8.144 | 8.028 | 1752.76 | 1734.85 | 1.032 |
| 4 | 1024 | 11.765 | 11.303 | 4.09 | 2558.96 | 2546.04 | 0.508 |
| 4 | 2048 | 19.568 | 17.735 | 10.33 | 4175.5 | 4165.26 | 0.246 |
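For reference, a rough sketch of how numbers like the inference ones above can be reproduced (an assumed setup, not the exact benchmark script): time a fixed number of generated tokens with CUDA events and record peak allocated memory for each backend.

```python
import torch
from transformers import AutoModelForCausalLM

def benchmark(attn_implementation, batch_size=1, seq_len=128, new_tokens=100):
    model = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/pythia-410m-deduped",
        torch_dtype=torch.float16,
        attn_implementation=attn_implementation,
    ).to("cuda")
    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device="cuda")

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    with torch.no_grad():
        model.generate(input_ids, min_new_tokens=new_tokens, max_new_tokens=new_tokens, do_sample=False)
    end.record()
    torch.cuda.synchronize()

    per_token_ms = start.elapsed_time(end) / new_tokens  # elapsed_time returns milliseconds
    peak_mb = torch.cuda.max_memory_allocated() / 1024**2
    return per_token_ms, peak_mb

for impl in ("eager", "sdpa"):
    latency, mem = benchmark(impl)
    print(f"{impl}: {latency:.3f} ms/token, {mem:.1f} MB peak")
```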

Remaining relevant issues:

  • RUN_SLOW=True pytest tests/models/gpt_neox -k "test_eager_matches_sdpa_inference" -s -vvvvv leads to failures on bf16. I'm not sure if it's an implementation issue on my side or because the model's RoPE is naturally stored in fp32. Help/advice appreciated. Edit: see the comment below for a summary of the captured failing tests.
  • Might've missed some docs :D just went ahead with it.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@fxmarty @ArthurZucker @amyeroberts

@vasqu (Contributor Author) commented:

Bigger models were needed to match in generation. Same as in Llama; not sure if that's an issue.

@vasqu (Contributor Author) commented:

Exchanged the old attention mask implementation here. Can revert, but I thought it kept things cleaner.

@vasqu (Contributor Author) commented May 26, 2024

Ran RUN_SLOW=True pytest tests/models/gpt_neox -k "test_eager_matches_sdpa_inference" -s -vvvvv 100 times (setup is the same as above) and captured the failures at bf16 (fp16 and fp32 stay consistent, without any failures):

[
    {
        "task_env": "padding_side=left, use_mask=True, batch_size=5, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
        "mean_of_mean_differences": 0.02271,
        "std_of_mean_differences": 0.03539,
        "min_of_mean_differences": 0.00616,
        "max_of_mean_differences": 0.3262,
        "total_fails": 87
    },
    {
        "task_env": "padding_side=right, use_mask=True, batch_size=5, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
        "mean_of_mean_differences": 0.02072,
        "std_of_mean_differences": 0.01623,
        "min_of_mean_differences": 0.00693,
        "max_of_mean_differences": 0.1079,
        "total_fails": 86
    },
    {
        "task_env": "padding_side=left, use_mask=False, batch_size=5, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
        "mean_of_mean_differences": 0.02172,
        "std_of_mean_differences": 0.01623,
        "min_of_mean_differences": 0.01025,
        "max_of_mean_differences": 0.09912,
        "total_fails": 80
    },
    {
        "task_env": "padding_side=right, use_mask=False, batch_size=5, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
        "mean_of_mean_differences": 0.02172,
        "std_of_mean_differences": 0.01623,
        "min_of_mean_differences": 0.01025,
        "max_of_mean_differences": 0.09912,
        "total_fails": 80
    },
    {
        "task_env": "padding_side=left, use_mask=False, batch_size=1, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
        "mean_of_mean_differences": 0.03034,
        "std_of_mean_differences": 0.05975,
        "min_of_mean_differences": 0.00812,
        "max_of_mean_differences": 0.2949,
        "total_fails": 21
    },
    {
        "task_env": "padding_side=right, use_mask=False, batch_size=1, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
        "mean_of_mean_differences": 0.03034,
        "std_of_mean_differences": 0.05975,
        "min_of_mean_differences": 0.00812,
        "max_of_mean_differences": 0.2949,
        "total_fails": 21
    },
    {
        "task_env": "padding_side=left, use_mask=False, batch_size=5, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
        "mean_of_mean_differences": 0.00205,
        "std_of_mean_differences": 0.0008,
        "min_of_mean_differences": 0.00095,
        "max_of_mean_differences": 0.00397,
        "total_fails": 17
    },
    {
        "task_env": "padding_side=right, use_mask=False, batch_size=5, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
        "mean_of_mean_differences": 0.00205,
        "std_of_mean_differences": 0.0008,
        "min_of_mean_differences": 0.00095,
        "max_of_mean_differences": 0.00397,
        "total_fails": 17
    },
    {
        "task_env": "padding_side=right, use_mask=True, batch_size=1, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
        "mean_of_mean_differences": 0.01624,
        "std_of_mean_differences": 0.00679,
        "min_of_mean_differences": 0.00653,
        "max_of_mean_differences": 0.02954,
        "total_fails": 17
    },
    {
        "task_env": "padding_side=left, use_mask=True, batch_size=5, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
        "mean_of_mean_differences": 0.0026,
        "std_of_mean_differences": 0.0014,
        "min_of_mean_differences": 0.00119,
        "max_of_mean_differences": 0.00635,
        "total_fails": 15
    },
    {
        "task_env": "padding_side=right, use_mask=True, batch_size=5, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
        "mean_of_mean_differences": 0.00245,
        "std_of_mean_differences": 0.00108,
        "min_of_mean_differences": 0.00119,
        "max_of_mean_differences": 0.00473,
        "total_fails": 15
    },
    {
        "task_env": "padding_side=left, use_mask=True, batch_size=1, enable_kernels=True, torch_atol=0.01, torch_rtol=0.03, cuda=Yes",
        "mean_of_mean_differences": 0.01135,
        "std_of_mean_differences": 0.00861,
        "min_of_mean_differences": 0.00629,
        "max_of_mean_differences": 0.03516,
        "total_fails": 9
    },
    {
        "task_env": "padding_side=left, use_mask=False, batch_size=1, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
        "mean_of_mean_differences": 0.00454,
        "std_of_mean_differences": 0.00205,
        "min_of_mean_differences": 0.002,
        "max_of_mean_differences": 0.00702,
        "total_fails": 3
    },
    {
        "task_env": "padding_side=right, use_mask=False, batch_size=1, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
        "mean_of_mean_differences": 0.00454,
        "std_of_mean_differences": 0.00205,
        "min_of_mean_differences": 0.002,
        "max_of_mean_differences": 0.00702,
        "total_fails": 3
    },
    {
        "task_env": "padding_side=left, use_mask=True, batch_size=1, enable_kernels=False, torch_atol=0.01, torch_rtol=0.01, cuda=Maybe",
        "mean_of_mean_differences": 0.002,
        "std_of_mean_differences": 0.0,
        "min_of_mean_differences": 0.002,
        "max_of_mean_differences": 0.002,
        "total_fails": 1
    }
]

The tests seem pretty flaky, but bf16 consistently fails in at least one case. Possibly it needs bigger tolerance bounds and/or an overridden test in GPT-NeoX, as sketched below. But maybe I also messed up somewhere.
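To make the second option concrete, here is a hypothetical sketch of what an overridden comparison with relaxed bf16 tolerances could look like in the GPT-NeoX test file; the class/method names and the exact bounds are illustrative, not the change made in this PR.

```python
import unittest
import torch

class GPTNeoXEagerVsSdpaToleranceExample(unittest.TestCase):
    def assert_eager_matches_sdpa(self, eager_out, sdpa_out, dtype):
        # bf16 has 7 mantissa bits vs. fp16's 10, so kernel-order differences between
        # the fused SDPA kernels and the eager matmul+softmax path surface as larger
        # per-element deviations; widen the bounds for bf16 only.
        if dtype == torch.bfloat16:
            atol, rtol = 5e-2, 5e-2  # relaxed, illustrative bounds
        else:
            atol, rtol = 1e-2, 3e-2  # the defaults used by the common test above
        torch.testing.assert_close(eager_out, sdpa_out, atol=atol, rtol=rtol)
```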

@amyeroberts (Contributor) commented:
@fxmarty Could you do a first review?

Comment on lines 171 to 172

@vasqu (Contributor Author) commented:

I've generalised the projections and RoPE application, since a lot of this is repeated among all the attn implementations. Can revert if needed.
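As an illustration of that refactor, a sketch of a shared helper that every attention variant (eager / flash / SDPA) could call to get the rotary-embedded q, k, v. The `self.*` attributes (`query_key_value`, `num_attention_heads`, `head_size`, `rotary_ndims`, `rotary_emb`) and `apply_rotary_pos_emb` follow the existing GPT-NeoX modeling file; the helper itself is a sketch of the idea, not the exact diff.

```python
import torch

# Assumed context: `self` is a GPTNeoXAttention instance and `apply_rotary_pos_emb`
# is the RoPE helper already defined in modeling_gpt_neox.py.
def _attn_projections_and_rope(self, hidden_states, position_ids, layer_past=None):
    # Fused QKV projection: [batch, seq, 3 * num_heads * head_size]
    qkv = self.query_key_value(hidden_states)
    new_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
    qkv = qkv.view(*new_shape)

    # Split into q, k, v with shape [batch, num_heads, seq, head_size]
    query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
    key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
    value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

    # Apply rotary embeddings to the rotary fraction of the head dimension only
    query_rot, query_pass = query[..., : self.rotary_ndims], query[..., self.rotary_ndims :]
    key_rot, key_pass = key[..., : self.rotary_ndims], key[..., self.rotary_ndims :]
    seq_len = key.shape[-2]
    if layer_past is not None:
        seq_len += layer_past[0].shape[-2]  # account for cached tokens
    cos, sin = self.rotary_emb(value, seq_len=seq_len)
    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
    query = torch.cat((query_rot, query_pass), dim=-1)
    key = torch.cat((key_rot, key_pass), dim=-1)

    # Append cached keys/values when generating with a KV cache
    if layer_past is not None:
        key = torch.cat((layer_past[0], key), dim=-2)
        value = torch.cat((layer_past[1], value), dim=-2)

    return query, key, value
```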

@vasqu (Contributor Author) commented:

padding_mask is never used in any attn implementation, but to keep things consistent between the different attns I added it to the method signature.

@fxmarty (Contributor) commented:

@vasqu I think you can actually remove the padding_mask argument from the GPTNeoXAttention.forward method. This one is a leftover from the previous flash attention implementation in Transformers. It is not used anymore (and actually never has been, see 9270ab0), so it is fine to remove.

@vasqu (Contributor Author) commented Jun 24, 2024:

I've removed it from the function signature 👍

@vasqu (Contributor Author) commented May 28, 2024

The new commits are all cosmetic/refactoring changes that don't affect the logic, except for the additional head_mask check for SDPA (in the attention mask creation); see the sketch below.
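For reference, a sketch of the kind of check meant here (illustrative, not the exact diff): SDPA cannot apply a per-head mask inside the fused kernel, so when a head_mask is given the mask creation falls back to the regular 4D causal mask used by eager attention. The two `_prepare_4d_causal_attention_mask*` helpers exist in transformers.modeling_attn_mask_utils; the wrapper function and its arguments are assumptions for the sketch.

```python
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

def build_causal_mask(attn_implementation, attention_mask, head_mask, output_attentions,
                      inputs_embeds, past_key_values_length):
    batch_size, seq_length = inputs_embeds.shape[:2]
    if attn_implementation == "sdpa" and head_mask is None and not output_attentions:
        # SDPA path: may return None so torch can dispatch its fused causal kernel
        return _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
    # With a head_mask (or when attention weights must be returned), build the full 4D mask
    return _prepare_4d_causal_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )
```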

@vasqu mentioned this pull request Jun 1, 2024
@vasqu (Contributor Author) commented Jun 1, 2024

Further removed a wrong artefact in the code copied from Llama.

  • Updated the failing-tests summary above (the bf16 issue persists).
  • Other tests pass (i.e. via RUN_SLOW=True pytest tests/models/gpt_neox -s -vvvvv).
  • Benchmarks should stay about the same, so they were not updated.

@fxmarty (Contributor) left a comment:

LGTM

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Contributor) left a comment:

Great - thanks for all the work adding this!

All LGTM, I would just like a second look from @ArthurZucker regarding the RoPE scaling logic.

@amyeroberts (Contributor) commented:
Before merging, we'll need to do a run of all the slow tests for the model. Pushing a commit with the message [run_slow] gpt_neox should trigger a run (the workflow will need to be approved by someone from HF).

@vasqu (Contributor Author) commented Jun 25, 2024

Should now be committed with that message.

@amyeroberts (Contributor) commented:
@vasqu Thanks for pushing and again for adding this capability. All looks good - let's merge 🤗

@amyeroberts amyeroberts merged commit b07770c into huggingface:main Jun 26, 2024
@vasqu vasqu deleted the gptneox-sdpa branch June 26, 2024 14:41
@avishaiElmakies mentioned this pull request Sep 25, 2024