-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
68 lines (50 loc) · 2.04 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
from argparse import ArgumentParser, BooleanOptionalAction
import torch
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
from utils.model import UNet
from utils import config
def log(message, dots=True):
    """Print an ``[INFO]``-prefixed status line.

    Args:
        message: Text to report; it is formatted with ``str()``.
        dots: When true, append ``...`` to indicate an operation in progress.
    """
    suffix = '...' if dots else ''
    print(f'[INFO] {message}{suffix}')
def predict(model: UNet, img: np.ndarray, size: tuple[int, int] = (128, 128)):
    """Predict a binary segmentation mask for a single OpenCV (BGR) image.

    The image is converted from BGR to RGB, resized with ``cv.resize`` to
    ``size``, scaled to [0, 1], reordered from HWC to CHW, given a leading
    batch dimension of 1 and moved to ``config.DEVICE`` before the forward
    pass. The network output is squeezed, passed through a sigmoid and
    thresholded at ``config.THRESHOLD``.

    Args:
        model: The segmentation network. This function does not change its
            train/eval mode; callers should put it in eval mode themselves.
        img: Image as returned by ``cv.imread``. Presumably an HxWx3 uint8
            BGR array; TODO confirm for other sources.
        size: ``(width, height)`` target passed to ``cv.resize``. Defaults
            to the 128x128 input size this script has always used.

    Returns:
        A ``uint8`` array where pixels whose probability exceeds
        ``config.THRESHOLD`` are 255 and all others are 0. Its shape is
        whatever the model output squeezes to. Presumably (height, width)
        for a single-channel, size-preserving UNet; verify.

    Raises:
        ValueError: If ``img`` is None, e.g. because ``cv.imread`` could not
            read the file.
    """
    if img is None:
        raise ValueError('img is None (cv.imread returns None for unreadable paths)')
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    img = cv.resize(img, size)
    img = img.astype(np.float32) / 255.
    # HWC -> CHW, then add a batch dimension of 1 as the model expects.
    img = np.transpose(img, (2, 0, 1))
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).to(config.DEVICE)
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(img).squeeze()
        probabilities = torch.sigmoid(logits)
    probabilities = probabilities.cpu().numpy()
    # Filter out the weak predictions and convert them to an 8-bit 0/255 mask.
    predicted_mask = ((probabilities > config.THRESHOLD) * 255).astype(np.uint8)
    return predicted_mask
if __name__ == '__main__':
    # Command-line entry point: load a pickled UNet, predict a mask for one
    # image, then optionally display it and/or save it to results/result.png.
    parser = ArgumentParser()
    parser.add_argument('--img', type=str, required=True,
                        help='Your image path to make prediction on it')
    parser.add_argument('--model', type=str, default='UNet_tgs_salt.pth',
                        help='Your model path (default: UNet_tgs_salt.pth)')
    # BooleanOptionalAction already produces a bool and adds the --no-show /
    # --no-save forms; passing `type=` to it is deprecated since Python 3.12.
    parser.add_argument('--show', action=BooleanOptionalAction,
                        default=True, help='Show the predicted mask')
    parser.add_argument('--save', action=BooleanOptionalAction,
                        default=True, help='Save the predicted mask')
    opt = parser.parse_args()

    log('Loading up the UNet model')
    # The checkpoint is a whole pickled nn.Module (not a state_dict), so
    # PyTorch >= 2.6 needs weights_only=False to load it.
    # SECURITY: unpickling can execute arbitrary code; only load trusted files.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    unet: UNet = torch.load(opt.model, map_location=config.DEVICE,
                            weights_only=False).to(config.DEVICE)
    # Switch any BatchNorm/Dropout layers to inference behaviour.
    unet.eval()

    img = cv.imread(opt.img)
    if img is None:
        # cv.imread signals failure by returning None rather than raising.
        parser.error(f'could not read image: {opt.img}')

    log('Make predictions')
    prediction = predict(unet, img)

    if opt.show:
        cv.imshow('predictions', prediction)
        cv.waitKey(0)
        cv.destroyAllWindows()

    if opt.save:
        os.makedirs('results', exist_ok=True)
        # cv.imwrite reports failure via its return value, not an exception.
        if not cv.imwrite('results/result.png', prediction):
            raise SystemExit('[ERROR] failed to write results/result.png')
        log('Predicted mask saved.', dots=False)