diff --git a/openfold3/core/utils/atomize_utils.py b/openfold3/core/utils/atomize_utils.py index 79436981..077923ad 100644 --- a/openfold3/core/utils/atomize_utils.py +++ b/openfold3/core/utils/atomize_utils.py @@ -762,23 +762,23 @@ def get_token_frame_atoms( valid_frame_mask: [*, N_token] Mask denoting valid frames """ - # Create pairwise atom mask - pair_mask = atom_mask[..., None] * atom_mask[..., None, :] + # Pairwise atom mask, kept as bool throughout to avoid materializing + # large fp32 [N_atom, N_atom] intermediates. + am_bool = atom_mask.bool() + pair_mask = am_bool[..., None] & am_bool[..., None, :] - # Update pairwise atom mask # Restrict to atoms within the same chain atom_asym_id = broadcast_token_feat_to_atoms( token_mask=batch["token_mask"], num_atoms_per_token=batch["num_atoms_per_token"], token_feat=batch["asym_id"], ) - atom_asym_id_mask = atom_asym_id[..., None] == atom_asym_id[..., None, :] - pair_mask = pair_mask * atom_asym_id_mask + pair_mask &= atom_asym_id[..., None] == atom_asym_id[..., None, :] - # Compute distance matrix + # Compute distance matrix. Use cdist to avoid materializing N*N*3 intermediate # [*, N_atom, N_atom] - d = torch.sum(eps + (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1) ** 0.5 - d = d * pair_mask + inf * (1 - pair_mask) + d = torch.cdist(x, x) + d.masked_fill_(~pair_mask, inf) # Find indices of two closest atoms for start atoms # [*, N_token]