Make config chunk_size maximum rather than minimum#227
Conversation
|
@christinaflo PTAL |
This converts the configured chunk size to the *maximum* rather than the *minimum*. I don't think there's a clear use case for this to be the minimum and users were confused. This also allows us to remove the hardcoded values for the max with default and flash-style kernels: these instead are just config settings. I removed the minimum chunk size entirely, so the tuner will search all powers of 2 up to the maximum. I'm struggling to think of a use case where someone would want to limit chunks size to a minimum of 4 (I think this number just came from the original suggested chunk size in the AlphaFold 3 paper). Note that this changes the hardcoded different chunk size for diffusion conditioning. I'm not sure how important this is to preserve and whether we saw actual latency improvements from raising the chunk size here vs just observing that it was possible to do so. If we want different chunk sizes for different modules, I think we should probably thread through a config setting per module because having a hardcoded value here is a bit surprising IMO (and note that it only gets applied when tuning is turned on). Probably the best analog in the config would be `offload_inference` as a per-module setting that isn't an init arg. If more things got added here, IDK whether we'd want to keep this structure or perhaps mirror the structure of architecture.
This would have caught the issue with the chunk size for attention getting divided by 4. I think for only that bug, it's not really carrying its weight, but testing this boundary case still seems useful.
| ) | ||
| if chunk_size is not None: | ||
| config.settings.memory.eval.chunk_size = chunk_size | ||
| config.settings.memory.train.chunk_size = chunk_size |
There was a problem hiding this comment.
We never really run chunking with training because the activations stack up anyway in the backward pass so it doesnt save you much. I havent actually run this but i think it may fail some assert. Diffusion conditioning at least has a assert not self.training for its chunking function
There was a problem hiding this comment.
ah okay i see now you only have a test for eval mode anyway
There was a problem hiding this comment.
can we just delete this line anyway since it cant run
There was a problem hiding this comment.
In that case, shouldn't we remove chunk_size from the train settings entirely? I set it here because if chunk size is overridden for the test it would be surprising if it weren't overridden for training as well IMO. I can delete it if you prefer though
There was a problem hiding this comment.
I had it there originally because I needed chunking enabled during validation for some large samples but not for training, so it'll pick what to use in the model like: chunk_size=mode_mem_settings.chunk_size depending on the stage it's in. I guess there's nothing stopping you from doing it during training, it's just not really worth it ever.
We could fix the assert in diffusion conditioning to match the other modules so it runs:
if chunk_size is not None and self.chunk_size_tuner is not None:
assert not self.training
or change model.py to always set it to None for training and not reference the config, i thought it was easier to distinguish in the config though
There was a problem hiding this comment.
Oh yeah thats a lot of memory. on 80gb gpus with no kernels I can see the chunking take effect with seq lengths > 1500 tokens. I have a really old and messy script for benchmarking inference speed + mem using random_of3_features. I can clean it up and share it, but I have a bit of a backlog this week so I can just send it for testing #213 so I don't block this PR. As you said, this is really just a config change. Btw realistic values are n_msa=16384 (this will get subsampled to 1024 each recycle) and n_templ=4.
There was a problem hiding this comment.
Perfect, thanks! I think that gives me enough to do testing without being limited by what I can find in the pdb. A standardized script in-repo would be nice, but you don't need to rush to clean up yours 🙂
So those don't vary much with input or scale with n_tok? I found that homoers used significantly less memory at very large n_tok I think due to msa reuse for the shared sequences, but I didn't dig in.
I can test for smaller memory caps and the chunk tuner behavior by limiting torch mem_fraction. The only problem is the combinatorial explosion of options. If 80gb is of particular interest (H100?) I can test it as well. The other option is to test fixed chunk sizes and just track memory. I've got a memory snapshotting callback (happy to clean up and upstream if it would be generally useful). The only issues there are that tuning itself can affect peak due to some clones and then diffusion conditioning getting chunk size 2048 is guarded by the tuner getting on (is the diff there worth it?)
There was a problem hiding this comment.
It's just the max allowable input n_msa and n_templ, there could be less but it'll get capped at those numbers. I'm not sure why large homomers use less memory either off the top of my head, I'd have to look into that as well, since msas should still be capped at 1024 per recycle I'd expect they would still reach that threshold.
Yeah 80 gb h100 is generally what we have access to so that is of particular interest. Even on mi300A I try to limit model gpu memory quite a bit anyway since our data loader processes take up a lot of memory, so I kind of treat it like an h100.
For the memory snapshotting, I have a callback for this also internally that will get merged in at some point, I assume they're probably the same.
For the diffusion conditioning chunk size, this PR changes it to the global max but I think it's fine, I highly doubt it's that much slower and I eyeballed that number to begin with :).
There was a problem hiding this comment.
Ah I looked closer and found the issue. One of the homomers I had Claude dig up for me (6R7M-1) is a 40-chain homomer with an MSA depth of only 122. This was throwing things off and I thought it would be an issue with all homomers to a lesser extent, but it's really just this one that is weird. If n_msa gets subsampled to 1024 every recycle, does it actually make a difference if the random input has n_msa=1024 or n_msa=16384?
For the diffusion conditioning chunk size, this PR changes it to the global max but I think it's fine, I highly doubt it's that much slower and I eyeballed that number to begin with :).
Oh right, I forgot that I did that already :-D
There was a problem hiding this comment.
No actually it doesn't make a difference, it's only if you want to exercise the subsampling logic which doesn't matter here.
I added an assert to guard against this so there isn't some surprising behavior.
b9d20a1 to
22edd67
Compare
Summary
The
chunk_sizespecified in the config is currently the minimum value that the tuner will go down to. I don't think there's a clear use case for this to be the minimum and users were confused. Instead, we can make this the maximum. This also allows us to remove the hardcoded values for the max with default and flash-style kernels: these instead are just config settings. I removed the minimum chunk size entirely, so the tuner will search all powers of 2 up to the maximum. I'm struggling to think of a use case where someone would want to limit chunks size to a minimum of 4 (I think this number just came from the original suggested chunk size in the AlphaFold 3 paper).Note that this changes the hardcoded different chunk size for diffusion conditioning. I'm not sure how important this is to preserve and whether we saw actual latency improvements from raising the chunk size here vs just observing that it was possible to do so. If we want different chunk sizes for different modules, I think we should probably thread through a config setting per module because having a hardcoded value here is a bit surprising IMO (and note that it only gets applied when tuning is turned on). Probably the best analog in the config would be
offload_inferenceas a per-module setting that isn't an init arg. If more things got added here, IDK whether we'd want to keep this structure or perhaps mirror the structure of architecture.Changes
chunk_sizeindicate the maximum chunk size when tuning is activeTesting