from argparse import Namespace
import csv
from typing import List, Optional, Tuple
# import pandas as pd
import numpy as np
import torch
from tqdm import tqdm

from .predict import predict
from chemprop.data import MoleculeDataset
from chemprop.data.utils import get_data, get_data_from_smiles
from chemprop.utils import load_args, load_checkpoint, load_scalers


def make_predictions(args: Namespace, smiles: Optional[List[str]] = None) -> Tuple[List[Optional[List[float]]], List[str], List[Optional[List[float]]]]:
    """
    Makes predictions. If smiles is provided, makes predictions on smiles. Otherwise makes predictions on the data at args.test_path.

    :param args: Arguments.
    :param smiles: SMILES to make predictions on.
    :return: A tuple of (average predictions, SMILES of the valid molecules, standard deviations of the individual model predictions).
    """
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    print('Loading training args')
    scaler, features_scaler = load_scalers(args.checkpoint_paths[0])
    train_args = load_args(args.checkpoint_paths[0])

    # Update args with training arguments
    for key, value in vars(train_args).items():
        if not hasattr(args, key):
            setattr(args, key, value)

    print('Loading data')
    if smiles is not None:
        test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False)
    else:
        test_data = get_data(path=args.test_path, args=args, use_compound_names=args.use_compound_names, skip_invalid_smiles=False)

    print('Validating SMILES')
    valid_indices = [i for i in range(len(test_data)) if test_data[i].mol is not None]
    full_data = test_data
    test_data = MoleculeDataset([test_data[i] for i in valid_indices])

    # Edge case: if an empty list of SMILES (or only invalid SMILES) is provided,
    # return placeholders shaped like the normal (avg_preds, smiles, std_preds) tuple
    if len(test_data) == 0:
        return [None] * len(full_data), [], [None] * len(full_data)

    if args.use_compound_names:
        compound_names = test_data.compound_names()
    print(f'Test size = {len(test_data):,}')

    # Normalize features
    if train_args.features_scaling:
        test_data.normalize_features(features_scaler)

    # Predict with each model individually and sum predictions
    if args.dataset_type == 'multiclass':
        sum_preds = np.zeros((len(test_data), args.num_tasks, args.multiclass_num_classes))
    else:
        sum_preds = np.zeros((len(test_data), args.num_tasks))
    # all_preds = np.empty((len(test_data), args.num_tasks, 1))
    print(f'Predicting with an ensemble of {len(args.checkpoint_paths)} models')

    # Modification by RL: collect every individual model's predictions so that the
    # standard deviation across the ensemble can be reported alongside the average.
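    # The per-model predictions are stacked along a new trailing axis, so all_preds
    # has shape (num_molecules, num_tasks, num_models) for regression and binary
    # classification; np.std over that last axis then gives a per-molecule,
    # per-task spread across the ensemble members.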
    for model_index, checkpoint_path in enumerate(tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths))):
        # Load model
        model = load_checkpoint(checkpoint_path, cuda=args.cuda)
        model_preds = predict(
            model=model,
            data=test_data,
            batch_size=args.batch_size,
            scaler=scaler
        )
        model_preds = np.array(model_preds)
        sum_preds += model_preds
        # Add a trailing axis so this model's predictions can be stacked with the others
        mm_preds = model_preds[:, :, np.newaxis]
        # df = pd.DataFrame({'smiles': test_data.smiles()})
        # for i in range(len(model_preds[0])):
        #     df[f'pred_{i}'] = [item[i] for item in model_preds]
        # df.to_csv(f'./pred_out_{model_index}.csv', index=False)

        if model_index == 0:
            all_preds = mm_preds
        else:
            all_preds = np.concatenate((all_preds, mm_preds), axis=2)

    # Ensemble predictions
    avg_preds = sum_preds / len(args.checkpoint_paths)
    avg_preds = avg_preds.tolist()
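    # Note: np.std defaults to ddof=0, i.e. this is the population standard
    # deviation over the ensemble members rather than the sample standard deviation.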
    std_preds = np.std(all_preds, axis=2)
    std_preds = std_preds.tolist()

    return avg_preds, test_data.smiles(), std_preds
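

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative only, not part of the original module).
# The checkpoint paths, test CSV, and settings below are hypothetical; any
# attribute not set on the Namespace is filled in from the training args of
# the first checkpoint inside make_predictions. Because this module uses a
# relative import, run it with `python -m <package>.make_predictions` (the
# exact package path depends on where the module lives) rather than as a
# standalone script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    example_args = Namespace(
        gpu=None,                                # stay on CPU
        cuda=False,
        checkpoint_paths=['model_0/model.pt',    # hypothetical checkpoint paths
                          'model_1/model.pt'],
        test_path='test_smiles.csv',             # hypothetical CSV; ignored when smiles is passed
        use_compound_names=False,
        batch_size=50,
    )
    # Predict directly on a list of SMILES; avg/std are per-molecule, per-task lists
    avg_preds, valid_smiles, std_preds = make_predictions(example_args, smiles=['CCO', 'c1ccccc1'])
    for smile, avg, std in zip(valid_smiles, avg_preds, std_preds):
        print(smile, avg, std)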