diff --git a/src/transformers/models/esm/openfold_utils/__init__.py b/src/transformers/models/esm/openfold_utils/__init__.py index 5273860260c1..4a0d932a05c4 100644 --- a/src/transformers/models/esm/openfold_utils/__init__.py +++ b/src/transformers/models/esm/openfold_utils/__init__.py @@ -6,3 +6,4 @@ from .protein import Protein as OFProtein from .protein import to_pdb from .rigid_utils import Rigid, Rotation +from .tensor_utils import dict_multimap, flatten_final_dims, permute_final_dims diff --git a/src/transformers/models/esm/openfold_utils/chunk_utils.py b/src/transformers/models/esm/openfold_utils/chunk_utils.py index 11e5fff929b2..4f68503e99bb 100644 --- a/src/transformers/models/esm/openfold_utils/chunk_utils.py +++ b/src/transformers/models/esm/openfold_utils/chunk_utils.py @@ -14,23 +14,22 @@ import logging import math from functools import partial -from typing import Any, Callable, Dict, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch from .tensor_utils import tensor_tree_map, tree_map -def _fetch_dims(tree): +def _fetch_dims(tree: Union[dict, list, tuple, torch.Tensor]) -> List[Tuple[int, ...]]: shapes = [] - tree_type = type(tree) - if tree_type is dict: + if isinstance(tree, dict): for v in tree.values(): shapes.extend(_fetch_dims(v)) - elif tree_type is list or tree_type is tuple: + elif isinstance(tree, (list, tuple)): for t in tree: shapes.extend(_fetch_dims(t)) - elif tree_type is torch.Tensor: + elif isinstance(tree, torch.Tensor): shapes.append(tree.shape) else: raise ValueError("Not supported") @@ -39,10 +38,7 @@ def _fetch_dims(tree): @torch.jit.ignore -def _flat_idx_to_idx( - flat_idx: int, - dims: Tuple[int], -) -> Tuple[int]: +def _flat_idx_to_idx(flat_idx: int, dims: Tuple[int, ...]) -> Tuple[int, ...]: idx = [] for d in reversed(dims): idx.append(flat_idx % d) @@ -55,10 +51,10 @@ def _flat_idx_to_idx( def _get_minimal_slice_set( start: Sequence[int], end: Sequence[int], - dims: int, + dims: Sequence[int], start_edges: Optional[Sequence[bool]] = None, end_edges: Optional[Sequence[bool]] = None, -) -> Sequence[Tuple[int]]: +) -> List[Tuple[slice, ...]]: """ Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of @@ -69,11 +65,11 @@ def _get_minimal_slice_set( # start_edges and end_edges both indicate whether, starting from any given # dimension, the start/end index is at the top/bottom edge of the # corresponding tensor, modeled as a tree - def reduce_edge_list(l): - tally = 1 + def reduce_edge_list(l: List[bool]) -> None: + tally = True for i in range(len(l)): reversed_idx = -1 * (i + 1) - l[reversed_idx] *= tally + l[reversed_idx] &= tally tally = l[reversed_idx] if start_edges is None: @@ -90,48 +86,54 @@ def reduce_edge_list(l): elif len(start) == 1: return [(slice(start[0], end[0] + 1),)] - slices = [] - path = [] + slices: List[Tuple[slice, ...]] = [] + path_list: List[slice] = [] # Dimensions common to start and end can be selected directly for s, e in zip(start, end): if s == e: - path.append(slice(s, s + 1)) + path_list.append(slice(s, s + 1)) else: break - path = tuple(path) + path: Tuple[slice, ...] 
= tuple(path_list) divergence_idx = len(path) # start == end, and we're done if divergence_idx == len(dims): - return [tuple(path)] + return [path] + + def upper() -> Tuple[Tuple[slice, ...], ...]: + assert start_edges is not None + assert end_edges is not None - def upper(): sdi = start[divergence_idx] - return [ + return tuple( path + (slice(sdi, sdi + 1),) + s for s in _get_minimal_slice_set( start[divergence_idx + 1 :], [d - 1 for d in dims[divergence_idx + 1 :]], dims[divergence_idx + 1 :], start_edges=start_edges[divergence_idx + 1 :], - end_edges=[1 for _ in end_edges[divergence_idx + 1 :]], + end_edges=[True for _ in end_edges[divergence_idx + 1 :]], ) - ] + ) + + def lower() -> Tuple[Tuple[slice, ...], ...]: + assert start_edges is not None + assert end_edges is not None - def lower(): edi = end[divergence_idx] - return [ + return tuple( path + (slice(edi, edi + 1),) + s for s in _get_minimal_slice_set( [0 for _ in start[divergence_idx + 1 :]], end[divergence_idx + 1 :], dims[divergence_idx + 1 :], - start_edges=[1 for _ in start_edges[divergence_idx + 1 :]], + start_edges=[True for _ in start_edges[divergence_idx + 1 :]], end_edges=end_edges[divergence_idx + 1 :], ) - ] + ) # If both start and end are at the edges of the subtree rooted at # divergence_idx, we can just select the whole subtree at once @@ -156,16 +158,11 @@ def lower(): slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)) slices.extend(lower()) - return [tuple(s) for s in slices] + return slices @torch.jit.ignore -def _chunk_slice( - t: torch.Tensor, - flat_start: int, - flat_end: int, - no_batch_dims: int, -) -> torch.Tensor: +def _chunk_slice(t: torch.Tensor, flat_start: int, flat_end: int, no_batch_dims: int) -> torch.Tensor: """ Equivalent to @@ -232,7 +229,7 @@ def chunk_layer( initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)] orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) - def _prep_inputs(t): + def _prep_inputs(t: torch.Tensor) -> torch.Tensor: if not low_mem: if not sum(t.shape[:no_batch_dims]) == no_batch_dims: t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) @@ -241,7 +238,7 @@ def _prep_inputs(t): t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) return t - prepped_inputs = tensor_tree_map(_prep_inputs, inputs) + prepped_inputs: Dict[str, Any] = tensor_tree_map(_prep_inputs, inputs) prepped_outputs = None if _out is not None: prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out) @@ -252,7 +249,7 @@ def _prep_inputs(t): no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0) - def _select_chunk(t): + def _select_chunk(t: torch.Tensor) -> torch.Tensor: return t[i : i + chunk_size] if t.shape[0] != 1 else t i = 0 @@ -269,7 +266,7 @@ def _select_chunk(t): no_batch_dims=len(orig_batch_dims), ) - chunks = tensor_tree_map(select_chunk, prepped_inputs) + chunks: Dict[str, Any] = tensor_tree_map(select_chunk, prepped_inputs) # Run the layer on the chunk output_chunk = layer(**chunks) @@ -279,12 +276,11 @@ def _select_chunk(t): out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk) # Put the chunk in its pre-allocated space - out_type = type(output_chunk) - if out_type is dict: + if isinstance(output_chunk, dict): - def assign(d1, d2): + def assign(d1: dict, d2: dict) -> None: for k, v in d1.items(): - if type(v) is dict: + if isinstance(v, dict): assign(v, d2[k]) else: if _add_into_out: @@ -293,13 +289,13 @@ def assign(d1, d2): v[i : i + 
chunk_size] = d2[k] assign(out, output_chunk) - elif out_type is tuple: + elif isinstance(output_chunk, tuple): for x1, x2 in zip(out, output_chunk): if _add_into_out: x1[i : i + chunk_size] += x2 else: x1[i : i + chunk_size] = x2 - elif out_type is torch.Tensor: + elif isinstance(output_chunk, torch.Tensor): if _add_into_out: out[i : i + chunk_size] += output_chunk else: @@ -319,24 +315,24 @@ def __init__( self, # Heuristically, runtimes for most of the modules in the network # plateau earlier than this on all GPUs I've run the model on. - max_chunk_size=512, + max_chunk_size: int = 512, ): self.max_chunk_size = max_chunk_size - self.cached_chunk_size = None - self.cached_arg_data = None + self.cached_chunk_size: Optional[int] = None + self.cached_arg_data: Optional[tuple] = None - def _determine_favorable_chunk_size(self, fn, args, min_chunk_size): + def _determine_favorable_chunk_size(self, fn: Callable, args: tuple, min_chunk_size: int) -> int: logging.info("Tuning chunk size...") if min_chunk_size >= self.max_chunk_size: return min_chunk_size - candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)] + candidates: List[int] = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)] candidates = [c for c in candidates if c > min_chunk_size] candidates = [min_chunk_size] + candidates candidates[-1] += 4 - def test_chunk_size(chunk_size): + def test_chunk_size(chunk_size: int) -> bool: try: with torch.no_grad(): fn(*args, chunk_size=chunk_size) @@ -356,13 +352,13 @@ def test_chunk_size(chunk_size): return candidates[min_viable_chunk_size_index] - def _compare_arg_caches(self, ac1, ac2): + def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool: consistent = True for a1, a2 in zip(ac1, ac2): assert type(ac1) == type(ac2) - if type(ac1) is list or type(ac1) is tuple: + if isinstance(ac1, (list, tuple)): consistent &= self._compare_arg_caches(a1, a2) - elif type(ac1) is dict: + elif isinstance(ac1, dict): a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])] a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])] consistent &= self._compare_arg_caches(a1_items, a2_items) @@ -374,11 +370,11 @@ def _compare_arg_caches(self, ac1, ac2): def tune_chunk_size( self, representative_fn: Callable, - args: Tuple[Any], + args: tuple, min_chunk_size: int, ) -> int: consistent = True - arg_data = tree_map(lambda a: a.shape if type(a) is torch.Tensor else a, args, object) + arg_data: tuple = tree_map(lambda a: a.shape if isinstance(a, torch.Tensor) else a, args, object) if self.cached_arg_data is not None: # If args have changed shape/value, we need to re-tune assert len(self.cached_arg_data) == len(arg_data) @@ -395,4 +391,6 @@ def tune_chunk_size( ) self.cached_arg_data = arg_data + assert self.cached_chunk_size is not None + return self.cached_chunk_size diff --git a/src/transformers/models/esm/openfold_utils/data_transforms.py b/src/transformers/models/esm/openfold_utils/data_transforms.py index e9e5d693169d..8d4c17589ae6 100644 --- a/src/transformers/models/esm/openfold_utils/data_transforms.py +++ b/src/transformers/models/esm/openfold_utils/data_transforms.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
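
# -----------------------------------------------------------------------------
# Editor's usage sketch (not part of the patch): driving the chunk_layer /
# ChunkSizeTuner utilities whose annotations change above. The toy `layer`
# and all tensor shapes here are made up for illustration.
import torch

from transformers.models.esm.openfold_utils.chunk_utils import ChunkSizeTuner, chunk_layer


def layer(pair: torch.Tensor) -> torch.Tensor:
    return pair.sum(dim=-1)


pair = torch.randn(4, 128, 32, 16)
# Flatten the leading [4, 128] batch dims and run the layer 64 rows at a time.
out = chunk_layer(layer, {"pair": pair}, chunk_size=64, no_batch_dims=2)
assert out.shape == (4, 128, 32)

# The tuner binary-searches power-of-two chunk sizes for the largest one that
# executes without raising (e.g. CUDA OOM); on CPU every candidate succeeds.
tuner = ChunkSizeTuner(max_chunk_size=512)
best = tuner.tune_chunk_size(
    lambda t, chunk_size: chunk_layer(layer, {"pair": t}, chunk_size=chunk_size, no_batch_dims=2),
    args=(pair,),
    min_chunk_size=4,
)
# -----------------------------------------------------------------------------
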
+from typing import Dict + import numpy as np import torch @@ -20,39 +22,39 @@ from .tensor_utils import tensor_tree_map, tree_map -def make_atom14_masks(protein): +def make_atom14_masks(protein: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Construct denser atom positions (14 dimensions instead of 37).""" - restype_atom14_to_atom37 = [] - restype_atom37_to_atom14 = [] - restype_atom14_mask = [] + restype_atom14_to_atom37_list = [] + restype_atom37_to_atom14_list = [] + restype_atom14_mask_list = [] for rt in rc.restypes: atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]] - restype_atom14_to_atom37.append([(rc.atom_order[name] if name else 0) for name in atom_names]) + restype_atom14_to_atom37_list.append([(rc.atom_order[name] if name else 0) for name in atom_names]) atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} - restype_atom37_to_atom14.append( + restype_atom37_to_atom14_list.append( [(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types] ) - restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names]) + restype_atom14_mask_list.append([(1.0 if name else 0.0) for name in atom_names]) # Add dummy mapping for restype 'UNK' - restype_atom14_to_atom37.append([0] * 14) - restype_atom37_to_atom14.append([0] * 37) - restype_atom14_mask.append([0.0] * 14) + restype_atom14_to_atom37_list.append([0] * 14) + restype_atom37_to_atom14_list.append([0] * 37) + restype_atom14_mask_list.append([0.0] * 14) restype_atom14_to_atom37 = torch.tensor( - restype_atom14_to_atom37, + restype_atom14_to_atom37_list, dtype=torch.int32, device=protein["aatype"].device, ) restype_atom37_to_atom14 = torch.tensor( - restype_atom37_to_atom14, + restype_atom37_to_atom14_list, dtype=torch.int32, device=protein["aatype"].device, ) restype_atom14_mask = torch.tensor( - restype_atom14_mask, + restype_atom14_mask_list, dtype=torch.float32, device=protein["aatype"].device, ) @@ -85,8 +87,7 @@ def make_atom14_masks(protein): return protein -def make_atom14_masks_np(batch): +def make_atom14_masks_np(batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]: batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray) - out = make_atom14_masks(batch) - out = tensor_tree_map(lambda t: np.array(t), out) + out = tensor_tree_map(lambda t: np.array(t), make_atom14_masks(batch)) return out diff --git a/src/transformers/models/esm/openfold_utils/feats.py b/src/transformers/models/esm/openfold_utils/feats.py index dbfda88805f7..18b01a1fecac 100644 --- a/src/transformers/models/esm/openfold_utils/feats.py +++ b/src/transformers/models/esm/openfold_utils/feats.py @@ -13,14 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Tuple, overload + import torch -import torch.nn as nn +import torch.types +from torch import nn from . import residue_constants as rc from .rigid_utils import Rigid, Rotation from .tensor_utils import batched_gather +@overload +def pseudo_beta_fn(aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: None) -> torch.Tensor: + ... + + +@overload +def pseudo_beta_fn( + aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + ... 
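
# -----------------------------------------------------------------------------
# Editor's note (not part of the patch): @overload stubs like the two above are
# erased at runtime -- only the plain implementation that follows executes.
# They let a type checker infer that all_atom_masks=None yields a single tensor
# while a mask tensor yields a (positions, mask) pair. A minimal sketch of the
# same pattern, using a hypothetical function f:
from typing import Optional, Tuple, Union, overload

import torch


@overload
def f(x: torch.Tensor, mask: None) -> torch.Tensor: ...


@overload
def f(x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ...


def f(
    x: torch.Tensor, mask: Optional[torch.Tensor]
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    return x if mask is None else (x, mask)
# -----------------------------------------------------------------------------
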
+ + def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): is_gly = aatype == rc.restype_order["G"] ca_idx = rc.atom_order["CA"] @@ -42,7 +57,7 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): return pseudo_beta -def atom14_to_atom37(atom14, batch): +def atom14_to_atom37(atom14: torch.Tensor, batch: Dict[str, torch.Tensor]) -> torch.Tensor: atom37_data = batched_gather( atom14, batch["residx_atom37_to_atom14"], @@ -55,7 +70,7 @@ def atom14_to_atom37(atom14, batch): return atom37_data -def build_template_angle_feat(template_feats): +def build_template_angle_feat(template_feats: Dict[str, torch.Tensor]) -> torch.Tensor: template_aatype = template_feats["template_aatype"] torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"] alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"] @@ -73,7 +88,15 @@ def build_template_angle_feat(template_feats): return template_angle_feat -def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8): +def build_template_pair_feat( + batch: Dict[str, torch.Tensor], + min_bin: torch.types.Number, + max_bin: torch.types.Number, + no_bins: int, + use_unit_vector: bool = False, + eps: float = 1e-20, + inf: float = 1e8, +) -> torch.Tensor: template_mask = batch["template_pseudo_beta_mask"] template_mask_2d = template_mask[..., None] * template_mask[..., None, :] @@ -86,7 +109,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=F to_concat = [dgram, template_mask_2d[..., None]] - aatype_one_hot = nn.functional.one_hot( + aatype_one_hot: torch.LongTensor = nn.functional.one_hot( batch["template_aatype"], rc.restype_num + 2, ) @@ -126,8 +149,8 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=F return act -def build_extra_msa_feat(batch): - msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23) +def build_extra_msa_feat(batch: Dict[str, torch.Tensor]) -> torch.Tensor: + msa_1hot: torch.LongTensor = nn.functional.one_hot(batch["extra_msa"], 23) msa_feat = [ msa_1hot, batch["extra_has_deletion"].unsqueeze(-1), @@ -141,7 +164,7 @@ def torsion_angles_to_frames( alpha: torch.Tensor, aatype: torch.Tensor, rrgdf: torch.Tensor, -): +) -> Rigid: # [*, N, 8, 4, 4] default_4x4 = rrgdf[aatype, ...] @@ -172,9 +195,7 @@ def torsion_angles_to_frames( all_rots[..., 1, 2] = -alpha[..., 0] all_rots[..., 2, 1:] = alpha - all_rots = Rigid(Rotation(rot_mats=all_rots), None) - - all_frames = default_r.compose(all_rots) + all_frames = default_r.compose(Rigid(Rotation(rot_mats=all_rots), None)) chi2_frame_to_frame = all_frames[..., 5] chi3_frame_to_frame = all_frames[..., 6] @@ -203,22 +224,22 @@ def torsion_angles_to_frames( def frames_and_literature_positions_to_atom14_pos( r: Rigid, aatype: torch.Tensor, - default_frames, - group_idx, - atom_mask, - lit_positions, -): + default_frames: torch.Tensor, + group_idx: torch.Tensor, + atom_mask: torch.Tensor, + lit_positions: torch.Tensor, +) -> torch.Tensor: # [*, N, 14] group_mask = group_idx[aatype, ...] 
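    # (Editor's note, not part of the patch: a recurring pattern in this PR is
    # renaming values whose type changes mid-function -- the one-hot result
    # below becomes `group_mask_one_hot`, and the Python lists in
    # make_atom14_masks gain a `_list` suffix before the torch.tensor() call --
    # so that each name keeps a single static type for the checker.)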
# [*, N, 14, 8] - group_mask = nn.functional.one_hot( + group_mask_one_hot: torch.LongTensor = nn.functional.one_hot( group_mask, num_classes=default_frames.shape[-3], ) # [*, N, 14, 8] - t_atoms_to_global = r[..., None, :] * group_mask + t_atoms_to_global = r[..., None, :] * group_mask_one_hot # [*, N, 14] t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) diff --git a/src/transformers/models/esm/openfold_utils/loss.py b/src/transformers/models/esm/openfold_utils/loss.py index 4d60b1049137..e9523491d519 100644 --- a/src/transformers/models/esm/openfold_utils/loss.py +++ b/src/transformers/models/esm/openfold_utils/loss.py @@ -18,7 +18,7 @@ import torch -def _calculate_bin_centers(boundaries: torch.Tensor): +def _calculate_bin_centers(boundaries: torch.Tensor) -> torch.Tensor: step = boundaries[1] - boundaries[0] bin_centers = boundaries + step / 2 bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0) diff --git a/src/transformers/models/esm/openfold_utils/protein.py b/src/transformers/models/esm/openfold_utils/protein.py index 750027117a81..32e01571715c 100644 --- a/src/transformers/models/esm/openfold_utils/protein.py +++ b/src/transformers/models/esm/openfold_utils/protein.py @@ -17,7 +17,7 @@ import dataclasses import re import string -from typing import Any, Mapping, Optional, Sequence +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple import numpy as np @@ -69,10 +69,10 @@ class Protein: def from_proteinnet_string(proteinnet_str: str) -> Protein: tag_re = r"(\[[A-Z]+\]\n)" - tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0] - groups = zip(tags[0::2], [l.split("\n") for l in tags[1::2]]) + tags: List[str] = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0] + groups: Iterator[Tuple[str, List[str]]] = zip(tags[0::2], [l.split("\n") for l in tags[1::2]]) - atoms = ["N", "CA", "C"] + atoms: List[str] = ["N", "CA", "C"] aatype = None atom_positions = None atom_mask = None @@ -81,12 +81,12 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein: seq = g[1][0].strip() for i in range(len(seq)): if seq[i] not in residue_constants.restypes: - seq[i] = "X" + seq[i] = "X" # FIXME: strings are immutable aatype = np.array( [residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq] ) elif "[TERTIARY]" == g[0]: - tertiary = [] + tertiary: List[List[float]] = [] for axis in range(3): tertiary.append(list(map(float, g[1][axis].split()))) tertiary_np = np.array(tertiary) @@ -106,6 +106,8 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein: atom_mask[:, residue_constants.atom_order[atom]] = 1 atom_mask *= mask[..., None] + assert aatype is not None + return Protein( atom_positions=atom_positions, atom_mask=atom_mask, @@ -115,8 +117,8 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein: ) -def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]: - pdb_headers = [] +def get_pdb_headers(prot: Protein, chain_id: int = 0) -> List[str]: + pdb_headers: List[str] = [] remark = prot.remark if remark is not None: @@ -124,7 +126,7 @@ def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]: parents = prot.parents parents_chain_index = prot.parents_chain_index - if parents_chain_index is not None: + if parents is not None and parents_chain_index is not None: parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id] if parents is None or len(parents) == 
0: @@ -139,18 +141,18 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str: """Add pdb headers to an existing PDB string. Useful during multi-chain recycling """ - out_pdb_lines = [] + out_pdb_lines: List[str] = [] lines = pdb_str.split("\n") remark = prot.remark if remark is not None: out_pdb_lines.append(f"REMARK {remark}") - parents_per_chain = None + parents_per_chain: List[List[str]] if prot.parents is not None and len(prot.parents) > 0: parents_per_chain = [] if prot.parents_chain_index is not None: - parent_dict = {} + parent_dict: Dict[str, List[str]] = {} for p, i in zip(prot.parents, prot.parents_chain_index): parent_dict.setdefault(str(i), []) parent_dict[str(i)].append(p) @@ -160,11 +162,11 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str: chain_parents = parent_dict.get(str(i), ["N/A"]) parents_per_chain.append(chain_parents) else: - parents_per_chain.append(prot.parents) + parents_per_chain.append(list(prot.parents)) else: parents_per_chain = [["N/A"]] - def make_parent_line(p): + def make_parent_line(p: Sequence[str]) -> str: return f"PARENT {' '.join(p)}" out_pdb_lines.append(make_parent_line(parents_per_chain[0])) @@ -196,12 +198,12 @@ def to_pdb(prot: Protein) -> str: """ restypes = residue_constants.restypes + ["X"] - def res_1to3(r): + def res_1to3(r: int) -> str: return residue_constants.restype_1to3.get(restypes[r], "UNK") atom_types = residue_constants.atom_types - pdb_lines = [] + pdb_lines: List[str] = [] atom_mask = prot.atom_mask aatype = prot.aatype @@ -221,6 +223,7 @@ def res_1to3(r): atom_index = 1 prev_chain_index = 0 chain_tags = string.ascii_uppercase + chain_tag = None # Add all atom sites. for i in range(n): res_name_3 = res_1to3(aatype[i]) @@ -313,15 +316,12 @@ def from_prediction( Returns: A protein instance. """ - if b_factors is None: - b_factors = np.zeros_like(result["final_atom_mask"]) - return Protein( aatype=features["aatype"], atom_positions=result["final_atom_positions"], atom_mask=result["final_atom_mask"], residue_index=features["residue_index"] + 1, - b_factors=b_factors, + b_factors=b_factors if b_factors is not None else np.zeros_like(result["final_atom_mask"]), chain_index=chain_index, remark=remark, parents=parents, diff --git a/src/transformers/models/esm/openfold_utils/residue_constants.py b/src/transformers/models/esm/openfold_utils/residue_constants.py index 37f0fc081d06..6cab95652c63 100644 --- a/src/transformers/models/esm/openfold_utils/residue_constants.py +++ b/src/transformers/models/esm/openfold_utils/residue_constants.py @@ -19,7 +19,7 @@ import copy import functools from importlib import resources -from typing import List, Mapping, Tuple +from typing import Dict, List, Mapping, Sequence, Tuple import numpy as np @@ -33,43 +33,21 @@ # Format: The list for each AA type contains chi1, chi2, chi3, chi4 in # this order (or a relevant subset from chi1 onwards). ALA and GLY don't have # chi angles so their chi angle lists are empty. -chi_angles_atoms = { +chi_angles_atoms: Dict[str, List[List[str]]] = { "ALA": [], # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. 
- "ARG": [ - ["N", "CA", "CB", "CG"], - ["CA", "CB", "CG", "CD"], - ["CB", "CG", "CD", "NE"], - ["CG", "CD", "NE", "CZ"], - ], + "ARG": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "NE"], ["CG", "CD", "NE", "CZ"]], "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], "CYS": [["N", "CA", "CB", "SG"]], - "GLN": [ - ["N", "CA", "CB", "CG"], - ["CA", "CB", "CG", "CD"], - ["CB", "CG", "CD", "OE1"], - ], - "GLU": [ - ["N", "CA", "CB", "CG"], - ["CA", "CB", "CG", "CD"], - ["CB", "CG", "CD", "OE1"], - ], + "GLN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "OE1"]], + "GLU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "OE1"]], "GLY": [], "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]], "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]], "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], - "LYS": [ - ["N", "CA", "CB", "CG"], - ["CA", "CB", "CG", "CD"], - ["CB", "CG", "CD", "CE"], - ["CG", "CD", "CE", "NZ"], - ], - "MET": [ - ["N", "CA", "CB", "CG"], - ["CA", "CB", "CG", "SD"], - ["CB", "CG", "SD", "CE"], - ], + "LYS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "CE"], ["CG", "CD", "CE", "NZ"]], + "MET": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "SD"], ["CB", "CG", "SD", "CE"]], "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]], "SER": [["N", "CA", "CB", "OG"]], @@ -81,7 +59,7 @@ # If chi angles given in fixed-length array, this matrix determines how to mask # them for each AA type. The order is as per restype_order (see below). -chi_angles_mask = [ +chi_angles_mask: List[List[float]] = [ [0.0, 0.0, 0.0, 0.0], # ALA [1.0, 1.0, 1.0, 1.0], # ARG [1.0, 1.0, 0.0, 0.0], # ASN @@ -106,7 +84,7 @@ # The following chi angles are pi periodic: they can be rotated by a multiple # of pi without affecting the structure. -chi_pi_periodic = [ +chi_pi_periodic: List[List[float]] = [ [0.0, 0.0, 0.0, 0.0], # ALA [0.0, 0.0, 0.0, 0.0], # ARG [0.0, 0.0, 0.0, 0.0], # ASN @@ -142,219 +120,219 @@ # is defined such that the dihedral-angle-definiting atom (the last entry in # chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). 
# format: [atomname, group_idx, rel_position] -rigid_group_atom_positions = { +rigid_group_atom_positions: Dict[str, List[Tuple[str, int, Tuple[float, float, float]]]] = { "ALA": [ - ["N", 0, (-0.525, 1.363, 0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.526, -0.000, -0.000)], - ["CB", 0, (-0.529, -0.774, -1.205)], - ["O", 3, (0.627, 1.062, 0.000)], + ("N", 0, (-0.525, 1.363, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.526, -0.000, -0.000)), + ("CB", 0, (-0.529, -0.774, -1.205)), + ("O", 3, (0.627, 1.062, 0.000)), ], "ARG": [ - ["N", 0, (-0.524, 1.362, -0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.525, -0.000, -0.000)], - ["CB", 0, (-0.524, -0.778, -1.209)], - ["O", 3, (0.626, 1.062, 0.000)], - ["CG", 4, (0.616, 1.390, -0.000)], - ["CD", 5, (0.564, 1.414, 0.000)], - ["NE", 6, (0.539, 1.357, -0.000)], - ["NH1", 7, (0.206, 2.301, 0.000)], - ["NH2", 7, (2.078, 0.978, -0.000)], - ["CZ", 7, (0.758, 1.093, -0.000)], + ("N", 0, (-0.524, 1.362, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.525, -0.000, -0.000)), + ("CB", 0, (-0.524, -0.778, -1.209)), + ("O", 3, (0.626, 1.062, 0.000)), + ("CG", 4, (0.616, 1.390, -0.000)), + ("CD", 5, (0.564, 1.414, 0.000)), + ("NE", 6, (0.539, 1.357, -0.000)), + ("NH1", 7, (0.206, 2.301, 0.000)), + ("NH2", 7, (2.078, 0.978, -0.000)), + ("CZ", 7, (0.758, 1.093, -0.000)), ], "ASN": [ - ["N", 0, (-0.536, 1.357, 0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.526, -0.000, -0.000)], - ["CB", 0, (-0.531, -0.787, -1.200)], - ["O", 3, (0.625, 1.062, 0.000)], - ["CG", 4, (0.584, 1.399, 0.000)], - ["ND2", 5, (0.593, -1.188, 0.001)], - ["OD1", 5, (0.633, 1.059, 0.000)], + ("N", 0, (-0.536, 1.357, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.526, -0.000, -0.000)), + ("CB", 0, (-0.531, -0.787, -1.200)), + ("O", 3, (0.625, 1.062, 0.000)), + ("CG", 4, (0.584, 1.399, 0.000)), + ("ND2", 5, (0.593, -1.188, 0.001)), + ("OD1", 5, (0.633, 1.059, 0.000)), ], "ASP": [ - ["N", 0, (-0.525, 1.362, -0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.527, 0.000, -0.000)], - ["CB", 0, (-0.526, -0.778, -1.208)], - ["O", 3, (0.626, 1.062, -0.000)], - ["CG", 4, (0.593, 1.398, -0.000)], - ["OD1", 5, (0.610, 1.091, 0.000)], - ["OD2", 5, (0.592, -1.101, -0.003)], + ("N", 0, (-0.525, 1.362, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.527, 0.000, -0.000)), + ("CB", 0, (-0.526, -0.778, -1.208)), + ("O", 3, (0.626, 1.062, -0.000)), + ("CG", 4, (0.593, 1.398, -0.000)), + ("OD1", 5, (0.610, 1.091, 0.000)), + ("OD2", 5, (0.592, -1.101, -0.003)), ], "CYS": [ - ["N", 0, (-0.522, 1.362, -0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.524, 0.000, 0.000)], - ["CB", 0, (-0.519, -0.773, -1.212)], - ["O", 3, (0.625, 1.062, -0.000)], - ["SG", 4, (0.728, 1.653, 0.000)], + ("N", 0, (-0.522, 1.362, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.524, 0.000, 0.000)), + ("CB", 0, (-0.519, -0.773, -1.212)), + ("O", 3, (0.625, 1.062, -0.000)), + ("SG", 4, (0.728, 1.653, 0.000)), ], "GLN": [ - ["N", 0, (-0.526, 1.361, -0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.526, 0.000, 0.000)], - ["CB", 0, (-0.525, -0.779, -1.207)], - ["O", 3, (0.626, 1.062, -0.000)], - ["CG", 4, (0.615, 1.393, 0.000)], - ["CD", 5, (0.587, 1.399, -0.000)], - ["NE2", 6, (0.593, -1.189, -0.001)], - ["OE1", 6, (0.634, 1.060, 0.000)], + ("N", 0, (-0.526, 1.361, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.526, 0.000, 0.000)), + ("CB", 0, (-0.525, -0.779, -1.207)), + ("O", 3, (0.626, 1.062, -0.000)), 
+ ("CG", 4, (0.615, 1.393, 0.000)), + ("CD", 5, (0.587, 1.399, -0.000)), + ("NE2", 6, (0.593, -1.189, -0.001)), + ("OE1", 6, (0.634, 1.060, 0.000)), ], "GLU": [ - ["N", 0, (-0.528, 1.361, 0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.526, -0.000, -0.000)], - ["CB", 0, (-0.526, -0.781, -1.207)], - ["O", 3, (0.626, 1.062, 0.000)], - ["CG", 4, (0.615, 1.392, 0.000)], - ["CD", 5, (0.600, 1.397, 0.000)], - ["OE1", 6, (0.607, 1.095, -0.000)], - ["OE2", 6, (0.589, -1.104, -0.001)], + ("N", 0, (-0.528, 1.361, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.526, -0.000, -0.000)), + ("CB", 0, (-0.526, -0.781, -1.207)), + ("O", 3, (0.626, 1.062, 0.000)), + ("CG", 4, (0.615, 1.392, 0.000)), + ("CD", 5, (0.600, 1.397, 0.000)), + ("OE1", 6, (0.607, 1.095, -0.000)), + ("OE2", 6, (0.589, -1.104, -0.001)), ], "GLY": [ - ["N", 0, (-0.572, 1.337, 0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.517, -0.000, -0.000)], - ["O", 3, (0.626, 1.062, -0.000)], + ("N", 0, (-0.572, 1.337, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.517, -0.000, -0.000)), + ("O", 3, (0.626, 1.062, -0.000)), ], "HIS": [ - ["N", 0, (-0.527, 1.360, 0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.525, 0.000, 0.000)], - ["CB", 0, (-0.525, -0.778, -1.208)], - ["O", 3, (0.625, 1.063, 0.000)], - ["CG", 4, (0.600, 1.370, -0.000)], - ["CD2", 5, (0.889, -1.021, 0.003)], - ["ND1", 5, (0.744, 1.160, -0.000)], - ["CE1", 5, (2.030, 0.851, 0.002)], - ["NE2", 5, (2.145, -0.466, 0.004)], + ("N", 0, (-0.527, 1.360, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.525, 0.000, 0.000)), + ("CB", 0, (-0.525, -0.778, -1.208)), + ("O", 3, (0.625, 1.063, 0.000)), + ("CG", 4, (0.600, 1.370, -0.000)), + ("CD2", 5, (0.889, -1.021, 0.003)), + ("ND1", 5, (0.744, 1.160, -0.000)), + ("CE1", 5, (2.030, 0.851, 0.002)), + ("NE2", 5, (2.145, -0.466, 0.004)), ], "ILE": [ - ["N", 0, (-0.493, 1.373, -0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.527, -0.000, -0.000)], - ["CB", 0, (-0.536, -0.793, -1.213)], - ["O", 3, (0.627, 1.062, -0.000)], - ["CG1", 4, (0.534, 1.437, -0.000)], - ["CG2", 4, (0.540, -0.785, -1.199)], - ["CD1", 5, (0.619, 1.391, 0.000)], + ("N", 0, (-0.493, 1.373, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.527, -0.000, -0.000)), + ("CB", 0, (-0.536, -0.793, -1.213)), + ("O", 3, (0.627, 1.062, -0.000)), + ("CG1", 4, (0.534, 1.437, -0.000)), + ("CG2", 4, (0.540, -0.785, -1.199)), + ("CD1", 5, (0.619, 1.391, 0.000)), ], "LEU": [ - ["N", 0, (-0.520, 1.363, 0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.525, -0.000, -0.000)], - ["CB", 0, (-0.522, -0.773, -1.214)], - ["O", 3, (0.625, 1.063, -0.000)], - ["CG", 4, (0.678, 1.371, 0.000)], - ["CD1", 5, (0.530, 1.430, -0.000)], - ["CD2", 5, (0.535, -0.774, 1.200)], + ("N", 0, (-0.520, 1.363, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.525, -0.000, -0.000)), + ("CB", 0, (-0.522, -0.773, -1.214)), + ("O", 3, (0.625, 1.063, -0.000)), + ("CG", 4, (0.678, 1.371, 0.000)), + ("CD1", 5, (0.530, 1.430, -0.000)), + ("CD2", 5, (0.535, -0.774, 1.200)), ], "LYS": [ - ["N", 0, (-0.526, 1.362, -0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.526, 0.000, 0.000)], - ["CB", 0, (-0.524, -0.778, -1.208)], - ["O", 3, (0.626, 1.062, -0.000)], - ["CG", 4, (0.619, 1.390, 0.000)], - ["CD", 5, (0.559, 1.417, 0.000)], - ["CE", 6, (0.560, 1.416, 0.000)], - ["NZ", 7, (0.554, 1.387, 0.000)], + ("N", 0, (-0.526, 1.362, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.526, 0.000, 0.000)), + ("CB", 0, 
(-0.524, -0.778, -1.208)), + ("O", 3, (0.626, 1.062, -0.000)), + ("CG", 4, (0.619, 1.390, 0.000)), + ("CD", 5, (0.559, 1.417, 0.000)), + ("CE", 6, (0.560, 1.416, 0.000)), + ("NZ", 7, (0.554, 1.387, 0.000)), ], "MET": [ - ["N", 0, (-0.521, 1.364, -0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.525, 0.000, 0.000)], - ["CB", 0, (-0.523, -0.776, -1.210)], - ["O", 3, (0.625, 1.062, -0.000)], - ["CG", 4, (0.613, 1.391, -0.000)], - ["SD", 5, (0.703, 1.695, 0.000)], - ["CE", 6, (0.320, 1.786, -0.000)], + ("N", 0, (-0.521, 1.364, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.525, 0.000, 0.000)), + ("CB", 0, (-0.523, -0.776, -1.210)), + ("O", 3, (0.625, 1.062, -0.000)), + ("CG", 4, (0.613, 1.391, -0.000)), + ("SD", 5, (0.703, 1.695, 0.000)), + ("CE", 6, (0.320, 1.786, -0.000)), ], "PHE": [ - ["N", 0, (-0.518, 1.363, 0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.524, 0.000, -0.000)], - ["CB", 0, (-0.525, -0.776, -1.212)], - ["O", 3, (0.626, 1.062, -0.000)], - ["CG", 4, (0.607, 1.377, 0.000)], - ["CD1", 5, (0.709, 1.195, -0.000)], - ["CD2", 5, (0.706, -1.196, 0.000)], - ["CE1", 5, (2.102, 1.198, -0.000)], - ["CE2", 5, (2.098, -1.201, -0.000)], - ["CZ", 5, (2.794, -0.003, -0.001)], + ("N", 0, (-0.518, 1.363, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.524, 0.000, -0.000)), + ("CB", 0, (-0.525, -0.776, -1.212)), + ("O", 3, (0.626, 1.062, -0.000)), + ("CG", 4, (0.607, 1.377, 0.000)), + ("CD1", 5, (0.709, 1.195, -0.000)), + ("CD2", 5, (0.706, -1.196, 0.000)), + ("CE1", 5, (2.102, 1.198, -0.000)), + ("CE2", 5, (2.098, -1.201, -0.000)), + ("CZ", 5, (2.794, -0.003, -0.001)), ], "PRO": [ - ["N", 0, (-0.566, 1.351, -0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.527, -0.000, 0.000)], - ["CB", 0, (-0.546, -0.611, -1.293)], - ["O", 3, (0.621, 1.066, 0.000)], - ["CG", 4, (0.382, 1.445, 0.0)], - # ['CD', 5, (0.427, 1.440, 0.0)], - ["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ("N", 0, (-0.566, 1.351, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.527, -0.000, 0.000)), + ("CB", 0, (-0.546, -0.611, -1.293)), + ("O", 3, (0.621, 1.066, 0.000)), + ("CG", 4, (0.382, 1.445, 0.0)), + # ('CD', 5, (0.427, 1.440, 0.0)), + ("CD", 5, (0.477, 1.424, 0.0)), # manually made angle 2 degrees larger ], "SER": [ - ["N", 0, (-0.529, 1.360, -0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.525, -0.000, -0.000)], - ["CB", 0, (-0.518, -0.777, -1.211)], - ["O", 3, (0.626, 1.062, -0.000)], - ["OG", 4, (0.503, 1.325, 0.000)], + ("N", 0, (-0.529, 1.360, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.525, -0.000, -0.000)), + ("CB", 0, (-0.518, -0.777, -1.211)), + ("O", 3, (0.626, 1.062, -0.000)), + ("OG", 4, (0.503, 1.325, 0.000)), ], "THR": [ - ["N", 0, (-0.517, 1.364, 0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.526, 0.000, -0.000)], - ["CB", 0, (-0.516, -0.793, -1.215)], - ["O", 3, (0.626, 1.062, 0.000)], - ["CG2", 4, (0.550, -0.718, -1.228)], - ["OG1", 4, (0.472, 1.353, 0.000)], + ("N", 0, (-0.517, 1.364, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.526, 0.000, -0.000)), + ("CB", 0, (-0.516, -0.793, -1.215)), + ("O", 3, (0.626, 1.062, 0.000)), + ("CG2", 4, (0.550, -0.718, -1.228)), + ("OG1", 4, (0.472, 1.353, 0.000)), ], "TRP": [ - ["N", 0, (-0.521, 1.363, 0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.525, -0.000, 0.000)], - ["CB", 0, (-0.523, -0.776, -1.212)], - ["O", 3, (0.627, 1.062, 0.000)], - ["CG", 4, (0.609, 1.370, -0.000)], - ["CD1", 5, (0.824, 1.091, 0.000)], - 
["CD2", 5, (0.854, -1.148, -0.005)], - ["CE2", 5, (2.186, -0.678, -0.007)], - ["CE3", 5, (0.622, -2.530, -0.007)], - ["NE1", 5, (2.140, 0.690, -0.004)], - ["CH2", 5, (3.028, -2.890, -0.013)], - ["CZ2", 5, (3.283, -1.543, -0.011)], - ["CZ3", 5, (1.715, -3.389, -0.011)], + ("N", 0, (-0.521, 1.363, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.525, -0.000, 0.000)), + ("CB", 0, (-0.523, -0.776, -1.212)), + ("O", 3, (0.627, 1.062, 0.000)), + ("CG", 4, (0.609, 1.370, -0.000)), + ("CD1", 5, (0.824, 1.091, 0.000)), + ("CD2", 5, (0.854, -1.148, -0.005)), + ("CE2", 5, (2.186, -0.678, -0.007)), + ("CE3", 5, (0.622, -2.530, -0.007)), + ("NE1", 5, (2.140, 0.690, -0.004)), + ("CH2", 5, (3.028, -2.890, -0.013)), + ("CZ2", 5, (3.283, -1.543, -0.011)), + ("CZ3", 5, (1.715, -3.389, -0.011)), ], "TYR": [ - ["N", 0, (-0.522, 1.362, 0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.524, -0.000, -0.000)], - ["CB", 0, (-0.522, -0.776, -1.213)], - ["O", 3, (0.627, 1.062, -0.000)], - ["CG", 4, (0.607, 1.382, -0.000)], - ["CD1", 5, (0.716, 1.195, -0.000)], - ["CD2", 5, (0.713, -1.194, -0.001)], - ["CE1", 5, (2.107, 1.200, -0.002)], - ["CE2", 5, (2.104, -1.201, -0.003)], - ["OH", 5, (4.168, -0.002, -0.005)], - ["CZ", 5, (2.791, -0.001, -0.003)], + ("N", 0, (-0.522, 1.362, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.524, -0.000, -0.000)), + ("CB", 0, (-0.522, -0.776, -1.213)), + ("O", 3, (0.627, 1.062, -0.000)), + ("CG", 4, (0.607, 1.382, -0.000)), + ("CD1", 5, (0.716, 1.195, -0.000)), + ("CD2", 5, (0.713, -1.194, -0.001)), + ("CE1", 5, (2.107, 1.200, -0.002)), + ("CE2", 5, (2.104, -1.201, -0.003)), + ("OH", 5, (4.168, -0.002, -0.005)), + ("CZ", 5, (2.791, -0.001, -0.003)), ], "VAL": [ - ["N", 0, (-0.494, 1.373, -0.000)], - ["CA", 0, (0.000, 0.000, 0.000)], - ["C", 0, (1.527, -0.000, -0.000)], - ["CB", 0, (-0.533, -0.795, -1.213)], - ["O", 3, (0.627, 1.062, -0.000)], - ["CG1", 4, (0.540, 1.429, -0.000)], - ["CG2", 4, (0.533, -0.776, 1.203)], + ("N", 0, (-0.494, 1.373, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.527, -0.000, -0.000)), + ("CB", 0, (-0.533, -0.795, -1.213)), + ("O", 3, (0.627, 1.062, -0.000)), + ("CG1", 4, (0.540, 1.429, -0.000)), + ("CG2", 4, (0.533, -0.776, 1.203)), ], } # A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. 
-residue_atoms = { +residue_atoms: Dict[str, List[str]] = { "ALA": ["C", "CA", "CB", "N", "O"], "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"], "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"], @@ -372,36 +350,8 @@ "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"], "SER": ["C", "CA", "CB", "N", "O", "OG"], "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"], - "TRP": [ - "C", - "CA", - "CB", - "CG", - "CD1", - "CD2", - "CE2", - "CE3", - "CZ2", - "CZ3", - "CH2", - "N", - "NE1", - "O", - ], - "TYR": [ - "C", - "CA", - "CB", - "CG", - "CD1", - "CD2", - "CE1", - "CE2", - "CZ", - "N", - "O", - "OH", - ], + "TRP": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "CZ2", "CZ3", "CH2", "N", "NE1", "O"], + "TYR": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O", "OH"], "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"], } @@ -412,7 +362,7 @@ # in LEU, VAL and ARG can be resolved by using the 3d constellations of # the 'ambiguous' atoms and their neighbours) # TODO: ^ interpret this -residue_atom_renaming_swaps = { +residue_atom_renaming_swaps: Dict[str, Dict[str, str]] = { "ASP": {"OD1": "OD2"}, "GLU": {"OE1": "OE2"}, "PHE": {"CD1": "CD2", "CE1": "CE2"}, @@ -420,7 +370,7 @@ } # Van der Waals radii [Angstroem] of the atoms (from Wikipedia) -van_der_waals_radius = { +van_der_waals_radius: Dict[str, float] = { "C": 1.7, "N": 1.55, "O": 1.52, @@ -434,7 +384,7 @@ ) -def map_structure_with_atom_order(in_list: List, first_call: bool = True): +def map_structure_with_atom_order(in_list: list, first_call: bool = True) -> list: # Maps strings in a nested list structure to their corresponding index in atom_order if first_call: in_list = copy.deepcopy(in_list) @@ -468,20 +418,20 @@ def load_stereo_chemical_props() -> Tuple[ lines_iter = iter(stereo_chemical_props.splitlines()) # Load bond lengths. - residue_bonds = {} + residue_bonds: Dict[str, List[Bond]] = {} next(lines_iter) # Skip header line. for line in lines_iter: if line.strip() == "-": break - bond, resname, length, stddev = line.split() + bond, resname, bond_length, stddev = line.split() atom1, atom2 = bond.split("-") if resname not in residue_bonds: residue_bonds[resname] = [] - residue_bonds[resname].append(Bond(atom1, atom2, float(length), float(stddev))) + residue_bonds[resname].append(Bond(atom1, atom2, float(bond_length), float(stddev))) residue_bonds["UNK"] = [] # Load bond angles. - residue_bond_angles = {} + residue_bond_angles: Dict[str, List[BondAngle]] = {} next(lines_iter) # Skip empty line. next(lines_iter) # Skip header line. for line in lines_iter: @@ -502,15 +452,15 @@ def load_stereo_chemical_props() -> Tuple[ ) residue_bond_angles["UNK"] = [] - def make_bond_key(atom1_name, atom2_name): + def make_bond_key(atom1_name: str, atom2_name: str) -> str: """Unique key to lookup bonds.""" return "-".join(sorted([atom1_name, atom2_name])) # Translate bond angles into distances ("virtual bonds"). - residue_virtual_bonds = {} + residue_virtual_bonds: Dict[str, List[Bond]] = {} for resname, bond_angles in residue_bond_angles.items(): # Create a fast lookup dict for bond lengths. - bond_cache = {} + bond_cache: Dict[str, Bond] = {} for b in residue_bonds[resname]: bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b residue_virtual_bonds[resname] = [] @@ -538,16 +488,16 @@ def make_bond_key(atom1_name, atom2_name): # Between-residue bond lengths for general bonds (first element) and for Proline # (second element). 
-between_res_bond_length_c_n = [1.329, 1.341] -between_res_bond_length_stddev_c_n = [0.014, 0.016] +between_res_bond_length_c_n: Tuple[float, float] = (1.329, 1.341) +between_res_bond_length_stddev_c_n: Tuple[float, float] = (0.014, 0.016) # Between-residue cos_angles. -between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 -between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 +between_res_cos_angles_c_n_ca: Tuple[float, float] = (-0.5203, 0.0353) # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n: Tuple[float, float] = (-0.4473, 0.0311) # degrees: 116.568 +- 1.995 # This mapping is used when we need to store atom data in a format that requires # fixed atom data size for every residue (e.g. a numpy array). -atom_types = [ +atom_types: List[str] = [ "N", "CA", "C", @@ -586,258 +536,33 @@ def make_bond_key(atom1_name, atom2_name): "NZ", "OXT", ] -atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_order: Dict[str, int] = {atom_type: i for i, atom_type in enumerate(atom_types)} atom_type_num = len(atom_types) # := 37. # A compact atom encoding with 14 columns # pylint: disable=line-too-long # pylint: disable=bad-whitespace -restype_name_to_atom14_names = { +restype_name_to_atom14_names: Dict[str, List[str]] = { "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""], - "ARG": [ - "N", - "CA", - "C", - "O", - "CB", - "CG", - "CD", - "NE", - "CZ", - "NH1", - "NH2", - "", - "", - "", - ], - "ASN": [ - "N", - "CA", - "C", - "O", - "CB", - "CG", - "OD1", - "ND2", - "", - "", - "", - "", - "", - "", - ], - "ASP": [ - "N", - "CA", - "C", - "O", - "CB", - "CG", - "OD1", - "OD2", - "", - "", - "", - "", - "", - "", - ], + "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", ""], + "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""], + "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""], "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""], - "GLN": [ - "N", - "CA", - "C", - "O", - "CB", - "CG", - "CD", - "OE1", - "NE2", - "", - "", - "", - "", - "", - ], - "GLU": [ - "N", - "CA", - "C", - "O", - "CB", - "CG", - "CD", - "OE1", - "OE2", - "", - "", - "", - "", - "", - ], + "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""], + "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""], "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""], - "HIS": [ - "N", - "CA", - "C", - "O", - "CB", - "CG", - "ND1", - "CD2", - "CE1", - "NE2", - "", - "", - "", - "", - ], - "ILE": [ - "N", - "CA", - "C", - "O", - "CB", - "CG1", - "CG2", - "CD1", - "", - "", - "", - "", - "", - "", - ], - "LEU": [ - "N", - "CA", - "C", - "O", - "CB", - "CG", - "CD1", - "CD2", - "", - "", - "", - "", - "", - "", - ], - "LYS": [ - "N", - "CA", - "C", - "O", - "CB", - "CG", - "CD", - "CE", - "NZ", - "", - "", - "", - "", - "", - ], - "MET": [ - "N", - "CA", - "C", - "O", - "CB", - "CG", - "SD", - "CE", - "", - "", - "", - "", - "", - "", - ], - "PHE": [ - "N", - "CA", - "C", - "O", - "CB", - "CG", - "CD1", - "CD2", - "CE1", - "CE2", - "CZ", - "", - "", - "", - ], + "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2", "", "", "", ""], + "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""], + "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""], + "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", 
"", "", ""], + "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""], + "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "", "", ""], "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""], "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""], - "THR": [ - "N", - "CA", - "C", - "O", - "CB", - "OG1", - "CG2", - "", - "", - "", - "", - "", - "", - "", - ], - "TRP": [ - "N", - "CA", - "C", - "O", - "CB", - "CG", - "CD1", - "CD2", - "NE1", - "CE2", - "CE3", - "CZ2", - "CZ3", - "CH2", - ], - "TYR": [ - "N", - "CA", - "C", - "O", - "CB", - "CG", - "CD1", - "CD2", - "CE1", - "CE2", - "CZ", - "OH", - "", - "", - ], - "VAL": [ - "N", - "CA", - "C", - "O", - "CB", - "CG1", - "CG2", - "", - "", - "", - "", - "", - "", - "", - ], + "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""], + "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"], + "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH", "", ""], + "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""], "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""], } # pylint: enable=line-too-long @@ -846,7 +571,7 @@ def make_bond_key(atom1_name, atom2_name): # This is the standard residue order when coding AA type as a number. # Reproduce it by taking 3-letter AA codes and sorting them alphabetically. -restypes = [ +restypes: List[str] = [ "A", "R", "N", @@ -868,12 +593,12 @@ def make_bond_key(atom1_name, atom2_name): "Y", "V", ] -restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_order: Dict[str, int] = {restype: i for i, restype in enumerate(restypes)} restype_num = len(restypes) # := 20. unk_restype_index = restype_num # Catch-all index for unknown restypes. -restypes_with_x = restypes + ["X"] -restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)} +restypes_with_x: List[str] = restypes + ["X"] +restype_order_with_x: Dict[str, int] = {restype: i for i, restype in enumerate(restypes_with_x)} def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False) -> np.ndarray: @@ -916,7 +641,7 @@ def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to return one_hot_arr -restype_1to3 = { +restype_1to3: Dict[str, str] = { "A": "ALA", "R": "ARG", "N": "ASN", @@ -944,13 +669,13 @@ def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to # 1-to-1 mapping of 3 letter names to one letter names. The latter contains # many more, and less common, three letter names as keys and maps many of these # to the same one letter name (including 'X' and 'U' which we don't use here). -restype_3to1 = {v: k for k, v in restype_1to3.items()} +restype_3to1: Dict[str, str] = {v: k for k, v in restype_1to3.items()} # Define a restype name for all unknown residues. unk_restype = "UNK" -resnames = [restype_1to3[r] for r in restypes] + [unk_restype] -resname_to_idx = {resname: i for i, resname in enumerate(resnames)} +resnames: List[str] = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx: Dict[str, int] = {resname: i for i, resname in enumerate(resnames)} # The mapping here uses hhblits convention, so that B is mapped to D, J and O @@ -960,7 +685,7 @@ def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to # "-" representing a missing amino acid in an alignment. 
The id for these # codes is put at the end (20 and 21) so that they can easily be ignored if # desired. -HHBLITS_AA_TO_ID = { +HHBLITS_AA_TO_ID: Dict[str, int] = { "A": 0, "B": 2, "C": 1, @@ -991,7 +716,7 @@ def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to } # Partial inversion of HHBLITS_AA_TO_ID. -ID_TO_HHBLITS_AA = { +ID_TO_HHBLITS_AA: Dict[int, str] = { 0: "A", 1: "C", # Also U. 2: "D", # Also B. @@ -1016,8 +741,8 @@ def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to 21: "-", } -restypes_with_x_and_gap = restypes + ["X", "-"] -MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple( +restypes_with_x_and_gap: List[str] = restypes + ["X", "-"] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE: Tuple[int, ...] = tuple( restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) for i in range(len(restypes_with_x_and_gap)) ) @@ -1066,15 +791,15 @@ def chi_angle_atom(atom_index: int) -> np.ndarray: chi_atom_2_one_hot = chi_angle_atom(2) # An array like chi_angles_atoms but using indices rather than names. -chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] -chi_angles_atom_indices_ours = map_structure_with_atom_order(chi_angles_atom_indices) +chi_angles_atom_indices_list: List[List[List[str]]] = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices_ours: list = map_structure_with_atom_order(chi_angles_atom_indices_list) chi_angles_atom_indices = np.array( - [chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) for chi_atoms in chi_angles_atom_indices] + [chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) for chi_atoms in chi_angles_atom_indices_list] ) # Mapping from (res_name, atom_name) pairs to the atom's chi group index # and atom index within that group. -chi_groups_for_atom = collections.defaultdict(list) +chi_groups_for_atom: Dict[Tuple[str, str], List[Tuple[int, int]]] = collections.defaultdict(list) for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): for atom_i, atom in enumerate(chi_group): @@ -1082,7 +807,7 @@ def chi_angle_atom(atom_index: int) -> np.ndarray: chi_groups_for_atom = dict(chi_groups_for_atom) -def _make_rigid_transformation_4x4(ex, ey, translation): +def _make_rigid_transformation_4x4(ex: np.ndarray, ey: np.ndarray, translation: np.ndarray) -> np.ndarray: """Create a rigid 4x4 transformation matrix from two axes and transl.""" # Normalize ex. 
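    # (Editor's note, not part of the patch: besides the normalization on the
    # next line, the body -- elided by the diff context -- orthogonalizes ey
    # against ex, takes ez = ex x ey, and stacks [ex, ey, ez, translation] as
    # the columns of a homogeneous 4x4 whose last row is [0, 0, 0, 1].)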
ex_normalized = ex / np.linalg.norm(ex) @@ -1111,7 +836,7 @@ def _make_rigid_transformation_4x4(ex, ey, translation): restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) -def _make_rigid_group_constants(): +def _make_rigid_group_constants() -> None: """Fill the arrays above.""" for restype, restype_letter in enumerate(restypes): resname = restype_1to3[restype_letter] @@ -1128,7 +853,9 @@ def _make_rigid_group_constants(): for restype, restype_letter in enumerate(restypes): resname = restype_1to3[restype_letter] - atom_positions = {name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname]} + atom_positions: Dict[str, np.ndarray] = { + name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname] + } # backbone to backbone is the identity transform restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) @@ -1183,7 +910,10 @@ def _make_rigid_group_constants(): _make_rigid_group_constants() -def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=15): +def make_atom14_dists_bounds( + overlap_tolerance: float = 1.5, + bond_length_tolerance_factor: int = 15, +) -> Dict[str, np.ndarray]: """compute upper and lower bounds for bonds to assess violations.""" restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) @@ -1229,10 +959,10 @@ def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32) -restype_atom14_ambiguous_atoms_swap_idx = np.tile(np.arange(14, dtype=int), (21, 1)) +restype_atom14_ambiguous_atoms_swap_idx: np.ndarray = np.tile(np.arange(14, dtype=int), (21, 1)) -def _make_atom14_ambiguity_feats(): +def _make_atom14_ambiguity_feats() -> None: for res, pairs in residue_atom_renaming_swaps.items(): res_idx = restype_order[restype_3to1[res]] for atom1, atom2 in pairs.items(): @@ -1247,5 +977,5 @@ def _make_atom14_ambiguity_feats(): _make_atom14_ambiguity_feats() -def aatype_to_str_sequence(aatype): +def aatype_to_str_sequence(aatype: Sequence[int]) -> str: return "".join([restypes_with_x[aatype[i]] for i in range(len(aatype))]) diff --git a/src/transformers/models/esm/openfold_utils/rigid_utils.py b/src/transformers/models/esm/openfold_utils/rigid_utils.py index c437cf7953f8..2bc2fe5f5c4e 100644 --- a/src/transformers/models/esm/openfold_utils/rigid_utils.py +++ b/src/transformers/models/esm/openfold_utils/rigid_utils.py @@ -16,7 +16,7 @@ from __future__ import annotations from functools import lru_cache -from typing import Any, Callable, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import numpy as np import torch @@ -33,7 +33,7 @@ def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: The product ab """ - def row_mul(i): + def row_mul(i: int) -> torch.Tensor: return torch.stack( [ a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0], @@ -76,7 +76,7 @@ def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor: @lru_cache(maxsize=None) def identity_rot_mats( - batch_dims: Tuple[int], + batch_dims: Tuple[int, ...], dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = True, @@ -91,7 +91,7 @@ def identity_rot_mats( @lru_cache(maxsize=None) def identity_trans( - batch_dims: Tuple[int], + batch_dims: Tuple[int, ...], dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = 
None, requires_grad: bool = True, @@ -102,7 +102,7 @@ def identity_trans( @lru_cache(maxsize=None) def identity_quats( - batch_dims: Tuple[int], + batch_dims: Tuple[int, ...], dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = True, @@ -115,15 +115,14 @@ def identity_quats( return quat -_quat_elements = ["a", "b", "c", "d"] -_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements] -_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)} +_quat_elements: List[str] = ["a", "b", "c", "d"] +_qtr_keys: List[str] = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements] +_qtr_ind_dict: Dict[str, int] = {key: ind for ind, key in enumerate(_qtr_keys)} -def _to_mat(pairs): +def _to_mat(pairs: List[Tuple[str, int]]) -> np.ndarray: mat = np.zeros((4, 4)) - for pair in pairs: - key, value = pair + for key, value in pairs: ind = _qtr_ind_dict[key] mat[ind // 4][ind % 4] = value @@ -165,14 +164,11 @@ def quat_to_rot(quat: torch.Tensor) -> torch.Tensor: return torch.sum(quat, dim=(-3, -4)) -def rot_to_quat( - rot: torch.Tensor, -): +def rot_to_quat(rot: torch.Tensor) -> torch.Tensor: if rot.shape[-2:] != (3, 3): raise ValueError("Input rotation is incorrectly shaped") - rot = [[rot[..., i, j] for j in range(3)] for i in range(3)] - [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot + [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = [[rot[..., i, j] for j in range(3)] for i in range(3)] k = [ [ @@ -201,9 +197,7 @@ def rot_to_quat( ], ] - k = (1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2) - - _, vectors = torch.linalg.eigh(k) + _, vectors = torch.linalg.eigh((1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)) return vectors[..., -1] @@ -218,7 +212,7 @@ def rot_to_quat( _QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] -_CACHED_QUATS = { +_CACHED_QUATS: Dict[str, np.ndarray] = { "_QTR_MAT": _QTR_MAT, "_QUAT_MULTIPLY": _QUAT_MULTIPLY, "_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC, @@ -226,29 +220,29 @@ def rot_to_quat( @lru_cache(maxsize=None) -def _get_quat(quat_key, dtype, device): +def _get_quat(quat_key: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor: return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device) -def quat_multiply(quat1, quat2): +def quat_multiply(quat1: torch.Tensor, quat2: torch.Tensor) -> torch.Tensor: """Multiply a quaternion by another quaternion.""" mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device) reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape) return torch.sum(reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], dim=(-3, -2)) -def quat_multiply_by_vec(quat, vec): +def quat_multiply_by_vec(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: """Multiply a quaternion by a pure-vector quaternion.""" mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device) reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape) return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2)) -def invert_rot_mat(rot_mat: torch.Tensor): +def invert_rot_mat(rot_mat: torch.Tensor) -> torch.Tensor: return rot_mat.transpose(-1, -2) -def invert_quat(quat: torch.Tensor): +def invert_quat(quat: torch.Tensor) -> torch.Tensor: quat_prime = quat.clone() quat_prime[..., 1:] *= -1 inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True) @@ -361,10 +355,7 @@ def __getitem__(self, index: Any) -> Rotation: else: raise ValueError("Both rotations are 
None") - def __mul__( - self, - right: torch.Tensor, - ) -> Rotation: + def __mul__(self, right: torch.Tensor) -> Rotation: """ Pointwise left multiplication of the rotation with a tensor. Can be used to e.g. mask the Rotation. @@ -386,10 +377,7 @@ def __mul__( else: raise ValueError("Both rotations are None") - def __rmul__( - self, - left: torch.Tensor, - ) -> Rotation: + def __rmul__(self, left: torch.Tensor) -> Rotation: """ Reverse pointwise multiplication of the rotation with a tensor. @@ -413,13 +401,12 @@ def shape(self) -> torch.Size: Returns: The virtual shape of the rotation object """ - s = None - if self._quats is not None: - s = self._quats.shape[:-1] + if self._rot_mats is not None: + return self._rot_mats.shape[:-2] + elif self._quats is not None: + return self._quats.shape[:-1] else: - s = self._rot_mats.shape[:-2] - - return s + raise ValueError("Both rotations are None") @property def dtype(self) -> torch.dtype: @@ -473,14 +460,12 @@ def get_rot_mats(self) -> torch.Tensor: Returns: The rotation as a rotation matrix tensor """ - rot_mats = self._rot_mats - if rot_mats is None: - if self._quats is None: - raise ValueError("Both rotations are None") - else: - rot_mats = quat_to_rot(self._quats) - - return rot_mats + if self._rot_mats is not None: + return self._rot_mats + elif self._quats is not None: + return quat_to_rot(self._quats) + else: + raise ValueError("Both rotations are None") def get_quats(self) -> torch.Tensor: """ @@ -491,14 +476,12 @@ def get_quats(self) -> torch.Tensor: Returns: The rotation as a quaternion tensor. """ - quats = self._quats - if quats is None: - if self._rot_mats is None: - raise ValueError("Both rotations are None") - else: - quats = rot_to_quat(self._rot_mats) - - return quats + if self._rot_mats is not None: + return rot_to_quat(self._rot_mats) + elif self._quats is not None: + return self._quats + else: + raise ValueError("Both rotations are None") def get_cur_rot(self) -> torch.Tensor: """ @@ -618,10 +601,7 @@ def invert(self) -> Rotation: # "Tensor" stuff - def unsqueeze( - self, - dim: int, - ) -> Rigid: + def unsqueeze(self, dim: int) -> Rotation: """ Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object. @@ -643,10 +623,7 @@ def unsqueeze( raise ValueError("Both rotations are None") @staticmethod - def cat( - rs: Sequence[Rotation], - dim: int, - ) -> Rigid: + def cat(rs: Sequence[Rotation], dim: int) -> Rotation: """ Concatenates rotations along one of the batch dimensions. Analogous to torch.cat(). @@ -661,12 +638,14 @@ def cat( Returns: A concatenated Rotation object in rotation matrix format """ - rot_mats = [r.get_rot_mats() for r in rs] - rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) + rot_mats = torch.cat( + [r.get_rot_mats() for r in rs], + dim=dim if dim >= 0 else dim - 2, + ) return Rotation(rot_mats=rot_mats, quats=None) - def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rotation: + def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rotation: """ Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can be used e.g. to sum out a one-hot batch dimension. @@ -754,11 +733,7 @@ class Rigid: dimensions of its component parts. 
""" - def __init__( - self, - rots: Optional[Rotation], - trans: Optional[torch.Tensor], - ): + def __init__(self, rots: Optional[Rotation], trans: Optional[torch.Tensor]): """ Args: rots: A [*, 3, 3] rotation tensor @@ -795,6 +770,9 @@ def __init__( requires_grad, ) + assert rots is not None + assert trans is not None + if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device): raise ValueError("Rots and trans incompatible") @@ -806,7 +784,7 @@ def __init__( @staticmethod def identity( - shape: Tuple[int], + shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = True, @@ -832,10 +810,7 @@ def identity( identity_trans(shape, dtype, device, requires_grad), ) - def __getitem__( - self, - index: Any, - ) -> Rigid: + def __getitem__(self, index: Any) -> Rigid: """ Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of both the rotation and the translation. @@ -860,10 +835,7 @@ def __getitem__( self._trans[index + (slice(None),)], ) - def __mul__( - self, - right: torch.Tensor, - ) -> Rigid: + def __mul__(self, right: torch.Tensor) -> Rigid: """ Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid. @@ -881,10 +853,7 @@ def __mul__( return Rigid(new_rots, new_trans) - def __rmul__( - self, - left: torch.Tensor, - ) -> Rigid: + def __rmul__(self, left: torch.Tensor) -> Rigid: """ Reverse pointwise multiplication of the transformation with a tensor. @@ -904,8 +873,7 @@ def shape(self) -> torch.Size: Returns: The shape of the transformation """ - s = self._trans.shape[:-1] - return s + return self._trans.shape[:-1] @property def device(self) -> torch.device: @@ -935,10 +903,7 @@ def get_trans(self) -> torch.Tensor: """ return self._trans - def compose_q_update_vec( - self, - q_update_vec: torch.Tensor, - ) -> Rigid: + def compose_q_update_vec(self, q_update_vec: torch.Tensor) -> Rigid: """ Composes the transformation with a quaternion update vector of shape [*, 6], where the final 6 columns represent the x, y, and z values of a quaternion of form (1, x, y, z) followed by a 3D translation. @@ -956,10 +921,7 @@ def compose_q_update_vec( return Rigid(new_rots, new_translation) - def compose( - self, - r: Rigid, - ) -> Rigid: + def compose(self, r: Rigid) -> Rigid: """ Composes the current rigid object with another. @@ -973,10 +935,7 @@ def compose( new_trans = self._rots.apply(r._trans) + self._trans return Rigid(new_rot, new_trans) - def apply( - self, - pts: torch.Tensor, - ) -> torch.Tensor: + def apply(self, pts: torch.Tensor) -> torch.Tensor: """ Applies the transformation to a coordinate tensor. @@ -1012,7 +971,7 @@ def invert(self) -> Rigid: return Rigid(rot_inv, -1 * trn_inv) - def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid: + def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid: """ Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the translation/rotation dimensions respectively. 
@@ -1074,10 +1033,7 @@ def to_tensor_7(self) -> torch.Tensor:
         return tensor
 
     @staticmethod
-    def from_tensor_7(
-        t: torch.Tensor,
-        normalize_quats: bool = False,
-    ) -> Rigid:
+    def from_tensor_7(t: torch.Tensor, normalize_quats: bool = False) -> Rigid:
         if t.shape[-1] != 7:
             raise ValueError("Incorrectly shaped input tensor")
 
@@ -1102,18 +1058,18 @@ def from_3_points(
         Returns:
             A transformation object of shape [*]
         """
-        p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
-        origin = torch.unbind(origin, dim=-1)
-        p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
+        p_neg_x_axis_unbound = torch.unbind(p_neg_x_axis, dim=-1)
+        origin_unbound = torch.unbind(origin, dim=-1)
+        p_xy_plane_unbound = torch.unbind(p_xy_plane, dim=-1)
 
-        e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
-        e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
+        e0 = [c1 - c2 for c1, c2 in zip(origin_unbound, p_neg_x_axis_unbound)]
+        e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane_unbound, origin_unbound)]
 
-        denom = torch.sqrt(sum((c * c for c in e0)) + eps)
+        denom = torch.sqrt(sum(c * c for c in e0) + eps * torch.ones_like(e0[0]))
         e0 = [c / denom for c in e0]
         dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
         e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
-        denom = torch.sqrt(sum((c * c for c in e1)) + eps)
+        denom = torch.sqrt(sum((c * c for c in e1)) + eps * torch.ones_like(e1[0]))
         e1 = [c / denom for c in e1]
         e2 = [
             e0[1] * e1[2] - e0[2] * e1[1],
@@ -1126,12 +1082,9 @@ def from_3_points(
 
         rot_obj = Rotation(rot_mats=rots, quats=None)
 
-        return Rigid(rot_obj, torch.stack(origin, dim=-1))
+        return Rigid(rot_obj, torch.stack(origin_unbound, dim=-1))
 
-    def unsqueeze(
-        self,
-        dim: int,
-    ) -> Rigid:
+    def unsqueeze(self, dim: int) -> Rigid:
         """
         Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation.
 
@@ -1148,10 +1101,7 @@ def unsqueeze(
         return Rigid(rots, trans)
 
     @staticmethod
-    def cat(
-        ts: Sequence[Rigid],
-        dim: int,
-    ) -> Rigid:
+    def cat(ts: Sequence[Rigid], dim: int) -> Rigid:
         """
         Concatenates transformations along a new dimension.
 
@@ -1168,7 +1118,7 @@ def cat(
 
         return Rigid(rots, trans)
 
-    def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid:
+    def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Rigid:
         """
         Applies a Rotation -> Rotation function to the stored rotation object.
 
@@ -1179,7 +1129,7 @@ def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid:
         """
         return Rigid(fn(self._rots), self._trans)
 
-    def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid:
+    def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
         """
         Applies a Tensor -> Tensor function to the stored translation.
 
@@ -1213,7 +1163,9 @@ def stop_rot_gradient(self) -> Rigid:
         return self.apply_rot_fn(lambda r: r.detach())
 
     @staticmethod
-    def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
+    def make_transform_from_reference(
+        n_xyz: torch.Tensor, ca_xyz: torch.Tensor, c_xyz: torch.Tensor, eps: float = 1e-20
+    ) -> Rigid:
         """
         Returns a transformation object from reference coordinates.
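
The `from_3_points` edits above only rename locals and fold `eps` into a tensor term; the construction itself is unchanged. An illustrative frame built from three backbone-like points (not part of the patch; import path assumed):

    import torch

    # Sketch only; same assumed import path as the earlier sketches.
    from transformers.models.esm.openfold_utils.rigid_utils import Rigid

    n = torch.tensor([[-1.0, 0.0, 0.0]])  # point on the negative x-axis
    ca = torch.tensor([[0.0, 0.0, 0.0]])  # origin of the frame
    c = torch.tensor([[1.0, 1.0, 0.0]])   # point in the xy-plane

    frame = Rigid.from_3_points(n, ca, c)

    # The translation component is the origin point itself.
    assert torch.allclose(frame.get_trans(), ca)
    print(frame.shape)  # torch.Size([1])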
diff --git a/src/transformers/models/esm/openfold_utils/tensor_utils.py b/src/transformers/models/esm/openfold_utils/tensor_utils.py
index 60e8b3f21466..99dd6dbe47b6 100644
--- a/src/transformers/models/esm/openfold_utils/tensor_utils.py
+++ b/src/transformers/models/esm/openfold_utils/tensor_utils.py
@@ -14,13 +14,14 @@
 # limitations under the License.
 
 from functools import partial
-from typing import List
+from typing import Any, Callable, Dict, List, Type, TypeVar, Union, overload
 
 import torch
 import torch.nn as nn
+import torch.types
 
 
-def add(m1, m2, inplace):
+def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
     # The first operation in a checkpoint can't be in-place, but it's
     # nice to have in-place addition during inference. Thus...
     if not inplace:
@@ -31,33 +32,35 @@ def add(m1, m2, inplace):
     return m1
 
 
-def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
+def permute_final_dims(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor:
     zero_index = -1 * len(inds)
     first_inds = list(range(len(tensor.shape[:zero_index])))
     return tensor.permute(first_inds + [zero_index + i for i in inds])
 
 
-def flatten_final_dims(t: torch.Tensor, no_dims: int):
+def flatten_final_dims(t: torch.Tensor, no_dims: int) -> torch.Tensor:
     return t.reshape(t.shape[:-no_dims] + (-1,))
 
 
-def masked_mean(mask, value, dim, eps=1e-4):
+def masked_mean(mask: torch.Tensor, value: torch.Tensor, dim: int, eps: float = 1e-4) -> torch.Tensor:
     mask = mask.expand(*value.shape)
     return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
 
 
-def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
+def pts_to_distogram(
+    pts: torch.Tensor, min_bin: torch.types.Number = 2.3125, max_bin: torch.types.Number = 21.6875, no_bins: int = 64
+) -> torch.Tensor:
     boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
     dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1))
     return torch.bucketize(dists, boundaries)
 
 
-def dict_multimap(fn, dicts):
+def dict_multimap(fn: Callable[[list], Any], dicts: List[dict]) -> dict:
     first = dicts[0]
     new_dict = {}
     for k, v in first.items():
         all_v = [d[k] for d in dicts]
-        if type(v) is dict:
+        if isinstance(v, dict):
             new_dict[k] = dict_multimap(fn, all_v)
         else:
             new_dict[k] = fn(all_v)
@@ -65,21 +68,21 @@ def dict_multimap(fn, dicts):
 
     return new_dict
 
 
-def one_hot(x, v_bins):
+def one_hot(x: torch.Tensor, v_bins: torch.Tensor) -> torch.Tensor:
     reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
     diffs = x[..., None] - reshaped_bins
     am = torch.argmin(torch.abs(diffs), dim=-1)
     return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
 
 
-def batched_gather(data, inds, dim=0, no_batch_dims=0):
-    ranges = []
+def batched_gather(data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0) -> torch.Tensor:
+    ranges: List[Union[slice, torch.Tensor]] = []
     for i, s in enumerate(data.shape[:no_batch_dims]):
         r = torch.arange(s)
         r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
         ranges.append(r)
 
-    remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
+    remaining_dims: List[Union[slice, torch.Tensor]] = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
     remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
     ranges.extend(remaining_dims)
     # Matt note: Editing this to get around the behaviour of using a list as an array index changing
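
`permute_final_dims` and `flatten_final_dims` operate only on trailing dimensions, which is what makes them safe under arbitrary batch shapes. A quick shape-level illustration (not part of the patch; import path assumed):

    import torch

    # Sketch only; these helpers live in the tensor_utils module edited above.
    from transformers.models.esm.openfold_utils.tensor_utils import flatten_final_dims, permute_final_dims

    t = torch.randn(2, 5, 7, 3)

    # inds index into the last len(inds) dims, so (2, 0, 1) turns the trailing
    # dims [5, 7, 3] into [3, 5, 7] while leaving the batch dim alone.
    print(permute_final_dims(t, [2, 0, 1]).shape)  # torch.Size([2, 3, 5, 7])

    # Flatten the last two dims into one: [2, 5, 7, 3] -> [2, 5, 21].
    print(flatten_final_dims(t, 2).shape)  # torch.Size([2, 5, 21])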
@@ -87,11 +90,16 @@ def batched_gather(data, inds, dim=0, no_batch_dims=0):
     return data[tuple(ranges)]
 
 
+T = TypeVar("T")
+
+
 # With tree_map, a poor man's JAX tree_map
-def dict_map(fn, dic, leaf_type):
-    new_dict = {}
+def dict_map(
+    fn: Callable[[T], Any], dic: Dict[Any, Union[dict, list, tuple, T]], leaf_type: Type[T]
+) -> Dict[Any, Union[dict, list, tuple, Any]]:
+    new_dict: Dict[Any, Union[dict, list, tuple, Any]] = {}
     for k, v in dic.items():
-        if type(v) is dict:
+        if isinstance(v, dict):
             new_dict[k] = dict_map(fn, v, leaf_type)
         else:
             new_dict[k] = tree_map(fn, v, leaf_type)
@@ -99,13 +107,33 @@ def dict_map(fn, dic, leaf_type):
 
     return new_dict
 
 
+@overload
+def tree_map(fn: Callable[[T], Any], tree: T, leaf_type: Type[T]) -> Any:
+    ...
+
+
+@overload
+def tree_map(fn: Callable[[T], Any], tree: dict, leaf_type: Type[T]) -> dict:
+    ...
+
+
+@overload
+def tree_map(fn: Callable[[T], Any], tree: list, leaf_type: Type[T]) -> list:
+    ...
+
+
+@overload
+def tree_map(fn: Callable[[T], Any], tree: tuple, leaf_type: Type[T]) -> tuple:
+    ...
+
+
 def tree_map(fn, tree, leaf_type):
     if isinstance(tree, dict):
         return dict_map(fn, tree, leaf_type)
     elif isinstance(tree, list):
         return [tree_map(fn, x, leaf_type) for x in tree]
     elif isinstance(tree, tuple):
-        return tuple([tree_map(fn, x, leaf_type) for x in tree])
+        return tuple(tree_map(fn, x, leaf_type) for x in tree)
     elif isinstance(tree, leaf_type):
         return fn(tree)
     else: