diff --git a/.gitignore b/.gitignore index fcc5213..cd52057 100644 --- a/.gitignore +++ b/.gitignore @@ -137,4 +137,9 @@ dmypy.json .pyre/ -/wandb/ \ No newline at end of file +/wandb/ +scripts/rgn2_models/ +scripts/sidechainnet_data +scripts/wandb + +.idea/ diff --git a/rgn2_play.ipynb b/rgn2_play.ipynb new file mode 100644 index 0000000..d29fe67 --- /dev/null +++ b/rgn2_play.ipynb @@ -0,0 +1,602 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "rgn2_play.ipynb", + "provenance": [], + "collapsed_sections": [], + "machine_shape": "hm", + "authorship_tag": "ABX9TyMCCozHZ4MhRpJY40idz1IE", + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9Y-jmJk5S5sG" + }, + "source": [ + "* How would coevolution implicitly affect the language model training?\n", + "* Would kNN-style unsupervised learning be useful for RGN2?"
+ ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "id": "bisOhBV5FBxp", + "outputId": "9a869ace-a3ff-4b40-9d8d-46b80077c364" + }, + "source": [ + "import IPython\n", + "from google.colab import output\n", + "\n", + "display(IPython.display.Javascript('''\n", + " function ClickConnect(){\n", + " btn = document.querySelector(\"colab-connect-button\")\n", + " if (btn != null){\n", + " console.log(\"Click colab-connect-button\"); \n", + " btn.click() \n", + " }\n", + " \n", + " btn = document.getElementById('ok')\n", + " if (btn != null){\n", + " console.log(\"Click reconnect\"); \n", + " btn.click() \n", + " }\n", + " }\n", + " \n", + "setInterval(ClickConnect,60000)\n", + "'''))\n", + "\n", + "print(\"Done.\")" + ], + "execution_count": 35, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "\n", + " function ClickConnect(){\n", + " btn = document.querySelector(\"colab-connect-button\")\n", + " if (btn != null){\n", + " console.log(\"Click colab-connect-button\"); \n", + " btn.click() \n", + " }\n", + " \n", + " btn = document.getElementById('ok')\n", + " if (btn != null){\n", + " console.log(\"Click reconnect\"); \n", + " btn.click() \n", + " }\n", + " }\n", + " \n", + "setInterval(ClickConnect,60000)\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Done.\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "SO2-ZpV7SuFi" + }, + "source": [ + "!git clone https://github.com/hushuangwei/rgn2-replica.git" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "qS7DR2x_9kV_" + }, + "source": [ + "# https://blog.csdn.net/NEUdeep/article/details/115724826\n", + "!export PYTHONWARNINGS='ignore:semaphore_tracker:UserWarning'" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "code", + 
"metadata": { + "id": "r4w9hFGdUZmc" + }, + "source": [ + "!pip install wandb sidechainnet einops proDy tqdm datasets transformers x-transformers pytorch-lightning fair-esm En-transformer pytorch3d invariant_point_attention" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CYVrT9vEBthE", + "outputId": "80e40f7b-ad1c-44a5-eab9-95abb1049203" + }, + "source": [ + "%cd rgn2-replica/scripts/" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content/rgn2-replica/scripts\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h6kBnPjeMxMw" + }, + "source": [ + "You can skip this Google Drive step and download the SidechainNet data directly, although that is slower." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rUVr-8p-at6T", + "outputId": "d11f57da-bae4-44a9-a6e8-6c0050e2a9f6" + }, + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/drive')" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Mounted at /content/drive\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "iKeXyuo6Z4-o" + }, + "source": [ + "!mkdir -p /content/rgn2-replica/scripts/sidechainnet_data/\n", + "!cp /content/drive/MyDrive/protein/sidechainnet_casp12_90.pkl /content/rgn2-replica/scripts/sidechainnet_data/" + ], + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "t0JBGC4vHYe-", + "outputId": "eb9e3478-925d-4402-fa82-4a136d5da534" + }, + "source": [ + "!wandb login" + ], + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your
browser here: https://wandb.ai/authorize\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter: \n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Y8z-qNqGTlVh" + }, + "source": [ + "!nohup python train_rgn2.py --device cuda:0 --wb_entity hushuangwei \\\n", + " --wb_proj rgn2_replica --run_name RGN2_ipa_1e-4 \\\n", + " --min_len_valid 50 --casp_version 12 --scn_thinning 90 \\\n", + " --min_len 0 --max_len 384 --input_dropout 0.1 --num_layers 6 \\\n", + " --bidirectional 1 --layer_type LSTM --act aconc --num_recycles_train 8 \\\n", + " --refiner_args \"{\\\"refiner_type\\\": \\\"IPA\\\"}\" \\\n", + " > RGN2X_vanillaLSTM_full_run_logs.txt 2>&1 &" + ], + "execution_count": 33, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Vh2vtuZ-Uc72", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "b9cb8a21-b6ca-47fb-c657-dd7b676d5242" + }, + "source": [ + "!tail -f RGN2X_vanillaLSTM_full_run_logs.txt" + ], + "execution_count": 32, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "wandb: loss 0.08731\n", + "wandb: rmsd 9.9361\n", + "wandb: torsion_loss 1.76916\n", + "wandb: viol_loss 0.00342\n", + "wandb: \n", + "wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)\n", + "wandb: Synced RGN2_ipa_1e-4: https://wandb.ai/hushuangwei/rgn2_replica/runs/372ib897\n", + "wandb: Find logs at: ./wandb/run-20211010_110523-372ib897/logs/debug.log\n", + "wandb: \n", + "\n", + "^C\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_tNr33PMAsiy" + }, + "source": [ + "!pkill -f train_rgn2.py" + ], + "execution_count": 27, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KGrjhdZB0crD", + "outputId": 
"5088c897-2ef1-48d8-cb01-87926a6a9d8d" + }, + "source": [ + "!ps aux | grep train_rgn2.py" + ], + "execution_count": 61, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "root 2285 0.0 0.0 39200 6252 ? S 12:47 0:00 /bin/bash -c ps aux | grep train_rgn2.py\n", + "root 2287 0.0 0.0 38572 5180 ? S 12:47 0:00 grep train_rgn2.py\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Qa2fwnskPqRD", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "71e4390e-08bd-4348-d82c-cb45cba91408" + }, + "source": [ + "!nvidia-smi" + ], + "execution_count": 53, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Tue Oct 5 08:13:04 2021 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 470.74 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 Tesla V100-SXM2... 
Off | 00000000:00:04.0 Off | 0 |\n", + "| N/A 53C P0 103W / 300W | 11293MiB / 16160MiB | 59% Default |\n", + "| | | N/A |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jlgXayRhMa0v" + }, + "source": [ + "# below is unfinished" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "UWV02qQUNvFO" + }, + "source": [ + "import sidechainnet as scn\n", + "import random \n", + "\n", + "import sys\n", + "sys.path.append(\"..\")\n", + "\n", + "import torch\n", + "import py3Dmol\n", + "import esm\n", + "\n", + "from rgn2_replica import *\n", + "from rgn2_replica.rgn2 import *\n", + "from rgn2_replica.rgn2_utils import *\n", + "from rgn2_replica.rgn2_trainers import *" + ], + "execution_count": 56, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "o04uJC6CVZFi" + }, + "source": [ + "from sidechainnet.utils.sequence import ProteinVocabulary as VOCAB\n", + "VOCAB = VOCAB()" + ], + "execution_count": 78, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "q-y5nsoxwZSM", + "outputId": "67a7536a-a23d-4f37-9492-43d1c6f7d7ae" + }, + "source": [ + "set_seed(42)\n", + "dataloaders = scn.load(casp_version=12, thinning=90, with_pytorch=\"dataloaders\", \n", + " batch_size=1, dynamic_batching=False)" + ], + "execution_count": 81, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "SidechainNet was loaded from 
./sidechainnet_data/sidechainnet_casp12_90.pkl.\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "OfTlIEmgK1C4" + }, + "source": [ + "# if torch.cuda.is_available():\n", + "# device = torch.device(\"cuda\")\n", + "# else:\n", + "# device = torch.device(\"cpu\")\n", + "device = \"cpu\"\n", + "model = RGN2_IPA(embedding_dim=1284).to(device)" + ], + "execution_count": 82, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YkgieANeMKe_", + "outputId": "42073236-ac77-4284-9fcc-6d00d21fc002" + }, + "source": [ + "save_path = \"/content/rgn2-replica/scripts/rgn2_models/RGN2_ipa_1e-4@_32K.pt\"\n", + "model.load_state_dict(torch.load(save_path))\n", + "sum([p.numel() for p in model.parameters()])" + ], + "execution_count": 83, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "34216661" + ] + }, + "metadata": {}, + "execution_count": 83 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "HgJDm6ohv0FO" + }, + "source": [ + "dataloaders[\"train\"].dataset\n", + "MIN_LEN_TEST = 70\n", + "MIN_LEN = 0\n", + "MAX_LEN = 512" + ], + "execution_count": 84, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "SjWn3Q3vOrtt" + }, + "source": [ + "embedder, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()\n", + "batch_converter = alphabet.get_batch_converter()" + ], + "execution_count": 85, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "wD28yPldv-xi" + }, + "source": [ + "embedder = embedder.to(device)" + ], + "execution_count": 86, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 398 + }, + "id": "0fICN-VLUwEq", + "outputId": "d91ecc41-92e2-4b99-d3b6-04aca1bb2b40" + }, + "source": [ + "### TEST\n", + "tic = time.time()\n", + "get_prot_test_ = mp_nerf.utils.get_prot( \n", + " dataloader_=dataloaders, \n", + " 
vocab_=VOCAB, # mp_nerf.utils.\n", + " min_len=MIN_LEN, max_len=MAX_LEN, \n", + " verbose=False, subset=\"test\"\n", + ")\n", + "# get num of unique, full-masked proteins\n", + "seqs = []\n", + "for i, prot_args in enumerate(dataloaders[\"test\"].dataset):\n", + " # (id, int_seq, mask, ... , str_seq)\n", + " length = len(prot_args[-1]) \n", + " if 0 < length < MAX_LEN and sum( prot_args[2] ) == length:\n", + " seqs.append( prot_args[-1] )\n", + "\n", + "metrics_stuff_test = predict(\n", + " get_prot_= get_prot_test_, \n", + " steps = len(set(seqs)), # 24 for MIN_LEN=70\n", + " model = model,\n", + " embedder = embedder, \n", + " return_preds = True,\n", + " log_every = 4,\n", + " accumulate_every = len(set(seqs)),\n", + " seed = 42, # 42\n", + " mode = \"fast_test\", # \"test\" # \"test\" is for AR, \"fast_test\" is for iterative\n", + " recycle_func = lambda x: 1, # 5 # 3 # 2 \n", + " wandbai = False,\n", + ")\n", + "preds_list_test, metrics_list_test, metrics_stats_test = metrics_stuff_test\n", + "print(\"\\n\", \"Test Results:\", sep=\"\")\n", + "for k,v in metrics_stats_test.items():\n", + " offset = \" \" * ( max(len(ki) for ki in metrics_stats_test.keys()) - len(k) )\n", + " print(k + offset, \":\", v)\n", + "print(\"\\n\")\n", + "print(\"Time taken: \", time.time()-tic, \"\\n\")" + ], + "execution_count": 90, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([3, 246, 2])\n" + ] + }, + { + "output_type": "error", + "ename": "TypeError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"fast_test\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# \"test\" # \"test\" is for AR, \"fast_test\" is for 
iterative\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mrecycle_func\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# 5 # 3 # 2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0mwandbai\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m )\n\u001b[1;32m 31\u001b[0m \u001b[0mpreds_list_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics_list_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics_stats_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmetrics_stuff_test\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/content/rgn2-replica/rgn2_replica/rgn2_trainers.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(get_prot_, steps, model, embedder, return_preds, accumulate_every, log_every, seed, wandbai, recycle_func, mode)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprots\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0membedder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0membedder\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 335\u001b[0;31m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecycle_func\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrecycle_func\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 336\u001b[0m )\n\u001b[1;32m 337\u001b[0m \u001b[0;31m# calculate metrics || calc loss terms || baselines for next-term: 
torsion=2, fape=0.95\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/content/rgn2-replica/rgn2_replica/rgn2_trainers.py\u001b[0m in \u001b[0;36mbatched_inference\u001b[0;34m(model, embedder, mode, device, recycle_func, config, *args)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;31m# don't pass angles info - just 0 at start (sin=0, cos=1)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mangles_input\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mangles_input\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m ], dim=-1)\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: expected Tensor as element 0 in argument 0, but got dict" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qovzXcFDPnvC" + }, + "source": [ + "" + ] + } + ] +} \ No newline at end of file diff --git a/rgn2_replica/embedders.py b/rgn2_replica/embedders.py index c4443f6..295d0a6 100644 --- a/rgn2_replica/embedders.py +++ b/rgn2_replica/embedders.py @@ -31,8 +31,8 @@ def forward(self, aa_seq): torch.Tensor (B, L) according to MP-NeRF encoding """ # format - if isinstance(aa_seqs, torch.Tensor): - aa_seq = ids_to_embed_input(to_cpu(aa_seqs).tolist()) + if isinstance(aa_seq, torch.Tensor): + aa_seq = ids_to_embed_input(to_cpu(aa_seq).tolist()) with torch.no_grad(): tokenized_seq = 
self.tokenizer(aa_seq, context_length=len(aa_seq), return_mask=False) diff --git a/rgn2_replica/mp_nerf/LICENSE b/rgn2_replica/mp_nerf/LICENSE new file mode 100644 index 0000000..1132c9a --- /dev/null +++ b/rgn2_replica/mp_nerf/LICENSE @@ -0,0 +1,421 @@ + +Copyright (c) 2021, Eric Alcaide +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following license. + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + license in the documentation and or other materials provided + with the distribution. + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + +Attribution-NonCommercial-NoDerivatives 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. 
+ +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. 
+ Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 +International Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-NoDerivatives 4.0 International Public +License ("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + c. 
Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + d. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + e. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + f. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + g. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + h. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + i. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + j. 
Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + k. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce and reproduce, but not Share, Adapted Material + for NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. 
Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material, You must: + + a. 
retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + For the avoidance of doubt, You do not have permission under + this Public License to Share Adapted Material. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only and provided You do not Share Adapted Material; + + b. 
if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. 
The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. 
To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. 
diff --git a/rgn2_replica/mp_nerf/__init__.py b/rgn2_replica/mp_nerf/__init__.py new file mode 100644 index 0000000..2fd09c0 --- /dev/null +++ b/rgn2_replica/mp_nerf/__init__.py @@ -0,0 +1,5 @@ +# from rgn2_replica.mp_nerf import * +# from rgn2_replica.mp_nerf import * +import rgn2_replica.mp_nerf.utils +import rgn2_replica.mp_nerf.proteins +import rgn2_replica.mp_nerf.ml_utils \ No newline at end of file diff --git a/rgn2_replica/mp_nerf/kb_proteins.py b/rgn2_replica/mp_nerf/kb_proteins.py new file mode 100644 index 0000000..6d5e17f --- /dev/null +++ b/rgn2_replica/mp_nerf/kb_proteins.py @@ -0,0 +1,846 @@ +# Author: Eric Alcaide + +# A substantial part has been borrowed from +# https://github.com/jonathanking/sidechainnet +# +# Here's the License for it: +# +# Copyright 2020 Jonathan King +# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +# following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +# products derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +# THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import numpy as np + +######################### +### FROM SIDECHAINNET ### +######################### + +# modified by considering rigid bodies in sidechains (remove extra torsions) + +SC_BUILD_INFO = { + 'A': { + 'angles-names': ['N-CA-CB'], + 'angles-types': ['N -CX-CT'], + 'angles-vals': [1.9146261894377796], + 'atom-names': ['CB'], + 'bonds-names': ['CA-CB'], + 'bonds-types': ['CX-CT'], + 'bonds-vals': [1.526], + 'torsion-names': ['C-N-CA-CB'], + 'torsion-types': ['C -N -CX-CT'], + 'torsion-vals': ['p'], + 'rigid-frames-idxs': [[0,1,2], [0,1,4]], + }, + + 'R': { + 'angles-names': [ + 'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-NE', 'CD-NE-CZ', 'NE-CZ-NH1', + 'NE-CZ-NH2' + ], + 'angles-types': [ + 'N -CX-C8', 'CX-C8-C8', 'C8-C8-C8', 'C8-C8-N2', 'C8-N2-CA', 'N2-CA-N2', + 'N2-CA-N2' + ], + 'angles-vals': [ + 1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.9408061282176945, + 2.150245638457014, 2.0943951023931953, 2.0943951023931953 + ], + 'atom-names': ['CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2'], + 'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-NE', 'NE-CZ', 'CZ-NH1', 'CZ-NH2'], + 'bonds-types': ['CX-C8', 'C8-C8', 'C8-C8', 'C8-N2', 'N2-CA', 'CA-N2', 'CA-N2'], + 'bonds-vals': [1.526, 1.526, 1.526, 1.463, 1.34, 1.34, 1.34], + 'torsion-names': [ + 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-NE', 'CG-CD-NE-CZ', + 'CD-NE-CZ-NH1', 'CD-NE-CZ-NH2' + ], + 'torsion-types': [ + 'C -N -CX-C8', 'N -CX-C8-C8', 'CX-C8-C8-C8', 'C8-C8-C8-N2', 
'C8-C8-N2-CA', + 'C8-N2-CA-N2', 'C8-N2-CA-N2' + ], + 'torsion-vals': ['p', 'p', 'p', 'p', 'p', 0., 3.141592], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6], [5,6,7], [6,7,8]], + }, + + 'N': { + 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-OD1', 'CB-CG-ND2'], + 'angles-types': ['N -CX-2C', 'CX-2C-C ', '2C-C -O ', '2C-C -N '], + 'angles-vals': [ + 1.9146261894377796, 1.9390607989657, 2.101376419401173, 2.035053907825388 + ], + 'atom-names': ['CB', 'CG', 'OD1', 'ND2'], + 'bonds-names': ['CA-CB', 'CB-CG', 'CG-OD1', 'CG-ND2'], + 'bonds-types': ['CX-2C', '2C-C ', 'C -O ', 'C -N '], + 'bonds-vals': [1.526, 1.522, 1.229, 1.335], + 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-OD1', 'CA-CB-CG-ND2'], + 'torsion-types': ['C -N -CX-2C', 'N -CX-2C-C ', 'CX-2C-C -O ', 'CX-2C-C -N '], + 'torsion-vals': ['p', 'p', 'p', 'i'], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]], + }, + + 'D': { + 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-OD1', 'CB-CG-OD2'], + 'angles-types': ['N -CX-2C', 'CX-2C-CO', '2C-CO-O2', '2C-CO-O2'], + 'angles-vals': [ + 1.9146261894377796, 1.9390607989657, 2.0420352248333655, 2.0420352248333655 + ], + 'atom-names': ['CB', 'CG', 'OD1', 'OD2'], + 'bonds-names': ['CA-CB', 'CB-CG', 'CG-OD1', 'CG-OD2'], + 'bonds-types': ['CX-2C', '2C-CO', 'CO-O2', 'CO-O2'], + 'bonds-vals': [1.526, 1.522, 1.25, 1.25], + 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-OD1', 'CA-CB-CG-OD2'], + 'torsion-types': ['C -N -CX-2C', 'N -CX-2C-CO', 'CX-2C-CO-O2', 'CX-2C-CO-O2'], + 'torsion-vals': ['p', 'p', 'p', 'i'], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]], + }, + + 'C': { + 'angles-names': ['N-CA-CB', 'CA-CB-SG'], + 'angles-types': ['N -CX-2C', 'CX-2C-SH'], + 'angles-vals': [1.9146261894377796, 1.8954275676658419], + 'atom-names': ['CB', 'SG'], + 'bonds-names': ['CA-CB', 'CB-SG'], + 'bonds-types': ['CX-2C', '2C-SH'], + 'bonds-vals': [1.526, 1.81], + 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-SG'], + 'torsion-types': ['C -N 
-CX-2C', 'N -CX-2C-SH'], + 'torsion-vals': ['p', 'p'], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5]], + }, + + 'Q': { + 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-OE1', 'CG-CD-NE2'], + 'angles-types': ['N -CX-2C', 'CX-2C-2C', '2C-2C-C ', '2C-C -O ', '2C-C -N '], + 'angles-vals': [ + 1.9146261894377796, 1.911135530933791, 1.9390607989657, 2.101376419401173, + 2.035053907825388 + ], + 'atom-names': ['CB', 'CG', 'CD', 'OE1', 'NE2'], + 'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-OE1', 'CD-NE2'], + 'bonds-types': ['CX-2C', '2C-2C', '2C-C ', 'C -O ', 'C -N '], + 'bonds-vals': [1.526, 1.526, 1.522, 1.229, 1.335], + 'torsion-names': [ + 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-OE1', 'CB-CG-CD-NE2' + ], + 'torsion-types': [ + 'C -N -CX-2C', 'N -CX-2C-2C', 'CX-2C-2C-C ', '2C-2C-C -O ', '2C-2C-C -N ' + ], + 'torsion-vals': ['p', 'p', 'p', 'p', 'i'], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6], [5,6,7]], + }, + + 'E': { + 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-OE1', 'CG-CD-OE2'], + 'angles-types': ['N -CX-2C', 'CX-2C-2C', '2C-2C-CO', '2C-CO-O2', '2C-CO-O2'], + 'angles-vals': [ + 1.9146261894377796, 1.911135530933791, 1.9390607989657, 2.0420352248333655, + 2.0420352248333655 + ], + 'atom-names': ['CB', 'CG', 'CD', 'OE1', 'OE2'], + 'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-OE1', 'CD-OE2'], + 'bonds-types': ['CX-2C', '2C-2C', '2C-CO', 'CO-O2', 'CO-O2'], + 'bonds-vals': [1.526, 1.526, 1.522, 1.25, 1.25], + 'torsion-names': [ + 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-OE1', 'CB-CG-CD-OE2' + ], + 'torsion-types': [ + 'C -N -CX-2C', 'N -CX-2C-2C', 'CX-2C-2C-CO', '2C-2C-CO-O2', '2C-2C-CO-O2' + ], + 'torsion-vals': ['p', 'p', 'p', 'p', 'i'], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6], [5,6,7]], + }, + + 'G': { + 'angles-names': [], + 'angles-types': [], + 'angles-vals': [], + 'atom-names': [], + 'bonds-names': [], + 'bonds-types': [], + 'bonds-vals': [], + 'torsion-names': [], + 
'torsion-types': [], + 'torsion-vals': [], + 'rigid-frames-idxs': [[0,1,2]], + }, + + 'H': { + 'angles-names': [ + 'N-CA-CB', 'CA-CB-CG', 'CB-CG-ND1', 'CG-ND1-CE1', 'ND1-CE1-NE2', 'CE1-NE2-CD2' + ], + 'angles-types': [ + 'N -CX-CT', 'CX-CT-CC', 'CT-CC-NA', 'CC-NA-CR', 'NA-CR-NB', 'CR-NB-CV' + ], + 'angles-vals': [ + 1.9146261894377796, 1.9739673840055867, 2.0943951023931953, + 1.8849555921538759, 1.8849555921538759, 1.8849555921538759 + ], + 'atom-names': ['CB', 'CG', 'ND1', 'CE1', 'NE2', 'CD2'], + 'bonds-names': ['CA-CB', 'CB-CG', 'CG-ND1', 'ND1-CE1', 'CE1-NE2', 'NE2-CD2'], + 'bonds-types': ['CX-CT', 'CT-CC', 'CC-NA', 'NA-CR', 'CR-NB', 'NB-CV'], + 'bonds-vals': [1.526, 1.504, 1.385, 1.343, 1.335, 1.394], + 'torsion-names': [ + 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-ND1', 'CB-CG-ND1-CE1', 'CG-ND1-CE1-NE2', + 'ND1-CE1-NE2-CD2' + ], + 'torsion-types': [ + 'C -N -CX-CT', 'N -CX-CT-CC', 'CX-CT-CC-NA', 'CT-CC-NA-CR', 'CC-NA-CR-NB', + 'NA-CR-NB-CV' + ], + 'torsion-vals': ['p', 'p', 'p', 3.141592653589793, 0.0, 0.0], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]], + }, + + 'I': { + 'angles-names': ['N-CA-CB', 'CA-CB-CG1', 'CB-CG1-CD1', 'CA-CB-CG2'], + 'angles-types': ['N -CX-3C', 'CX-3C-2C', '3C-2C-CT', 'CX-3C-CT'], + 'angles-vals': [ + 1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.911135530933791 + ], + 'atom-names': ['CB', 'CG1', 'CD1', 'CG2'], + 'bonds-names': ['CA-CB', 'CB-CG1', 'CG1-CD1', 'CB-CG2'], + 'bonds-types': ['CX-3C', '3C-2C', '2C-CT', '3C-CT'], + 'bonds-vals': [1.526, 1.526, 1.526, 1.526], + 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG1', 'CA-CB-CG1-CD1', 'N-CA-CB-CG2'], + 'torsion-types': ['C -N -CX-3C', 'N -CX-3C-2C', 'CX-3C-2C-CT', 'N -CX-3C-CT'], + 'torsion-vals': ['p', 'p', 'p', -2.1315], # last one was 'p' in the original - but cg1-cg2 = "2.133" + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,7]], + }, + + 'L': { + 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CB-CG-CD2'], + 'angles-types': ['N -CX-2C', 
'CX-2C-3C', '2C-3C-CT', '2C-3C-CT'], + 'angles-vals': [ + 1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.911135530933791 + ], + 'atom-names': ['CB', 'CG', 'CD1', 'CD2'], + 'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD1', 'CG-CD2'], + 'bonds-types': ['CX-2C', '2C-3C', '3C-CT', '3C-CT'], + 'bonds-vals': [1.526, 1.526, 1.526, 1.526], + 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CA-CB-CG-CD2'], + 'torsion-types': ['C -N -CX-2C', 'N -CX-2C-3C', 'CX-2C-3C-CT', 'CX-2C-3C-CT'], + # extra torsion is in negative bc in mask construction, previous angle is summed. + 'torsion-vals': ['p', 'p', 'p', 2.1315], # last one was 'p' in the original - but cd1-cd2 = "-2.130" + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]], + }, + + 'K': { + 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD', 'CG-CD-CE', 'CD-CE-NZ'], + 'angles-types': ['N -CX-C8', 'CX-C8-C8', 'C8-C8-C8', 'C8-C8-C8', 'C8-C8-N3'], + 'angles-vals': [ + 1.9146261894377796, 1.911135530933791, 1.911135530933791, 1.911135530933791, + 1.9408061282176945 + ], + 'atom-names': ['CB', 'CG', 'CD', 'CE', 'NZ'], + 'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD', 'CD-CE', 'CE-NZ'], + 'bonds-types': ['CX-C8', 'C8-C8', 'C8-C8', 'C8-C8', 'C8-N3'], + 'bonds-vals': [1.526, 1.526, 1.526, 1.526, 1.471], + 'torsion-names': [ + 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD', 'CB-CG-CD-CE', 'CG-CD-CE-NZ' + ], + 'torsion-types': [ + 'C -N -CX-C8', 'N -CX-C8-C8', 'CX-C8-C8-C8', 'C8-C8-C8-C8', 'C8-C8-C8-N3' + ], + 'torsion-vals': ['p', 'p', 'p', 'p', 'p'], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6], [5,6,7], [6,7,8]], + }, + + 'M': { + 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-SD', 'CG-SD-CE'], + 'angles-types': ['N -CX-2C', 'CX-2C-2C', '2C-2C-S ', '2C-S -CT'], + 'angles-vals': [ + 1.9146261894377796, 1.911135530933791, 2.0018926520374962, 1.726130630222392 + ], + 'atom-names': ['CB', 'CG', 'SD', 'CE'], + 'bonds-names': ['CA-CB', 'CB-CG', 'CG-SD', 'SD-CE'], + 'bonds-types': ['CX-2C', '2C-2C', '2C-S ', 
'S -CT'], + 'bonds-vals': [1.526, 1.526, 1.81, 1.81], + 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-SD', 'CB-CG-SD-CE'], + 'torsion-types': ['C -N -CX-2C', 'N -CX-2C-2C', 'CX-2C-2C-S ', '2C-2C-S -CT'], + 'torsion-vals': ['p', 'p', 'p', 'p'], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6], [5,6,7]], + }, + + 'F': { + 'angles-names': [ + 'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CG-CD1-CE1', 'CD1-CE1-CZ', 'CE1-CZ-CE2', + 'CZ-CE2-CD2' + ], + 'angles-types': [ + 'N -CX-CT', 'CX-CT-CA', 'CT-CA-CA', 'CA-CA-CA', 'CA-CA-CA', 'CA-CA-CA', + 'CA-CA-CA' + ], + 'angles-vals': [ + 1.9146261894377796, 1.9896753472735358, 2.0943951023931953, + 2.0943951023931953, 2.0943951023931953, 2.0943951023931953, 2.0943951023931953 + ], + 'atom-names': ['CB', 'CG', 'CD1', 'CE1', 'CZ', 'CE2', 'CD2'], + 'bonds-names': [ + 'CA-CB', 'CB-CG', 'CG-CD1', 'CD1-CE1', 'CE1-CZ', 'CZ-CE2', 'CE2-CD2' + ], + 'bonds-types': ['CX-CT', 'CT-CA', 'CA-CA', 'CA-CA', 'CA-CA', 'CA-CA', 'CA-CA'], + 'bonds-vals': [1.526, 1.51, 1.4, 1.4, 1.4, 1.4, 1.4], + 'torsion-names': [ + 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CB-CG-CD1-CE1', 'CG-CD1-CE1-CZ', + 'CD1-CE1-CZ-CE2', 'CE1-CZ-CE2-CD2' + ], + 'torsion-types': [ + 'C -N -CX-CT', 'N -CX-CT-CA', 'CX-CT-CA-CA', 'CT-CA-CA-CA', 'CA-CA-CA-CA', + 'CA-CA-CA-CA', 'CA-CA-CA-CA' + ], + 'torsion-vals': ['p', 'p', 'p', 3.141592653589793, 0.0, 0.0, 0.0], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]], + }, + + 'P': { + 'angles-names': ['N-CA-CB', 'CA-CB-CG', 'CB-CG-CD'], + 'angles-types': ['N -CX-CT', 'CX-CT-CT', 'CT-CT-CT'], + 'angles-vals': [1.9146261894377796, 1.911135530933791, 1.911135530933791], + 'atom-names': ['CB', 'CG', 'CD'], + 'bonds-names': ['CA-CB', 'CB-CG', 'CG-CD'], + 'bonds-types': ['CX-CT', 'CT-CT', 'CT-CT'], + 'bonds-vals': [1.526, 1.526, 1.526], + 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD'], + 'torsion-types': ['C -N -CX-CT', 'N -CX-CT-CT', 'CX-CT-CT-CT'], + 'torsion-vals': ['p', 'p', 'p'], + 'rigid-frames-idxs': 
[[0,1,2], [0,1,4], [1,4,5], [4,5,6]], + }, + + 'S': { + 'angles-names': ['N-CA-CB', 'CA-CB-OG'], + 'angles-types': ['N -CX-2C', 'CX-2C-OH'], + 'angles-vals': [1.9146261894377796, 1.911135530933791], + 'atom-names': ['CB', 'OG'], + 'bonds-names': ['CA-CB', 'CB-OG'], + 'bonds-types': ['CX-2C', '2C-OH'], + 'bonds-vals': [1.526, 1.41], + 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-OG'], + 'torsion-types': ['C -N -CX-2C', 'N -CX-2C-OH'], + 'torsion-vals': ['p', 'p'], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5]], + }, + + 'T': { + 'angles-names': ['N-CA-CB', 'CA-CB-OG1', 'CA-CB-CG2'], + 'angles-types': ['N -CX-3C', 'CX-3C-OH', 'CX-3C-CT'], + 'angles-vals': [1.9146261894377796, 1.911135530933791, 1.911135530933791], + 'atom-names': ['CB', 'OG1', 'CG2'], + 'bonds-names': ['CA-CB', 'CB-OG1', 'CB-CG2'], + 'bonds-types': ['CX-3C', '3C-OH', '3C-CT'], + 'bonds-vals': [1.526, 1.41, 1.526], + 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-OG1', 'N-CA-CB-CG2'], + 'torsion-types': ['C -N -CX-3C', 'N -CX-3C-OH', 'N -CX-3C-CT'], + 'torsion-vals': ['p', 'p', 'p'], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5]], + }, + + 'W': { + 'angles-names': [ + 'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CG-CD1-NE1', 'CD1-NE1-CE2', + 'NE1-CE2-CZ2', 'CE2-CZ2-CH2', 'CZ2-CH2-CZ3', 'CH2-CZ3-CE3', 'CZ3-CE3-CD2' + ], + 'angles-types': [ + 'N -CX-CT', 'CX-CT-C*', 'CT-C*-CW', 'C*-CW-NA', 'CW-NA-CN', 'NA-CN-CA', + 'CN-CA-CA', 'CA-CA-CA', 'CA-CA-CA', 'CA-CA-CB' + ], + 'angles-vals': [ + 1.9146261894377796, 2.0176006153054447, 2.181661564992912, 1.8971728969178363, + 1.9477874452256716, 2.3177972466484698, 2.0943951023931953, + 2.0943951023931953, 2.0943951023931953, 2.0943951023931953 + ], + 'atom-names': [ + 'CB', 'CG', 'CD1', 'NE1', 'CE2', 'CZ2', 'CH2', 'CZ3', 'CE3', 'CD2' + ], + 'bonds-names': [ + 'CA-CB', 'CB-CG', 'CG-CD1', 'CD1-NE1', 'NE1-CE2', 'CE2-CZ2', 'CZ2-CH2', + 'CH2-CZ3', 'CZ3-CE3', 'CE3-CD2' + ], + 'bonds-types': [ + 'CX-CT', 'CT-C*', 'C*-CW', 'CW-NA', 'NA-CN', 'CN-CA', 'CA-CA', 'CA-CA', + 'CA-CA', 
'CA-CB' + ], + 'bonds-vals': [1.526, 1.495, 1.352, 1.381, 1.38, 1.4, 1.4, 1.4, 1.4, 1.404], + 'torsion-names': [ + 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CB-CG-CD1-NE1', 'CG-CD1-NE1-CE2', + 'CD1-NE1-CE2-CZ2', 'NE1-CE2-CZ2-CH2', 'CE2-CZ2-CH2-CZ3', 'CZ2-CH2-CZ3-CE3', + 'CH2-CZ3-CE3-CD2' + ], + 'torsion-types': [ + 'C -N -CX-CT', 'N -CX-CT-C*', 'CX-CT-C*-CW', 'CT-C*-CW-NA', 'C*-CW-NA-CN', + 'CW-NA-CN-CA', 'NA-CN-CA-CA', 'CN-CA-CA-CA', 'CA-CA-CA-CA', 'CA-CA-CA-CB' + ], + 'torsion-vals': [ + 'p', 'p', 'p', 3.141592653589793, 0.0, 3.141592653589793, 3.141592653589793, + 0.0, 0.0, 0.0 + ], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]] + }, + + 'Y': { + 'angles-names': [ + 'N-CA-CB', 'CA-CB-CG', 'CB-CG-CD1', 'CG-CD1-CE1', 'CD1-CE1-CZ', 'CE1-CZ-OH', + 'CE1-CZ-CE2', 'CZ-CE2-CD2' + ], + 'angles-types': [ + 'N -CX-CT', 'CX-CT-CA', 'CT-CA-CA', 'CA-CA-CA', 'CA-CA-C ', 'CA-C -OH', + 'CA-C -CA', 'C -CA-CA' + ], + 'angles-vals': [ + 1.9146261894377796, 1.9896753472735358, 2.0943951023931953, + 2.0943951023931953, 2.0943951023931953, 2.0943951023931953, + 2.0943951023931953, 2.0943951023931953 + ], + 'atom-names': ['CB', 'CG', 'CD1', 'CE1', 'CZ', 'OH', 'CE2', 'CD2'], + 'bonds-names': [ + 'CA-CB', 'CB-CG', 'CG-CD1', 'CD1-CE1', 'CE1-CZ', 'CZ-OH', 'CZ-CE2', 'CE2-CD2' + ], + 'bonds-types': [ + 'CX-CT', 'CT-CA', 'CA-CA', 'CA-CA', 'CA-C ', 'C -OH', 'C -CA', 'CA-CA' + ], + 'bonds-vals': [1.526, 1.51, 1.4, 1.4, 1.409, 1.364, 1.409, 1.4], + 'torsion-names': [ + 'C-N-CA-CB', 'N-CA-CB-CG', 'CA-CB-CG-CD1', 'CB-CG-CD1-CE1', 'CG-CD1-CE1-CZ', + 'CD1-CE1-CZ-OH', 'CD1-CE1-CZ-CE2', 'CE1-CZ-CE2-CD2' + ], + 'torsion-types': [ + 'C -N -CX-CT', 'N -CX-CT-CA', 'CX-CT-CA-CA', 'CT-CA-CA-CA', 'CA-CA-CA-C ', + 'CA-CA-C -OH', 'CA-CA-C -CA', 'CA-C -CA-CA' + ], + 'torsion-vals': [ + 'p', 'p', 'p', 3.141592653589793, 0.0, 3.141592653589793, 0.0, 0.0 + ], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5], [4,5,6]], + }, + + 'V': { + 'angles-names': ['N-CA-CB', 'CA-CB-CG1', 'CA-CB-CG2'], + 
'angles-types': ['N -CX-3C', 'CX-3C-CT', 'CX-3C-CT'], + 'angles-vals': [1.9146261894377796, 1.911135530933791, 1.911135530933791], + 'atom-names': ['CB', 'CG1', 'CG2'], + 'bonds-names': ['CA-CB', 'CB-CG1', 'CB-CG2'], + 'bonds-types': ['CX-3C', '3C-CT', '3C-CT'], + 'bonds-vals': [1.526, 1.526, 1.526], + 'torsion-names': ['C-N-CA-CB', 'N-CA-CB-CG1', 'N-CA-CB-CG2'], + 'torsion-types': ['C -N -CX-3C', 'N -CX-3C-CT', 'N -CX-3C-CT'], + 'torsion-vals': ['p', 'p', 'p'], + 'rigid-frames-idxs': [[0,1,2], [0,1,4], [1,4,5]] + }, + + '_': { + 'angles-names': [], + 'angles-types': [], + 'angles-vals': [], + 'atom-names': [], + 'bonds-names': [], + 'bonds-types': [], + 'bonds-vals': [], + 'torsion-names': [], + 'torsion-types': [], + 'torsion-vals': [], + 'rigid-frames-idxs': [[]], + } +} + +BB_BUILD_INFO = { + "BONDLENS": { + # the updated is according to crystal data from 1DPE_1_A and validated with other structures + # the commented is the sidechainnet one + 'n-ca': 1.4664931, # 1.442, + 'ca-c': 1.524119, # 1.498, + 'c-n': 1.3289373, # 1.379, + 'c-o': 1.229, # From parm10.dat || huge variability according to structures + # we get 1.3389416 from 1DPE_1_A but also 1.2289 from 2F2H_d2f2hf1 + 'c-oh': 1.364 + }, + # From parm10.dat, for OXT + # For placing oxygens + "BONDANGS": { + 'ca-c-o': 2.0944, # Approximated to be 2pi / 3; parm10.dat says 2.0350539 + 'ca-c-oh': 2.0944, + 'ca-c-n': 2.03, + 'n-ca-c': 1.94, + 'c-n-ca': 2.08, + }, + # Equal to 'ca-c-o', for OXT + "BONDTORSIONS": { + 'n-ca-c-n': -0.785398163, # psi (-44 deg, bimodal distro, pick one) + 'c-n-ca-c': -1.3962634015954636, # phi (-80 deg, bimodal distro, pick one) + 'ca-n-c-ca': 3.141592, # omega (180 deg - https://doi.org/10.1016/j.jmb.2005.01.065) + 'n-ca-c-o': -2.406 # oxygen + } # A simple approximation, not meant to be exact. 
+} + + +# numbers follow the same order as sidechainnet atoms +SCN_CONNECT = { + 'A': { + 'bonds': [[0,1], [1,2], [2,3], [1,4]] + }, + 'R': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [6,7], [7,8], [8,9], [8,10]] + }, + 'N': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [5,7]] + }, + 'D': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [5,7]] + }, + 'C': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5]] + }, + 'Q': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [6,7], [6,8]] + }, + 'E': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [6,7], [6,8]] + }, + 'G': { + 'bonds': [[0,1], [1,2], [2,3]] + }, + 'H': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [6,7], [7,8], [8,9], [5,9]] + }, + 'I': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [4,7]] + }, + 'L': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [5,7]] + }, + 'K': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [6,7], [7,8]] + }, + 'M': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [6,7]] + }, + 'F': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [6,7], [7,8], [8,9], [9,10], [5,10]] + }, + 'P': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [0,6]] + }, + 'S': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5]] + }, + 'T': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [4,6]] + }, + 'W': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [6,7], [7,8], [8,9], [9,10], [10,11], [11,12], + [12, 13], [5,13], [8,13]] + }, + 'Y': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [5,6], + [6,7], [7,8], [8,9], [8,10], [10,11], [5,11]] + }, + 'V': { + 'bonds': [[0,1], [1,2], [2,3], [1,4], [4,5], [4,6]] + }, + '_': { + 'bonds': [] + } + } + +# from: https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf +AMBIGUOUS = { + "D": {"names": [["OD1", "OD2"]], + "indexs": [[6, 7]], + }, + "E": {"names": [["OE1", "OE2"]], + 
"indexs": [[7, 8]], + }, + "F": {"names": [["CD1", "CD2"], ["CE1", "CE2"]], + "indexs": [[6, 10], [7, 9]], + }, + "Y": {"names": [["CD1", "CD2"], ["CE1", "CE2"]], + "indexs": [[6,10], [7,9]], + }, +} + + +# AA subst mat +BLOSUM = { + "A" : [4.0, -1.0, -2.0, -2.0, 0.0, -1.0, -1.0, 0.0, -2.0, -1.0, -1.0, -1.0, -1.0, -2.0, -1.0, 1.0, 0.0, -3.0, -2.0, 0.0, 0.0], + "C" : [-1.0, 5.0, 0.0, -2.0, -3.0, 1.0, 0.0, -2.0, 0.0, -3.0, -2.0, 2.0, -1.0, -3.0, -2.0, -1.0, -1.0, -3.0, -2.0, -3.0, 0.0], + "D" : [-2.0, 0.0, 6.0, 1.0, -3.0, 0.0, 0.0, 0.0, 1.0, -3.0, -3.0, 0.0, -2.0, -3.0, -2.0, 1.0, 0.0, -4.0, -2.0, -3.0, 0.0], + "E" : [-2.0, -2.0, 1.0, 6.0, -3.0, 0.0, 2.0, -1.0, -1.0, -3.0, -4.0, -1.0, -3.0, -3.0, -1.0, 0.0, -1.0, -4.0, -3.0, -3.0, 0.0], + "F" : [0.0, -3.0, -3.0, -3.0, 9.0, -3.0, -4.0, -3.0, -3.0, -1.0, -1.0, -3.0, -1.0, -2.0, -3.0, -1.0, -1.0, -2.0, -2.0, -1.0, 0.0], + "G" : [-1.0, 1.0, 0.0, 0.0, -3.0, 5.0, 2.0, -2.0, 0.0, -3.0, -2.0, 1.0, 0.0, -3.0, -1.0, 0.0, -1.0, -2.0, -1.0, -2.0, 0.0], + "H" : [-1.0, 0.0, 0.0, 2.0, -4.0, 2.0, 5.0, -2.0, 0.0, -3.0, -3.0, 1.0, -2.0, -3.0, -1.0, 0.0, -1.0, -3.0, -2.0, -2.0, 0.0], + "I" : [0.0, -2.0, 0.0, -1.0, -3.0, -2.0, -2.0, 6.0, -2.0, -4.0, -4.0, -2.0, -3.0, -3.0, -2.0, 0.0, -2.0, -2.0, -3.0, -3.0, 0.0], + "K" : [-2.0, 0.0, 1.0, -1.0, -3.0, 0.0, 0.0, -2.0, 8.0, -3.0, -3.0, -1.0, -2.0, -1.0, -2.0, -1.0, -2.0, -2.0, 2.0, -3.0, 0.0], + "L" : [-1.0, -3.0, -3.0, -3.0, -1.0, -3.0, -3.0, -4.0, -3.0, 4.0, 2.0, -3.0, 1.0, 0.0, -3.0, -2.0, -1.0, -3.0, -1.0, 3.0, 0.0], + "M" : [-1.0, -2.0, -3.0, -4.0, -1.0, -2.0, -3.0, -4.0, -3.0, 2.0, 4.0, -2.0, 2.0, 0.0, -3.0, -2.0, -1.0, -2.0, -1.0, 1.0, 0.0], + "N" : [-1.0, 2.0, 0.0, -1.0, -3.0, 1.0, 1.0, -2.0, -1.0, -3.0, -2.0, 5.0, -1.0, -3.0, -1.0, 0.0, -1.0, -3.0, -2.0, -2.0, 0.0], + "P" : [-1.0, -1.0, -2.0, -3.0, -1.0, 0.0, -2.0, -3.0, -2.0, 1.0, 2.0, -1.0, 5.0, 0.0, -2.0, -1.0, -1.0, -1.0, -1.0, 1.0, 0.0], + "Q" : [-2.0, -3.0, -3.0, -3.0, -2.0, -3.0, -3.0, -3.0, -1.0, 0.0, 0.0, -3.0, 0.0, 6.0, 
-4.0, -2.0, -2.0, 1.0, 3.0, -1.0, 0.0], + "R" : [-1.0, -2.0, -2.0, -1.0, -3.0, -1.0, -1.0, -2.0, -2.0, -3.0, -3.0, -1.0, -2.0, -4.0, 7.0, -1.0, -1.0, -4.0, -3.0, -2.0, 0.0], + "S" : [1.0, -1.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, -2.0, -2.0, 0.0, -1.0, -2.0, -1.0, 4.0, 1.0, -3.0, -2.0, -2.0, 0.0], + "T" : [0.0, -1.0, 0.0, -1.0, -1.0, -1.0, -1.0, -2.0, -2.0, -1.0, -1.0, -1.0, -1.0, -2.0, -1.0, 1.0, 5.0, -2.0, -2.0, 0.0, 0.0], + "V" : [-3.0, -3.0, -4.0, -4.0, -2.0, -2.0, -3.0, -2.0, -2.0, -3.0, -2.0, -3.0, -1.0, 1.0, -4.0, -3.0, -2.0, 11.0, 2.0, -3.0, 0.0], + "W" : [-2.0, -2.0, -2.0, -3.0, -2.0, -1.0, -2.0, -3.0, 2.0, -1.0, -1.0, -2.0, -1.0, 3.0, -3.0, -2.0, -2.0, 2.0, 7.0, -1.0, 0.0], + "Y" : [0.0, -3.0, -3.0, -3.0, -1.0, -2.0, -2.0, -3.0, -3.0, 3.0, 1.0, -2.0, 1.0, -1.0, -2.0, -2.0, 0.0, -3.0, -1.0, 4.0, 0.0], + "_" : [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], +} + + +# modified manually to match the mode +MP3SC_INFO = { + 'A': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146265, 'bond_dihedral': 2.848366} + }, + 'R': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146265, 'bond_dihedral': 2.6976738}, + 'CG': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': -1.2}, + 'CD': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': -3.141592}, + 'NE': {'bond_lens': 1.463, 'bond_angs': 1.9408059, 'bond_dihedral': -3.141592}, + 'CZ': {'bond_lens': 1.34, 'bond_angs': 2.1502457, 'bond_dihedral': -3.141592}, + 'NH1': {'bond_lens': 1.34, 'bond_angs': 2.094395, 'bond_dihedral': 0.}, + 'NH2': {'bond_lens': 1.34, 'bond_angs': 2.094395, 'bond_dihedral': -3.141592} + }, + 'N': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146265, 'bond_dihedral': 2.8416245}, + 'CG': {'bond_lens': 1.5219998, 'bond_angs': 1.9390607, 'bond_dihedral': -1.15}, + 'OD1': {'bond_lens': 1.229, 'bond_angs': 2.101376, 'bond_dihedral': -1.}, # spread out w/ mean at -1 + 'ND2': {'bond_lens': 1.3349999, 'bond_angs': 
2.0350537, 'bond_dihedral': 2.14} # spread out with mean at -4 + }, + 'D': {'CB': {'bond_lens': 1.526, 'bond_angs': 1.9146265, 'bond_dihedral': 2.7741134}, + 'CG': {'bond_lens': 1.522, 'bond_angs': 1.9390608, 'bond_dihedral': -1.07}, + 'OD1': {'bond_lens': 1.25, 'bond_angs': 2.0420356, 'bond_dihedral': -0.2678593}, + 'OD2': {'bond_lens': 1.25, 'bond_angs': 2.0420356, 'bond_dihedral': 2.95} + }, + 'C': {'CB': {'bond_lens': 1.5259998, 'bond_angs': 1.9146262, 'bond_dihedral': 2.553627}, + 'SG': {'bond_lens': 1.8099997, 'bond_angs': 1.8954275, 'bond_dihedral': -1.07} + }, + 'Q': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146266, 'bond_dihedral': 2.7262106}, + 'CG': {'bond_lens': 1.5260003, 'bond_angs': 1.9111353, 'bond_dihedral': -1.075}, + 'CD': {'bond_lens': 1.5219998, 'bond_angs': 1.9390606, 'bond_dihedral': -3.141592}, + 'OE1': {'bond_lens': 1.229, 'bond_angs': 2.101376, 'bond_dihedral': -1}, # bimodal at -1, +1 + 'NE2': {'bond_lens': 1.3349998, 'bond_angs': 2.0350537, 'bond_dihedral': 2.14} # bimodal at -2, -4 + }, + 'E': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146267, 'bond_dihedral': 2.7813723}, + 'CG': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': -1.07}, # bimodal at -1.07, 3.14 + 'CD': {'bond_lens': 1.5219998, 'bond_angs': 1.9390606, 'bond_dihedral': -3.0907722155200403}, + 'OE1': {'bond_lens': 1.25, 'bond_angs': 2.0420356, 'bond_dihedral': 0.003740118}, # spread out between -1,1 + 'OE2': {'bond_lens': 1.25, 'bond_angs': 2.0420356, 'bond_dihedral': -3.1378527} # spread out between -4.3, -2.14 + }, + 'G': {}, + 'H': {'CB': {'bond_lens': 1.5259998, 'bond_angs': 1.9146264, 'bond_dihedral': 2.614421}, + 'CG': {'bond_lens': 1.5039998, 'bond_angs': 1.9739674, 'bond_dihedral': -1.05}, + 'ND1': {'bond_lens': 1.3850001, 'bond_angs': 2.094395, 'bond_dihedral': -1.41}, # bimodal at -1.4, 1.4 + 'CE1': {'bond_lens': 1.3430002, 'bond_angs': 1.8849558, 'bond_dihedral': 3.14}, + 'NE2': {'bond_lens': 1.335, 'bond_angs': 1.8849558, 
'bond_dihedral': 0.0}, + 'CD2': {'bond_lens': 1.3940002, 'bond_angs': 1.8849558, 'bond_dihedral': 0.0} + }, + 'I': {'CB': {'bond_lens': 1.526, 'bond_angs': 1.9146265, 'bond_dihedral': 2.5604365}, + 'CG1': {'bond_lens': 1.526, 'bond_angs': 1.9111353, 'bond_dihedral': -1.025}, + 'CD1': {'bond_lens': 1.526, 'bond_angs': 1.9111353, 'bond_dihedral': -3.0667439142810267}, + 'CG2': {'bond_lens': 1.526, 'bond_angs': 1.9111353, 'bond_dihedral': -3.1225884596454065} + }, + 'L': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146265, 'bond_dihedral': 2.711971}, + 'CG': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': -1.15}, + 'CD1': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': 3.14}, + 'CD2': {'bond_lens': 1.5260003, 'bond_angs': 1.9111352, 'bond_dihedral': -1.05} + }, + 'K': {'CB': {'bond_lens': 1.526, 'bond_angs': 1.9146266, 'bond_dihedral': 2.7441595}, + 'CG': {'bond_lens': 1.526, 'bond_angs': 1.9111353, 'bond_dihedral': -1.15}, + 'CD': {'bond_lens': 1.526, 'bond_angs': 1.9111353, 'bond_dihedral': -3.09}, + 'CE': {'bond_lens': 1.526, 'bond_angs': 1.9111353, 'bond_dihedral': 3.092959}, + 'NZ': {'bond_lens': 1.4710001, 'bond_angs': 1.940806, 'bond_dihedral': 3.0515378} + }, + 'M': {'CB': {'bond_lens': 1.526, 'bond_angs': 1.9146264, 'bond_dihedral': 2.7051392}, + 'CG': {'bond_lens': 1.526, 'bond_angs': 1.9111354, 'bond_dihedral': -1.1}, + 'SD': {'bond_lens': 1.8099998, 'bond_angs': 2.001892, 'bond_dihedral': 3.1411812}, # bimodal at 0, 3.14 + 'CE': {'bond_lens': 1.8099998, 'bond_angs': 1.7261307, 'bond_dihedral': -0.048235133} # trimodal at -1.41, 0, 1.41 + }, + 'F': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146266, 'bond_dihedral': 2.545154}, + 'CG': {'bond_lens': 1.5100001, 'bond_angs': 1.9896755, 'bond_dihedral': -1.2}, # bimodal at -1, 3.14 + 'CD1': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 1.41}, # bimodal -1.41, 1.41 + 'CE1': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 3.141592}, + 
'CZ': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 0.0}, + 'CE2': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 0.0}, + 'CD2': {'bond_lens': 1.3999997, 'bond_angs': 2.094395, 'bond_dihedral': 0.0} + }, + 'P': {'CB': {'bond_lens': 1.5260001, 'bond_angs': 1.9146266, 'bond_dihedral': 3.141592}, + 'CG': {'bond_lens': 1.5260001, 'bond_angs': 1.9111352, 'bond_dihedral': -0.707}, # bimodal at -0.7, 0.7 + 'CD': {'bond_lens': 1.5260001, 'bond_angs': 1.9111352, 'bond_dihedral': 0.85} # bimodal at -0.85, 0.85 + }, + 'S': {'CB': {'bond_lens': 1.5260001, 'bond_angs': 1.9146266, 'bond_dihedral': 2.6017702}, + 'OG': {'bond_lens': 1.41, 'bond_angs': 1.9111352, 'bond_dihedral': 1.1} + }, + 'T': {'CB': {'bond_lens': 1.5260001, 'bond_angs': 1.9146265, 'bond_dihedral': 2.55}, + 'OG1': {'bond_lens': 1.4099998, 'bond_angs': 1.9111353, 'bond_dihedral': -1.07}, # bimodal at -1 and +1 + 'CG2': {'bond_lens': 1.5260001, 'bond_angs': 1.9111353, 'bond_dihedral': -3.05} # bimodal at -1 and -3 + }, + 'W': {'CB': {'bond_lens': 1.526, 'bond_angs': 1.9146266, 'bond_dihedral': 3.141592}, + 'CG': {'bond_lens': 1.4950002, 'bond_angs': 2.0176008, 'bond_dihedral': -1.2}, + 'CD1': {'bond_lens': 1.3520001, 'bond_angs': 2.1816616, 'bond_dihedral': 1.53}, + 'NE1': {'bond_lens': 1.3810003, 'bond_angs': 1.8971729, 'bond_dihedral': 3.141592}, + 'CE2': {'bond_lens': 1.3799998, 'bond_angs': 1.9477878, 'bond_dihedral': 0.0}, + 'CZ2': {'bond_lens': 1.3999999, 'bond_angs': 2.317797, 'bond_dihedral': 3.141592}, + 'CH2': {'bond_lens': 1.3999999, 'bond_angs': 2.094395, 'bond_dihedral': 3.141592}, + 'CZ3': {'bond_lens': 1.3999999, 'bond_angs': 2.094395, 'bond_dihedral': 0.0}, + 'CE3': {'bond_lens': 1.3999999, 'bond_angs': 2.094395, 'bond_dihedral': 0.0}, + 'CD2': {'bond_lens': 1.404, 'bond_angs': 2.094395, 'bond_dihedral': 0.0} + }, + 'Y': {'CB': {'bond_lens': 1.5260003, 'bond_angs': 1.9146266, 'bond_dihedral': 3.1}, + 'CG': {'bond_lens': 1.5100001, 'bond_angs': 1.9896754, 
def make_cloud_mask(aa):
    """ Builds the atom-presence mask of a residue.
        Inputs:
        * aa: str. 1-letter amino acid code ("_" is the padding token)
        Outputs: (14,) float array. relevant points are 1, paddings are 0.
    """
    cloud = np.zeros(14)
    # padding token: nothing to mark
    if aa == "_":
        return cloud
    # 4 leading slots + one slot per sidechain atom listed for this residue
    sidechain_atoms = SC_BUILD_INFO[aa]["atom-names"]
    cloud[:4 + len(sidechain_atoms)] = True
    return cloud
""" + mask = np.zeros(14) + # backbone + if aa != "_": + mask[0] = BB_BUILD_INFO["BONDLENS"]['c-n'] + mask[1] = BB_BUILD_INFO["BONDLENS"]['n-ca'] + mask[2] = BB_BUILD_INFO["BONDLENS"]['ca-c'] + mask[3] = BB_BUILD_INFO["BONDLENS"]['c-o'] + # sidechain - except padding token + if aa in SC_BUILD_INFO.keys(): + for i,bond in enumerate(SC_BUILD_INFO[aa]['bonds-vals']): + mask[4+i] = bond + return mask + +def make_theta_mask(aa): + """ Gives the theta of the bond originating each atom. """ + mask = np.zeros(14) + # backbone + if aa != "_": + mask[0] = BB_BUILD_INFO["BONDANGS"]['ca-c-n'] # nitrogen + mask[1] = BB_BUILD_INFO["BONDANGS"]['c-n-ca'] # c_alpha + mask[2] = BB_BUILD_INFO["BONDANGS"]['n-ca-c'] # carbon + mask[3] = BB_BUILD_INFO["BONDANGS"]['ca-c-o'] # oxygen + # sidechain + for i,theta in enumerate(SC_BUILD_INFO[aa]['angles-vals']): + mask[4+i] = theta + return mask + +def make_torsion_mask(aa, fill=False): + """ Gives the dihedral of the bond originating each atom. """ + mask = np.zeros(14) + if aa != "_": + # backbone + mask[0] = BB_BUILD_INFO["BONDTORSIONS"]['n-ca-c-n'] # psi + mask[1] = BB_BUILD_INFO["BONDTORSIONS"]['ca-n-c-ca'] # omega + mask[2] = BB_BUILD_INFO["BONDTORSIONS"]['c-n-ca-c'] # psi + mask[3] = BB_BUILD_INFO["BONDTORSIONS"]['n-ca-c-o'] # oxygen + # sidechain + for i, torsion in enumerate(SC_BUILD_INFO[aa]['torsion-vals']): + if fill: + mask[4+i] = MP3SC_INFO[aa][ SC_BUILD_INFO[aa]["atom-names"][i] ]["bond_dihedral"] + else: + # https://github.com/jonathanking/sidechainnet/blob/master/sidechainnet/structure/StructureBuilder.py#L372 + # 999 is an anotation -- change later || same for 555 + mask[4+i] = np.nan if torsion == 'p' else 999 if torsion == "i" else torsion + return mask + +def make_idx_mask(aa): + """ Gives the idxs of the 3 previous points. 
""" + mask = np.zeros((11, 3)) + if aa != "_": + # backbone + mask[0, :] = np.arange(3) + # sidechain + mapper = {"N": 0, "CA": 1, "C":2, "CB": 4} + for i, torsion in enumerate(SC_BUILD_INFO[aa]['torsion-names']): + # get all the atoms forming the dihedral + torsions = [x.rstrip(" ") for x in torsion.split("-")] + # for each atom + for n, torsion in enumerate(torsions[:-1]): + # get the index of the atom in the coords array + loc = mapper[torsion] if torsion in mapper.keys() else 4 + SC_BUILD_INFO[aa]['atom-names'].index(torsion) + # set position to index + mask[i+1][n] = loc + return mask + +def make_atom_token_mask(aa): + """ Return the tokens for each atom in the aa. """ + mask = np.zeros(14) + # get atom id + if aa != "_": + atom_list = ["N", "CA", "C", "O"] + SC_BUILD_INFO[ aa ]["atom-names"] + for i,atom in enumerate(atom_list): + mask[i] = ATOM_TOKEN_IDS[atom] + return mask + + +################### +##### GETTERS ##### +################### +INDEX2AAS = "ACDEFGHIKLMNPQRSTVWY_" +AAS2INDEX = {aa:i for i,aa in enumerate(INDEX2AAS)} +SUPREME_INFO = {k: {"cloud_mask": make_cloud_mask(k), + "bond_mask": make_bond_mask(k), + "theta_mask": make_theta_mask(k), + "torsion_mask": make_torsion_mask(k), + "torsion_mask_filled": make_torsion_mask(k, fill=True), + "idx_mask": make_idx_mask(k), + "atom_token_mask": make_atom_token_mask(k), + "rigid_idx_mask": SC_BUILD_INFO[k]['rigid-frames-idxs'], + } + for k in INDEX2AAS} + diff --git a/rgn2_replica/mp_nerf/massive_pnerf.py b/rgn2_replica/mp_nerf/massive_pnerf.py new file mode 100644 index 0000000..cf0d43d --- /dev/null +++ b/rgn2_replica/mp_nerf/massive_pnerf.py @@ -0,0 +1,67 @@ +import time +import numpy as np +# diff ml +import torch +from einops import repeat + + +def get_axis_matrix(a, b, c, norm=True): + """ Gets an orthonomal basis as a matrix of [e1, e2, e3]. 
def mp_nerf_torch(a, b, c, l, theta, chi):
    """ Custom Natural extension of Reference Frame.
        Places the point d from the three previous points (a, b, c) and
        its internal coordinates (bond length, bond angle, dihedral).
        Inputs:
        * a: (batch, 3) or (3,). point(s) of the plane, not connected to d
        * b: (batch, 3) or (3,). point(s) of the plane, not connected to d
        * c: (batch, 3) or (3,). point(s) of the plane, connected to d
        * l: (batch,) or float. length(s) of the c-d bond
        * theta: (batch,) or float. angle(s) between b-c-d
        * chi: (batch,) or float. dihedral angle(s) between the a-b-c and b-c-d planes
        Outputs: d (batch, 3) or (3,). the next point in the sequence, linked to c
        Raises: ValueError if any theta lies outside [-pi, pi].
    """
    # plain python / numpy scalars are documented as valid inputs: promote them
    # to tensors living next to `c`. tensors are passed through untouched.
    l, theta, chi = [x if torch.is_tensor(x)
                     else torch.as_tensor(x, dtype=c.dtype, device=c.device)
                     for x in (l, theta, chi)]
    # safety check
    if not ((-np.pi <= theta) & (theta <= np.pi)).all().item():
        raise ValueError(f"theta(s) must be in radians and in [-pi, pi]. theta(s) = {theta}")
    # calc vecs
    ba = b - a
    cb = c - b
    # calc rotation matrix. columns are the bond direction and the two plane
    # normals, each one scaled to unit length
    n_plane = torch.cross(ba, cb, dim=-1)
    n_plane_ = torch.cross(n_plane, cb, dim=-1)
    rotate = torch.stack([cb, n_plane_, n_plane], dim=-1)
    # out-of-place on purpose: `rotate` is the input of torch.norm, so it is
    # not overwritten while the norm may still need it for its gradient
    rotate = rotate / torch.norm(rotate, dim=-2, keepdim=True)
    # calc proto point, rotate. add (-1 for sidechainnet convention)
    # https://github.com/jonathanking/sidechainnet/issues/14
    d = torch.stack([-torch.cos(theta),
                     torch.sin(theta) * torch.cos(chi),
                     torch.sin(theta) * torch.sin(chi)], dim=-1).unsqueeze(-1)
    # extend base point, set length. only the trailing matmul dim is dropped,
    # so a batch of size 1 keeps its batch dim
    return c + l.unsqueeze(-1) * torch.matmul(rotate, d).squeeze(-1)
mask for present atoms + * pred_feats: (batch, L, 14, D) optional. atom-wise predicted features + + Warning! A coordinate might be missing. TODO: + Outputs: pred_coors, pred_feats + """ + aux_cloud_mask = cloud_mask.clone() # will be manipulated + + for i,seq in enumerate(seq_list): + for aa, pairs in AMBIGUOUS.items(): + # indexes of aas in chain - check coords are given for aa + amb_idxs = np.array(pairs["indexs"]).flatten().tolist() + idxs = torch.tensor([ + k for k,s in enumerate(seq) if s==aa and \ + k in set( torch.nonzero(aux_cloud_mask[i, :, amb_idxs].sum(dim=-1)).tolist()[0] ) + ]).long() + # check if any AAs matching + if idxs.shape[0] == 0: + continue + # get indexes of non-ambiguous + aux_cloud_mask[i, idxs, amb_idxs] = False + non_amb_idx = torch.nonzero(aux_cloud_mask[i, idxs[0]]).tolist() + for a, pair in enumerate(pairs["indexs"]): + # calc distances + d_ij_pred = torch.cdist(pred_coors[ i, idxs, pair ], pred_coors[i, idxs, non_amb_idx], p=2) # 2, N + d_ij_true = torch.cdist(true_coors[ i, idxs, pair+pair[::-1] ], true_coors[i, idxs, non_amb_idx], p=2) # 2, 2N + # see if alternative is better (less distance) + idxs_to_change = ( (d_ij_pred - d_ij_true[2:]).sum(dim=-1) < (d_ij_pred - d_ij_true[:2]).sum(dim=-1) ).nonzero() + # change those + pred_coors[i, idxs[idxs_to_change], pair] = pred_coors[i, idxs[idxs_to_change], pair[::-1]] + if pred_feats is not None: + pred_feats[i, idxs[idxs_to_change], pair] = pred_feats[i, idxs[idxs_to_change], pair[::-1]] + + return pred_coors, pred_feats + + +def angle_to_point_in_circum(angles): + """ Converts an angle to a point in the unit circumference. + Inputs: + * angles: tensor of (any) shape. 
def torsion_angle_loss(pred_torsions=None, true_torsions=None,
                       pred_points=None, true_points=None,
                       alt_true_points=None, alt_true_torsions=None,
                       coeff=2., norm_coeff=1e-2, angle_mask=None, eps=1e-8):
    """ Computes a loss on the angles as the cosine of the difference.
        Equivalent to an L2 on the unit circle.
        Due to angle periodicity, for angle inputs, calculate the
        disparity on both sides.
        Alternative truths should only be passed if not previous renaming.
        Inputs:
        * pred_torsions: ( (B), L, X ) float. Predicted torsion angles.(-pi, pi)
                                       Same format as sidechainnet.
        * true_torsions: ( (B), L, X ) true torsion angles. (-pi, pi)
        * pred_points: ( (B), L, X, 2) float. Predicted points in circum.
        * true_points: ( (B), L, X, 2) float. true points in circum.
        * alt_true_torsions: ( (B), L, X ) alt true torsion angles. (-pi, pi)
        * alt_true_points: ( (B), L, X, 2) float. alt true points in circum.
        * coeff: float. weight coefficient
        * norm_coeff: float. coefficient for norm term. avoids big outputs.
        * angle_mask: ((B), L, (X)) bool. Masks the non-existing angles.
        * eps: float. lower bound for the norm used in the division, so a
               zero-length prediction gives a finite loss instead of nan.
        Outputs: 1-D tensor with one value per selected angle:
                 (coeff/2) * squared distance on the unit circle
                 + norm_coeff * |1 - norm of the prediction|
        Raises: ValueError if no prediction or no truth is passed.
    """
    # convert to sin·cos rep if not available
    if pred_torsions is not None and pred_points is None:
        pred_points = angle_to_point_in_circum(pred_torsions)
    if true_torsions is not None and true_points is None:
        true_points = angle_to_point_in_circum(true_torsions)
    if alt_true_torsions is not None and alt_true_points is None:
        alt_true_points = angle_to_point_in_circum(alt_true_torsions)

    if pred_points is None or true_points is None:
        raise ValueError("Predictions and truths must be passed, either as torsions or as points")

    # calc norm of angles
    norm = torch.norm(pred_points, dim=-1)
    angle_norm_loss = norm_coeff * (1 - norm).abs()

    # do L2 on unit circle. the clamp only acts on (near-)zero norms
    pred_points = pred_points / norm.clamp(min=eps).unsqueeze(-1)
    torsion_loss = torch.pow(pred_points - true_points, 2).sum(dim=-1)

    # keep the closest of the two valid truths
    if alt_true_points is not None:
        torsion_loss = torch.minimum(
            torsion_loss,
            torch.pow(pred_points - alt_true_points, 2).sum(dim=-1)
        )
    if coeff != 2.:
        torsion_loss = torsion_loss * (coeff / 2)

    # default: keep every angle. built from the loss so that shape and
    # device always agree with it
    if angle_mask is None:
        angle_mask = torch.ones_like(torsion_loss, dtype=torch.bool)

    return (torsion_loss + angle_norm_loss)[angle_mask]
maximum points to rotate at once. + the higher, the more batching allowed. + Outputs: (B, N_atoms) + """ + fape_store = [] + if l_func is None: + l_func = lambda x,y,eps=1e-7,sup=d_clamp: (((x-y)**2).sum(dim=-1) + \ + eps).sqrt().clamp(0, sup) + # for chain + for s in range(pred_coords.shape[0]): + fape_store.append(0) + cloud_mask = (torch.abs(true_coords[s]).sum(dim=-1) != 0) + # center both structures + pred_center = pred_coords[s] - pred_coords[s, cloud_mask].mean(dim=0, keepdim=True) + true_center = true_coords[s] - true_coords[s, cloud_mask].mean(dim=0, keepdim=True) + # convert to (B, L*C, 3) + pred_center = rearrange(pred_center, 'l c d -> (l c) d') + true_center = rearrange(true_center, 'l c d -> (l c) d') + mask_center = rearrange(cloud_mask, 'l c -> (l c)') + # get frames and conversions - same scheme as in mp_nerf proteins' concat of monomers + if rot_mats_g is None: + rigid_idxs = scn_rigid_index_mask(seq_list[s], c_alpha=partial=="c_alpha") + true_frames = get_axis_matrix(*true_center[rigid_idxs], norm=True) + pred_frames = get_axis_matrix(*pred_center[rigid_idxs], norm=True) + rot_mats = torch.matmul(torch.transpose(pred_frames, -1, -2), true_frames).detach() + else: + rot_mats = rot_mats_g[s] + + # calculate loss only on c_alphas + if partial is not None: + mask_center = torch.zeros_like(mask_center, dtype=torch.bool) + if partial == "c_alpha": # only keep c-alphas + mask_center[np.arange(0, pred_coords.shape[1]) * 14 + 1] = \ + mask_center[np.arange(0, pred_coords.shape[1]) * 14 + 1] + True + else: # only keep backbone(+cb) frames' atoms + mask_center[rigid_idxs] = mask_center[rigid_idxs] + True + + pred_center = pred_center[mask_center] + true_center = true_center[mask_center] + + # return pred_center, true_center, mask_center, rot_mats + # measure errors - for residue + num = 0 + batch_size = max(1, int( max_points // pred_center.shape[0] ) ) + + while num <= rot_mats.shape[0]: + fape_store[s] = fape_store[s] + l_func( + pred_center @ 
def atom_selector(scn_seq, x, option=None, discard_absent=True):
    """ Returns a selection of the atoms in a protein.
        Inputs:
        * scn_seq: (batch, len) sidechainnet format or list of strings
        * x: (batch, (len * n_aa), dims) sidechainnet format
        * option: one of [torch.tensor, 'backbone', 'backbone-only',
                  'backbone-with-cbeta', 'all', 'backbone-with-oxygen',
                  'backbone-with-cbeta-and-oxygen']
        * discard_absent: bool. Whether to discard the points for which
                          there are no labels (bad recordings)
        Outputs:
        * x[mask]: (n_selected, dims) the selected points
        * mask: (batch, len * 14) bool. the selection itself
        Raises: ValueError if option is neither a string nor a tensor.
    """
    # presence mask per residue: (batch, len, 14)
    present = []
    for i, seq in enumerate(scn_seq):
        if discard_absent:
            # scn_cloud_mask reads coords as 'b (l c) d': add a batch dim
            # for the call and strip it from the answer
            present.append(scn_cloud_mask(seq, coords=x[i].unsqueeze(0))[0])
            continue
        # the sequence is only needed when the mask is not measured on coords
        if isinstance(seq, torch.Tensor):
            seq = "".join([INDEX2AAS[tok] for tok in seq.cpu().detach().tolist()])
        present.append(scn_cloud_mask(seq, coords=None))

    present = torch.stack(present, dim=0).bool()

    # atom mask. slots of a residue follow N, CA, C, O, CB, ...
    if isinstance(option, str):
        # c-alpha is always kept
        atom_mask = torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
        if "backbone" in option:
            atom_mask[[0, 2]] = 1

        if option in ("backbone", "backbone-only"):
            pass
        elif option == 'backbone-with-oxygen':
            atom_mask[3] = 1
        elif option == 'backbone-with-cbeta':
            atom_mask[4] = 1  # CB sits in slot 4 of the coords array
        elif option == 'backbone-with-cbeta-and-oxygen':
            atom_mask[3] = 1
            atom_mask[4] = 1
        elif option == 'all':
            atom_mask[:] = 1
        else:
            print("Your string doesn't match any option.")

    elif isinstance(option, torch.Tensor):
        atom_mask = option
    else:
        raise ValueError('option needs to be a valid string or a mask tensor of shape (14,) ')

    # both masks must live on the same device to be combined
    atom_mask = atom_mask.to(present.device)
    mask = rearrange(present * atom_mask.unsqueeze(0).unsqueeze(0).bool(), 'b l c -> b (l c)')
    return x[mask], mask
internal coordinates -> dihedral and bond angles. + Inputs: + * seq: string. Sequence in FASTA format + * angles: (l, 11) sidechainnet angles tensor + * coords: (l, 14, 13) + * noise_scale: float. std of noise gaussian. + * theta_scale: float. multiplier for bond angles + Outputs: + * chain (l, c, d) + * cloud_mask (l, c) + """ + assert angles is not None or coords is not None, \ + "You must pass either angles or coordinates" + # get scaffolds + if angles is None: + angles = torch.randn(coords.shape[0], 12).to(coords.device) + + scaffolds = build_scaffolds_from_scn_angles(seq, angles.clone()) + + if coords is not None: + scaffolds = modify_scaffolds_with_coords(scaffolds, coords) + + # noise bond angles and dihedrals (dihedrals of everyone, angles only of BB) + if noise_scale > 0.: + if verbose: + print("noising", noise_scale) + # thetas (half of noise of dihedrals. only for BB) + noised_bb = scaffolds["angles_mask"][0, :, :3].clone() + noised_bb += theta_scale*noise_scale * torch.randn_like(noised_bb) + # get noised values between [-pi, pi] + off_bounds = (noised_bb > 2*np.pi) + (noised_bb < -2*np.pi) + if off_bounds.sum().item() > 0: + noised_bb[off_bounds] = noised_bb[off_bounds] % (2*np.pi) + + upper, lower = noised_bb > np.pi, noised_bb < -np.pi + if upper.sum().item() > 0: + noised_bb[upper] = - ( 2*np.pi - noised_bb[upper] ).clone() + if lower.sum().item() > 0: + noised_bb[lower] = 2*np.pi + noised_bb[lower].clone() + scaffolds["angles_mask"][0, :, :3] = noised_bb + + # dihedrals + noised_dihedrals = scaffolds["angles_mask"][1].clone() + noised_dihedrals += noise_scale * torch.randn_like(noised_dihedrals) + # get noised values between [-pi, pi] + off_bounds = (noised_dihedrals > 2*np.pi) + (noised_dihedrals < -2*np.pi) + if off_bounds.sum().item() > 0: + noised_dihedrals[off_bounds] = noised_dihedrals[off_bounds] % (2*np.pi) + + upper, lower = noised_dihedrals > np.pi, noised_dihedrals < -np.pi + if upper.sum().item() > 0: + noised_dihedrals[upper] = - ( 
2*np.pi - noised_dihedrals[upper] ).clone() + if lower.sum().item() > 0: + noised_dihedrals[lower] = 2*np.pi + noised_dihedrals[lower].clone() + scaffolds["angles_mask"][1] = noised_dihedrals + + # reconstruct + return protein_fold(**scaffolds) + + +def combine_noise(true_coords, seq=None, int_seq=None, angles=None, + NOISE_INTERNALS=1e-2, INTERNALS_SCN_SCALE=5., + SIDECHAIN_RECONSTRUCT=True): + """ Combines noises. For internal noise, no points can be missing. + Inputs: + * true_coords: ((B), N, D) + * int_seq: (N,) torch long tensor of sidechainnet AA tokens + * seq: str of length N. FASTA AAs. + * angles: (N_aa, D_). optional. used for internal noising + * NOISE_INTERNALS: float. amount of noise for internal coordinates. + * SIDECHAIN_RECONSTRUCT: bool. whether to discard the sidechain and + rebuild by sampling from plausible distro. + Outputs: (B, N, D) coords and (B, N) boolean mask + """ + # get seqs right + assert int_seq is not None or seq is not None, "Either int_seq or seq must be passed" + if int_seq is not None and seq is None: + seq = "".join([INDEX2AAS[x] for x in int_seq.cpu().detach().tolist()]) + elif int_seq is None and seq is not None: + int_seq = torch.tensor([AAS2INDEX[x] for x in seq.upper()], device=true_coords.device) + + cloud_mask_flat = (true_coords == 0.).sum(dim=-1) != true_coords.shape[-1] + naive_cloud_mask = scn_cloud_mask(seq).bool() + + if NOISE_INTERNALS: + assert cloud_mask_flat.sum().item() == naive_cloud_mask.sum().item(), \ + "atoms missing: {0}".format( naive_cloud_mask.sum().item() - \ + cloud_mask_flat.sum().item() ) + # expand to batch dim if needed + if len(true_coords.shape) < 3: + true_coords = true_coords.unsqueeze(0) + noised_coords = true_coords.clone() + coords_scn = rearrange(true_coords, 'b (l c) d -> b l c d', c=14) + + ###### SETP 1: internals ######### + if NOISE_INTERNALS: + # create noised and masked noised coords + noised_coords, cloud_mask = noise_internals(seq, angles = angles, + coords = 
coords_scn.squeeze(), + noise_scale = NOISE_INTERNALS, + theta_scale = INTERNALS_SCN_SCALE, + verbose = False) + masked_noised = noised_coords[naive_cloud_mask] + noised_coords = rearrange(noised_coords, 'l c d -> () (l c) d') + + ###### SETP 2: build from backbone ######### + if SIDECHAIN_RECONSTRUCT: + bb, mask = atom_selector(int_seq.unsqueeze(0), noised_coords, option="backbone", discard_absent=False) + scaffolds = build_scaffolds_from_scn_angles(seq, angles=None, device="cpu") + noised_coords[~mask] = 0. + noised_coords = rearrange(noised_coords, '() (l c) d -> l c d', c=14) + noised_coords, _ = sidechain_fold(wrapper = noised_coords.cpu(), **scaffolds, c_beta = False) + noised_coords = rearrange(noised_coords, 'l c d -> () (l c) d').to(true_coords.device) + + + return noised_coords, cloud_mask_flat + + + +if __name__ == "__main__": + import joblib + # imports of data (from mp_nerf.utils.get_prot) + prots = joblib.load("some_route_to_local_serialized_file_with_prots") + + # set params + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # unpack and test + seq, int_seq, true_coords, angles, padding_seq, mask, pid = prots[-1] + + true_coords = true_coords.unsqueeze(0) + + # check noised internals + coords_scn = rearrange(true_coords, 'b (l c) d -> b l c d', c=14) + cloud, cloud_mask = noise_internals(seq, angles=angles, coords=coords_scn[0], noise_scale=1.) 
def scn_cloud_mask(seq, coords=None, strict=False):
    """ Gets the boolean mask atom positions (not all aas have same atoms).
        Inputs:
        * seq: (length) iterable of 1-letter aa codes of a protein
        * coords: optional. (batch, l*14, 3) or (l*14, 3) sidechainnet coords.
                  when passed, the mask is measured on the coordinates: a
                  point is present if any of its components is non-zero
                  (solves potential atoms that might not be provided)
        * strict: bool. whether to discard the next points after a missing one
        Outputs: float mask of 1s (present) and 0s (absent).
                 (length, 14) when built from `seq`; the leading dims of
                 `coords` + (length, 14) when built from `coords`
    """
    if coords is not None:
        # (..., l*14, d) -> (..., l, 14, d)
        per_atom = coords.reshape(*coords.shape[:-2], -1, 14, coords.shape[-1])
        start = (per_atom != 0).any(dim=-1).float()
        # if a point is 0, the following are 0s as well: the running product
        # along the atoms of a residue stays at 0 after the first gap
        if strict:
            start = start.cumprod(dim=-1)
        return start
    return torch.tensor([SUPREME_INFO[aa]['cloud_mask'] for aa in seq])
