Reduce peak memory usage#248
Draft
GMNGeoffrey wants to merge 6 commits into
Draft
Conversation
This is the memory peak for large complexes (around 5k tokens). The pairwise distance computation here materializes an 3*N_atom^2 intermediate (I'm seeing 42GB of peak memory from this call alone). `torch.cdist` uses the identity `||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b` with more than 25 points (or always if you set `use_mm_for_euclid_dist`). In floating point math this has some negative consequences for precision as it ends up subtracting large numbers to get a small result, but given that we then feed it into topk, this shouldn't have that big an effect. It's also about 4x faster (just for the distance computation), though I don't think that really matters here. I tested across 14 monomer (49-4563 residues, 366-36,356 atoms), so we have realistic coordinates. But because distances are only used for frame atoms when is_atomized is true, I artificially flipped all of them to set them as atomized (would've been better to use actual atomized inputs, but this is what I had on hand). `cdist` changes the output of `phi` for four atoms total in the dataset (per input mean 0.05% of atoms, max single input 0.5% of atoms) and reduces the memory usage of this computation for a 42k atoms from 42GB to 13GB. There are some other locations that also do this distance calculation. I think these aren't run during inference, so I haven't hit them, but it might be worth investigating them as well. They don't get fed straight into a topk though, so there may be more precision and differentiability complications.
The masking line `d = d * pair_mask + inf * (1 - pair_mask)` momentarily holds 4-5 simultaneous fp32 [N_atom, N_atom] tensors (the existing pair_mask and d, plus three new intermediates). At N~42k this expression alone accounts for ~27 GB of the function's ~33 GB peak after applying the cdist optimization. Switching to: - pair_mask kept as bool (1.7 GB at N=41863 vs 6.7 GB as fp32) - in-place AND for the chain-restriction step - inplace masked_fill_ instead of multiply-and-add drops the function's peak to ~10GB at N~42K. This is bit-exact when pair_mask is 0/1-valued (which it is here, since atom_mask is 0/1-float and atom_asym_id_mask is bool): d * 1 + 0 == d and 0 + inf * 1 == inf in IEEE 754. Verified empirically 14 monomer inputs that `phi` and `valid_frame_mask` are bit-identical to the previous masking form.
Under offload_inference, pairformer_embedding returns zij on CPU. Even if `apply_per_sample` was set, the previous code moved the whole multi-sample pair representation back to GPU (S*N_tok^2*C_z) before running the per-pair heads (PDE/PAE), which is the single largest allocation in the run around 5k tokens and 5 samples. Instead, plumb a compute_device kwarg through PDE/PAE's forward into `_chunk`. The output buffer is still allocated on zij.device (CPU under offload), but each iteration copies one sample's slice to compute_device, runs the linear+layernorm, copies the result slice back to zij.device, and writes it into the output buffer. The multi-sample pair-rep tensor never lives on GPU. Validated on 4967 residue 8uq5_A monomer on MI300X: - peak max_memory_allocated: 165 GB -> 132 GB (-20%) - runtime: 1840s -> 1824s (noise) - outputs match exactly when torch is run in deterministic mode Note that to keep the API contract sensible, this adds a copy back to zij.device even in the non `apply_per_sample` case (otherwise, the return device would differ based on `apply_per_sample`, which would be a bit weird). This means that if `apply_per_sample` is False and `offload_inference` is True, whichever of PDE/PAE is called last will unnecessarily copy the full zij to the CPU and then immediately back to the GPU in the following dict comprehension (if it's not last then the copy back is actually useful as it frees GPU memory for the subsequent head). I played with having an output device parameter, but it made things excessively complex. I think this is not really a sensible usage scenario (I can't think of a scenario where offload threshold should be lower than per-sample threshold) and the overhead of the copy is not really very large (see benchmark numbers above), so I just left the simpler thing.
Under offload_inference, aux_heads previously moved PDE, PAE, and distogram logits back to GPU in its final dict comp, simultaneously putting both per-pair head outputs on GPU `~2*S*N_tok^2*C_out*4` bytes. This is ~64 GB at N_tok=5k, S=5, C_out=64, which is a binding memory peak (after my previous memory reduction changes). Instead, we leave them on CPU at `aux_heads` exit and have `_get_confidence_scores` move each one onto GPU only while it is being consumed (PDE for softmax->pde->gpde, PAE for softmax->pae and the sample-ranking compute, distogram for gpde). Then drop the per-call GPU reference. PDE and PAE no longer coexist on GPU. This doesn't always lower the program peak by itself (there are other points that hit the same memory peak this is removing), but unblocks subsequent optimizations. Verified outputs match exactly in torch deterministic mode and runtime is within noise.
`_embed_feats` has a chain of 8 'a = a + self.<some>_linear(x)' out-of-place additions. Each step has the accumulator a, the new linear output, and a+linear(x) all live during the add, so the peak per step is 3x `[*, N_templ, N, N, C]` tensors (~32 GB for my example N_tok=5570 monomer). `add_` folds the linear output into `a` in-place, dropping the 3x to 2x and freeing the temporary linear output immediately after fold-in. Confirmed output is bit-identical in torch deterministic mode.
Unchunked, this cat is now the memory bottleneck. At ~5k tokens, this drops peak memory by ~26GB.
Contributor
Author
christinaflo
reviewed
Jun 9, 2026
| a = a + self.y_linear(y[..., None]) | ||
| a = a + self.z_linear(z[..., None]) | ||
| a = a + self.backbone_mask_linear(backbone_frame_pair_mask) | ||
| a.add_(self.pseudo_beta_mask_linear(pseudo_beta_pair_mask)) |
Collaborator
There was a problem hiding this comment.
this should only run if inplace_safe=True, similar to other places in the code. we have a function
from openfold3.core.utils.tensor_utils import add that passes the flag like a = add(a, self.pseudo_beta_mask_linear(pseudo_beta_pair_mask), inplace=inplace_safe). it's always true for inference, not training
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.

Summary
This builds on #244 to further reduce peak memory usage across a range of input sizes. The changes are most noticeable at 3k+ residues.
I had Claude Code iteratively look for the peak memory usage point and come up with a plan to reduce it. Then I cleaned up and extracted some of its commits (some I deemed excessively invasive)
Changes
get_token_frame_atoms(from Reduceget_token_frame_atomsmemory usage #244)get_token_frame_atoms(from Reduceget_token_frame_atomsmemory usage #244)Related Issues
get_token_frame_atoms#225. My report about an OOM that motivated the cdist change.Testing
Other Notes
I'm bundling these because the effects are sort of interdependent (e.g. one change might not reduce peak memory usage because there's another binding peak) and because thoroughly testing them is a bit time consuming, so I think it might be more efficient to show that the whole stack doesn't affect outputs or latency while reducing memory usage rather than doing it for each one individually. But I can break this up. A draft PR currently because I wanted something to point to.
Some of this might also change if we get rid of
apply_per_sampleas tentatively discussed in #213.