Chunk large LayerNorm inputs to work around torch bug #231

Open
GMNGeoffrey wants to merge 2 commits into aqlaboratory:main from GMNGeoffrey:ln-chunklayer-fix

Conversation

@GMNGeoffrey

Contributor

Summary
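There are at least two bugs in the Torch layer norm kernel for very large inputs:

  • pytorch/pytorch#181555: when numel > 2^32 and the channel dimension is a multiple of 4 (so the vectorized kernel is used), there is a uint32 overflow in the row offset. Fixed in pytorch/pytorch#181600, which is confirmed in nightly but not yet in the stable release.
  • pytorch/pytorch#184826 (which I just reported): when the combined batch dim is >= 2^23 and the channel dimension is not a multiple of 4 (so the non-vectorized kernel is used), a failure that is not yet root-caused leaves some outputs unwritten and filled with garbage values.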

I ran into the second bug while trying to run inference on three large monomer sequences with 4563, 4967, and 5247 residues. The layer norm call in diffusion conditioning has input shape (1, 1, N_tok, N_tok, 267) (c_z=128 + relpos=139 -> channels=267). I also ran into the first bug when trying to triangulate the exact trigger for the second. Since it appears that the channel dim here will always be odd by construction (sum of 3 even numbers plus 1 for same-entity features), I think we would never hit the first bug in this layernorm. We could hit it in the pairformer layernorm, though, where the call is (1, 1, N_tok, N_tok, 128): it would be triggered by any sequence of >= 5793 tokens (sqrt(2^32/2^7) = 2^12.5 ~ 5792.6).
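The token-count thresholds above can be checked with a few lines of arithmetic. This is a minimal sketch that assumes the trigger conditions are exactly as described in this PR (numel > 2^32 for the first bug, combined batch dim >= 2^23 for the second) and that the input is (1, 1, N_tok, N_tok, channels); the function names are hypothetical and not part of the codebase.

```python
import math

NUMEL_LIMIT = 2**32  # first bug: numel > 2^32 (channels a multiple of 4)
ROWS_LIMIT = 2**23   # second bug: combined batch dim >= 2^23 (channels not a multiple of 4)


def min_tokens_first_bug(channels: int) -> int:
    """Smallest N_tok with N_tok * N_tok * channels > 2^32."""
    n = math.isqrt(NUMEL_LIMIT // channels)
    while n * n * channels <= NUMEL_LIMIT:
        n += 1
    return n


def min_tokens_second_bug() -> int:
    """Smallest N_tok with N_tok * N_tok >= 2^23."""
    n = math.isqrt(ROWS_LIMIT)
    while n * n < ROWS_LIMIT:
        n += 1
    return n


print(min_tokens_first_bug(128))  # 5793 (pairformer layernorm, channels=128)
print(min_tokens_second_bug())    # 2897 (any channel count not divisible by 4)
```

Under these assumptions, all three reported sequences (4563, 4967, and 5247 residues) are well past the 2897-token threshold for the second bug, which is consistent with the failures described above.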

Fix: use our existing chunking infrastructure to chunk the layer norm kernel when it would exceed either of these thresholds. I did this unconditionally, since at least the first bug affects CUDA as well as ROCm and these chunks are so large that I think they don't add any meaningful overhead. I also didn't gate it on the channel parity, since that seemed like unnecessary extra complication. In my tests this chunking did not introduce any noticeable additional latency. I also didn't hook this into any of the chunk size tuning or configuration, since the purpose here is different.
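To illustrate the idea, here is a rough sketch of how row ranges could be planned so that every chunk stays below both thresholds. It is not the chunking infrastructure this PR uses: `plan_row_chunks` and its default limits are hypothetical, chosen only to match the trigger conditions described above, and the sketch assumes the input has been flattened to (rows, channels).

```python
def plan_row_chunks(num_rows, channels, max_rows=2**23 - 1, max_numel=2**32):
    """Return (start, stop) row ranges that each respect both limits.

    Each chunk has at most max_rows rows (so fewer than 2^23 by default) and
    at most max_numel elements (so never more than 2^32 by default).
    """
    if num_rows < 0 or channels <= 0:
        raise ValueError("num_rows must be >= 0 and channels must be > 0")
    rows_per_chunk = min(max_rows, max_numel // channels)
    if rows_per_chunk < 1:
        raise ValueError("a single row already exceeds max_numel")
    return [
        (start, min(start + rows_per_chunk, num_rows))
        for start in range(0, num_rows, rows_per_chunk)
    ]


# Diffusion-conditioning case from this PR: N_tok = 4563, channels = 267.
chunks = plan_row_chunks(4563 * 4563, 267)
print(len(chunks))  # 3
```

Each range could then be normalized separately, for example by calling `torch.nn.functional.layer_norm` on `flat[start:stop]` and concatenating the results. Because layer norm only reduces over the channel dimension, splitting along rows should not change the output.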

Verified that this resolves the issue on these sequences.

Changes

  • Chunk very large layernorm calls
  • Add tests for layernorm at these sizes

Testing

  • Unit tests for these bugs. Confirmed they failed without this fix.
  • Confirmed I no longer got NaN outputs for any of my large monomer inputs with this fix.

Other notes

  • The 2^32 element test can be quite slow. I marked it as slow, but I'm not sure whether we want to exclude it by default (we could exclude all slow tests by default by adding addopts = "-m 'not slow'" to the pytest configuration in pyproject.toml).

@GMNGeoffrey

Contributor Author

@christinaflo PTAL

There are at least two bugs in the Torch layer norm kernel for very
large inputs:

- pytorch/pytorch#181555: when numel > 2^32
  and the channel dimension is a multiple of 4 (so the vectorized kernel
  is used), there is a uint32 overflow in the row offset. Fixed in
  pytorch/pytorch#181600, which is confirmed in nightly but not yet in
  the stable release.
- pytorch/pytorch#184826 (which I just
  reported): when the combined batch dim is >= 2^23 and the channel
  dimension is not a multiple of 4 (so the non-vectorized kernel is
  used), a failure that is not yet root-caused leaves some outputs
  unwritten and filled with garbage values.
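The two trigger conditions can be written as a small predicate over the input shape. This is only a sketch of the conditions as stated in the list above: `layer_norm_bug_triggers` is a hypothetical helper, and it assumes normalization over the last dimension only, with the "combined batch dim" being the product of all leading dimensions.

```python
import math


def layer_norm_bug_triggers(shape):
    """Return the upstream issues whose stated trigger condition matches `shape`."""
    *batch_dims, channels = shape
    rows = math.prod(batch_dims)  # combined batch dim (assumed: all leading dims)
    numel = rows * channels
    triggers = []
    # Vectorized kernel path: channels a multiple of 4, numel > 2^32.
    if channels % 4 == 0 and numel > 2**32:
        triggers.append("pytorch/pytorch#181555")
    # Non-vectorized kernel path: channels not a multiple of 4, rows >= 2^23.
    if channels % 4 != 0 and rows >= 2**23:
        triggers.append("pytorch/pytorch#184826")
    return triggers


print(layer_norm_bug_triggers((1, 1, 4563, 4563, 267)))  # ['pytorch/pytorch#184826']
print(layer_norm_bug_triggers((1, 1, 5793, 5793, 128)))  # ['pytorch/pytorch#181555']
```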

I ran into the second bug while trying to run inference on three large
monomer sequences with 4563, 4967, and 5247 residues. The layer norm
call in diffusion conditioning is `(1, 1, N_tok, N_tok, 267)` (c_z=128 +
relpos=139 -> channels=267). I also ran into the first bug when trying
to triangulate the exact trigger for the second. Since it appears that
the channel dim will always be odd by construction (sum of 3 even
numbers plus 1 for same-entity features), I think we would never hit
the first bug in this layernorm, but we would hit it for the layernorm in
pairformer where the call is `(1, 1, N_tok, N_tok, 128)`, so it would be
triggered by any sequence >= 2^12.5 ~ 5793 tokens (`sqrt(2^32/2^7)`).

Fix: use our existing chunking infrastructure to chunk the layer norm
kernel when it would exceed either of these thresholds. I did this
unconditionally, since at least the first bug affects CUDA as well as
ROCm. I also didn't gate it on the channel parity, since that seemed
like unnecessary extra complication. In my tests this chunking did not
introduce any noticeable additional latency. I also didn't hook this
into any of the chunk size tuning or configuration, since the purpose
here is different.

Verified that this resolves the issue on these sequences.