# test_rec_nme.py
# Test NME on rec data.
import argparse
import cv2
import sys
import numpy as np
import os
import mxnet as mx
import datetime
import img_helper
from config import config
from data import FaceSegIter
from metric import LossValueMetric, NMEMetric
def _build_parser():
    """Create the command-line parser for this script.

    Options are registered in the order listed so that ``--help`` shows them
    in the same sequence.
    """
    cli = argparse.ArgumentParser(description='test nme on rec data')
    # (flag, keyword arguments passed to add_argument)
    options = (
        ('--rec', dict(default='./data_2d/ibug.rec', help='rec data path')),
        ('--prefix', dict(default='', help='model prefix')),
        ('--epoch', dict(type=int, default=1, help='model epoch')),
        ('--gpu', dict(type=int, default=0, help='')),
        ('--landmark-type', dict(default='2d', help='')),
        ('--image-size', dict(type=int, default=128, help='')),
    )
    for flag, kwargs in options:
        cli.add_argument(flag, **kwargs)
    return cli


parser = _build_parser()
args = parser.parse_args()
# Unpack the parsed command-line options into the names used further down.
rec_path, ctx_id, prefix, epoch = args.rec, args.gpu, args.prefix, args.epoch
# The same edge length is used for both entries of image_size.
image_size = (args.image_size,) * 2

# Push the CLI choices into the shared config object.
config.landmark_type = args.landmark_type
config.input_img_size = args.image_size

# A negative --gpu value selects the CPU; otherwise that GPU id is used.
ctx = mx.gpu(ctx_id) if ctx_id >= 0 else mx.cpu()
# Load the checkpoint, then keep only the graph up to 'heatmap_output'.
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
sym = sym.get_internals()['heatmap_output']

# Inference-only module: one 'data' input, no label inputs.
model = mx.mod.Module(
    symbol=sym,
    context=ctx,
    data_names=['data'],
    label_names=None,
)
# Input shape is (1, 3, image_size[0], image_size[1]).
input_shape = (1, 3) + image_size
model.bind(for_training=False, data_shapes=[('data', input_shape)])
model.set_params(arg_params, aux_params)
# Evaluation pass: feed every batch from the rec file through the model,
# collect the per-batch NME from NMEMetric.cal_nme and print the mean.
val_iter = FaceSegIter(
    path_imgrec=rec_path,
    batch_size=1,
    aug_level=0,
)
_metric = NMEMetric()


def _batch_nme(eval_batch):
    """Run one batch through ``model`` and return its NME.

    Only ``eval_batch.data`` is wrapped into the DataBatch given to the
    network; the first label array is read separately and compared with the
    last model output via ``_metric.cal_nme(label, prediction)``.
    """
    model.forward(mx.io.DataBatch(eval_batch.data), is_train=False)
    pred_label = model.get_outputs()[-1].asnumpy()
    label = eval_batch.label[0].asnumpy()
    return _metric.cal_nme(label, pred_label)


nme = []
for i, eval_batch in enumerate(val_iter):
    # Progress message every 10 batches.
    if i % 10 == 0:
        print('processing', i)
    nme.append(_batch_nme(eval_batch))

# np.mean([]) yields nan plus a RuntimeWarning, which would be printed as if
# it were a score. Stop with an explicit error (written to stderr, exit
# status 1) when the iterator produced no batches.
if not nme:
    sys.exit('no samples were read from %s' % rec_path)
print(np.mean(nme))