-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
39 lines (33 loc) · 1.45 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
from utils.helper_functions import decode_pred
def test(model, test_dataset, thresh=0.6):
"""
This function is used for inference, it takes input, test_dataset and thresh as an input.
It output list of numpy arrays prediction as well as ground truth for the provided test_dataset
for each class i.e color and state.
This data is used in main to provide classification report i.e precision, recall and f1 score.
"""
prediction_color = []
prediction_state = []
gt_color = []
gt_state = []
for batch in test_dataset:
pred = model(batch["img"].unsqueeze(0))
pred_color = pred["color"].sigmoid().cpu().detach().numpy()
pred_color = np.where(pred_color <= thresh, 0.0, 1.0)
pred_state = pred["state"].sigmoid().cpu().detach().numpy()
pred_state = np.where(pred_state <= thresh, 0.0, 1.0)
prediction_color.append(pred_color[0])
gt_color.append(batch["labels"]["color_labels"].cpu().numpy())
prediction_state.append(pred_state[0])
gt_state.append(batch["labels"]["state_labels"].cpu().numpy())
return prediction_color, gt_color, prediction_state, gt_state
def infer_single_image(model, data_attrib, img):
    """Run ``model`` on one image, then print and return the decoded prediction.

    Args:
        model: the network. It is switched to eval mode and left in eval mode.
        data_attrib: attribute metadata, passed unchanged to ``decode_pred``.
        img: a single image tensor. A leading batch dim of size 1 is added
            before the forward pass.

    Returns:
        The result of ``decode_pred(pred, data_attrib)``. Per the original
        docstring this is in the same form as the JSON annotations; confirm
        against ``decode_pred``. The value is also printed, as before.
    """
    import torch  # local import: the file does not import torch at top level

    model.eval()
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        pred = model(img.unsqueeze(0))
    decoded = decode_pred(pred, data_attrib)
    # Keep printing for existing callers; returning the value is the new part.
    print(decoded)
    return decoded