Skip to content

Treat arg-cache length mismatch as a cache miss in ChunkSizeTuner#224

Merged
christinaflo merged 2 commits into
aqlaboratory:mainfrom
GMNGeoffrey:chunk-tune-rank-change
Jun 3, 2026
Merged

Treat arg-cache length mismatch as a cache miss in ChunkSizeTuner#224
christinaflo merged 2 commits into
aqlaboratory:mainfrom
GMNGeoffrey:chunk-tune-rank-change

Conversation

@GMNGeoffrey

@GMNGeoffrey GMNGeoffrey commented May 18, 2026

Copy link
Copy Markdown
Contributor

Summary
The chunk size tuner currently throws if the function it's tuning for receives arguments with different ranks in different calls. Instead, treat this as a cache miss. I don't think the tuner should be forcing invariants on the functions it's tuning for.

I encountered an issue with this when doing #213. That enables batch > 1 and chunking to happen together with Triton kernels (you hit the same bug with kernels off without that change). When you run inference with samples both above and below per_sample_token_cutoff, the small inputs use the batched/normal pairformer embedding path which flattens the batch dimensions (i.e. batch and num_samples), but the larger inputs follow the per-sample path, which doesn't flatten their batch dimensions and results in higher-rank inputs (#223 aims to fix this mismatch).

_compare_arg_caches recurses on tensor shapes (torch.Size is a tuple subclass), so when the same ChunkSizeTuner instance was invoked with tensors of different ranks across calls, the inner zip(..., strict=True) raised. Instead in this change we treat any length mismatch as a cache miss so the caller re-tunes instead. The top-level length assert in tune_chunk_size is now handled the same way and removed.

Also adds dtype element size to the cache key for a tensor argument. We could use the entire dtype, but I think in terms of what matters for chunking, the element size is the key factor.

Changes

  • Remove assert that tuned function always called with same number of args
  • Add dtype.itemsize to tensor cache key along with shape
  • Record a cache miss when cache key lengths differ rather than raising from zip(...,strict=True)
  • Unit tests for the same

Related Issues

Testing

  • Unit tests verifying cache miss on arg count, arg rank, and arg dtype size changes.

@GMNGeoffrey

Copy link
Copy Markdown
Contributor Author

@christinaflo @jnwei PTAL

I encountered an issue with this when doing
aqlaboratory#213. That enables
`batch > 1` and chunking to happen together. When you run inference with
samples both above and below per_sample_token_cutoff, the small inputs
use the batched/normal pairformer embedding path and flatten their batch
dimensions (i.e. batch and num_samples), but the larger inputs follow
the per-sample path, which *doesn't* flatten its batch dimensions and
results in higher-rank inputs
(aqlaboratory#223 aims to fix this
mismatch).

`_compare_arg_caches` recurses on tensor shapes (`torch.Size` is a tuple
subclass), so when the same `ChunkSizeTuner` instance is invoked with
tensors of different ranks across calls, the inner
`zip(..., strict=True)` raised. Instead treat any length mismatch as a
cache miss so the caller re-tunes instead. The redundant top-level
length assert in `tune_chunk_size` is now handled the same way and
removed.

Also adds dtype element size to the cache key for a tensor argument. We
could use the entire dtype, but I think in terms of what matters for
chunking, the element size is the key factor.
@GMNGeoffrey GMNGeoffrey force-pushed the chunk-tune-rank-change branch from 0baf5ad to f943d7c Compare May 28, 2026 16:53
@christinaflo christinaflo self-requested a review May 29, 2026 01:18
@christinaflo christinaflo added the safe-to-test Internal only label used to indicate PRs that are ready for automated CI testing. label May 29, 2026

@jandom jandom left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those are all good – all 3 test for the negative case, would it make sense to test for the positive as well, or do we think that's overkill?

Meta-comment, and outside of scope, we should really migrate all these to pytest just to be consistent

@GMNGeoffrey

Copy link
Copy Markdown
Contributor Author

Those are all good – all 3 test for the negative case, would it make sense to test for the positive as well

Done

@christinaflo christinaflo left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@GMNGeoffrey

Copy link
Copy Markdown
Contributor Author

LGTM!

Can you mark safe to test and merge? Or were you waiting for @jandom?

@christinaflo christinaflo added safe-to-test Internal only label used to indicate PRs that are ready for automated CI testing. and removed safe-to-test Internal only label used to indicate PRs that are ready for automated CI testing. labels Jun 3, 2026
@christinaflo

Copy link
Copy Markdown
Collaborator

LGTM!

Can you mark safe to test and merge? Or were you waiting for @jandom?

Yeah I was going to see if @jandom had other comments but I think it's fine to merge after the tests finish running.

@christinaflo christinaflo merged commit cf2c00c into aqlaboratory:main Jun 3, 2026
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

safe-to-test Internal only label used to indicate PRs that are ready for automated CI testing.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants