Chunk large LayerNorm inputs to work around torch bug #231
Open
GMNGeoffrey wants to merge 2 commits into
@christinaflo PTAL
Force-pushed from 017ca67 to cbb394d.
Summary
There are at least two bugs in the Torch layer norm kernel for very large inputs:

- pytorch/pytorch#181555: when numel is `>2^32` and the channel dimension is a multiple of 4 (so the vectorized kernel is used), there is a uint32 overflow in the row offset. Fixed in pytorch/pytorch#181600 ("[CUDA] Fix int32 overflow in layer_norm for tensors with >2^32 elements"), but that is not yet in the stable release (though confirmed fixed in nightly).
- pytorch/pytorch#184826 ("`F.layer_norm` silently produces wrong output for `M >= 2^23, C%4!=0`", which I just reported): when the combined batch dim is `>2^23` and the channel dimension is not a multiple of 4 (so the non-vectorized kernel is used), there is a failure, not yet root-caused, that means some outputs are never addressed and are filled with garbage values. At exactly `2^23` I actually get an "invalid configuration argument" error from HIP.

I ran into the second bug while trying to run inference on three large monomer sequences with 4563, 4967, and 5247 residues. The layer norm input in diffusion conditioning has shape `(1, 1, N_tok, N_tok, 267)` (c_z=128 + relpos=139 -> channels=267). I also ran into the first bug when trying to triangulate the exact trigger for the second. Since it appears that the channel dim will always be odd by construction (sum of 3 even numbers plus 1 for same-entity features), I think we would never hit the first bug in this layernorm, but we could hit it for the layernorm in pairformer, where the input shape is `(1, 1, N_tok, N_tok, 128)`, so it would be triggered by any sequence >= 2^12.5 ~ 5793 tokens (`sqrt(2^32/2^7)`).

Fix: use our existing chunking infrastructure to chunk the layer norm call when its input would exceed either of these thresholds. I did this unconditionally, since at least the first bug affects CUDA as well as ROCm and these chunks are so large that I think they don't add any meaningful overhead. I also didn't gate it on the channel parity, since that seemed like unnecessary extra complication. In my tests this chunking did not introduce any noticeable additional latency. I also didn't hook this into any of the chunk size tuning or configuration, since the purpose here is different.
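The size thresholds quoted above can be sanity-checked with plain integer arithmetic. The snippet below is only an illustrative check, not code from this PR; the helper name `min_tokens_exceeding` and the constant names are made up for the example, and the limits are taken directly from the two issue descriptions.

```python
import math

ROW_LIMIT = 2**23    # combined batch dim threshold (second bug)
NUMEL_LIMIT = 2**32  # element-count threshold (first bug)


def min_tokens_exceeding(limit: int, channels: int = 1) -> int:
    """Smallest N_tok such that N_tok * N_tok * channels > limit."""
    n = math.isqrt(limit // channels)
    while n * n * channels <= limit:
        n += 1
    return n


# Row count N_tok^2 first exceeds 2^23 (second bug, channel count irrelevant).
print(min_tokens_exceeding(ROW_LIMIT))         # 2897
# numel N_tok^2 * 128 first exceeds 2^32 (first bug, pairformer layernorm).
print(min_tokens_exceeding(NUMEL_LIMIT, 128))  # 5793

# The three sequences from the report: all exceed the row limit, and
# 267 channels is not a multiple of 4, so the non-vectorized kernel is used.
for n_tok in (4563, 4967, 5247):
    print(n_tok, n_tok * n_tok > ROW_LIMIT, 267 % 4 != 0)
```

Under these assumptions, any pair-shaped input with more than 2896 tokens crosses the row threshold, which is consistent with all three reported sequences failing.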
Verified that this resolves the issue on these sequences.
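For readers unfamiliar with the approach, here is a minimal sketch of row-wise chunking for layer norm. It is not the implementation in this PR and does not use the repository's chunking infrastructure; `plan_row_chunks`, `chunked_layer_norm`, and the two constants are hypothetical names, and the limits are deliberately conservative (strictly below both thresholds). `torch` is imported inside the function so the planning helper can be run without it.

```python
# Assumed limits, taken from the two upstream issues described in the summary.
MAX_ROWS = 2**23   # combined batch dim at which the non-vectorized kernel fails
MAX_NUMEL = 2**32  # element count above which the row offset overflows


def plan_row_chunks(n_rows: int, n_channels: int) -> list[tuple[int, int]]:
    """Split [0, n_rows) into ranges that stay strictly below both limits."""
    rows_per_chunk = min(MAX_ROWS - 1, (MAX_NUMEL - 1) // n_channels)
    if rows_per_chunk < 1:
        raise ValueError("channel dimension too large to chunk by rows")
    return [
        (start, min(start + rows_per_chunk, n_rows))
        for start in range(0, n_rows, rows_per_chunk)
    ]


def chunked_layer_norm(x, weight=None, bias=None, eps=1e-5):
    """Layer norm over the last dim, applied one row chunk at a time."""
    import torch
    import torch.nn.functional as F

    n_channels = x.shape[-1]
    rows = x.reshape(-1, n_channels)  # flatten all leading dims into rows
    out = torch.empty_like(rows)
    for start, stop in plan_row_chunks(rows.shape[0], n_channels):
        out[start:stop] = F.layer_norm(
            rows[start:stop], (n_channels,), weight, bias, eps
        )
    return out.reshape(x.shape)
```

With these limits, the `(1, 1, 4563, 4563, 267)` input from the report would be split into three row chunks, which matches the observation that the chunks are large enough for the loop overhead to be small.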
Changes
Testing
Other notes
`addopts = "-m 'not slow'"` in the pytest `pyproject.toml` configuration.