-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
64 lines (45 loc) · 1.61 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 7 08:40:11 2023
@author: Ariane Djeupang
"""
import argparse
from predict_helper import load_checkpoint
from predict_helper import predict
import json
import torch
def build_parser():
    """Return the command-line parser for the flower prediction script.

    Positional arguments are the image path and the checkpoint path;
    ``--top_k``, ``--category_names`` and ``--gpu`` are optional.
    """
    parser = argparse.ArgumentParser(
        description='Predicting a flower name from an image')
    parser.add_argument('path_to_image', action="store",
                        help="path of the image to classify")
    parser.add_argument('checkpoint', action="store",
                        help="path of the checkpoint to load the model from")
    parser.add_argument('--top_k', action="store", type=int, dest="top_k",
                        default=5,
                        help="number of top predictions to request (default: 5)")
    parser.add_argument('--category_names', action="store",
                        dest="category_names", default="",
                        help="JSON file mapping class labels to display names; "
                             "when omitted the raw classes are printed")
    parser.add_argument('--gpu', action="store_true", dest="gpu",
                        default=False,
                        help="run on the GPU if CUDA is available")
    return parser


def select_device(use_gpu):
    """Return the torch device to run on.

    CUDA is selected only when the caller asked for it *and* it is
    available; otherwise the CPU is used. (Previously the ``--gpu`` flag
    was parsed but ignored, so CUDA was used whenever it was present.)
    """
    if use_gpu and torch.cuda.is_available():
        return torch.device("cuda")
    # Either the GPU was not requested or CUDA is missing: fall back to CPU.
    return torch.device("cpu")


def lookup_names(classes, class_to_idx, cat_to_name):
    """Translate predicted class indices into human-readable names.

    Args:
        classes: iterable of predicted indices (values of ``class_to_idx``).
        class_to_idx: mapping from class label to index.
        cat_to_name: mapping from class label to display name.

    Returns:
        A list with one display name per entry of ``classes``, in order.

    Raises:
        KeyError: if an index has no entry in ``class_to_idx`` or its label
            has no entry in ``cat_to_name``.
    """
    # Invert the mapping once instead of scanning it for every prediction.
    idx_to_class = {index: clas for clas, index in class_to_idx.items()}
    names = []
    for i in classes:
        # Tensor / NumPy scalars are unwrapped so they hash like plain numbers.
        key = i.item() if hasattr(i, "item") else i
        if key not in idx_to_class:
            # Fail loudly rather than reusing the label of a previous match.
            raise KeyError(f"no class label found for predicted index {key!r}")
        names.append(cat_to_name[idx_to_class[key]])
    return names


def main(argv=None):
    """Parse the arguments, run the prediction and print the results.

    Args:
        argv: argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    results = build_parser().parse_args(argv)
    device = select_device(results.gpu)

    ## Load the checkpoint
    model = load_checkpoint(results.checkpoint)
    model.to(device)

    ## Make predictions
    probs, classes = predict(results.path_to_image, model, results.top_k)

    if not results.category_names:
        # No name mapping supplied: print the classes exactly as predicted.
        print(classes)
        print(probs)
        return

    with open(results.category_names, 'r', encoding="utf-8") as f:
        cat_to_name = json.load(f)
    print(lookup_names(classes, model.class_to_idx, cat_to_name))
    print(probs)


if __name__ == "__main__":
    main()