Support fused masking in Attention #1924
Conversation
Adding a whole sludge of numbers here:

Key highlights are in the longer sequences, for example at head dim 128.

Do you mind sharing labels for those columns? The numbers look amazing 🚀 but I'm trying to understand the nuances a bit more.
angeloskath left a comment
🚀🚀🚀
Looks great and the results are sweet!
Yes, I thought they copied over. So the table would be:
Hm, the test failure is weird. Can you check whether it is a numerical tolerance issue and maybe set a fixed seed? After that, can't wait for you to merge :-)
It goes away after re-runs; I probably just need to add a fixed random seed to sort it out.
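For reference, a minimal sketch of what fixing the seed could look like in the Python tests (the shape below is illustrative, not taken from the PR's actual test code):

```python
import mlx.core as mx

# Seeding MLX's global PRNG makes the randomly generated inputs identical on
# every run, so a tolerance-based comparison cannot flake across CI re-runs.
mx.random.seed(0)
q = mx.random.normal((1, 8, 1024, 64))  # illustrative shape only
```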
* Update API to allow mask='causal' in fast::sdpa
* Add fallback
* Update steel::AttnParams
* Fix typo
* WIP, basic causal
* Update tests
* Update benchmarking
* Update masking loop limits
* Add bool masking and update tests
* Update additive mask
* Update benchmarks
* Update benchmarks
* Update tests
* Update for bfloat error
* Update early exit
* Add random seed to tests
Proposed changes

* Update `fast::scaled_dot_product_attention` to take a variant for the mask (a usage sketch follows after the checklist).

Checklist

Put an `x` in the boxes that apply.

* `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
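For illustration, here is a minimal sketch of how the new mask variants could be used from Python, assuming the `mx.fast.scaled_dot_product_attention` binding mirrors the C++ change. The shapes, scale, and the boolean-mask convention below are assumptions for the sketch, not taken from this PR:

```python
import math
import mlx.core as mx

B, H, L, D = 1, 8, 1024, 128  # illustrative batch, heads, sequence, head dim
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
scale = 1.0 / math.sqrt(D)

# String variant: let the fused kernel apply causal masking internally.
out_causal = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")

# Boolean variant: True marks positions that may attend (assumed convention).
idx = mx.arange(L)
bool_mask = idx[:, None] >= idx[None, :]
out_bool = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=bool_mask)

# Additive variant: a float mask added to the scores before the softmax,
# with -inf at disallowed positions.
add_mask = mx.where(bool_mask, 0.0, float("-inf")).astype(q.dtype)
out_add = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=add_mask)
```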