
Avoid flattening and expanding broadcast dimensions when chunking #213

Open
GMNGeoffrey wants to merge 6 commits into aqlaboratory:main from GMNGeoffrey:no-flatten-chunking


@GMNGeoffrey

@GMNGeoffrey GMNGeoffrey commented May 5, 2026

Contributor

Summary
This is an alternative, more complicated fix for #206, in place of #207. To avoid expanding an implicit broadcast dimension (e.g. in pair bias), when one exists it chunks across each batch dimension separately. If the innermost batch dimension (usually N_seq) is larger than chunk_size, then this is equivalent to splitting per sample. If it's smaller, though, then we might fit multiple samples in a single chunk.

As an example, when we have diffusion_samples 5 and sequence length 76 in the pairformer embedding, the batch dims are [5, 76] for QKV but [5, 1] for the pair bias. Rather than flattening to [380], we chunk each dimension separately. If chunk size were 128, we'd split into 5 chunks of [1, 76]; if it were 64, we'd split into 5 chunks of [1, 64] and 5 chunks of [1, 12]; if it were 256, we'd split into chunks of [3, 76] and [2, 76], etc.
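
The arithmetic in this example can be reproduced with a small sketch. This is not the code in this PR: `plan_chunks`, `chunk_shapes`, and `broadcast_slice` are hypothetical names, and the sketch assumes exactly two batch dimensions (outer = samples, inner = sequence) to keep it short.

```python
def plan_chunks(outer: int, inner: int, chunk_size: int) -> list[tuple[slice, slice]]:
    """Plan chunks over batch dims [outer, inner] without flattening them.

    Hypothetical helper for illustration only; assumes two batch dimensions.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be positive")
    plans = []
    if inner >= chunk_size:
        # The inner dimension alone fills a chunk: go one outer index at a
        # time and split the inner dimension into pieces of chunk_size.
        for i in range(outer):
            for j in range(0, inner, chunk_size):
                plans.append((slice(i, i + 1), slice(j, min(j + chunk_size, inner))))
    else:
        # Several whole inner rows fit in one chunk: group outer indices.
        rows = chunk_size // inner
        for i in range(0, outer, rows):
            plans.append((slice(i, min(i + rows, outer)), slice(0, inner)))
    return plans


def chunk_shapes(outer: int, inner: int, chunk_size: int) -> list[tuple[int, int]]:
    """Return the [outer, inner] extent of every planned chunk."""
    return [(o.stop - o.start, n.stop - n.start)
            for o, n in plan_chunks(outer, inner, chunk_size)]


def broadcast_slice(dim_size: int, sl: slice) -> slice:
    """Slice for an operand whose dimension may be an implicit broadcast (size 1)."""
    return slice(0, 1) if dim_size == 1 else sl


# diffusion_samples = 5, sequence length = 76, as in the example above.
print(chunk_shapes(5, 76, 128))  # five chunks of (1, 76)
print(chunk_shapes(5, 76, 256))  # [(3, 76), (2, 76)]
```

For an operand like the pair bias with batch dims [5, 1], the same plan would be applied with `broadcast_slice` on the size-1 dimension, so the bias is indexed as `[slice(0, 3), slice(0, 1)]` for the first chunk of the 256 case and is never expanded to [5, 76].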

Doing this in the chunk_layer function allows us to be chunk-size-aware, which we can't be at higher levels when chunk size has potentially not yet been determined due to chunk size tuning. I think this could completely replace apply_per_sample and per_sample_token_cutoff as well and is more robust.

Changes

  • Add new function for dimension-aware chunking employed when inputs would be broadcast in their batch dimensions
  • Remove workarounds that avoid chunking when using optimized kernels and batch size > 1
  • Allow kernel tests to use batch_size > 1

Related Issues

Testing

  • Unit tests for chunking logic
  • Integration tests for kernels now work with batch==2

Other Notes

- Avoid weird addition of 4 to power-of-two chunk sizes. This was added
  in aqlaboratory@a9a12890d without
  explanation. We can hypothesize that it was related to adding 4 to an
  input dimension in trace_utils.py (trying to get a test case to fit in
  one chunk?), but that file was long ago deleted. This just looks like
  a bug and makes us hit unhappy paths all over the place. Fixes
  aqlaboratory#203
- Enable chunking for the AuxiliaryHeadsAllAtom pairformer embedding when
  using optimized kernels. Without chunking, this is the first call to
  cause OOMs because it has `diffusion_samples*sequence_length` batches.
  Chunking gets turned off in prediction_heads.py when batch size > 1
  and optimized kernels are in use, because cross-sample chunking
  requires expanding out the pair bias, while the optimized kernels all
  require it to have size 1 in the second dimension with implicit
  broadcasting. So we turn on `apply_per_sample` when optimized kernels
  are in use. This splits the batch dimension that is larger than 1,
  which avoids this problematic path, and then we can do normal chunking
  for the rest if it's still too large. We could do something more
  elaborate (see suggestions in linked issue), but this is an
  improvement for now. Fixes
  aqlaboratory#206
With the introduction of dim-aware chunking, these can work fine.
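
The first note above, about adding 4 to power-of-two chunk sizes, can be illustrated with a rough sketch. This is not the tuner's actual code: `candidate_chunk_sizes`, `split_lengths`, and the `bump_last` parameter are hypothetical and exist only to show how a bumped candidate stops dividing power-of-two dimensions evenly.

```python
def candidate_chunk_sizes(max_chunk_size: int, min_chunk_size: int = 1,
                          bump_last: int = 0) -> list[int]:
    """Powers of two in [min_chunk_size, max_chunk_size].

    bump_last is a hypothetical knob that imitates adding a constant to the
    largest candidate; 0 leaves the candidates as clean powers of two.
    """
    top = max_chunk_size.bit_length() - 1  # floor(log2(max_chunk_size))
    candidates = [2**k for k in range(top + 1) if 2**k >= min_chunk_size]
    if bump_last:
        candidates[-1] += bump_last
    return candidates


def split_lengths(dim: int, chunk_size: int) -> list[int]:
    """Lengths of the pieces produced by chunking a dimension of size dim."""
    return [min(chunk_size, dim - start) for start in range(0, dim, chunk_size)]


clean = candidate_chunk_sizes(512)[-1]                # 512
bumped = candidate_chunk_sizes(512, bump_last=4)[-1]  # 516
print(split_lengths(1024, clean))   # [512, 512]
print(split_lengths(1024, bumped))  # [516, 508]
```

With clean powers of two, a power-of-two dimension splits into equal pieces; with the bumped value, the last piece is ragged.
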
GMNGeoffrey added a commit to GMNGeoffrey/openfold-3 that referenced this pull request May 18, 2026
I encountered an issue with this when doing
aqlaboratory#213. That enables
`batch > 1` and chunking to happen together. When you run inference with
samples both above and below per_sample_token_cutoff, the small inputs
use the batched/normal pairformer embedding path and flatten their batch
dimensions (i.e. batch and num_samples), but the larger inputs follow
the per-sample path, which *doesn't* flatten its batch dimensions and
results in higher-rank inputs
(aqlaboratory#223 aims to fix this
mismatch).

`_compare_arg_caches` recurses on tensor shapes (`torch.Size` is a tuple
subclass), so when the same `ChunkSizeTuner` instance is invoked with
tensors of different ranks across calls, the inner
`zip(..., strict=True)` raised. Now any length mismatch is treated as a
cache miss, so the caller re-tunes. The redundant top-level length
assert in `tune_chunk_size` is handled the same way and has been
removed.

Also adds dtype element size to the cache key for a tensor argument. We
could use the entire dtype, but I think in terms of what matters for
chunking, the element size is the key factor.
@christinaflo

christinaflo commented May 20, 2026

Collaborator

I am in favor of this kind of fix vs #207, I always wanted something like this to avoid the workarounds but never got to it lol, thank you for this! Would you be able to do the performance testing? I'm wondering how high it could go compared to apply_per_sample + offloading.

@GMNGeoffrey

Contributor Author

Yeah I'm happy to do some performance testing :-) Would be helpful to merge my other chunking PRs so I can get this on a stable base.

This solution does still leave something to be desired, because a chunk of 1024 that includes the batch dimension is not the same as one that covers only the sequence count: pair bias scales with the former but not the latter. But I guess that's OK, because we're caching chunk size based on the argument shapes anyway.

GMNGeoffrey added a commit to GMNGeoffrey/openfold-3 that referenced this pull request May 28, 2026
@jandom

jandom commented Jun 4, 2026

Collaborator

Sorry @GMNGeoffrey it looks like some of my changes gave conflicts here – hopefully won't be a nightmare to resolve

@GMNGeoffrey

Contributor Author

> Sorry @GMNGeoffrey it looks like some of my changes gave conflicts here – hopefully won't be a nightmare to resolve

NBD. I think most of the merge conflicts are with my own PRs :-)


Development

Successfully merging this pull request may close these issues.

  • Restoring chunking for batch_size > 1 with optimized kernels
  • [BUG/QUESTION] Chunk size incremented by 4 off largest power of 2

3 participants