Skip to content

Reduce peak memory usage#248

Draft
GMNGeoffrey wants to merge 6 commits into
aqlaboratory:mainfrom
GMNGeoffrey:mem-reduction
Draft

Reduce peak memory usage#248
GMNGeoffrey wants to merge 6 commits into
aqlaboratory:mainfrom
GMNGeoffrey:mem-reduction

Conversation

@GMNGeoffrey

Copy link
Copy Markdown
Contributor

Summary
This builds on #244 to further reduce peak memory usage across a range of input sizes. The changes are most noticeable at 3k+ residues.

I had Claude Code iteratively look for the peak memory usage point and come up with a plan to reduce it. Then I cleaned up and extracted some of its commits (some I deemed excessively invasive)

Changes

Related Issues

Testing

  • I've done e2e runs showing no measurable impact on latency and identical outputs as well as graphed the memory usage.

Other Notes
I'm bundling these because the effects are sort of interdependent (e.g. one change might not reduce peak memory usage because there's another binding peak) and because thoroughly testing them is a bit time consuming, so I think it might be more efficient to show that the whole stack doesn't affect outputs or latency while reducing memory usage rather than doing it for each one individually. But I can break this up. A draft PR currently because I wanted something to point to.

Some of this might also change if we get rid of apply_per_sample as tentatively discussed in #213.

This is the memory peak for large complexes (around 5k tokens). The
pairwise distance computation here materializes an 3*N_atom^2
intermediate (I'm seeing 42GB of peak memory from this call alone).
`torch.cdist` uses the identity
`||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b` with more than 25 points (or
always if you set `use_mm_for_euclid_dist`). In floating point math this
has some negative consequences for precision as it ends up subtracting
large numbers to get a small result, but given that we then feed it into
topk, this shouldn't have that big an effect. It's also about 4x faster
(just for the distance computation), though I don't think that really
matters here. I tested across 14 monomer (49-4563 residues, 366-36,356
atoms), so we have realistic coordinates. But because distances are only
used for frame atoms when is_atomized is true, I artificially flipped
all of them to set them as atomized (would've been better to use actual
atomized inputs, but this is what I had on hand). `cdist` changes the
output of `phi` for four atoms total in the dataset (per input mean
0.05% of atoms, max single input 0.5% of atoms) and reduces the memory
usage of this computation for a 42k atoms from 42GB to 13GB.

There are some other locations that also do this distance calculation. I
think these aren't run during inference, so I haven't hit them, but it
might be worth investigating them as well. They don't get fed straight
into a topk though, so there may be more precision and differentiability
complications.
The masking line `d = d * pair_mask + inf * (1 - pair_mask)` momentarily
holds 4-5 simultaneous fp32 [N_atom, N_atom] tensors (the existing
pair_mask and d, plus three new intermediates). At N~42k this expression
alone accounts for ~27 GB of the function's ~33 GB peak after applying
the cdist optimization.

Switching to:
- pair_mask kept as bool (1.7 GB at N=41863 vs 6.7 GB as fp32)
- in-place AND for the chain-restriction step
- inplace masked_fill_ instead of multiply-and-add

drops the function's peak to ~10GB at N~42K.

This is bit-exact when pair_mask is 0/1-valued (which it is here, since
atom_mask is 0/1-float and atom_asym_id_mask is bool): d * 1 + 0 == d
and 0 + inf * 1 == inf in IEEE 754. Verified empirically 14 monomer
inputs that `phi` and `valid_frame_mask` are bit-identical to the previous
masking form.
Under offload_inference, pairformer_embedding returns zij on CPU. Even
if `apply_per_sample` was set, the previous code moved the whole
multi-sample pair representation back to GPU (S*N_tok^2*C_z) before
running the per-pair heads (PDE/PAE), which is the single largest
allocation in the run around 5k tokens and 5 samples.

Instead, plumb a compute_device kwarg through PDE/PAE's forward into
`_chunk`. The output buffer is still allocated on zij.device (CPU under
offload), but each iteration copies one sample's slice to
compute_device, runs the linear+layernorm, copies the result slice back
to zij.device, and writes it into the output buffer. The multi-sample
pair-rep tensor never lives on GPU.

Validated on 4967 residue 8uq5_A monomer on MI300X:
  - peak max_memory_allocated: 165 GB -> 132 GB (-20%)
  - runtime: 1840s -> 1824s (noise)
  - outputs match exactly when torch is run in deterministic mode

Note that to keep the API contract sensible, this adds a copy back to
zij.device even in the non `apply_per_sample` case (otherwise, the
return device would differ based on `apply_per_sample`, which would be a
bit weird). This means that if `apply_per_sample` is False and
`offload_inference` is True, whichever of PDE/PAE is called last will
unnecessarily copy the full zij to the CPU and then immediately back to
the GPU in the following dict comprehension (if it's not last then the
copy back is actually useful as it frees GPU memory for the subsequent
head). I played with having an output device parameter, but it made
things excessively complex. I think this is not really a sensible usage
scenario (I can't think of a scenario where offload threshold should be
lower than per-sample threshold) and the overhead of the copy is not
really very large (see benchmark numbers above), so I just left the
simpler thing.
Under offload_inference, aux_heads previously moved PDE, PAE, and
distogram logits back to GPU in its final dict comp, simultaneously
putting both per-pair head outputs on GPU `~2*S*N_tok^2*C_out*4` bytes.
This is ~64 GB at N_tok=5k, S=5, C_out=64, which is a binding memory
peak (after my previous memory reduction changes).

Instead, we leave them on CPU at `aux_heads` exit and have
`_get_confidence_scores` move each one onto GPU only while it is being
consumed (PDE for softmax->pde->gpde, PAE for softmax->pae and the
sample-ranking compute, distogram for gpde). Then drop the per-call GPU
reference.

PDE and PAE no longer coexist on GPU. This doesn't always lower the
program peak by itself (there are other points that hit the same memory
peak this is removing), but unblocks subsequent optimizations.

Verified outputs match exactly in torch deterministic mode and runtime
is within noise.
`_embed_feats` has a chain of 8 'a = a + self.<some>_linear(x)'
out-of-place additions. Each step has the accumulator a, the new linear
output, and a+linear(x) all live during the add, so the peak per step is
3x `[*, N_templ, N, N, C]` tensors (~32 GB for my example N_tok=5570
monomer). `add_` folds the linear output into `a` in-place, dropping the
3x to 2x and freeing the temporary linear output immediately after
fold-in.

Confirmed output is bit-identical in torch deterministic mode.
Unchunked, this cat is now the memory bottleneck. At ~5k tokens, this
drops peak memory by ~26GB.
@GMNGeoffrey

GMNGeoffrey commented Jun 8, 2026

Copy link
Copy Markdown
Contributor Author

Here's a plot showing how this changes peak memory usage

image

Note that the changes especially with per-slice zij moves do introduce more memory fragmentation, which is why the end of the series with these changes is significantly below the full memory capacity

a = a + self.y_linear(y[..., None])
a = a + self.z_linear(z[..., None])
a = a + self.backbone_mask_linear(backbone_frame_pair_mask)
a.add_(self.pseudo_beta_mask_linear(pseudo_beta_pair_mask))

@christinaflo christinaflo Jun 9, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should only run if inplace_safe=True, similar to other places in the code. we have a function
from openfold3.core.utils.tensor_utils import add that passes the flag like a = add(a, self.pseudo_beta_mask_linear(pseudo_beta_pair_mask), inplace=inplace_safe). it's always true for inference, not training

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Inference OOM from get_token_frame_atoms [BUG] Cuda Out of Memory on large structure even low_mem is set

2 participants