Make chunking for Triton kernels closer to cueq#209
Conversation
The max chunk size when using the cuequivariance kernels is double the default (1024 instead of 512). Make the Triton kernels use the same. There are a few places where the already tuned chunk size is divided by 4, but that subdivision is skipped when using the cuequivariance kernels. Make the Triton kernels the same. Rationale: the Triton kernels are very similar to the cuequivariance kernels and both use a memory efficient flash-attention-style algorithm. I think it makes sense to default to having them use the same code paths for chunk size. Some of the chunking subdivisions had explicit comments about being to work around errors specific to native PyTorch ops. We might want to further tune the chunking based on more detailed benchmarking, but assuming the Triton kernels behave similar to the cuequivariance kernels seems like a better starting point than assuming they behave like native PyTorch.
|
@christinaflo and @jnwei since you were responding to my other issues about chunking :-) Thanks for taking a look |
|
I was going to mention this in the other PR, chunking is too aggressive for all these kernels including deepspeed so we should have a single variable for this as you mentioned, or just make it more configurable in general instead of a hardcoded param. I did notice that when modifying these settings for large sequences cueq ran out of mem faster than the other kernels which is why I lowered the max in this commit, but I don't know if this is optimal. For the purposes of this PR we can just make these two match, but I'd like to tune this better in a future PR. |
|
@christinaflo makes sense. I think the best thing would be to have chunk size something that's configured from the config files rather than having a hardcoded maximum in the code. Then different settings for the kernels would just be a preset. Not sure how we would handle the further chunking subdivision in that case though. The fact that that's necessary suggests that the chunk size tuner isn't really working as intended. For this PR, would you prefer I leave it like this with a new variable for the Triton kernels, also add a variable for the deepspeed kernels (and turn off the chunking subdivision), make this a single variable for all 3 types of optimized kernels, or something else? I can look into threading this through as a config parameter in a separate PR, if you'd like. |
|
I was thinking about this more and what makes the most sense to me would be using the configured |
|
FWIW, I've benchmarked this and #207 and played with the chunk size a bit across a range of input sizes (~500 to ~5000 residues). It seems like 1024 is actually the sweet spot. Giving attention a smaller chunk size does also seem to be slightly better, but it's unclear if it's significant and this way of doing it is a bit odd IMO. But it seems like overall these changes are an improvement. |
Just wondering do you know for your 5k residue example how much this increased peak memory? Re minimum chunk size, I've used it before when I wanted a set chunk size that was above the max, basically disabling the tuning but keeping the chunk_size // 4 for the attention. But that's just a hacky thing, and I think most people are confused by the fact that it's the minimum as well. So I agree that it makes more sense to change to maximum, but I guess that can be a separate PR.
Oh oops I missed this, I think originally I was thinking about having a single variable and adding deepspeed to this as well. They're all pretty similar with inference memory usage, so it doesn't seem necessary to separate them. Could we change to that instead actually? |
|
Converted to a single variable
I can send a follow-up with a PR for that. If we're doing that though, I don't think it makes sense to add a
I tested at baseline and then without the |
|
Reran using the torch mem tracking. rocminfo was just giving me the numbers torch had reserved, so basically useless (always 187.3 GiB according to Torch itself). All the chunking configurations had max 167.6 GiB allocated except the ones that tried chunk size of 2048 which had 173.7 GiB (and then presumably tried to allocate another 14+ GiB and failed). Timing ordering totally changed in the second run. This is also all on k8s, so it has overhead and different jobs might be getting different nodes and GPUs. I think the only thing I can confidently say from this data is that the proposed configuration here is very likely not worse than the current mainline and that raising the max chunk size further to 2048 is likely bad (although that is probably dependent on the sequence length). |
There was a problem hiding this comment.
Thanks for making the changes, lgtm! We can tune this later if needed, but from my own runs (up to 3k residues) this also seems like an good default chunk config. I do think reducing the attention chunk size probably would help, maybe //2 instead of //4 I dont know. But for most use cases this is fine.
Summary
Make the chunking values used for Triton kernels the same as those used for cuequivariance kernels
Rationale: the Triton kernels are very similar to the cuequivariance kernels and both use a memory efficient flash-attention-style algorithm. I think it makes sense to default to having them use the same code paths for chunk size. Some of the chunking subdivisions had explicit comments about being to work around errors specific to native PyTorch ops (I think maybe when I was looking through history: not now). We might want to further tune the chunking based on more detailed benchmarking, but assuming the Triton kernels behave similar to the cuequivariance kernels seems like a better starting point than assuming they behave like native PyTorch. This also reduces the number of unique call signatures the attention kernel sees which makes it easier to tune.
Changes
Related Issues
Sort of related to #203 and #206. It was noticing the weird calling pattern of chunks of 129 (
(512+4)/4) that caused me to dive into the chunking logic in the first placeTesting
Other Notes
We could have a single "optimized attention kernels" chunk size variable rather than having separate ones.