Add checks against expanded pair bias #228

Merged
christinaflo merged 2 commits into aqlaboratory:main from GMNGeoffrey:pair-bias-broadcast-asserts
May 28, 2026

@GMNGeoffrey GMNGeoffrey commented May 21, 2026

Contributor

Summary
The Triton kernel assumes that pair bias has size 1 in the N_seq dimension and broadcasts it implicitly (pair_bias.stride(1) is not even passed to the kernel), but nothing actually checks this. This is why chunking is turned off when batch size > 1 and we're using optimized kernels: the chunker expands the implicit broadcast dimensions. Added asserts for this and documented the issue.
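
To make the assumption concrete, here is a minimal sketch of the kind of check described above. It works on plain shape tuples so it runs without PyTorch or Triton. The function names and the 5-D layouts are assumptions for illustration, not the code in this PR; the only detail taken from the description is that dim 1 of the pair bias must have size 1.

```python
from typing import Sequence, Tuple


def check_pair_bias_shape(q_shape: Sequence[int], bias_shape: Sequence[int]) -> None:
    """Reject a pair bias that is not size 1 along dim 1.

    Assumed layouts (hypothetical, for illustration only):
      q_shape    = (batch, n_seq, heads, n_query, head_dim)
      bias_shape = (batch, 1,     heads, n_query, n_key)
    """
    if len(bias_shape) != len(q_shape):
        raise AssertionError(
            f"pair bias rank {len(bias_shape)} != query rank {len(q_shape)}"
        )
    if bias_shape[1] != 1:
        # The kernel is described as never receiving the dim-1 stride, so it
        # cannot index along this dim; anything other than size 1 is rejected.
        raise AssertionError(
            f"pair bias must have size 1 in dim 1, got {bias_shape[1]}"
        )


def expand_dim(shape: Sequence[int], dim: int, size: int) -> Tuple[int, ...]:
    """Return the shape after expanding a size-1 broadcast dim to `size`."""
    if shape[dim] != 1:
        raise ValueError(f"dim {dim} has size {shape[dim]}, expected 1")
    out = list(shape)
    out[dim] = size
    return tuple(out)


q_shape = (2, 8, 4, 8, 16)
bias_shape = (2, 1, 4, 8, 8)

# Passes: dim 1 is still the implicit broadcast dim.
check_pair_bias_shape(q_shape, bias_shape)

# Stand-in for what the chunker is said to do: expand the broadcast dim.
expanded = expand_dim(bias_shape, dim=1, size=q_shape[1])
try:
    check_pair_bias_shape(q_shape, expanded)
except AssertionError as err:
    print(f"rejected: {err}")
```

In real code the same check would presumably read `tensor.shape` directly; raising `AssertionError` explicitly here keeps the check active even when Python runs with `-O`.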

Changes

  • Assert on expected bias shapes in Triton triangle attention
  • Document why chunking gets turned off in prediction_heads.py when using optimized kernels
  • Update tests for Triton kernels in test_kernels.py to avoid batch_size > 1 + chunking (previously this was only done for deepspeed). Cuequivariance has a workaround, so its tests still use the larger batch.

Related Issues

Testing

  • Updated the tests to not use batch size > 1 (the existing tests would have hit the new asserts)
  • I don't think a test checking for the assert is very useful
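
As a rough illustration of the test restriction, the sketch below enumerates kernel/batch/chunk combinations and skips batch_size > 1 with chunking unless the kernel has a workaround. The kernel labels, the `kernel_test_cases` helper, and the chunk sizes are hypothetical and are not taken from test_kernels.py.

```python
from itertools import product
from typing import Iterable, List, Optional, Tuple

# Hypothetical kernel labels. Per the PR text, only cuequivariance has a
# workaround for batch_size > 1 combined with chunking.
KERNELS_WITH_WORKAROUND = frozenset({"cueq"})

Case = Tuple[str, int, Optional[int]]


def kernel_test_cases(
    kernels: Iterable[str],
    batch_sizes: Iterable[int],
    chunk_sizes: Iterable[Optional[int]],
) -> List[Case]:
    """Build (kernel, batch_size, chunk_size) cases, dropping unsafe combos."""
    cases: List[Case] = []
    for kernel, batch_size, chunk_size in product(kernels, batch_sizes, chunk_sizes):
        uses_chunking = chunk_size is not None
        if uses_chunking and batch_size > 1 and kernel not in KERNELS_WITH_WORKAROUND:
            # Chunking would expand the broadcast dim of the pair bias.
            continue
        cases.append((kernel, batch_size, chunk_size))
    return cases


cases = kernel_test_cases(["triton", "deepspeed", "cueq"], [1, 2], [None, 4])
for case in cases:
    print(case)
```

A list like this could feed a parametrized test, so the excluded combinations are visible in one place instead of being implied by per-test batch sizes.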

Other Notes
This, along with #226, contains the parts of #207 that I think we still want even if the approach in #213 is preferred overall. If we merge this, then I'll close #207 and iterate on #213 instead.

Assert on bias shapes in Triton triangle attention, notably that the
second dim of pair bias has size 1. This was implicitly assumed, with
`pair_bias.stride(1)` not even passed into the kernel.

Document that this is why chunking gets turned off in
prediction_heads.py when using optimized kernels. The current chunking
expands the broadcast dimension.

Update tests for all optimized kernels in test_kernels.py to avoid
batch_size > 1 + chunking (currently only set for deepspeed).
@GMNGeoffrey

Contributor Author

@christinaflo PTAL

@christinaflo christinaflo added the safe-to-test label (internal-only label used to indicate PRs that are ready for automated CI testing) May 21, 2026
Comment thread openfold3/tests/test_kernels.py Outdated
There's a workaround in place for the cueq kernel.
@GMNGeoffrey GMNGeoffrey requested a review from christinaflo May 21, 2026 04:33

@christinaflo christinaflo left a comment

Collaborator

LGTM!

@christinaflo christinaflo merged commit 4e1add7 into aqlaboratory:main May 28, 2026
2 checks passed