Treat arg-cache length mismatch as a cache miss in ChunkSizeTuner #224
Merged
christinaflo merged 2 commits on Jun 3, 2026
@christinaflo @jnwei PTAL
I encountered an issue with this when doing aqlaboratory#213. That enables `batch > 1` and chunking to happen together. When you run inference with samples both above and below `per_sample_token_cutoff`, the small inputs use the batched/normal pairformer embedding path and flatten their batch dimensions (i.e. batch and num_samples), but the larger inputs follow the per-sample path, which *doesn't* flatten its batch dimensions and results in higher-rank inputs (aqlaboratory#223 aims to fix this mismatch). `_compare_arg_caches` recurses on tensor shapes (`torch.Size` is a tuple subclass), so when the same `ChunkSizeTuner` instance is invoked with tensors of different ranks across calls, the inner `zip(..., strict=True)` raises. This change treats any length mismatch as a cache miss so the caller re-tunes. The redundant top-level length assert in `tune_chunk_size` is removed; that mismatch is now handled the same way. Also adds dtype element size to the cache key for a tensor argument. We could use the entire dtype, but I think in terms of what matters for chunking, the element size is the key factor.
jandom reviewed Jun 1, 2026
Done
Can you mark safe to test and merge? Or were you waiting for @jandom?
Summary
The chunk size tuner currently throws if the function it's tuning for receives arguments with different ranks in different calls. Instead, treat this as a cache miss. I don't think the tuner should be forcing invariants on the functions it's tuning for.
I encountered an issue with this when doing #213. That enables `batch > 1` and chunking to happen together with Triton kernels (you hit the same bug with kernels off without that change). When you run inference with samples both above and below `per_sample_token_cutoff`, the small inputs use the batched/normal pairformer embedding path, which flattens the batch dimensions (i.e. batch and num_samples), but the larger inputs follow the per-sample path, which doesn't flatten their batch dimensions and results in higher-rank inputs (#223 aims to fix this mismatch). `_compare_arg_caches` recurses on tensor shapes (`torch.Size` is a tuple subclass), so when the same `ChunkSizeTuner` instance was invoked with tensors of different ranks across calls, the inner `zip(..., strict=True)` raised. In this change we treat any length mismatch as a cache miss so the caller re-tunes. The top-level length assert in `tune_chunk_size` is removed; that mismatch is now handled the same way.
Also adds dtype element size to the cache key for a tensor argument. We could use the entire dtype, but I think in terms of what matters for chunking, the element size is the key factor.
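The behavior described above can be sketched roughly as follows. This is an illustration under stated assumptions, not the merged implementation: `tensor_cache_key` and `caches_match` are hypothetical names, plain tuples stand in for `torch.Size`, and the element size is passed as an integer where real code would read it from the tensor's dtype.

```python
def tensor_cache_key(shape, itemsize):
    """Hypothetical cache key for a tensor argument: shape plus element size in bytes."""
    return (tuple(shape), itemsize)


def caches_match(cached, current):
    """Return False (a cache miss) on any length or value mismatch; never raise."""
    if isinstance(cached, (list, tuple)) and isinstance(current, (list, tuple)):
        if len(cached) != len(current):
            # Different rank or argument count: report a miss so the caller re-tunes.
            return False
        return all(caches_match(a, b) for a, b in zip(cached, current))
    return cached == current


# Illustrative values only.
key_flat = tensor_cache_key((8, 256, 128), 4)           # flattened batch, 4-byte elements
key_per_sample = tensor_cache_key((2, 4, 256, 128), 4)  # higher rank, same element size
key_half = tensor_cache_key((8, 256, 128), 2)           # same shape, 2-byte elements

hit = caches_match([key_flat], [key_flat])
miss_rank = caches_match([key_flat], [key_per_sample])
miss_itemsize = caches_match([key_flat], [key_half])
miss_arg_count = caches_match([key_flat], [key_flat, key_half])
```

Checking lengths before zipping means a rank change and an argument-count change both fall through to the same "miss" result, so the tuner imposes no invariant on the function it tunes.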
Changes
- Add `dtype.itemsize` to the tensor cache key along with shape
- Treat a length mismatch as a cache miss instead of letting `zip(..., strict=True)` raise
Related Issues
Testing