-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
metric.py
107 lines (99 loc) · 3.96 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import mxnet as mx
import numpy as np
import math
import cv2
from config import config
class LossValueMetric(mx.metric.EvalMetric):
    """Accumulate a scalar loss value reported by the network.

    Each ``update`` call reads element ``0`` of ``preds[0]`` and adds it to
    ``sum_metric`` while counting one instance in ``num_inst``. Presumably the
    base class's ``get()`` reports ``sum_metric / num_inst``, i.e. the mean of
    the recorded values (MXNet ``EvalMetric`` behaviour -- confirm for the
    MXNet version in use).
    """

    def __init__(self):
        self.axis = 1
        super(LossValueMetric, self).__init__(
            name='lossvalue',
            axis=self.axis,
            output_names=None,
            label_names=None,
        )
        # Public attribute kept for compatibility; neither method of this
        # class reads or writes it after construction.
        self.losses = []

    def update(self, labels, preds):
        """Record the loss of one batch.

        Args:
            labels: ignored.
            preds: network outputs; element ``0`` of ``preds[0]`` (after
                ``asnumpy()``) is taken as the loss value.
        """
        loss_array = preds[0].asnumpy()
        batch_loss = loss_array[0]
        self.sum_metric += batch_loss
        self.num_inst += 1.0
class NMEMetric(mx.metric.EvalMetric):
    """Normalized Mean Error (NME) of predicted landmarks.

    For every annotated sample, the mean Euclidean distance between the
    ground-truth and predicted landmark locations is divided by a
    normalizing distance:

    * ``config.landmark_type == '2d'``: distance between the midpoint of
      landmarks 36/39 ("left eye") and the midpoint of landmarks 42/45
      ("right eye");
    * otherwise: the diagonal of the axis-aligned box spanning all
      ground-truth landmark locations of the sample.

    NOTE(review): indices 36/39/42/45 presumably are the eye corners of the
    68-point landmark scheme -- confirm against the dataset definition.
    """

    # Landmark index -> slot in the per-sample eye-corner list.
    _EYE_CORNERS = {36: 0, 39: 1, 42: 2, 45: 3}

    def __init__(self):
        self.axis = 1
        super(NMEMetric, self).__init__('NME',
                                        axis=self.axis,
                                        output_names=None,
                                        label_names=None)
        # Number of update() calls (batches seen), including batches that
        # contributed nothing to the average.
        self.count = 0

    @staticmethod
    def _peak(heatmap):
        """Return the array index of the maximum of ``heatmap`` as an ndarray."""
        return np.array(
            np.unravel_index(np.argmax(heatmap, axis=None), heatmap.shape))

    def cal_nme(self, label, pred_label):
        """Return the mean per-sample NME over one batch.

        Args:
            label: ground truth as a numpy array. If ``label.ndim == 4``,
                ``label[b][p]`` is a heatmap for landmark ``p`` of sample ``b``
                and its argmax is the landmark location; otherwise
                ``label[b][p]`` is used directly as the landmark location.
            pred_label: predicted heatmaps indexed ``pred_label[b][p]``; each
                is resized to ``config.input_img_size`` x
                ``config.input_img_size`` before taking its argmax.

        Samples whose ground truth is empty (all zeros in landmark 36's
        heatmap in the 4-D case, all zeros in ``label[b]`` otherwise) are
        skipped.

        Returns:
            The mean NME over the non-skipped samples, or ``nan`` when every
            sample in the batch was skipped.
        """
        # cv2.resize takes (width, height); the target is square here.
        size = (config.input_img_size, config.input_img_size)
        nme = []
        for b in range(pred_label.shape[0]):
            if label.ndim == 4:
                # An empty heatmap for landmark 36 marks an unannotated sample.
                if np.count_nonzero(label[b][36]) == 0:
                    continue
            elif np.count_nonzero(label[b]) == 0:
                # Coordinate labels: all zeros marks an unannotated sample.
                continue

            eye = [None] * 4  # gt locations of landmarks 36, 39, 42, 45
            lo = hi = None    # running min / max over all gt locations
            errors = []
            for p in range(pred_label.shape[1]):
                if label.ndim == 4:
                    ind_gt = self._peak(label[b][p])
                else:
                    ind_gt = label[b][p]
                # NOTE(review): in the heatmap case the gt peak is taken at the
                # label's own resolution while the prediction is first resized
                # to config.input_img_size; this assumes both share that
                # resolution -- confirm.
                ind_pred = self._peak(cv2.resize(pred_label[b][p], size))

                slot = self._EYE_CORNERS.get(p)
                if slot is not None:
                    eye[slot] = ind_gt

                if lo is None:
                    lo = hi = ind_gt
                else:
                    lo = np.minimum(lo, ind_gt)
                    hi = np.maximum(hi, ind_gt)

                errors.append(np.sqrt(np.sum(np.square(ind_gt - ind_pred))))

            mean_error = np.mean(errors)
            if config.landmark_type == '2d':
                left_eye = (eye[0] + eye[1]) / 2
                right_eye = (eye[2] + eye[3]) / 2
                dist = np.sqrt(np.sum(np.square(left_eye - right_eye)))
            else:
                dist = np.sqrt(np.sum(np.square(hi - lo)))
            nme.append(mean_error / dist)

        if not nme:
            # Every sample was skipped. np.mean([]) would warn and return nan
            # anyway; return it explicitly so update() can detect it.
            return np.nan
        return np.mean(nme)

    def update(self, labels, preds):
        """Add the NME of one batch to the running totals.

        Uses ``labels[0]`` as ground truth and ``preds[-1]`` as the predicted
        heatmaps. Batches with no annotated sample are counted in
        ``self.count`` but do not contribute to ``sum_metric``/``num_inst``.
        """
        self.count += 1
        label = labels[0].asnumpy()
        pred_label = preds[-1].asnumpy()
        nme = self.cal_nme(label, pred_label)
        if np.isnan(nme):
            # Adding nan would make sum_metric nan for the rest of the
            # evaluation, hiding every other batch's result.
            return
        self.sum_metric += nme
        self.num_inst += 1.0