Add checks against expanded pair bias#228
Merged
christinaflo merged 2 commits intoMay 28, 2026
Merged
Conversation
Assert on bias shapes in Triton triangle attention. Notably that the second dim of pair bias has size 1. This was implicitly assumed with `pair_bias.stride(1)` not even passed into the kernel. Document that this is why chunking gets turned off in prediction_heads.py when using optimized kernels. The current chunking expands the broadcast dimension. Update tests for all optimized kernels in test_kernels.py to avoid batch_size > 1 + chunking (currently only set for deepspeed).
Contributor
Author
|
@christinaflo PTAL |
There's a workaround in place for the cueq kernel.
This was referenced May 28, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The Triton kernel assumes that pair bias has size 1 in the N_seq dimension and does an implicit broadcast (
pair_bias.stride(1)is not even passed to the kernel), but this isn't actually checked. This is the reason chunking is turned off when batch size > 1 and we're using optimized kernels, because the chunker expands the implicit broadcast dimensions. Added asserts for this and documented the issue.Changes
Related Issues
Testing
Other Notes
This, along with #226, contains the parts of #207 that I think we still want even if the approach in #213 is preferred overall. If we merge this, then I'll close #207 and iterate on #213 instead.