Make pairformer_emb_per_sample call into pairformer_emb #223

Merged
christinaflo merged 4 commits into aqlaboratory:main from GMNGeoffrey:apply-per-sample
May 20, 2026

Conversation

@GMNGeoffrey GMNGeoffrey commented May 18, 2026

Contributor

Summary
The current code has some duplication, which has let the two paths diverge. In particular, the batched path flattens the batch dimensions whereas the per_sample path does not. This confuses some downstream code: the issue I ran into was the chunk tuner expecting the same function to always be called with inputs of the same rank (which is probably a bug as well). To the extent that the flattening needs to happen, I think it should happen on all paths, although it has a disadvantage: it can force implicit broadcasts to become explicit.
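
To make the broadcast cost concrete, here is a minimal shape-only sketch. The helper name `flattened_shape`, the example shapes, and the feature-dimension counts are assumptions chosen for illustration; none of this is the repository's code.

```python
from math import prod


def flattened_shape(shape, n_feature_dims, batch_shape):
    """Shape of a tensor after collapsing its batch dims into one.

    `shape` is split into leading batch dims and trailing feature dims.
    Batch dims of size 1 are implicit broadcasts; flattening to a single
    batch dim forces them to be expanded to `batch_shape` first.
    """
    batch, feature = shape[:-n_feature_dims], shape[-n_feature_dims:]
    if len(batch) != len(batch_shape):
        raise ValueError(f"expected {len(batch_shape)} batch dims, got {len(batch)}")
    for have, want in zip(batch, batch_shape):
        if have not in (1, want):
            raise ValueError(f"batch dims {batch} do not broadcast to {batch_shape}")
    return (prod(batch_shape),) + tuple(feature)


# Hypothetical shapes: batch=2, num_samples=5, 8 tokens, 16 pair channels.
batch_shape = (2, 5)
zij_shape = (2, 1, 8, 8, 16)  # the sample dim is an implicit broadcast (size 1)

flat = flattened_shape(zij_shape, n_feature_dims=3, batch_shape=batch_shape)
growth = prod(flat) // prod(zij_shape)  # how many times more elements are stored
print(flat, growth)
```

In this made-up case the flattened tensor holds five times as many elements as the unflattened one, because the size-1 sample dimension has to be materialized before the batch dimensions can be merged.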

Changes

  • Make per_sample_pairformer_emb a thin wrapper around pairformer_emb
  • Add validation that all inputs to PairformerEmbedding have the same number of batch dimensions
  • Fix slicing of x_pred to not assume that num_samples is the 1st dimension
  • Fix slicing of zij in a couple other places that assumed num_samples is the first dimension
  • Add tests for equivalence of per_sample and batched pairformer embedding paths
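
The batch-dimension validation could look roughly like the following sketch. It works on shapes only; the function name `common_batch_rank`, the input names, and the feature-dimension counts are hypothetical and may not match what `PairformerEmbedding` actually checks.

```python
def common_batch_rank(inputs):
    """Return the shared number of batch dims, or raise if inputs disagree.

    `inputs` maps a name to (shape, n_feature_dims); everything in front of
    the trailing feature dims counts as a batch dim.
    """
    if not inputs:
        raise ValueError("no inputs given")
    ranks = {name: len(shape) - n_feat for name, (shape, n_feat) in inputs.items()}
    if len(set(ranks.values())) > 1:
        raise ValueError(f"inputs disagree on the number of batch dims: {ranks}")
    return next(iter(ranks.values()))


# Hypothetical shapes with batch dims (batch, num_samples).
ok = {
    "si": ((2, 5, 8, 32), 2),      # [*, n_token, c_s]
    "zij": ((2, 5, 8, 8, 16), 3),  # [*, n_token, n_token, c_z]
    "x_pred": ((2, 5, 40, 3), 2),  # [*, n_atom, 3]
}
print(common_batch_rank(ok))

bad = dict(ok, zij=((10, 8, 8, 16), 3))  # zij was flattened, the others were not
try:
    common_batch_rank(bad)
except ValueError as err:
    print(err)
```

Failing early with the per-input ranks in the message makes a mixed flattened/unflattened call easy to spot, instead of surfacing later as a shape error deep inside the model.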

Related Issues

Testing

This avoids duplicate code paths that slightly diverge. In particular,
the batched path currently flattens the batch dimensions whereas the
per_sample path does not. This confuses some downstream code.

Includes equivalence tests between the per_sample and batched versions
(the two versions were already equivalent with the previous code, but
that was not tested).

Also adds some invariant checking on the batch dimensions being passed
into the pairformer embedding to avoid future confusion.
@GMNGeoffrey

Contributor Author

@jnwei @christinaflo PTAL

GMNGeoffrey added a commit to GMNGeoffrey/openfold-3 that referenced this pull request May 18, 2026
I encountered an issue with this when doing
aqlaboratory#213. That enables
`batch > 1` and chunking to happen together. When you run inference with
samples both above and below per_sample_token_cutoff, the small inputs
use the batched/normal pairformer embedding path and flatten their batch
dimensions (i.e. batch and num_samples), but the larger inputs follow
the per-sample path, which *doesn't* flatten its batch dimensions and
results in higher-rank inputs
(aqlaboratory#223 aims to fix this
mismatch).

`_compare_arg_caches` recurses on tensor shapes (`torch.Size` is a tuple
subclass), so when the same `ChunkSizeTuner` instance is invoked with
tensors of different ranks across calls, the inner
`zip(..., strict=True)` raised. Treat any length mismatch as a cache
miss instead, so that the caller re-tunes. The redundant top-level
length assert in `tune_chunk_size` is removed, since that case is now
handled the same way.

Also adds dtype element size to the cache key for a tensor argument. We
could use the entire dtype, but I think in terms of what matters for
chunking, the element size is the key factor.
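
The cache comparison described in this commit message can be sketched in plain Python as follows. The names `arg_key` and `caches_match` are hypothetical stand-ins, not the actual `_compare_arg_caches` implementation; shapes are modeled as tuples, and the element size is passed in directly (for a real tensor it could come from `tensor.element_size()`).

```python
from collections.abc import Mapping, Sequence


def arg_key(shape, element_size):
    """Cache key for one tensor argument: its shape plus bytes per element."""
    return (tuple(shape), element_size)


def caches_match(old, new):
    """Recursively compare two cached argument structures.

    A length or type mismatch at any level is reported as "no match"
    (a cache miss) rather than raising.
    """
    if isinstance(old, Mapping) and isinstance(new, Mapping):
        return old.keys() == new.keys() and all(
            caches_match(old[k], new[k]) for k in old
        )
    old_is_seq = isinstance(old, Sequence) and not isinstance(old, (str, bytes))
    new_is_seq = isinstance(new, Sequence) and not isinstance(new, (str, bytes))
    if old_is_seq and new_is_seq:
        if len(old) != len(new):
            return False  # e.g. tensors of different rank: re-tune
        # Lengths are equal here, so a plain zip cannot drop elements.
        return all(caches_match(a, b) for a, b in zip(old, new))
    if old_is_seq != new_is_seq:
        return False
    return old == new


batched = [arg_key((10, 8, 8, 16), 4)]       # flattened batch dims, 4-byte elements
per_sample = [arg_key((2, 5, 8, 8, 16), 4)]  # higher rank, 4-byte elements
half = [arg_key((10, 8, 8, 16), 2)]          # same shape, 2-byte elements

print(caches_match(batched, batched))
print(caches_match(batched, per_sample))
print(caches_match(batched, half))
```

Only the first comparison reports a match; a rank change or an element-size change both count as a cache miss, which is the point at which a caller would re-tune.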
c_s = config.architecture.shared.c_s
c_z = config.architecture.shared.c_z

pair_emb = PairformerEmbedding(

Collaborator

This should be a separate utility, but to assert that tensors are equal you have to modify the model weight init scheme, like here: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/tests/test_kernels.py#L411
Otherwise, a lot of "final" layers are initialized with zero weight/bias on the first pass, and the diffs won't show up.

Contributor Author

Oh thanks for flagging. I've hit this with other models before, but didn't think about it here. It would be really nice if Torch had some way to let the init function itself declare different options for different purposes

Contributor Author

Added non-zero initialization, although I actually found that the outputs weren't zero even without that.

Collaborator

They wouldn't be zero because they include the initial embeddings, but the updates should be zero, if I remember correctly.
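
A toy illustration of the point made in this thread, assuming a residual block whose update is scaled by a final output weight. The functions `block`, `inner_a`, and `inner_b` are made up for this sketch and have nothing to do with the real model layers.

```python
def block(x, w_out, inner):
    """Toy residual block: the input plus a scaled update."""
    return [xi + w_out * inner(xi) for xi in x]


def inner_a(v):
    return 2.0 * v


def inner_b(v):  # deliberately different from inner_a
    return 2.0 * v + 1.0


x = [0.5, -1.0, 3.0]

# Zero-initialized output weight: the update vanishes, so both paths agree
# and the output is non-zero only because it still carries the input.
zero_init_equal = block(x, 0.0, inner_a) == block(x, 0.0, inner_b)

# Non-zero output weight: the difference between the paths becomes visible.
nonzero_init_equal = block(x, 0.25, inner_a) == block(x, 0.25, inner_b)

print(zero_init_equal, nonzero_init_equal)
```

With the zero weight, an equality test between the two paths passes even though the inner computations differ, which is why the tests need a non-zero initialization to be meaningful.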

Collaborator

@jnwei we should modify tests to use this util after it's merged in (i.e. kernel and offloading tests)

Contributor

This is a nice addition, thanks for adding this utility and for all of the testing fixes, @GMNGeoffrey!

Comment thread openfold3/core/model/heads/prediction_heads.py
As in the pairformer embedding, these assumed that num_samples was
the 1st dimension rather than indexing backwards based on the number of
feature dimensions.
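
A minimal sketch of indexing from the back, assuming num_samples is the last batch dimension, directly in front of the feature dimensions. The helper names and the example layout are hypothetical.

```python
def sample_axis(ndim, n_feature_dims):
    """Axis of num_samples when it is the last batch dim."""
    return ndim - n_feature_dims - 1


def sample_index(i, n_feature_dims):
    """Index that keeps sample `i` (as a size-1 dim), counted from the back."""
    return (Ellipsis, slice(i, i + 1)) + (slice(None),) * n_feature_dims


# Hypothetical x_pred layout: [batch, num_samples, n_atom, 3].
shape = (2, 5, 40, 3)
print(sample_axis(len(shape), n_feature_dims=2))
print(sample_index(3, n_feature_dims=2))
```

For a NumPy array or PyTorch tensor, `x_pred[sample_index(i, 2)]` is the same as `x_pred[..., i:i+1, :, :]`, which selects the sample regardless of how many batch dimensions precede it, whereas `x_pred[i:i+1]` is only correct when num_samples happens to be the first dimension.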
@GMNGeoffrey GMNGeoffrey requested a review from christinaflo May 19, 2026 03:56
@christinaflo christinaflo added the safe-to-test Internal only label used to indicate PRs that are ready for automated CI testing. label May 20, 2026

@christinaflo christinaflo left a comment

Collaborator

lgtm!

@christinaflo christinaflo merged commit 7305d96 into aqlaboratory:main May 20, 2026
11 checks passed
@GMNGeoffrey GMNGeoffrey deleted the apply-per-sample branch May 21, 2026 21:48
GMNGeoffrey added a commit to GMNGeoffrey/openfold-3 that referenced this pull request May 28, 2026
Development

Successfully merging this pull request may close these issues.

[BUG] per_sample_pairformer_emb passes wrong-shaped x_pred to embed_zij

3 participants