Make pairformer_emb_per_sample call into pairformer_emb#223
Conversation
This avoids duplicate code paths that slightly diverge. In particular, the batched path currently flattens the batch dimensions whereas the per_sample path does not. This makes some downstream things confused. Includes equivalence tests between the per_sample and batched versions (these passed with the previous code as well, but weren't tested). Also adds some invariant checking on the batch dimensions being passed into the pairformer embedding to avoid future confusion.
|
@jnwei @christinaflo PTAL |
I encountered an issue with this when doing aqlaboratory#213. That enables `batch > 1` and chunking to happen together. When you run inference with samples both above and below per_sample_token_cutoff, the small inputs use the batched/normal pairformer embedding path and flatten their batch dimensions (i.e. batch and num_samples), but the larger inputs follow the per-sample path, which *doesn't* flatten its batch dimensions and results in higher-rank inputs (aqlaboratory#223 aims to fix this mismatch). `_compare_arg_caches` recurses on tensor shapes (`torch.Size` is a tuple subclass), so when the same `ChunkSizeTuner` instance is invoked with tensors of different ranks across calls, the inner `zip(..., strict=True)` raised. Instead treat any length mismatch as a cache miss so the caller re-tunes instead. The redundant top-level length assert in `tune_chunk_size` is now handled the same way and removed. Also adds dtype element size to the cache key for a tensor argument. We could use the entire dtype, but I think in terms of what matters for chunking, the element size is the key factor.
| c_s = config.architecture.shared.c_s | ||
| c_z = config.architecture.shared.c_z | ||
|
|
||
| pair_emb = PairformerEmbedding( |
There was a problem hiding this comment.
This should be a separate utility, but for asserting equal tensors you have to modify the model weight init scheme like here: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/tests/test_kernels.py#L411
Otherwise, there's a lot of "final" init layers that are zero weight/bias on the first pass and you won't see the diffs come up
There was a problem hiding this comment.
Oh thanks for flagging. I've hit this with other models before, but didn't think about it here. It would be really nice if Torch had some way to let the init function itself declare different options for different purposes
There was a problem hiding this comment.
Added non-zero initialization, although I actually found that the outputs weren't zero even without that.
There was a problem hiding this comment.
they wouldnt be zero bc they have the initial embeddings, but the updates should be zero if i remember correctly
There was a problem hiding this comment.
@jnwei we should modify tests to use this util after it's merged in (i.e. kernel and offloading tests)
There was a problem hiding this comment.
This is a nice addition, thanks for adding this utility and for all of the testing fixes @GMNGeoffrey !
Similar to in pairformer embedding, these assumed that num_samples was the 1st dimension rather than indexing backwards base on the number of feature dimensions.
I encountered an issue with this when doing aqlaboratory#213. That enables `batch > 1` and chunking to happen together. When you run inference with samples both above and below per_sample_token_cutoff, the small inputs use the batched/normal pairformer embedding path and flatten their batch dimensions (i.e. batch and num_samples), but the larger inputs follow the per-sample path, which *doesn't* flatten its batch dimensions and results in higher-rank inputs (aqlaboratory#223 aims to fix this mismatch). `_compare_arg_caches` recurses on tensor shapes (`torch.Size` is a tuple subclass), so when the same `ChunkSizeTuner` instance is invoked with tensors of different ranks across calls, the inner `zip(..., strict=True)` raised. Instead treat any length mismatch as a cache miss so the caller re-tunes instead. The redundant top-level length assert in `tune_chunk_size` is now handled the same way and removed. Also adds dtype element size to the cache key for a tensor argument. We could use the entire dtype, but I think in terms of what matters for chunking, the element size is the key factor.
I encountered an issue with this when doing aqlaboratory#213. That enables `batch > 1` and chunking to happen together. When you run inference with samples both above and below per_sample_token_cutoff, the small inputs use the batched/normal pairformer embedding path and flatten their batch dimensions (i.e. batch and num_samples), but the larger inputs follow the per-sample path, which *doesn't* flatten its batch dimensions and results in higher-rank inputs (aqlaboratory#223 aims to fix this mismatch). `_compare_arg_caches` recurses on tensor shapes (`torch.Size` is a tuple subclass), so when the same `ChunkSizeTuner` instance is invoked with tensors of different ranks across calls, the inner `zip(..., strict=True)` raised. Instead treat any length mismatch as a cache miss so the caller re-tunes instead. The redundant top-level length assert in `tune_chunk_size` is now handled the same way and removed. Also adds dtype element size to the cache key for a tensor argument. We could use the entire dtype, but I think in terms of what matters for chunking, the element size is the key factor.
Summary
The current code has some duplication that also causes the two paths to diverge. In particular, the batched path currently flattens the batch dimensions whereas the per_sample path does not. This makes some downstream things confused (the issue I ran into was the chunk tuner expecting the same function to always be called with the same rank inputs, which is probably a bug as well). But also, to the extent that the flattening needs to happen, then I think it should happen for all paths (there's a disadvantage in that it can force implicit broadcasts to become explicit).
Changes
per_sample_pairformer_emba thin wrapper aroundpairformer_embx_predto not assume that num_samples is the 1st dimensionzijin a couple other places that assumed num_samples is the first dimensionRelated Issues
Testing
ValueErrornow.