Skip to content

Reduce get_token_frame_atoms memory usage#244

Open
GMNGeoffrey wants to merge 2 commits into
aqlaboratory:mainfrom
GMNGeoffrey:atomize-pdist-cdist
Open

Reduce get_token_frame_atoms memory usage#244
GMNGeoffrey wants to merge 2 commits into
aqlaboratory:mainfrom
GMNGeoffrey:atomize-pdist-cdist

Conversation

@GMNGeoffrey

Copy link
Copy Markdown
Contributor

Summary
I am hitting OOMs in get_token_frame_atoms for large monomer inputs around 5k residues (42K atoms). This function contributes ~47GB to the peak memory usage. We can bring that down to ~10GB by avoiding intermediate materializations and operating on boolean rather than fp32 masks.

Changes

  • Replace manual pairwise distance computation with torch.cdist
  • Convert masks to booleans
  • Use in-place operations

Related Issues

Testing
Verified outputs change only marginally across 14 monomer inputs with a range of sizes (49-4563 residues, 366-36,356 atoms). cdist changes the output of phi for four atoms total in the dataset (per input mean 0.05%, max 0.5%). The mask changes don't affect the output at all.

Other Notes
Using torch.cdist can create numerics issues with floating point math, as it computes a small number as the subtraction of two large ones. Since we are feeding the result directly into topk, I think this doesn't end up mattering too much. In my testing, there were only a few cases where the output differed at all.

This is the memory peak for large complexes (around 5k tokens). The
pairwise distance computation here materializes an 3*N_atom^2
intermediate (I'm seeing 42GB of peak memory from this call alone).
`torch.cdist` uses the identity
`||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b` with more than 25 points (or
always if you set `use_mm_for_euclid_dist`). In floating point math this
has some negative consequences for precision as it ends up subtracting
large numbers to get a small result, but given that we then feed it into
topk, this shouldn't have that big an effect. It's also about 4x faster
(just for the distance computation), though I don't think that really
matters here. I tested across 14 monomer (49-4563 residues, 366-36,356
atoms), so we have realistic coordinates. But because distances are only
used for frame atoms when is_atomized is true, I artificially flipped
all of them to set them as atomized (would've been better to use actual
atomized inputs, but this is what I had on hand). `cdist` changes the
output of `phi` for four atoms total in the dataset (per input mean
0.05% of atoms, max single input 0.5% of atoms) and reduces the memory
usage of this computation for a 42k atoms from 42GB to 13GB.

There are some other locations that also do this distance calculation. I
think these aren't run during inference, so I haven't hit them, but it
might be worth investigating them as well. They don't get fed straight
into a topk though, so there may be more precision and differentiability
complications.
The masking line `d = d * pair_mask + inf * (1 - pair_mask)` momentarily
holds 4-5 simultaneous fp32 [N_atom, N_atom] tensors (the existing
pair_mask and d, plus three new intermediates). At N~42k this expression
alone accounts for ~27 GB of the function's ~33 GB peak after applying
the cdist optimization.

Switching to:
- pair_mask kept as bool (1.7 GB at N=41863 vs 6.7 GB as fp32)
- in-place AND for the chain-restriction step
- inplace masked_fill_ instead of multiply-and-add

drops the function's peak to ~10GB at N~42K.

This is bit-exact when pair_mask is 0/1-valued (which it is here, since
atom_mask is 0/1-float and atom_asym_id_mask is bool): d * 1 + 0 == d
and 0 + inf * 1 == inf in IEEE 754. Verified empirically 14 monomer
inputs that `phi` and `valid_frame_mask` are bit-identical to the previous
masking form.
@christinaflo

Copy link
Copy Markdown
Collaborator

And donot_use_mm_for_euclid_dist would remove the memory gains with this? I'm a little nervous since this is also used in training for the pae loss, I'd want a better idea of how much this changes outputs for a variety of different inputs. Separately, I was also thinking of moving this and other related indexing logic to the data pipeline and passing it in as a feature instead of computing it on the fly.

@GMNGeoffrey

Copy link
Copy Markdown
Contributor Author

Yeah I think basically all the gains are from using matmul. One of the alternatives I mentioned in the linked issue is chunking here if you're worried about the effect on the outputs. I think feeding into topk here makes this not a huge concern, but totally understand if you're leary of this sort of change without more thorough testing

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Inference OOM from get_token_frame_atoms

2 participants