Skip to content

Commit e8dca7f

Browse files
authored
Add files via upload
This 'make_predictions.py' code is used to estimate the mean absolute prediction errors.
1 parent d02aa75 commit e8dca7f

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed

make_predictions.py

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from argparse import Namespace
2+
import csv
3+
from typing import List, Optional, Tuple, Union
4+
#import pandas as pd
5+
import numpy as np
6+
import torch
7+
from tqdm import tqdm
8+
9+
from .predict import predict
10+
from chemprop.data import MoleculeDataset
11+
from chemprop.data.utils import get_data, get_data_from_smiles
12+
from chemprop.utils import load_args, load_checkpoint, load_scalers
13+
14+
15+
def make_predictions(
    args: Namespace, smiles: Optional[List[str]] = None
) -> Union[List[None], Tuple[List, List[str], List]]:
    """
    Makes ensemble predictions and reports the spread between the ensemble members.

    If smiles is provided, makes predictions on smiles. Otherwise makes predictions on the
    data loaded from args.test_path.

    :param args: Arguments. Every attribute saved with the first checkpoint's training
        arguments that args does not already have is copied onto args, so args is
        modified in place.
    :param smiles: Smiles to make predictions on.
    :return: A tuple (avg_preds, smiles, std_preds), restricted to the datapoints whose
        mol is not None. avg_preds is the mean of the per-model predictions, smiles is
        test_data.smiles() for the retained datapoints (presumably the SMILES strings),
        and std_preds is the population standard deviation (numpy default, ddof=0) of the
        per-model predictions. If no datapoint is valid, a list containing one None per
        input datapoint is returned instead of the tuple.
    """
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    print('Loading training args')
    # The scalers and training arguments are taken from the first checkpoint only.
    scaler, features_scaler = load_scalers(args.checkpoint_paths[0])
    train_args = load_args(args.checkpoint_paths[0])

    # Update args with training arguments (values already present on args win)
    for key, value in vars(train_args).items():
        if not hasattr(args, key):
            setattr(args, key, value)

    print('Loading data')
    if smiles is not None:
        test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False)
    else:
        test_data = get_data(
            path=args.test_path,
            args=args,
            use_compound_names=args.use_compound_names,
            skip_invalid_smiles=False
        )

    print('Validating SMILES')
    valid_indices = [i for i in range(len(test_data)) if test_data[i].mol is not None]
    full_data = test_data
    test_data = MoleculeDataset([test_data[i] for i in valid_indices])

    # Edge case if empty list of smiles is provided (or none of them is valid)
    # NOTE(review): this branch returns a flat list while the normal path returns a
    # 3-tuple; callers that unpack three values should confirm they handle this case.
    if len(test_data) == 0:
        return [None] * len(full_data)

    print(f'Test size = {len(test_data):,}')

    # Normalize features
    if train_args.features_scaling:
        test_data.normalize_features(features_scaler)

    # Predict with each model individually and sum predictions
    if args.dataset_type == 'multiclass':
        sum_preds = np.zeros((len(test_data), args.num_tasks, args.multiclass_num_classes))
    else:
        sum_preds = np.zeros((len(test_data), args.num_tasks))
    print(f'Predicting with an ensemble of {len(args.checkpoint_paths)} models')

    # Keep every individual model's predictions so that the standard deviation across
    # the ensemble can be computed once all models have run.
    per_model_preds = []
    for checkpoint_path in tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths)):
        # Load model
        model = load_checkpoint(checkpoint_path, cuda=args.cuda)
        model_preds = np.array(predict(
            model=model,
            data=test_data,
            batch_size=args.batch_size,
            scaler=scaler
        ))
        sum_preds += model_preds
        per_model_preds.append(model_preds)

    # Ensemble predictions
    avg_preds = (sum_preds / len(args.checkpoint_paths)).tolist()

    # Stack the models along a new axis 2 (one slice per model) and reduce over it;
    # a single stack avoids re-copying the accumulated array on every iteration.
    all_preds = np.stack(per_model_preds, axis=2)
    std_preds = np.std(all_preds, axis=2).tolist()

    return avg_preds, test_data.smiles(), std_preds

0 commit comments

Comments
 (0)