Skip to content

Commit

Permalink
Add ranking test.
Browse files Browse the repository at this point in the history
  • Loading branch information
jbenjoseph committed Sep 11, 2024
1 parent 15896ad commit d9a1c37
Showing 1 changed file with 179 additions and 23 deletions.
202 changes: 179 additions & 23 deletions notebooks/BindingAffinityML.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -31,9 +31,11 @@
"# Third-party imports for numerical operations and machine learning\n",
"import joblib\n",
"import numpy as np\n",
"import polars as pl\n",
"import torch\n",
"from scipy.stats import spearmanr, kendalltau\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import roc_auc_score\n",
"from sklearn.metrics import roc_auc_score, mean_squared_error \n",
"from sklearn.model_selection import train_test_split\n",
"from tqdm.auto import tqdm\n",
"from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType\n",
Expand All @@ -51,7 +53,7 @@
"protein_dim = 480\n",
"\n",
"# Only test on n examples, set to None to test on all examples\n",
"test_only_n_examples = 2000"
"test_only_n_examples = None"
]
},
{
Expand Down Expand Up @@ -141,12 +143,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cbf5232d90934fbea25012f32cf82f81",
"model_id": "9804ea1b3150445392c11600f5be871f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating protein embeddings: 0%| | 0/2000 [00:00<?, ? proteins/s]"
"Generating protein embeddings: 0%| | 0/591469 [00:00<?, ? proteins/s]"
]
},
"metadata": {},
Expand Down Expand Up @@ -217,12 +219,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dafd8e83a2964348987f6e66fc8590ce",
"model_id": "821fc6404aa841499edbdb9cef102f80",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating ligand embeddings: 0%| | 0/2000 [00:00<?, ? ligand/s]"
"Generating ligand embeddings: 0%| | 0/591469 [00:00<?, ? ligand/s]"
]
},
"metadata": {},
Expand Down Expand Up @@ -291,7 +293,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"ROC-AUC score: 0.8919\n"
"ROC-AUC score: 0.9863\n"
]
}
],
Expand Down Expand Up @@ -328,18 +330,116 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0ea6fc0c551f40168caf9839aa0c3b20",
"model_id": "c21c7e4b15b444e8b4ad25e63c176df6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating ligand embeddings: 0%| | 0/4 [00:00<?, ? ligand/s]"
"Generating ligand embeddings: 0%| | 0/26 [00:00<?, ? ligand/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "85d81a7037e44b3184902706ba08a607",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating ligand embeddings: 0%| | 0/87 [00:00<?, ? ligand/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d01d90ac44994ecb92c1592db3531cc4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating ligand embeddings: 0%| | 0/41 [00:00<?, ? ligand/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ca20906fca25450e8c144a38cc79ff90",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating ligand embeddings: 0%| | 0/51 [00:00<?, ? ligand/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "433e3baf84d24fc78ad0fb3eb7acd6e0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating ligand embeddings: 0%| | 0/33 [00:00<?, ? ligand/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "37ec3a8b64064dc9be57bcaf5801ee7b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating ligand embeddings: 0%| | 0/34 [00:00<?, ? ligand/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e0cc424a28fa4e1a91c99bac3807938e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating ligand embeddings: 0%| | 0/24 [00:00<?, ? ligand/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5183eba2236e4c8cb7b75f77b92e4241",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating ligand embeddings: 0%| | 0/53 [00:00<?, ? ligand/s]"
]
},
"metadata": {},
Expand All @@ -349,10 +449,22 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Ligand: CCCC, Score: 0.5100\n",
"Ligand: CCC, Score: 0.4900\n",
"Ligand: CCO, Score: 0.4800\n",
"Ligand: CCN, Score: 0.4400\n"
"shape: (8, 5)\n",
"┌──────────────────────────┬────────────────────────┬─────────────┬────────────────┬───────────────┐\n",
"│ Target ┆ Predicted ┆ MSE ┆ Spearman's Rho ┆ Kendall's Tau │\n",
"│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
"│ str ┆ str ┆ f64 ┆ f64 ┆ f64 │\n",
"╞══════════════════════════╪════════════════════════╪═════════════╪════════════════╪═══════════════╡\n",
"│ shp2 ┆ 10, 17, 1, 7, 16, 12 ┆ 88.461538 ┆ 0.213675 ┆ 0.144615 │\n",
"│ pfkfb3_automap ┆ 20, 68, 67, 76, 75, 25 ┆ 1655.563218 ┆ -0.31255 ┆ -0.232291 │\n",
"│ cdk8_5cei_new_helix_loop ┆ 1, 2, 7, 8, 9, 10 ┆ 119.219512 ┆ 0.574216 ┆ 0.429268 │\n",
"│ _extra ┆ ┆ ┆ ┆ │\n",
"│ hif2a_automap ┆ 25, 24, 10, 5, 40, 39 ┆ 533.490196 ┆ -0.231131 ┆ -0.162353 │\n",
"│ tnks2_fullmap ┆ 8, 1, 15, 6, 24, 4 ┆ 135.333333 ┆ 0.253676 ┆ 0.147727 │\n",
"│ eg5_extraprotomers ┆ 33, 23, 22, 20, 30, 31 ┆ 306.470588 ┆ -0.592055 ┆ -0.411765 │\n",
"│ cmet ┆ 7, 15, 2, 1, 20, 18 ┆ 95.583333 ┆ 0.002609 ┆ -0.014493 │\n",
"│ syk_4puz_fullmap ┆ 51, 43, 11, 40, 27, 28 ┆ 683.396226 ┆ -0.460248 ┆ -0.300435 │\n",
"└──────────────────────────┴────────────────────────┴─────────────┴────────────────┴───────────────┘\n"
]
}
],
Expand All @@ -371,16 +483,60 @@
" ranked_ligands = np.array(list_of_ligands)[ranked_indices]\n",
" ranked_scores = y_pred[ranked_indices]\n",
"\n",
" return ranked_ligands, ranked_scores\n",
" return ranked_indices, ranked_ligands, ranked_scores\n",
"\n",
"def calculate_metrics(true_ranks, predicted_ranks):\n",
" mse = mean_squared_error(true_ranks, predicted_ranks)\n",
" spearman_corr, _ = spearmanr(true_ranks, predicted_ranks)\n",
" kendall_tau, _ = kendalltau(true_ranks, predicted_ranks)\n",
" \n",
" return {\"MSE\": mse, \"Spearman's Rho\": spearman_corr, \"Kendall's Tau\": kendall_tau}\n",
"\n",
"def create_ranked_ligand_df(smiles_csv, rank_csv, fasta_file):\n",
" \"\"\"Create a DataFrame with ranked ligands and their scores.\"\"\"\n",
" smiles_df = pl.read_csv(smiles_csv, separator=\"\\t\") \n",
" rank_df = pl.read_csv(rank_csv, has_header=False, new_columns=[\"ligand\", \"rank\"])\n",
"\n",
" rank_df = rank_df.with_columns([\n",
" pl.col(\"rank\").str.strip_chars().cast(pl.Float64)\n",
" ])\n",
"\n",
" # Combine the two DataFrames on ligand\n",
" merged_df = smiles_df.join(rank_df, on=\"ligand\")\n",
"\n",
" # Load the target sequence from a FASTA file\n",
" with open(fasta_file) as f:\n",
" target_sequence = \"\".join(line.strip() for line in f.readlines()[1:])\n",
"\n",
" return merged_df, target_sequence\n",
"\n",
"\n",
"def walk_every_target_subdir_and_do_ranking(root_dir):\n",
" \"\"\"Walk through every subdirectory in the root directory and perform ranking.\"\"\"\n",
" for subdir in os.listdir(root_dir):\n",
" subdir_path = root_dir / subdir\n",
" if not subdir_path.is_dir():\n",
" continue\n",
"\n",
" smiles_csv = subdir_path / f\"{subdir}_smiles.csv\"\n",
" rank_csv = subdir_path / f\"{subdir}_rank.csv\"\n",
" fasta_file = subdir_path / f\"{subdir}_fasta\"\n",
"\n",
" test_df, target_sequence = create_ranked_ligand_df(smiles_csv, rank_csv, fasta_file)\n",
" ligands = test_df[\" Canonical SMILES \"]\n",
"\n",
" y_pred, ranked_ligands, ranked_scores = rank_binding_affinity(ligands, target_sequence)\n",
"\n",
" # y_true is merely the indices of the df\n",
" y_true = np.arange(len(ligands))\n",
"\n",
" result = {\"Target\": subdir, \"Predicted\": \", \".join([str(p) for p in y_pred[:6]])}\n",
" result.update(calculate_metrics(y_true, y_pred))\n",
"\n",
"# Test the ranking function\n",
"ligands = [\"CCO\", \"CCN\", \"CCC\", \"CCCC\"]\n",
"target_sequence = \"MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKR\"\n",
" yield result\n",
"\n",
"ranked_ligands, ranked_scores = rank_binding_affinity(ligands, target_sequence)\n",
"\n",
"for ligand, score in zip(ranked_ligands, ranked_scores):\n",
" print(f\"Ligand: {ligand}, Score: {score:.4f}\")"
"print(pl.DataFrame(walk_every_target_subdir_and_do_ranking(aiondata_path / \"schrodinger-fepp\")))"
]
},
{
Expand All @@ -392,7 +548,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down

0 comments on commit d9a1c37

Please sign in to comment.