Avoid flattening and expanding broadcast dimensions when chunking #213
GMNGeoffrey wants to merge 6 commits into
- Avoid weird addition of 4 to power-of-two chunk sizes. This was added in aqlaboratory@a9a12890d without explanation. We can hypothesize that it was related to adding 4 to an input dimension in trace_utils.py (trying to get a test case to fit in one chunk?), but that file was long ago deleted. This just looks like a bug and makes us hit unhappy paths all over the place. Fixes aqlaboratory#203
- Enable chunking for AuxiliaryHeadsAllAtom pairformer embedding when using optimized kernels. Without chunking, this is the first call to cause OOMs because it has `diffusion_samples*sequence_length` batches. Chunking gets turned off in prediction_heads.py when batch size > 1 and optimized kernels are in use, because cross-sample chunking requires expanding out pair bias and the optimized kernels all require it to have size 1 in the second dimension with implicit broadcasting. So we turn on `apply_per_sample` when optimized kernels are in use. This splits the > 1 batch dimension, which avoids the problematic path, and then we can do normal chunking for the rest if it's still too large. We could do something more elaborate (see suggestions in linked issue), but this is an improvement for now. Fixes aqlaboratory#206
With the introduction of dim-aware chunking, these can work fine.
I encountered an issue with this when doing aqlaboratory#213, which enables `batch > 1` and chunking to happen together. When you run inference with samples both above and below `per_sample_token_cutoff`, the small inputs use the batched/normal pairformer embedding path and flatten their batch dimensions (i.e. batch and num_samples), but the larger inputs follow the per-sample path, which *doesn't* flatten its batch dimensions and results in higher-rank inputs (aqlaboratory#223 aims to fix this mismatch). `_compare_arg_caches` recurses on tensor shapes (`torch.Size` is a tuple subclass), so when the same `ChunkSizeTuner` instance is invoked with tensors of different ranks across calls, the inner `zip(..., strict=True)` raised. Treat any length mismatch as a cache miss instead, so the caller re-tunes. The top-level length mismatch that `tune_chunk_size` used to assert on is now handled the same way, so the redundant assert is removed. Also adds dtype element size to the cache key for a tensor argument. We could use the entire dtype, but I think in terms of what matters for chunking, the element size is the key factor.
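To make the cache-miss behaviour concrete, here is a minimal sketch of the idea, not the actual `_compare_arg_caches` code. `FakeTensor`, `cache_key`, and `caches_match` are hypothetical names invented for illustration; `FakeTensor` stands in for a real tensor so the example runs without PyTorch (with real tensors the element size would come from `tensor.element_size()`).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: only what the cache key needs."""
    shape: tuple
    itemsize: int  # bytes per element, e.g. 4 for float32, 2 for bfloat16


def cache_key(arg):
    """Summarise an argument as nested tuples of shapes and element sizes."""
    if isinstance(arg, FakeTensor):
        return ("tensor", tuple(arg.shape), arg.itemsize)
    if isinstance(arg, (list, tuple)):
        return tuple(cache_key(a) for a in arg)
    return arg


def caches_match(a, b):
    """Recursively compare two cache keys.

    A length mismatch (e.g. shapes of different rank) is reported as
    "no match" so the caller can re-tune, rather than raising.
    """
    if isinstance(a, tuple) and isinstance(b, tuple):
        if len(a) != len(b):
            return False  # cache miss, not an error
        return all(caches_match(x, y) for x, y in zip(a, b))
    return a == b


# Flattened batch dims (rank 3) vs. unflattened batch dims (rank 4).
flattened = cache_key([FakeTensor((380, 76, 64), 4)])
per_sample = cache_key([FakeTensor((5, 76, 76, 64), 4)])
half_precision = cache_key([FakeTensor((380, 76, 64), 2)])
```

With these keys, `caches_match(flattened, per_sample)` is a plain miss instead of an exception, and the same shape at a different element size also counts as a miss.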
I am in favor of this kind of fix vs #207, I always wanted something like this to avoid the workarounds but never got to it lol, thank you for this! Would you be able to do the performance testing? I'm wondering how high it could go compared to apply_per_sample + offloading. |
Yeah I'm happy to do some performance testing :-) Would be helpful to merge my other chunking PRs so I can get this on a stable base. This solution still does leave something to be desired because a chunk of 1024 when it includes the batch dimension is not the same as one where it's just the sequence count, as pair bias scales with the former but not the latter. But I guess that's ok because we're caching chunk size based on the argument shapes anyways. |
Sorry @GMNGeoffrey it looks like some of my changes gave conflicts here – hopefully won't be a nightmare to resolve |
NBD. I think most of the merge conflicts are with my own PRs :-) |
Summary
This is an alternative, more complicated fix for #206 instead of #207. To avoid the expansion of an implicit broadcast dimension (e.g. in pair bias), if one exists it chunks across each batch dimension discretely. If the innermost batch dimension (usually N_seq) is larger than chunk_size then this is equivalent to splitting per sample. If it's smaller though, then we might fit multiple samples in a single chunk.
So as an example when we have diffusion_samples 5 and sequence length 76 in the pairformer embedding we have batch dims [5, 76] for QKV but for the pair bias it's [5, 1]. Rather than flattening to [380], we chunk each dimension separately. If chunk size were 128 then we'd split into 5 chunks of [1, 76]; if it were 64, then we'd split into 5 chunks of [1, 64] and 5 chunks of [1, 12]; if chunk size were 256, then we'd split into chunks of [3, 76] and [2, 76], etc.
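The arithmetic in that example can be reproduced with a small planning sketch. It is an illustration under simplifying assumptions, not the PR's `chunk_layer` implementation: it handles exactly two batch dimensions, only computes chunk shapes (no tensors are sliced), and `split_sizes` and `plan_chunks` are hypothetical helper names.

```python
def split_sizes(total, size):
    """Split `total` into consecutive pieces of at most `size` elements."""
    return [min(size, total - start) for start in range(0, total, size)]


def plan_chunks(batch_dims, chunk_size):
    """Return the batch shape of each chunk for batch dims (outer, inner).

    The dimensions are never flattened together, so an argument that is
    broadcast along `inner` (size 1 there, like the pair bias) can be
    sliced along `outer` alone without being expanded.
    """
    outer, inner = batch_dims
    if inner > chunk_size:
        # One sample at a time, each split along the inner dimension.
        return [(1, s) for _ in range(outer)
                for s in split_sizes(inner, chunk_size)]
    # Whole samples fit: pack as many as the chunk size allows.
    samples_per_chunk = chunk_size // inner
    return [(n, inner) for n in split_sizes(outer, samples_per_chunk)]


chunks_128 = plan_chunks((5, 76), 128)
chunks_64 = plan_chunks((5, 76), 64)
chunks_256 = plan_chunks((5, 76), 256)
```

For batch dims `[5, 76]` this gives five `(1, 76)` chunks at chunk size 128, five `(1, 64)` plus five `(1, 12)` chunks at 64, and `(3, 76)` followed by `(2, 76)` at 256, matching the cases described above.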
Doing this in the `chunk_layer` function allows us to be chunk-size-aware, which we can't be at higher levels when chunk size has potentially not yet been determined due to chunk size tuning. I think this could completely replace `apply_per_sample` and `per_sample_token_cutoff` as well and is more robust.
Changes
Related Issues
Testing
Other Notes