Reduce get_token_frame_atoms memory usage#244
Open
GMNGeoffrey wants to merge 2 commits into
Open
Conversation
This is the memory peak for large complexes (around 5k tokens). The pairwise distance computation here materializes an 3*N_atom^2 intermediate (I'm seeing 42GB of peak memory from this call alone). `torch.cdist` uses the identity `||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b` with more than 25 points (or always if you set `use_mm_for_euclid_dist`). In floating point math this has some negative consequences for precision as it ends up subtracting large numbers to get a small result, but given that we then feed it into topk, this shouldn't have that big an effect. It's also about 4x faster (just for the distance computation), though I don't think that really matters here. I tested across 14 monomer (49-4563 residues, 366-36,356 atoms), so we have realistic coordinates. But because distances are only used for frame atoms when is_atomized is true, I artificially flipped all of them to set them as atomized (would've been better to use actual atomized inputs, but this is what I had on hand). `cdist` changes the output of `phi` for four atoms total in the dataset (per input mean 0.05% of atoms, max single input 0.5% of atoms) and reduces the memory usage of this computation for a 42k atoms from 42GB to 13GB. There are some other locations that also do this distance calculation. I think these aren't run during inference, so I haven't hit them, but it might be worth investigating them as well. They don't get fed straight into a topk though, so there may be more precision and differentiability complications.
The masking line `d = d * pair_mask + inf * (1 - pair_mask)` momentarily holds 4-5 simultaneous fp32 [N_atom, N_atom] tensors (the existing pair_mask and d, plus three new intermediates). At N~42k this expression alone accounts for ~27 GB of the function's ~33 GB peak after applying the cdist optimization. Switching to: - pair_mask kept as bool (1.7 GB at N=41863 vs 6.7 GB as fp32) - in-place AND for the chain-restriction step - inplace masked_fill_ instead of multiply-and-add drops the function's peak to ~10GB at N~42K. This is bit-exact when pair_mask is 0/1-valued (which it is here, since atom_mask is 0/1-float and atom_asym_id_mask is bool): d * 1 + 0 == d and 0 + inf * 1 == inf in IEEE 754. Verified empirically 14 monomer inputs that `phi` and `valid_frame_mask` are bit-identical to the previous masking form.
Collaborator
|
And donot_use_mm_for_euclid_dist would remove the memory gains with this? I'm a little nervous since this is also used in training for the pae loss, I'd want a better idea of how much this changes outputs for a variety of different inputs. Separately, I was also thinking of moving this and other related indexing logic to the data pipeline and passing it in as a feature instead of computing it on the fly. |
Contributor
Author
|
Yeah I think basically all the gains are from using matmul. One of the alternatives I mentioned in the linked issue is chunking here if you're worried about the effect on the outputs. I think feeding into topk here makes this not a huge concern, but totally understand if you're leary of this sort of change without more thorough testing |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
I am hitting OOMs in
get_token_frame_atomsfor large monomer inputs around 5k residues (42K atoms). This function contributes ~47GB to the peak memory usage. We can bring that down to ~10GB by avoiding intermediate materializations and operating on boolean rather than fp32 masks.Changes
torch.cdistRelated Issues
get_token_frame_atoms#225Testing
Verified outputs change only marginally across 14 monomer inputs with a range of sizes (49-4563 residues, 366-36,356 atoms). cdist changes the output of phi for four atoms total in the dataset (per input mean 0.05%, max 0.5%). The mask changes don't affect the output at all.
Other Notes
Using
torch.cdistcan create numerics issues with floating point math, as it computes a small number as the subtraction of two large ones. Since we are feeding the result directly into topk, I think this doesn't end up mattering too much. In my testing, there were only a few cases where the output differed at all.