iterable of 1-letter aa codes of a protein + Outputs: (L, 14) maps point to bond length + """ + return torch.tensor([SUPREME_INFO[aa]['bond_mask'] for aa in seq]) + + +def scn_angle_mask(seq, angles=None, device=None): + """ Inputs: + * seq: (length). iterable of 1-letter aa codes of a protein + * angles: (length, 12). [phi, psi, omega, b_angle(n_ca_c), b_angle(ca_c_n), b_angle(c_n_ca), 6_scn_torsions] + Outputs: (L, 14) maps point to theta and dihedral. + first angle is theta, second is dihedral + """ + device = angles.device if angles is not None else torch.device("cpu") + precise = angles.dtype if angles is not None else torch.get_default_dtype() + torsion_mask_use = "torsion_mask" if angles is not None else "torsion_mask_filled" + # get masks + theta_mask = torch.tensor([SUPREME_INFO[aa]['theta_mask'] for aa in seq], dtype=precise).to(device) + torsion_mask = torch.tensor([SUPREME_INFO[aa][torsion_mask_use] for aa in seq], dtype=precise).to(device) + + # adapt general to specific angles if passed + if angles is not None: + # fill masks with angle values + theta_mask[:, 0] = angles[:, 4] # ca_c_n + theta_mask[1:, 1] = angles[:-1, 5] # c_n_ca + theta_mask[:, 2] = angles[:, 3] # n_ca_c + # backbone_torsions + torsion_mask[:, 0] = angles[:, 1] # n determined by psi of previous + torsion_mask[1:, 1] = angles[:-1, 2] # ca determined by omega of previous + torsion_mask[:, 2] = angles[:, 0] # c determined by phi + # https://github.com/jonathanking/sidechainnet/blob/master/sidechainnet/structure/StructureBuilder.py#L313 + torsion_mask[:, 3] = angles[:, 1] - np.pi + + # add torsions to sidechains - no need to modify indexes due to torsion modification + # since extra rigid modies are in terminal positions in sidechain + to_fill = torsion_mask != torsion_mask # "p" fill with passed values + to_pick = torsion_mask == 999 # "i" infer from previous one + for i,aa in enumerate(seq): + # check if any is nan -> fill the holes + number = to_fill[i].long().sum() + torsion_mask[i, 
def scn_rigid_index_mask(seq, c_alpha=None):
    """ Collects the indexes of the points that define the rigid frames.
        Inputs:
        * seq: (length). iterable of 1-letter aa codes of a protein
        * c_alpha: if truthy, only the first frame of each residue is kept.
        Outputs: (3, Length * Groups). indexes for 1st, 2nd and 3rd point
                 to construct frames for each group.
    """
    n_frames = 1 if c_alpha else None
    frames = []
    for pos, aa in enumerate(seq):
        local = torch.tensor(SUPREME_INFO[aa]['rigid_idx_mask'])[:n_frames]
        # indexes of later residues are shifted by 14 per position
        frames.append(local if pos == 0 else local + 14 * pos)
    return torch.cat(frames, dim=0).t()
def modify_angles_mask_with_torsions(seq, angles_mask, torsions):
    """ Writes variable torsions into the torsion half of an angles mask.
        Inputs:
        * seq: (L,) str. FASTA sequence
        * angles_mask: (2, L, 14) float tensor of (angles, torsions).
                       modified in place.
        * torsions: (L, 4) float tensor (or (L, 5) if it includes torsion for cb)
        Outputs: (2, L, 14) the same angles mask, with torsions filled in
    """
    n_torsions = torsions.shape[-1]
    # 5 columns means the c_beta torsion comes as well: start one slot earlier
    first = 4 if n_torsions == 5 else 5
    last = first + n_torsions
    template = torch.tensor([SUPREME_INFO[aa]["torsion_mask"] for aa in seq]).to(torsions.device)  # (L, 14)
    # nan entries are the values that need to be replaced
    fillable = torch.isnan(template)
    # nothing is written outside of the window covered by `torsions`
    fillable[:, :first] = False
    fillable[:, last:] = False

    angles_mask[1, fillable] = torsions[fillable[:, first:last]]
    return angles_mask
requested is from the previous residue + main_c_prev_idxs = coords[selector[:-1], idx_a[1:]]# (L-1), 3 + # concat + coords_a = torch.cat([first_next_n, main_c_prev_idxs]) + else: + coords_a = coords[selector, idx_a] + # get dihedrals + scaffolds["angles_mask"][1, :, i] = get_dihedral(coords_a, + coords[selector, idx_b], + coords[selector, idx_c], + coords[:, i]) + # correct angles and dihedrals for backbone + scaffolds["angles_mask"][0, :-1, 0] = get_angle(coords[:-1, 1], coords[:-1, 2], coords[1: , 0]) # ca_c_n + scaffolds["angles_mask"][0, 1:, 1] = get_angle(coords[:-1, 2], coords[1:, 0], coords[1: , 1]) # c_n_ca + scaffolds["angles_mask"][0, :, 2] = get_angle(coords[:, 0], coords[ :, 1], coords[ : , 2]) # n_ca_c + + # N determined by previous psi = f(n, ca, c, n+1) + scaffolds["angles_mask"][1, :-1, 0] = get_dihedral(coords[:-1, 0], coords[:-1, 1], coords[:-1, 2], coords[1:, 0]) + # CA determined by omega = f(ca, c, n+1, ca+1) + scaffolds["angles_mask"][1, 1:, 1] = get_dihedral(coords[:-1, 1], coords[:-1, 2], coords[1:, 0], coords[1:, 1]) + # C determined by phi = f(c-1, n, ca, c) + scaffolds["angles_mask"][1, 1:, 2] = get_dihedral(coords[:-1, 2], coords[1:, 0], coords[1:, 1], coords[1:, 2]) + + return scaffolds + + +################################## +####### MAIN FUNCTION ############ +################################## + + +def protein_fold(cloud_mask, point_ref_mask, angles_mask, bond_mask, + device=torch.device("cpu"), hybrid=False): + """ Calcs coords of a protein given it's + sequence and internal angles. 
+ Inputs: + * cloud_mask: (L, 14) mask of points that should be converted to coords + * point_ref_mask: (3, L, 11) maps point (except n-ca-c) to idxs of + previous 3 points in the coords array + * angles_mask: (2, 14, L) maps point to theta and dihedral + * bond_mask: (L, 14) gives the length of the bond originating that atom + + Output: (L, 14, 3) and (L, 14) coordinates and cloud_mask + """ + # automatic type (float, mixed, double) and size detection + precise = bond_mask.dtype + length = cloud_mask.shape[0] + # create coord wrapper + coords = torch.zeros(length, 14, 3, device=device, dtype=precise) + + # do first AA + coords[0, 1] = coords[0, 0] + torch.tensor([1, 0, 0], device=device, dtype=precise) * BB_BUILD_INFO["BONDLENS"]["n-ca"] + coords[0, 2] = coords[0, 1] + torch.tensor([torch.cos(np.pi - angles_mask[0, 0, 2]), + torch.sin(np.pi - angles_mask[0, 0, 2]), + 0.], device=device, dtype=precise) * BB_BUILD_INFO["BONDLENS"]["ca-c"] + + # starting positions (in the x,y plane) and normal vector [0,0,1] + init_a = repeat(torch.tensor([1., 0., 0.], device=device, dtype=precise), 'd -> l d', l=length) + init_b = repeat(torch.tensor([1., 1., 0.], device=device, dtype=precise), 'd -> l d', l=length) + # do N -> CA. don't do 1st since its done already + thetas, dihedrals = angles_mask[:, :, 1] + coords[1:, 1] = mp_nerf_torch(init_a, + init_b, + coords[:, 0], + bond_mask[:, 1], + thetas, dihedrals)[1:] + # do CA -> C. 
don't do 1st since its done already + thetas, dihedrals = angles_mask[:, :, 2] + coords[1:, 2] = mp_nerf_torch(init_b, + coords[:, 0], + coords[:, 1], + bond_mask[:, 2], + thetas, dihedrals)[1:] + # do C -> N + thetas, dihedrals = angles_mask[:, :, 0] + coords[:, 3] = mp_nerf_torch(coords[:, 0], + coords[:, 1], + coords[:, 2], + bond_mask[:, 0], + thetas, dihedrals) + + ######### + # sequential pass to join fragments + ######### + # part of rotation mat corresponding to origin - 3 orthogonals + mat_origin = get_axis_matrix(init_a[0], init_b[0], coords[0, 0], norm=False) + # part of rotation mat corresponding to destins || a, b, c = CA, C, N+1 + # (L-1) since the first is in the origin already + mat_destins = get_axis_matrix(coords[:-1, 1], coords[:-1, 2], coords[:-1, 3]) + + # get rotation matrices from origins + # https://math.stackexchange.com/questions/1876615/rotation-matrix-from-plane-a-to-b + rotations = torch.matmul(mat_origin.t(), mat_destins) + rotations /= torch.norm(rotations, dim=-1, keepdim=True) + + # do rotation concatenation - do for loop in cpu always - faster + rotations = rotations.cpu() if coords.is_cuda and hybrid else rotations + for i in range(1, length-1): + rotations[i] = torch.matmul(rotations[i], rotations[i-1]) + rotations = rotations.to(device) if coords.is_cuda and hybrid else rotations + # rotate all + coords[1:, :4] = torch.matmul(coords[1:, :4], rotations) + # offset each position by cumulative sum at that position + coords[1:, :4] += torch.cumsum(coords[:-1, 3], dim=0).unsqueeze(-2) + + + ######### + # parallel sidechain - do the oxygen, c-beta and side chain + ######### + for i in range(3,14): + level_mask = cloud_mask[:, i] + thetas, dihedrals = angles_mask[:, level_mask, i] + idx_a, idx_b, idx_c = point_ref_mask[:, level_mask, i-3] + + # to place C-beta, we need the carbons from prev res - not available for the 1st res + if i == 4: + # the c requested is from the previous residue - offset boolean mask by one + # can't be done 
with slicing bc glycines are inside chain (dont have cb) + coords_a = coords[(level_mask.nonzero().view(-1) - 1), idx_a] # (L-1), 3 + # if first residue is not glycine, + # for 1st residue, use position of the second residue's N (1,3) + if level_mask[0].item(): + coords_a[0] = coords[1, 1] + else: + coords_a = coords[level_mask, idx_a] + + coords[level_mask, i] = mp_nerf_torch(coords_a, + coords[level_mask, idx_b], + coords[level_mask, idx_c], + bond_mask[level_mask, i], + thetas, dihedrals) + + return coords, cloud_mask + + +def sidechain_fold(wrapper, cloud_mask, point_ref_mask, angles_mask, bond_mask, + device=torch.device("cpu"), c_beta=False): + """ Calcs coords of a protein given it's sequence and internal angles. + Inputs: + * wrapper: (L, 14, 3). coords container with backbone ([:, :3]) and optionally + c_beta ([:, 4]) + * cloud_mask: (L, 14) mask of points that should be converted to coords + * point_ref_mask: (3, L, 11) maps point (except n-ca-c) to idxs of + previous 3 points in the coords array + * angles_mask: (2, 14, L) maps point to theta and dihedral + * bond_mask: (L, 14) gives the length of the bond originating that atom + * c_beta: whether to place cbeta + + Output: (L, 14, 3) and (L, 14) coordinates and cloud_mask + """ + precise = wrapper.dtype + + # parallel sidechain - do the oxygen, c-beta and side chain + for i in range(3,14): + # skip cbeta if arg is set + if i == 4 and not isinstance(c_beta, str): + continue + # prepare inputs + level_mask = cloud_mask[:, i] + thetas, dihedrals = angles_mask[:, level_mask, i] + idx_a, idx_b, idx_c = point_ref_mask[:, level_mask, i-3] + + # to place C-beta, we need the carbons from prev res - not available for the 1st res + if i == 4: + # the c requested is from the previous residue - offset boolean mask by one + # can't be done with slicing bc glycines are inside chain (dont have cb) + coords_a = wrapper[(level_mask.nonzero().view(-1) - 1), idx_a] # (L-1), 3 + # if first residue is not glycine, + # for 
1st residue, use position of the second residue's N (1,3) + if level_mask[0].item(): + coords_a[0] = wrapper[1, 1] + else: + coords_a = wrapper[level_mask, idx_a] + + wrapper[level_mask, i] = mp_nerf_torch(coords_a, + wrapper[level_mask, idx_b], + wrapper[level_mask, idx_c], + bond_mask[level_mask, i], + thetas, dihedrals) + + return wrapper, cloud_mask + + +############################## +####### XTENSION ############ +############################## + + +# inspired by: https://www.biorxiv.org/content/10.1101/2021.08.02.454840v1 +def ca_from_angles(angles, bond_len=3.80): + """ Builds a C-alpha trace from a set of 2 angles (theta, chi). + Inputs: + * angles: (B, L, 4): float tensor. (cos, sin) · (theta, chi) + angles in point-in-unit-circumference format. + Outputs: (B, L, 3) coords for c-alpha trace + """ + device = angles.device + length = angles.shape[-2] + frames = [ torch.repeat_interleave( + torch.eye(3, device=device, dtype=torch.float).unsqueeze(0), + angles.shape[0], + dim=0 + )] + + rot_mats = torch.stack([ + torch.stack([ angles[...,0] * angles[...,2], angles[...,0] * angles[...,3], -angles[...,1] ], dim=-1), + torch.stack([ -angles[...,3] , angles[...,2] , angles[...,0]*0. ], dim=-1), + torch.stack([ angles[...,1] * angles[...,2], angles[...,1] * angles[...,3], angles[...,0] ], dim=-1), + ], dim=-2) # (B, L, 3, 3) + + # iterative update of frames - skip last frame. + for i in range(length-1): + frames.append( rot_mats[:, i] @ frames[i] ) # could do frames[-1] as well + frames = torch.stack(frames, dim=1) # (B, L, 3, 3) + + ca_trace = bond_len * frames[..., -1, :].cumsum(dim=-2) # (B, L, 3) + + return ca_trace, frames + + +# inspired by: https://github.com/psipred/DMPfold2/blob/master/dmpfold/network.py#L139 +def ca_bb_fold(ca_trace): + """ Calcs a backbone given the coordinate trace of the CAs. + Inputs: + * ca_trace: (B, L, 3) float tensor with CA coordinates. 
+ Outputs: (B, L, 14, 3) (-N-CA(-CB-...)-C(=O)-) + """ + wrapper = torch.zeros(ca_trace.shape[0], ca_trace.shape[1]+2, 14, 3, device=ca_trace.device) + wrapper[:, 1:-1, 1] = ca_trace + # Place dummy extra Cα atoms on extremenes to get the required vectors + vecs = ca_trace[ :, [0, 2, -1, -3] ] - ca_trace[ :, [1, 1, -2, -2] ] # (B, 4, 3) + wrapper[:, 0, 1] = ca_trace[:, 0] + 3.80 * F.normalize(torch.cross(vecs[:, 0], vecs[:, 1]), dim=-1) + wrapper[:, -1, 1] = ca_trace[:, -1] + 3.80 * F.normalize(torch.cross(vecs[:, 2], vecs[:, 3]), dim=-1) + + # place N and C term + vec_ca_can = wrapper[:, :-2, 1] - wrapper[:, 1:-1, 1] + vec_ca_cac = wrapper[:, 2: , 1] - wrapper[:, 1:-1, 1] + mid_ca_can = (wrapper[:, 1:, 1] + wrapper[:, :-1, 1]) / 2 + cross_vcan_vcac = F.normalize(torch.cross(vec_ca_can, vec_ca_cac, dim=-1), dim=-1) + wrapper[:, 1:-1, 0] = mid_ca_can[:, :-1] - vec_ca_can / 7.5 + cross_vcan_vcac / 3.33 + # placve all C but last, which is special + wrapper[:, 1:-2, 2] = (mid_ca_can[:, :-1] + vec_ca_can / 8 - cross_vcan_vcac / 2.5)[:, 1:] + wrapper[:, -2, 2] = mid_ca_can[:, -1, :] - vec_ca_cac[:, -1, :] / 8 + cross_vcan_vcac[:, -1, :] / 2.5 + + return wrapper[:, 1:-1] + + + +############################ +####### METRICS ############ +############################ + + +def get_protein_metrics( + true_coords, + pred_coords, + cloud_mask = None, + return_aligned = True, + detach = None + ): + """ Calculates many metrics for protein structure quality. + Aligns coordinates. + Inputs: + * true_coords: (B, L, 14, 3) unaligned coords (B = 1) + * pred_coords: (B, L, 14, 3) unaligned coords (B = 1) + * cloud_mask: (B, L, 14) bool. gotten from pred_coords if not passed + * return_aligned: bool. whether to return aligned structs. + * detach: bool. whether to detach inputs before compute. 
saves mem + Outputs: dict (k,v) + """ + metric_dict = { + "rmsd": rmsd_torch, + "drmsd": drmsd_torch, + # not implemented yet + # "gdt_ts": partial(GDT, mode="TS"), + # "gdt_ha": partial(GDT, mode="HA"), + # "tmscore": tmscore_torch, + # "lddt": lddt_torch, + } + + if detach: + true_coords = true_coords.detach() + pred_coords = pred_coords.detach() + + # clone so originals are not modified + true_coords = true_coords.clone() + pred_coords = pred_coords.clone() + cloud_mask = pred_coords.abs().sum(dim=-1).bool() * \ + true_coords.abs().sum(dim=-1).bool() # 1, L, 14 + chain_mask = cloud_mask.sum(dim=-1).bool() # 1, L + + true_aligned, pred_aligned = kabsch_torch( + pred_coords[cloud_mask].t(), true_coords[cloud_mask].t() + ) + # no need to rebuild true coords since unaffected by kabsch + true_coords[cloud_mask] = true_aligned.t() + pred_coords[cloud_mask] = pred_aligned.t() + + # compute metrics + outputs = {} + for k,f in metric_dict.items(): + # special. works only on ca trace + if k == "tmscore": + ca_trace = true_coords[:, :, 1].transpose(-1, -2) + ca_pred_trace = pred_coords[:, :, 1].transpose(-1, -2) + outputs[k] = f(ca_trace, ca_pred_trace) + # special. works on full prot + elif k == "lddt": + outputs[k] = f(true_coords[:, chain_mask[0]], pred_coords[:, chain_mask[0]], cloud_mask=cloud_mask) + # special. 
needs batch dim + elif "gdt" in k: + outputs[k] = f(true_aligned[None, ...], pred_aligned[None, ...]) + else: + outputs[k] = f(true_aligned, pred_aligned) + + if return_aligned: + outputs.update({ + "pred_align_wrap": pred_coords, + "true_align_wrap": true_coords, + }) + + return outputs + diff --git a/rgn2_replica/mp_nerf/utils.py b/rgn2_replica/mp_nerf/utils.py new file mode 100644 index 0000000..7e26f02 --- /dev/null +++ b/rgn2_replica/mp_nerf/utils.py @@ -0,0 +1,224 @@ +# Author: Eric Alcaide + +import torch +import numpy as np + + +# random hacks + +# to_pi_minus_pi(4) = -2.28 # to_pi_minus_pi(-4) = 2.28 # rads to pi-(-pi) +to_zero_two_pi = lambda x: ( x + (2*np.pi) * ( 1 + torch.floor_divide(x.abs(), 2*np.pi) ) ) % (2*np.pi) +def to_pi_minus_pi(x): + zero_two_pi = to_zero_two_pi(x) + return torch.where( + zero_two_pi < np.pi, zero_two_pi, -(2*np.pi - zero_two_pi) + ) + +@torch.jit.script +def cdist(x,y): + """ robust cdist - drop-in for pytorch's. + Inputs: + * x, y: (B, N, D) + """ + return torch.pow( + x.unsqueeze(-3) - y.unsqueeze(-2), 2 + ).sum(dim=-1).clamp(min=1e-7).sqrt() + +# data utils +def get_prot(dataloader_=None, vocab_=None, min_len=80, max_len=150, + verbose=True, subset="train", xray_filter=False, full_mask=True): + """ Gets a protein from sidechainnet and returns + the right attrs for training. + Inputs: + * dataloader_: sidechainnet iterator over dataset + * vocab_: sidechainnet VOCAB class + * min_len: int. minimum sequence length + * max_len: int. maximum sequence length + * verbose: bool. verbosity level + * subset: str. which subset to load proteins from. + * xray_filter: bool. whether to return only xray structures. + * mask_tol: bool or int. bool: whether to return seqs with unknown coords. 
+ int: number of minimum label positions + Outputs: (cleaned, without padding) + (seq_str, int_seq, coords, angles, padding_seq, mask, pid) + """ + if xray_filter: + raise NotImplementedError + + while True: + for b,batch in enumerate(dataloader_[subset]): + for i in range(batch.int_seqs.shape[0]): + # skip too short + if batch.int_seqs[i].shape[0] < min_len: + continue + + # strip padding - matching angles to string means + # only accepting prots with no missing residues (mask is 0) + padding_seq = (batch.int_seqs[i] == 20).sum().item() + padding_mask = -(batch.msks[i] - 1).sum().item() # find 0s + + if (full_mask and padding_seq == padding_mask) or \ + (full_mask is not True and batch.int_seqs[i].shape[0] - full_mask > 0): + # check for appropiate length + real_len = batch.int_seqs[i].shape[0] - padding_seq + if max_len >= real_len >= min_len: + # strip padding tokens + seq = batch.str_seqs[i] # seq is already unpadded - see README at scn repo + int_seq = batch.int_seqs[i][:-padding_seq or None] + angles = batch.angs[i][:-padding_seq or None] + mask = batch.msks[i][:-padding_seq or None] + coords = batch.crds[i][:-padding_seq*14 or None] + + if verbose: + print("stopping at sequence of length", real_len) + + yield seq, int_seq, coords, angles, padding_seq, mask, batch.pids[i] + else: + if verbose: + print("found a seq of length:", batch.int_seqs[i].shape, + "but oustide the threshold:", min_len, max_len) + else: + if verbose: + print("paddings not matching", padding_seq, padding_mask) + pass + return None + + +###################### +## structural utils ## +###################### + +def get_dihedral(c1, c2, c3, c4): + """ Returns the dihedral angle in radians. 
+ Will use atan2 formula from: + https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics + Inputs: + * c1: (batch, 3) or (3,) + * c2: (batch, 3) or (3,) + * c3: (batch, 3) or (3,) + * c4: (batch, 3) or (3,) + """ + u1 = c2 - c1 + u2 = c3 - c2 + u3 = c4 - c3 + + return torch.atan2( ( (torch.norm(u2, dim=-1, keepdim=True) * u1) * torch.cross(u2,u3, dim=-1) ).sum(dim=-1) , + ( torch.cross(u1,u2, dim=-1) * torch.cross(u2, u3, dim=-1) ).sum(dim=-1) ) + + +def get_cosine_angle(c1, c2, c3, eps=1e-7): + """ Returns the angle in radians. Uses cosine formula + Not all angles are possible all the time. + Inputs: + * c1: (batch, 3) or (3,) + * c2: (batch, 3) or (3,) + * c3: (batch, 3) or (3,) + """ + u1 = c2 - c1 + u2 = c3 - c2 + + return torch.acos( (u1*u2).sum(dim=-1) / (u1.norm(dim=-1)*u2.norm(dim=-1) + eps)) + + +def get_angle(c1, c2, c3): + """ Returns the angle in radians. + Inputs: + * c1: (batch, 3) or (3,) + * c2: (batch, 3) or (3,) + * c3: (batch, 3) or (3,) + """ + u1 = c2 - c1 + u2 = c3 - c2 + + # dont use acos since norms involved. + # better use atan2 formula: atan2(cross, dot) from here: + # https://johnblackburne.blogspot.com/2012/05/angle-between-two-3d-vectors.html + + # add a minus since we want the angle in reversed order - sidechainnet issues + return torch.atan2( torch.norm(torch.cross(u1,u2, dim=-1), dim=-1), + -(u1*u2).sum(dim=-1) ) + + +def kabsch_torch(X, Y): + """ Kabsch alignment of X into Y. + Assumes X,Y are both (D, N) - usually (3, N) + """ + # center X and Y to the origin + X_ = X - X.mean(dim=-1, keepdim=True) + Y_ = Y - Y.mean(dim=-1, keepdim=True) + # calculate convariance matrix (for each prot in the batch) + C = torch.matmul(X_, Y_.t()) + # Optimal rotation matrix via SVD - warning! 
W must be transposed + if int(torch.__version__.split(".")[1]) < 8: + V, S, W = torch.svd(C.detach()) + W = W.t() + else: + V, S, W = torch.linalg.svd(C.detach()) + # determinant sign for direction correction + d = (torch.det(V) * torch.det(W)) < 0.0 + if d: + S[-1] = S[-1] * (-1) + V[:, -1] = V[:, -1] * (-1) + # Create Rotation matrix U + U = torch.matmul(V, W) + # calculate rotations + X_ = torch.matmul(X_.t(), U).t() + # return centered and aligned + return X_, Y_ + + +def rmsd_torch(X, Y): + """ Assumes x,y are both (batch, d, n) - usually (batch, 3, N). """ + return torch.sqrt( torch.mean((X - Y)**2, axis=(-1, -2)) ) + + +def drmsd_torch(X, Y): + """ Assumes x,y are both (B x D x N). See below for wrapper. """ + X_ = X.transpose(-1, -2) + Y_ = Y.transpose(-1, -2) + x_dist = cdist(X_, X_) # (B, N, N) + y_dist = cdist(Y_, Y_) # (B, N, N) + + return torch.sqrt( torch.pow(x_dist-y_dist, 2).mean(dim=(-1, -2)).clamp(min=1e-7) ) + + +def ensure_chirality(coords_wrapper, use_backbone=True): + """ Ensures protein agrees with natural distribution + of chiral bonds (ramachandran plots). + Reflects ( (-1)*Z ) the ones that do not. + Inputs: + * coords_wrapper: (B, L, C, 3) float tensor. First 3 atoms + in C should be N-CA-C + * use_backbone: bool. whether to use the backbone (better, more robust) + if provided, or just use c-alphas. + Ouputs: (B, L, C, 3) + """ + + # detach gradients for angle calculation - mirror selection + coords_wrapper_ = coords_wrapper.detach() + mask = coords_wrapper_.abs().sum(dim=(-1, -2)) != 0. + + # if BB present: use bb dihedrals + if coords_wrapper[:, :, 0].abs().sum() != 0. 
and use_backbone: + # compute phis for every protein in the batch + phis = get_dihedral( + coords_wrapper_[:, :-1, 2], # C_{i-1} + coords_wrapper_[:, 1: , 0], # N_{i} + coords_wrapper_[:, 1: , 1], # CA_{i} + coords_wrapper_[:, 1: , 2], # C_{i} + ) + + # get proportion of negatives + props = [(phis[i, mask[i, :-1]] > 0).float().mean() for i in range(mask.shape[0])] + + # fix mirrors by (-1)*Z if more (+) than (-) phi angles + corrector = torch.tensor([ [1, 1, -1 if p > 0.5 else 1] # (B, 3) + for p in props ], dtype=coords_wrapper.dtype) + + return coords_wrapper * corrector.to(coords_wrapper.device)[:, None, None, :] + else: + return coords_wrapper + + + + diff --git a/rgn2_replica/rgn2.py b/rgn2_replica/rgn2.py index e313092..19aa472 100644 --- a/rgn2_replica/rgn2.py +++ b/rgn2_replica/rgn2.py @@ -1,4 +1,4 @@ -# Author: Eric Alcaide ( @hypnopump ) +# Author: Eric Alcaide ( @hypnopump ) import os import sys from typing import Optional, Tuple, List @@ -11,10 +11,11 @@ from x_transformers import XTransformer, Encoder from einops import rearrange, repeat # custom -import mp_nerf +from rgn2_replica import mp_nerf from rgn2_replica.utils import * # refiners import en_transformer +from pytorch3d.transforms import quaternion_multiply, quaternion_to_matrix import invariant_point_attention @@ -22,6 +23,15 @@ #### USEFUL PIECES #### ####################### +def exists(val): + return val is not None + + +def init_zero_(layer): + torch.nn.init.constant_(layer.weight, 0.) + if exists(layer.bias): + torch.nn.init.constant_(layer.bias, 0.) + @torch.jit.script def prediction_wrapper(x: torch.Tensor, pred: torch.Tensor): """ Facilitates recycling. 
Inputs the original input + prediction @@ -122,6 +132,112 @@ def pred_post_process(points_preds: torch.Tensor, return points_preds, ca_trace_pred, frames_preds, wrapper_pred +def rotations2angles(rotations: torch.Tensor): + # ref to eq.20 in paper: https://arxiv.org/pdf/1102.5658.pdf + # input: (B, L, 3, 3) = (B, L, [t, n, b]) + length = rotations.shape[1] + points_preds = torch.zeros(*rotations.shape[:-2], 2, 2, device=rotations.device) + points_preds[:, 0, [0, 1], 0] = 1. + points_preds[:, 0, [0, 1], 1] = 0. + points_preds[:, 1, 1, 0] = 1. + points_preds[:, 1, 1, 1] = 0. + + for i in range(length - 1): + # cos(theta) = t_{i+1} * t_i + points_preds[:, i+1, 0, 0] = \ + torch.einsum('b i, b i -> b', rotations[:, i, :, 0], rotations[:, i+1, :, 0]) + # sin(theta) = -n_{i+1} * t_i + points_preds[:, i+1, 0, 1] = \ + -torch.einsum('b i, b i -> b', rotations[:, i, :, 0], rotations[:, i+1, :, 1]) + if i > 0: + # cos(chi) = b_{i+1} * b_i + points_preds[:, i+1, 1, 0] = \ + torch.einsum('b i, b i -> b', rotations[:, i, :, 2], rotations[:, i+1, :, 2]) + # sin(chi) = -b_{i+1} * n_i + points_preds[:, i+1, 1, 1] = \ + -torch.einsum('b i, b i -> b', rotations[:, i, :, 1], rotations[:, i+1, :, 2]) + + return points_preds + + +def pred_post_process_ipa(coords_preds: torch.Tensor, + rotations: torch.Tensor, + seq_list: Optional[List] = None, + mask: Optional[torch.Tensor] = None, + model = None, + refine_args = {}): + """ Converts an angle-based output to structures. + Inputs: + * coords_preds: (B, L, 3) + * seq_list: (B,) list of str. FASTA sequences. Optional. build scns + * mask: (B, L) bool tensor. + * model: subclass of torch.nn.Module. prediction model w/ potential refiner + * model_args: dict. 
arguments to pass to model for refinement + Outputs: + * points_preds: (B, L, 2, 2) + * ca_trace_pred: (B, L, 14, 3) + * frames_preds: (B, L, 3, 3) + * wrapper_pred: (B, L, 14, 3) + """ + device = coords_preds.device + if mask is None: + mask = torch.ones(coords_preds.shape[:-2], dtype=torch.bool) + lengths = mask.sum(dim=-1).cpu().detach().tolist() + + frames_preds = rotations # simply forward it + + # restate first values to known ones (1st angle, 1s + 2nd dihedral) + points_preds = rotations2angles(rotations) + + # rebuild ca trace with angles - norm vectors to ensure mod=1. - (B, L, 14, 3) + ca_trace_pred = torch.zeros(*coords_preds.shape[:2], 14, 3, device=device) + ca_trace_pred[:, :, 1] = coords_preds + # delete extra part and chirally reflect + ca_trace_pred_aux = torch.zeros_like(ca_trace_pred) + for i in range(coords_preds.shape[0]): + ca_trace_pred_aux[i, :lengths[i]] = ca_trace_pred_aux[i, :lengths[i]] + \ + mp_nerf.utils.ensure_chirality(ca_trace_pred[i:i+1, :lengths[i]]) + ca_trace_pred = ca_trace_pred_aux + + # use model's refiner if available + if model is not None: + if model.refiner is not None: + for i in range(mask.shape[0]): + adj_mat = torch.from_numpy( + np.eye(lengths[i], k=1) + np.eye(lengths[i], k=1).T + ).bool().to(device).unsqueeze(0) + + coors = ca_trace_pred[i:i+1, :mask[i].shape[-1], 1].clone() + coors = coors.detach() if model.refiner.refiner_detach else coors + feats, coors, r_iters = model.refiner( + feats=refine_args[model.refiner.feats_inputs][i:i+1, :lengths[i]], # embeddings + coors=coors, + adj_mat=adj_mat, + recycle=refine_args["recycle"], + inter_recycle=refine_args["inter_recycle"], + ) + ca_trace_pred[i:i+1, :lengths[i], 1] = coors + + # calc BB - can't do batched bc relies on extremes. 
+ wrapper_pred = torch.zeros_like(ca_trace_pred) + for i in range(coords_preds.shape[0]): + wrapper_pred[i, :lengths[i]] = mp_nerf.proteins.ca_bb_fold( + ca_trace_pred[i:i+1, :lengths[i], 1] + ) + if seq_list is not None: + # solve backbone steric clashes + wrapper_pred[i, :lengths[i]] = mp_nerf.ml_utils.backbone_forcefield( + coords=wrapper_pred[i, :lengths[i]], coeffs=[3, 5, 3, 1], lr=1e-2 + ) + # build sidechains + scaffolds = mp_nerf.proteins.build_scaffolds_from_scn_angles(seq=seq_list[i], device=device) + wrapper_pred[i, :lengths[i]], _ = mp_nerf.proteins.sidechain_fold( + wrapper_pred[i, :lengths[i]], **scaffolds, c_beta="backbone" + ) + + return points_preds, ca_trace_pred, frames_preds, wrapper_pred + + class SqReLU(torch.jit.ScriptModule): r""" Squared ReLU activation from https://arxiv.org/abs/2109.08668v1. """ @@ -441,6 +557,8 @@ def __init__(self, embedding_dim=1280, hidden=[512], mlp_hidden=[128, 4], torch.nn.Linear(self.mlp_hidden[0], self.mlp_hidden[-1]) ) + self.refiner = None # to be implemented + def forward(self, x, mask : Optional[torch.Tensor] = None, recycle:int = 1, inter_recycle:bool = False): @@ -858,3 +976,149 @@ def forward(self, **data_dict): +class RGN2_IPA(torch.nn.Module): + def __init__(self, embedding_dim=1280, hidden=[512], mlp_hidden=[128, 4], + act="silu", structure_module_depth=8, predict_points=False, x_transformer_config={ + "depth": 8, + "heads": 4, + "attn_dim_head": 64, + # "attn_num_mem_kv": 16, # 16 memory key / values + "use_scalenorm": True, # set to true to use for all layers + "ff_glu": True, # set to true to use for all feedforwards + "attn_collab_heads": True, + "attn_collab_compression": .3, + "cross_attend": False, + "gate_values": True, # gate aggregated values with the input" + # "sandwich_coef": 6, # interleave attention and feedforwards with sandwich coefficient of 6 + "rotary_pos_emb": True # turns on rotary positional embeddings" + } + ): + """ Transformer drop-in for RGN2-LSTM. + Inputs: + * layers: int. 
number of rnn layers + * mlp_hidden: list of ints. + """ + super(RGN2_IPA, self).__init__() + act_types = { + "relu": torch.nn.ReLU, + "silu": torch.nn.SiLU, + } + # store params + self.embedding_dim = embedding_dim + self.hidden = hidden + self.mlp_hidden = mlp_hidden + self.structure_module_depth = structure_module_depth + self.predict_points = predict_points + + # declare layers + """ Declares an XTransformer model. + * No decoder, just predict embeddings + * project with a lst_mlp + + """ + self.to_latent = torch.nn.Linear(self.embedding_dim, self.hidden[0]) + self.transformer = Encoder( + dim= self.hidden[-1], + + **x_transformer_config + ) + self.last_mlp = torch.nn.Sequential( + torch.nn.Linear(self.hidden[-1], self.mlp_hidden[0]), + act_types[act](), + torch.nn.Linear(self.mlp_hidden[0], self.mlp_hidden[-1]) + ) + + """ + IPA stuff + """ + with torch_default_dtype(torch.float32): + self.ipa_block = invariant_point_attention.IPABlock( + dim=self.embedding_dim, + heads=4, #structure_module_heads, + require_pairwise_repr=False + ) + + self.to_quaternion_update = torch.nn.Linear(self.embedding_dim, 6) + + init_zero_(self.ipa_block.attn.to_out) + + self.to_points = torch.nn.Linear(self.embedding_dim, 3) + + + self.refiner = None # to be implemented + + + def forward(self, x, mask : Optional[torch.Tensor] = None, + recycle:int = 1, inter_recycle:bool = False): + """ Inputs: + * x (B, L, Emb_dim) + Outputs: (B, L, 4). + + """ + # same input for both rgn2-lstm and transformer, so mask angles + r_iters = [] # todo: implement this + x_buffer = x.clone() if recycle > 1 else x # buffer for recycling + x[..., -4:] = 0. 
+ + b, n, device = *x.shape[:2], x.device + + with torch_default_dtype(torch.float32): + quaternions = torch.tensor([1., 0., 0., 0.], device=device) + quaternions = repeat(quaternions, 'd -> b n d', b=b, n=n) + translations = torch.zeros((b, n, 3), device=device) + + # go through the layers and apply invariant point attention and feedforward + + for i in range(self.structure_module_depth): + is_last = i == (self.structure_module_depth - 1) + + # the detach comes from + # https://github.com/deepmind/alphafold/blob/0bab1bf84d9d887aba5cfb6d09af1e8c3ecbc408/alphafold/model/folding.py#L383 + rotations = quaternion_to_matrix(quaternions) + + if not is_last: + rotations = rotations.detach() + + x = self.ipa_block( + x, + mask=mask, + # pairwise_repr=pairwise_repr, + rotations=rotations, + translations=translations + ) + + # update quaternion and translation + + quaternion_update, translation_update = self.to_quaternion_update(x).chunk(2, dim=-1) + quaternion_update = F.pad(quaternion_update, (1, 0), value=1.) + + quaternions = quaternion_multiply(quaternions, quaternion_update) + translations = translations + torch.einsum('b n c, b n c r -> b n r', translation_update, rotations) + + points_local = self.to_points(x) + rotations = quaternion_to_matrix(quaternions) + x_pred = torch.einsum('b n c, b n c d -> b n d', points_local, rotations) + translations + + x_pred = x_pred.type(x.dtype).to(x.device) + # todo: support the inter_recycle option + r_iters = \ + torch.empty(x.shape[0], recycle - 1, device=x.device) # (B, recycle-1, L, 4) + + if not self.predict_points: + # todo: + return x_pred, r_iters, rotations, translations + + return x_pred, r_iters + + + def predict_fold(self, x, mask : Optional[torch.Tensor] = None, + recycle:int = 1, inter_recycle:bool = False): + """ Predicts all angles at once so no need for AR prediction. 
+ Same inputs / outputs than + """ + with torch.no_grad(): + return self.forward( + x=x, mask=mask, + recycle=recycle, inter_recycle=inter_recycle + ) + diff --git a/rgn2_replica/rgn2_trainers.py b/rgn2_replica/rgn2_trainers.py index 2838ab3..eea8177 100644 --- a/rgn2_replica/rgn2_trainers.py +++ b/rgn2_replica/rgn2_trainers.py @@ -2,14 +2,7 @@ import time import gc -import random -import numpy as np -import torch -from einops import rearrange, repeat -from functools import partial - -import mp_nerf from rgn2_replica.rgn2 import * from rgn2_replica.utils import * from rgn2_replica.rgn2_utils import * @@ -47,7 +40,7 @@ def batched_inference(*args, model, embedder, # create scaffolds int_seq = torch.ones(batch_dim, max_seq_len, dtype=torch.long) * 20 # padding tok # mask is true mask. long mask is for lstm - mask, long_mask = torch.zeros(2, *int_seq.shape, dtype=torch.bool) + mask, long_mask = torch.zeros(2, *int_seq.shape, dtype=torch.bool, device=device) true_coords = torch.zeros(int_seq.shape[0], int_seq.shape[1]*14, 3, device=device) # fill scaffolds for i,arg in enumerate(args): @@ -59,14 +52,14 @@ def batched_inference(*args, model, embedder, mask = mask.bool().to(device) coords = rearrange(true_coords, 'b (l c) d -> b l c d', c=14) ca_trace = coords[..., 1, :] - coords_rebuilt = mp_nerf.proteins.ca_bb_fold( ca_trace ) # beware extremes + # coords_rebuilt = mp_nerf.proteins.ca_bb_fold( ca_trace ) # beware extremes # calc angle labels angles_label_ = torch.zeros(*ca_trace.shape[:-1], 2, dtype=torch.float, device=device) angles_mask_ = torch.zeros_like(angles_label_).bool() # propagate mask to angles w/ missing points for i, arg in enumerate(args): length = arg[1].shape[-1] - angles_label_[i, 1:length-1, 0] = mp_nerf.utils.get_cosine_angle( + angles_label_[i, 1:length-1, 0] = mp_nerf.utils.get_cosine_angle( ca_trace[i, :length-2 , :], ca_trace[i, 1:length-1, :], ca_trace[i, 2:length , :], @@ -87,6 +80,7 @@ def batched_inference(*args, model, embedder, # later 
don't count them # angles_label_[~angles_mask_] = 0. angles_label_[angles_label_ != angles_label_] = 0. + print(angles_label_.shape) points_label = mp_nerf.ml_utils.angle_to_point_in_circum(angles_label_) # (B, L, 2, 2) # include angles of previous AA as input @@ -116,21 +110,39 @@ def batched_inference(*args, model, embedder, # PREDICT if mode in ["train", "test", "fast_test"]: # get angles - preds, r_iters = model.forward(embedds, mask=long_mask, - recycle=recycle_func(None)) # (B, L, 4) - points_preds = rearrange(preds, '... (a d) -> ... a d', a=2) # (B, L, 2, 2) - - # POST-PROCESS - points_preds, ca_trace_pred, frames_preds, wrapper_pred = pred_post_process( - points_preds, mask=long_mask, # long_mask == True for all seq_len - # seq_list = None, # don't fold sidechain - model=model, refine_args={ - "embedds": embedds, - "int_seq": int_seq.to(device), - "recycle": recycle_func(None), - "inter_recycle": False, - } - ) + refiner_type = config.refiner_args["refiner_type"] + if refiner_type == "En": + preds, r_iters = model.forward(embedds, mask=long_mask, + recycle=recycle_func(None)) # (B, L, 4) + + points_preds = rearrange(preds, '... (a d) -> ... 
a d', a=2) # (B, L, 2, 2) + + # POST-PROCESS + points_preds, ca_trace_pred, frames_preds, wrapper_pred = pred_post_process( + points_preds, mask=long_mask, # long_mask == True for all seq_len + # seq_list = None, # don't fold sidechain + model=model, refine_args={ + "embedds": embedds, + "int_seq": int_seq.to(device), + "recycle": recycle_func(None), + "inter_recycle": False, + } + ) + elif refiner_type == "IPA": # IPA returns coords + preds, r_iters, rotations, translations = model.forward(embedds, mask=long_mask, + recycle=recycle_func(None)) #  (B, L, 4) + points_preds, ca_trace_pred, frames_preds, wrapper_pred = pred_post_process_ipa( + preds, rotations, mask=long_mask, # long_mask == True for all seq_len + # seq_list = None, # don't fold sidechain + model=model, refine_args={ + "embedds": embedds, + "int_seq": int_seq.to(device), + "recycle": recycle_func(None), + "inter_recycle": False, + } + ) + else: + raise NotImplementedError("refiner types besides En/IPA are not supported.") # get frames (for labels) for for later fape bb_ca_trace_rebuilt, frames_labels = mp_nerf.proteins.ca_from_angles( @@ -188,7 +200,7 @@ def inference(*args, model, embedder, long_mask = torch.ones_like(mask) coords = rearrange(true_coords, '(l c) d -> () l c d', c=14).to(device) ca_trace = coords[..., 1, :] - coords_rebuilt = mp_nerf.proteins.ca_bb_fold( ca_trace ) + coords_rebuilt = mp_nerf.proteins.ca_bb_fold(ca_trace) # mask for thetas and chis angles_label_ = torch.zeros(*ca_trace.shape[:-1], 2, dtype=torch.float, device=device) angles_mask_ = torch.zeros_like(angles_label_).bool() @@ -203,12 +215,12 @@ def inference(*args, model, embedder, ca_trace[..., 2:-1, :], ca_trace[..., 3: , :], ) - angles_mask_[..., 1:-1, 0] = ( - mask[i, :-2] * mask[i, 1:-1] * mask[i, 2:] - ) - angles_mask_[i, 2:-1, 0] = ( - mask[i, :-3] * mask[i, 1:-2] * mask[i, 2:-1], mask[i, 3:] - ) + # angles_mask_[..., 1:-1, 0] = ( + # mask[i, :-2] * mask[i, 1:-1] * mask[i, 2:] + # ) + # angles_mask_[i, 2:-1, 0] = ( 
+ # mask[i, :-3] * mask[i, 1:-2] * mask[i, 2:-1], mask[i, 3:] + # ) # replace nan and (angles whose coords are not fully known) by 0. # angles_label_[~angles_mask_] = 0. angles_label_[angles_label_ != angles_label_] = 0. @@ -249,7 +261,7 @@ def inference(*args, model, embedder, ) # get frames for for later fape - bb_ca_trace_rebuilt, frames_labels = mp_nerf.proteins.ca_from_angles( + bb_ca_trace_rebuilt, frames_labels = mp_nerf.proteins.ca_from_angles( points_label.reshape(1, -1, 4) # (B, L, 2, 2) -> (B, L, 4) ) @@ -333,7 +345,7 @@ def predict(get_prot_, steps, model, embedder, return_preds=True, # violation loss btween calphas - L1 dist_mat = mp_nerf.utils.cdist(infer["wrapper_pred"][:, :, 1], - infer["wrapper_pred"][:, :, 1],) # B, L, L + infer["wrapper_pred"][:, :, 1], ) # B, L, L dist_mat[:, np.arange(dist_mat.shape[-1]), np.arange(dist_mat.shape[-1])] = \ dist_mat[:, np.arange(dist_mat.shape[-1]), np.arange(dist_mat.shape[-1])] + 5. viol_loss = -(dist_mat - 3.78).clamp(min=-np.inf, max=0.) @@ -443,7 +455,7 @@ def train(get_prot_, steps, model, embedder, optim, loss_f=None, # violation loss btween calphas - L1 dist_mat = mp_nerf.utils.cdist(infer["wrapper_pred"][:, :, 1], - infer["wrapper_pred"][:, :, 1],) # B, L, L + infer["wrapper_pred"][:, :, 1], ) # B, L, L dist_mat = dist_mat + torch.eye(dist_mat.shape[-1]).unsqueeze(0).to(dist_mat)*5. 
viol_loss = -(dist_mat - 3.78).clamp(min=-np.inf, max=0.).contiguous() diff --git a/rgn2_replica/utils.py b/rgn2_replica/utils.py index f96978b..fbb657f 100644 --- a/rgn2_replica/utils.py +++ b/rgn2_replica/utils.py @@ -3,6 +3,7 @@ import math import torch import numpy as np +import contextlib # random hacks - device utils for pyTorch - saves transfers @@ -46,7 +47,12 @@ def set_seed(seed, verbose=False): print("Seet seed to {0}".format(seed)) - +@contextlib.contextmanager +def torch_default_dtype(dtype): + prev_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(prev_dtype) diff --git a/scripts/rgn2_predict_fold.py b/scripts/rgn2_predict_fold.py index 9d5567f..e4e83d7 100644 --- a/scripts/rgn2_predict_fold.py +++ b/scripts/rgn2_predict_fold.py @@ -1,17 +1,16 @@ # Author: Eirc Alcaide (@hypnopump) +<<<<<<< HEAD +======= import os import re import json import numpy as np import torch +>>>>>>> 9c60d1ddc49a5b9dd73937ad6f0c9e21f8bf8867 # process import argparse -import joblib -from tqdm import tqdm # custom -import esm import sidechainnet -import mp_nerf from rgn2_replica import * from rgn2_replica.embedders import * from rgn2_replica.rgn2_refine import * @@ -118,7 +117,6 @@ # refine structs if args.rosetta_refine: from typing import Optional - import pyrosetta for i, seq in enumerate(seq_list): # only refine diff --git a/scripts/train_rgn2.py b/scripts/train_rgn2.py index fbdce59..1ca6db3 100644 --- a/scripts/train_rgn2.py +++ b/scripts/train_rgn2.py @@ -1,24 +1,19 @@ import os import json import argparse -import random -import numpy as np -import wandb -import torch -import esm import sidechainnet from sidechainnet.utils.sequence import ProteinVocabulary as VOCAB +import sys +sys.path.append("..") + # IMPORTED ALSO IN LATER MODULES VOCAB = VOCAB() -import mp_nerf from rgn2_replica.rgn2_trainers import * from rgn2_replica.embedders import * -from rgn2_replica import set_seed, RGN2_Naive - - +from rgn2_replica import 
set_seed, RGN2_Naive, mp_nerf def parse_arguments(): @@ -34,9 +29,9 @@ def parse_arguments(): # data params parser.add_argument("--min_len", help="Min seq len, for train", type=int, default=0) parser.add_argument("--min_len_valid", help="Min seq len, for valid", type=int, default=0) - parser.add_argument("--max_len", help="Max seq len", type=int, default=512) - parser.add_argument("--casp_version", help="SCN dataset version", type=int, default=12) - parser.add_argument("--scn_thinning", help="SCN dataset thinning", type=int, default=90) + parser.add_argument("--max_len", help="Max seq len", type=int, default=128)#512) + parser.add_argument("--casp_version", help="SCN dataset version", type=int, default=7) + parser.add_argument("--scn_thinning", help="SCN dataset thinning", type=int, default=30) parser.add_argument("--xray", help="only use xray structures", type=bool, default=0) parser.add_argument("--frac_true_torsions", help="Provide right torsions for some prots", type=bool, default=0) parser.add_argument("--full_mask", help="require full mask in proteins", type=bool, default=1) @@ -53,7 +48,8 @@ def parse_arguments(): parser.add_argument("--num_recycles_train", type=int, default=3, help="number of recycling iters. 
set to 1 to speed training.",) # refiner params - parser.add_argument("--refiner_args", help="args for refiner module", type=json.loads, default={}) + parser.add_argument("--refiner_args", help="args for refiner module", type=json.loads, + default={"refiner_type": "En"}) parser.add_argument("--seed", help="Random seed", default=101) return parser.parse_args() @@ -145,17 +141,20 @@ def run_train_schedule(dataloaders, embedder, config, args): embedder = embedder.to(device) set_seed(config.seed) - model = RGN2_Naive(layers=config.num_layers, - emb_dim=config.emb_dim+4, - hidden=config.hidden, - bidirectional=config.bidirectional, - mlp_hidden=config.mlp_hidden, - act=config.act, - layer_type=config.layer_type, - input_dropout=config.input_dropout, - angularize=config.angularize, - refiner_args=config.refiner_args, - ).to(device) + # model = RGN2_Naive(layers=config.num_layers, + # emb_dim=config.emb_dim+4, + # hidden=config.hidden, + # bidirectional=config.bidirectional, + # mlp_hidden=config.mlp_hidden, + # act=config.act, + # layer_type=config.layer_type, + # input_dropout=config.input_dropout, + # angularize=config.angularize, + # refiner_args=config.refiner_args, + # ).to(device) + model = RGN2_IPA( + embedding_dim=config.emb_dim+4, + ).to(device) if args.resume_name is not None: model.load_my_state_dict(torch.load(args.resume_name, map_location=device)) @@ -326,9 +325,10 @@ def get_training_schedule(args): loss_f = " metrics['drmsd'].mean() / len(infer['seq']) " # steps, ckpt, lr , bs , max_len, clip, loss_f - return [[32000, 135 , 1e-3, 16 , args.max_len, None, loss_f, 42 , ], - [64000, 135 , 1e-3, 32 , args.max_len, None, loss_f, 42 , ], + return [[32000, 135 , 1e-4, 16 , args.max_len, None, loss_f, 42 , ], + [64000, 135 , 1e-4, 32 , args.max_len, None, loss_f, 42 , ], [32000, 135 , 1e-4, 32 , args.max_len, None, loss_f, 42 , ],] + # return [[32, 2, 1e-3, 16, args.max_len, None, loss_f, 42, ]] if __name__ == '__main__': diff --git a/setup.py b/setup.py index 
f39dab7..9e5a113 100644 --- a/setup.py +++ b/setup.py @@ -22,14 +22,16 @@ 'sidechainnet', 'proDy', 'tqdm', - 'mp-nerf', + # 'mp-nerf', 'en-transformer>=0.5.0', 'datasets>=1.10', 'transformers>=4.2', 'x-transformers>=0.16.1', 'pytorch-lightning>=1.4', 'wandb', - 'fair-esm>=0.4.0' + 'fair-esm>=0.4.0', + 'pytorch3d', + 'invariant_point_attention' ], setup_requires=[ 'pytest-runner',