Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions openfold3/core/metrics/aggregate_confidence_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@


def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> dict:
# Under inference offload, aux_heads returns pde_logits, pae_logits, and
# distogram_logits on CPU. Move each one onto the compute device only for as
# long as it is being consumed, either through a temporary or through a local
# that is deleted afterwards, so that the PDE and PAE logits are never both
# device-resident at the same time. atom_positions_predicted serves as the
# device anchor because the diffusion sampler always produces it on the
# compute device.
compute_device = outputs["atom_positions_predicted"].device

confidence_scores = {}
confidence_scores["plddt"] = (
probs_to_expected_error(
Expand All @@ -39,7 +48,7 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di
* 100.0
)

pde_probs = torch.softmax(outputs["pde_logits"], dim=-1)
pde_probs = torch.softmax(outputs["pde_logits"].to(device=compute_device), dim=-1)
confidence_scores["pde"] = probs_to_expected_error(
pde_probs, **config.confidence.pde
)
Expand All @@ -50,7 +59,7 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di

confidence_scores["gpde"], contact_probs = compute_global_predicted_distance_error(
pde=confidence_scores["pde"],
logits=outputs["distogram_logits"],
logits=outputs["distogram_logits"].to(device=compute_device),
**config.confidence.distogram,
)
if config.confidence.distogram.return_contact_probs:
Expand All @@ -59,7 +68,8 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di
del contact_probs

if config.architecture.heads.pae.enabled:
pae_probs = torch.softmax(outputs["pae_logits"], dim=-1)
pae_logits_on_device = outputs["pae_logits"].to(device=compute_device)
pae_probs = torch.softmax(pae_logits_on_device, dim=-1)
confidence_scores["pae"] = probs_to_expected_error(
pae_probs, **config.confidence.pae
)
Expand All @@ -76,10 +86,16 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di

valid_frame_mask = valid_frame_mask.bool()

# Patch outputs locally so downstream sample-ranking sees the
# device-resident pae_logits without us having to thread it through
# every callee signature.
outputs_for_ranking = dict(outputs)
outputs_for_ranking["pae_logits"] = pae_logits_on_device

confidence_scores.update(
full_complex_sample_ranking_metric(
batch=batch,
output=outputs,
output=outputs_for_ranking,
has_frame=valid_frame_mask,
**config.confidence.sample_ranking.full_complex,
**config.confidence.ptm,
Expand All @@ -90,7 +106,7 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di
confidence_scores.update(
compute_chain_pair_iptm(
batch=batch,
logits=outputs["pae_logits"],
logits=pae_logits_on_device,
has_frame=valid_frame_mask,
**config.confidence.ptm,
)
Expand All @@ -100,12 +116,15 @@ def _get_confidence_scores(batch: dict, outputs: dict, config: ConfigDict) -> di
confidence_scores.update(
compute_chain_ptm(
batch=batch,
outputs=outputs,
outputs=outputs_for_ranking,
has_frame=valid_frame_mask,
**config.confidence.ptm,
)
)

del outputs_for_ranking
del pae_logits_on_device

return confidence_scores


Expand Down
33 changes: 23 additions & 10 deletions openfold3/core/model/feature_embedders/template_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import openfold3.core.config.default_linear_init_config as lin_init
from openfold3.core.model.primitives import LayerNorm, Linear
from openfold3.core.utils.tensor_utils import add


class TemplatePairEmbedderAllAtom(nn.Module):
Expand Down Expand Up @@ -67,7 +68,7 @@ def __init__(
self.layer_norm_z = LayerNorm(c_in)
self.linear_z = Linear(c_in, c_out, **linear_init_params.linear_z)

def _embed_feats(self, batch: dict):
def _embed_feats(self, batch: dict, inplace_safe: bool = False):
dtype = batch["template_unit_vector"].dtype

# [*, N_token, N_token]
Expand Down Expand Up @@ -103,17 +104,29 @@ def _embed_feats(self, batch: dict):
)

a = self.dgram_linear(template_distogram)
a = a + self.pseudo_beta_mask_linear(pseudo_beta_pair_mask)
a = a + self.aatype_linear_1(template_restype_ti.to(dtype=dtype))
a = a + self.aatype_linear_2(template_restype_tj.to(dtype=dtype))
a = a + self.x_linear(x[..., None])
a = a + self.y_linear(y[..., None])
a = a + self.z_linear(z[..., None])
a = a + self.backbone_mask_linear(backbone_frame_pair_mask)
a = add(
a, self.pseudo_beta_mask_linear(pseudo_beta_pair_mask), inplace=inplace_safe
)
a = add(
a,
self.aatype_linear_1(template_restype_ti.to(dtype=dtype)),
inplace=inplace_safe,
)
a = add(
a,
self.aatype_linear_2(template_restype_tj.to(dtype=dtype)),
inplace=inplace_safe,
)
a = add(a, self.x_linear(x[..., None]), inplace=inplace_safe)
a = add(a, self.y_linear(y[..., None]), inplace=inplace_safe)
a = add(a, self.z_linear(z[..., None]), inplace=inplace_safe)
a = add(
a, self.backbone_mask_linear(backbone_frame_pair_mask), inplace=inplace_safe
)

return a

def forward(self, batch, z):
def forward(self, batch, z, inplace_safe: bool = False):
"""
Args:
batch:
Expand All @@ -123,7 +136,7 @@ def forward(self, batch, z):
Returns:
# [*, N_templ, N_token, N_token, C_out] Template pair feature embedding
"""
a = self._embed_feats(batch=batch)
a = self._embed_feats(batch=batch, inplace_safe=inplace_safe)

# [*, N_templ, N_token, N_token, C_out]
z = self.linear_z(self.layer_norm_z(z))
Expand Down
31 changes: 19 additions & 12 deletions openfold3/core/model/heads/head_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,12 @@ def forward(

# Distogram head: Main loop (Algorithm 1), line 17
distogram_logits = self.distogram(z=zij)
# Under offload_inference, move distogram off GPU now; downstream
# confidence scoring consumes it once at the very end (gpde) and
# can pull it back on demand. Saves ~S*N^2*C_out*4 bytes of GPU
# peak during all of the per-pair-head compute that follows.
if offload_inference:
distogram_logits = distogram_logits.to(device="cpu")
aux_out["distogram_logits"] = distogram_logits

# Stop grad
Expand Down Expand Up @@ -239,23 +245,24 @@ def forward(
)
aux_out["experimentally_resolved_logits"] = experimentally_resolved_logits

# zij is moved back to GPU after the single rep confidence heads
# because building the max_atom_per_token_mask uses a lot of memory
zij = zij.to(device=out_device)

pde_logits = self.pde(zij, apply_per_sample=apply_per_sample)
# We leave zij on CPU here and let the PDE/PAE heads pull what they
# need. This enables moving only a single sample onto the GPU at a time
# if running with apply_per_sample.
aux_out["pde_logits"] = self.pde(
zij,
apply_per_sample=apply_per_sample,
compute_device=out_device,
)

if self.config.pae.enabled:
# Offload pde logits to not keep all three pairwise tensors
# in GPU memory at once
offload_device = "cpu" if offload_inference else out_device
pde_logits = pde_logits.to(device=offload_device)
aux_out["pae_logits"] = self.pae(zij, apply_per_sample=apply_per_sample)
aux_out["pae_logits"] = self.pae(
zij,
apply_per_sample=apply_per_sample,
compute_device=out_device,
)

del zij

aux_out["pde_logits"] = pde_logits.to(device=out_device)

aux_out = {k: v.to(dtype=out_dtype) for k, v in aux_out.items()}

return aux_out
55 changes: 43 additions & 12 deletions openfold3/core/model/heads/prediction_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,19 +399,28 @@ def _compute_logits(self, zij: torch.Tensor):
def _chunk(
self,
zij: torch.Tensor,
compute_device: torch.device | None = None,
) -> torch.Tensor:
# ``zij`` will be moved in slices to ``compute_device`` for the layer
# norm + linear, and the output logits will be moved afterwards to
# the original ``zij.device``
zij_out = torch.zeros(
(*zij.shape[:-1], self.c_out), device=zij.device, dtype=zij.dtype
)
no_samples = zij.shape[-4]
for i in range(no_samples):
zij_out[..., i : i + 1, :, :, :] = self._compute_logits(
zij[..., i : i + 1, :, :, :]
)
slice_in = zij[..., i : i + 1, :, :, :].to(device=compute_device)
slice_out = self._compute_logits(slice_in).to(device=zij.device)
zij_out[..., i : i + 1, :, :, :] = slice_out

return zij_out

def forward(self, zij, apply_per_sample: bool = False):
def forward(
self,
zij,
apply_per_sample: bool = False,
compute_device: torch.device | None = None,
):
"""
Args:
zij:
Expand All @@ -421,14 +430,22 @@ def forward(self, zij, apply_per_sample: bool = False):
This is a memory optimization which is only used during
validation/inference and will depend on the number of samples
in the full rollout.
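compute_device:
Device on which to run the computation. If None (the default),
the computation runs on ``zij``'s current device. Otherwise
``zij`` is moved here before the logits are computed. When
apply_per_sample is true, each per-sample slice of ``zij`` is
moved onto this device separately and its output is moved back to
``zij.device`` before the next slice is processed; otherwise the
whole tensor is moved at once. In both cases the returned logits
are on ``zij.device``.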
Returns:
logits:
[*, N, N, C_out] Logits
"""
if apply_per_sample:
logits = self._chunk(zij=zij)
logits = self._chunk(zij=zij, compute_device=compute_device)
else:
logits = self._compute_logits(zij=zij)
logits = self._compute_logits(zij=zij.to(device=compute_device)).to(
device=zij.device
)

return logits

Expand Down Expand Up @@ -471,19 +488,25 @@ def _compute_logits(self, zij: torch.Tensor):
def _chunk(
self,
zij: torch.Tensor,
compute_device: torch.device | None = None,
) -> torch.Tensor:
zij_out = torch.zeros(
(*zij.shape[:-1], self.c_out), device=zij.device, dtype=zij.dtype
)
no_samples = zij.shape[-4]
for i in range(no_samples):
zij_out[..., i : i + 1, :, :, :] = self._compute_logits(
zij[..., i : i + 1, :, :, :]
)
slice_in = zij[..., i : i + 1, :, :, :].to(device=compute_device)
slice_out = self._compute_logits(slice_in).to(device=zij.device)
zij_out[..., i : i + 1, :, :, :] = slice_out

return zij_out

def forward(self, zij, apply_per_sample: bool = False):
def forward(
self,
zij,
apply_per_sample: bool = False,
compute_device: torch.device | None = None,
):
"""
Args:
zij:
Expand All @@ -493,14 +516,22 @@ def forward(self, zij, apply_per_sample: bool = False):
This is a memory optimization which is only used during
validation/inference and will depend on the number of samples
in the full rollout.
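compute_device:
Device on which to run the computation. If None (the default),
the computation runs on ``zij``'s current device. Otherwise
``zij`` is moved here before the logits are computed. When
apply_per_sample is true, each per-sample slice of ``zij`` is
moved onto this device separately and its output is moved back to
``zij.device`` before the next slice is processed; otherwise the
whole tensor is moved at once. In both cases the returned logits
are on ``zij.device``.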
Returns:
logits:
[*, N, N, C_out] Logits
"""
if apply_per_sample:
logits = self._chunk(zij=zij)
logits = self._chunk(zij=zij, compute_device=compute_device)
else:
logits = self._compute_logits(zij=zij)
logits = self._compute_logits(zij=zij.to(device=compute_device)).to(
device=zij.device
)

return logits

Expand Down
3 changes: 2 additions & 1 deletion openfold3/core/model/latent/template_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ def _forward_offload(
t = self.template_pair_embedder(
batch=batch_templ,
z=z,
inplace_safe=inplace_safe,
)

# [*, N_templ, N_token, N_token, C_z]
Expand Down Expand Up @@ -608,7 +609,7 @@ def _forward(
inplace_safe: bool = False,
) -> torch.Tensor:
# [*, N_templ, N_token, N_token, C_t]
t = self.template_pair_embedder(batch, z)
t = self.template_pair_embedder(batch, z, inplace_safe=inplace_safe)

# [*, 1, N_token, N_token]
pair_mask = pair_mask[..., None, :, :].to(dtype=z.dtype)
Expand Down
27 changes: 22 additions & 5 deletions openfold3/core/model/layers/diffusion_conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from openfold3.core.model.layers.transition import SwiGLUTransition
from openfold3.core.model.primitives.linear import Linear
from openfold3.core.model.primitives.normalization import LayerNorm
from openfold3.core.utils.chunk_utils import ChunkSizeTuner
from openfold3.core.utils.chunk_utils import ChunkSizeTuner, chunk_layer
from openfold3.core.utils.relpos import relpos_complex


Expand Down Expand Up @@ -137,16 +137,28 @@ def _embed_trunk_inputs(
si_input: torch.Tensor,
si_trunk: torch.Tensor,
zij_trunk: torch.Tensor,
chunk_size: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Pair conditioning
relpos_zij = relpos_complex(
batch=batch,
max_relative_idx=self.max_relative_idx,
max_relative_chain=self.max_relative_chain,
).to(dtype=zij_trunk.dtype)

zij = torch.cat([zij_trunk, relpos_zij], dim=-1)
zij = self.linear_z(self.layer_norm_z(zij))
def _proj_zij(zij_trunk_in, relpos_in):
return self.linear_z(
self.layer_norm_z(torch.cat([zij_trunk_in, relpos_in], dim=-1))
)

if chunk_size is not None:
zij = chunk_layer(
layer=_proj_zij,
inputs={"zij_trunk_in": zij_trunk, "relpos_in": relpos_zij},
chunk_size=chunk_size,
no_batch_dims=zij_trunk.dim() - 2,
)
else:
zij = _proj_zij(zij_trunk, relpos_zij)

# Single conditioning
si = torch.cat([si_trunk, si_input], dim=-1)
Expand Down Expand Up @@ -246,7 +258,12 @@ def forward(
zij_trunk = zij_trunk * 0

si, zij = self._embed_trunk_inputs(
batch=batch, t=t, si_input=si_input, si_trunk=si_trunk, zij_trunk=zij_trunk
batch=batch,
t=t,
si_input=si_input,
si_trunk=si_trunk,
zij_trunk=zij_trunk,
chunk_size=chunk_size,
)

if chunk_size is not None:
Expand Down
Loading