diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index ecb8e1c3bc22..2a33cf4d9d28 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -860,7 +860,7 @@ class MCC(EvalMetric): .. note:: - This version of MCC only supports binary classification. + This version of MCC only supports binary classification. See PCC. Parameters ---------- @@ -1476,6 +1476,136 @@ def update(self, labels, preds): self.global_num_inst += 1 +@register +class PCC(EvalMetric): + """PCC is a multiclass equivalent for the Matthews correlation coefficient derived + from a discrete solution to the Pearson correlation coefficient. + + .. math:: + \\text{PCC} = \\frac {\\sum _{k}\\sum _{l}\\sum _{m}C_{kk}C_{lm}-C_{kl}C_{mk}} + {{\\sqrt {\\sum _{k}(\\sum _{l}C_{kl})(\\sum _{k'|k'\\neq k}\\sum _{l'}C_{k'l'})}} + {\\sqrt {\\sum _{k}(\\sum _{l}C_{lk})(\\sum _{k'|k'\\neq k}\\sum _{l'}C_{l'k'})}}} + + defined in terms of a K x K confusion matrix C. + + When there are more than two labels the PCC will no longer range between -1 and +1. + Instead the minimum value will be between -1 and 0 depending on the true distribution. + The maximum value is always +1. + + Parameters + ---------- + name : str + Name of this metric instance for display. + output_names : list of str, or None + Name of predictions that should be used when updating with update_dict. + By default include all predictions. + label_names : list of str, or None + Name of labels that should be used when updating with update_dict. + By default include all labels. + + Examples + -------- + >>> # In this example the network almost always predicts positive + >>> false_positives = 1000 + >>> false_negatives = 1 + >>> true_positives = 10000 + >>> true_negatives = 1 + >>> predicts = [mx.nd.array( + [[.3, .7]]*false_positives + + [[.7, .3]]*true_negatives + + [[.7, .3]]*false_negatives + + [[.3, .7]]*true_positives + )] + >>> labels = [mx.nd.array( + [0]*(false_positives + true_negatives) + + [1]*(false_negatives + true_positives) + )] + >>> f1 = mx.metric.F1() + >>> f1.update(preds = predicts, labels = labels) + >>> pcc = mx.metric.PCC() + >>> pcc.update(preds = predicts, labels = labels) + >>> print f1.get() + ('f1', 0.95233560306652054) + >>> print pcc.get() + ('pcc', 0.01917751877733392) + """ + def __init__(self, name='pcc', + output_names=None, label_names=None, + has_global_stats=True): + self.k = 2 + super(PCC, self).__init__( + name=name, output_names=output_names, label_names=label_names, + has_global_stats=has_global_stats) + + def _grow(self, inc): + self.lcm = numpy.pad( + self.lcm, ((0, inc), (0, inc)), 'constant', constant_values=(0)) + self.gcm = numpy.pad( + self.gcm, ((0, inc), (0, inc)), 'constant', constant_values=(0)) + self.k += inc + + def _calc_mcc(self, cmat): + n = cmat.sum() + x = cmat.sum(axis=1) + y = cmat.sum(axis=0) + cov_xx = numpy.sum(x * (n - x)) + cov_yy = numpy.sum(y * (n - y)) + if cov_xx == 0 or cov_yy == 0: + return float('nan') + i = cmat.diagonal() + cov_xy = numpy.sum(i * n - x * y) + return cov_xy / (cov_xx * cov_yy) ** 0.5 + + def update(self, labels, preds): + """Updates the internal evaluation result. + + Parameters + ---------- + labels : list of `NDArray` + The labels of the data. + + preds : list of `NDArray` + Predicted values. + """ + labels, preds = check_label_shapes(labels, preds, True) + + # update the confusion matrix + for label, pred in zip(labels, preds): + label = label.astype('int32').asnumpy() + pred = pred.argmax(axis=1).astype('int32').asnumpy() + n = max(pred.max(), label.max()) + if n >= self.k: + self._grow(n + 1 - self.k) + bcm = numpy.zeros((self.k, self.k)) + for i, j in zip(pred, label): + bcm[i, j] += 1 + self.lcm += bcm + self.gcm += bcm + + self.num_inst += 1 + self.global_num_inst += 1 + + @property + def sum_metric(self): + return self._calc_mcc(self.lcm) * self.num_inst + + @property + def global_sum_metric(self): + return self._calc_mcc(self.gcm) * self.global_num_inst + + def reset(self): + """Resets the internal evaluation result to initial state.""" + self.global_num_inst = 0. + self.gcm = numpy.zeros((self.k, self.k)) + self.reset_local() + + def reset_local(self): + """Resets the local portion of the internal evaluation results + to initial state.""" + self.num_inst = 0. + self.lcm = numpy.zeros((self.k, self.k)) + + @register class Loss(EvalMetric): """Dummy metric for directly printing loss. diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index 2821c4bbae3c..4b5b2047d595 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -34,6 +34,7 @@ def test_metrics(): check_metric('mcc') check_metric('perplexity', -1) check_metric('pearsonr') + check_metric('pcc') check_metric('nll_loss') check_metric('loss') composite = mx.metric.create(['acc', 'f1']) @@ -89,6 +90,7 @@ def test_global_metric(): _check_global_metric('mcc', shape=(10,2), average='micro') _check_global_metric('perplexity', -1) _check_global_metric('pearsonr', use_same_shape=True) + _check_global_metric('pcc', shape=(10,2)) _check_global_metric('nll_loss') _check_global_metric('loss') _check_global_metric('ce') @@ -253,6 +255,76 @@ def test_pearsonr(): _, pearsonr = metric.get() assert pearsonr == pearsonr_expected +def cm_batch(cm): + # generate a batch yielding a given confusion matrix + n = len(cm) + # identity matrix + ident = [ + [ 1.0 if i == j else 0.0 for i in range(n) ] + for j in range(n) + ] + labels = [] + preds = [] + for i in range(n): + for j in range(n): + labels += [ i ] * cm[i][j] + preds += [ ident[j] ] * cm[i][j] + return ([ mx.nd.array(labels, dtype='int32') ], [ mx.nd.array(preds) ]) + +def test_pcc(): + labels, preds = cm_batch([ + [ 7, 3 ], + [ 2, 5 ], + ]) + met_pcc = mx.metric.create('pcc') + met_pcc.update(labels, preds) + _, pcc = met_pcc.get() + + # pcc should agree with mcc for binary classification + met_mcc = mx.metric.create('mcc') + met_mcc.update(labels, preds) + _, mcc = met_mcc.get() + np.testing.assert_almost_equal(pcc, mcc) + + # pcc should agree with Pearson for binary classification + met_pear = mx.metric.create('pearsonr') + met_pear.update(labels, [p.argmax(axis=1) for p in preds]) + _, pear = met_pear.get() + np.testing.assert_almost_equal(pcc, pear) + + # check multiclass case against reference implementation + met_pcc.reset() + CM = [ + [ 23, 13, 3 ], + [ 7, 19, 11 ], + [ 2, 5, 17 ], + ] + K = 3 + ref = sum( + CM[k][k] * CM[l][m] - CM[k][l] * CM[m][k] + for k in range(K) + for l in range(K) + for m in range(K) + ) / (sum( + sum(CM[k][l] for l in range(K)) * sum( + sum(CM[f][g] for g in range(K)) + for f in range(K) + if f != k + ) + for k in range(K) + ) * sum( + sum(CM[l][k] for l in range(K)) * sum( + sum(CM[f][g] for f in range(K)) + for g in range(K) + if g != k + ) + for k in range(K) + )) ** 0.5 + labels, preds = cm_batch(CM) + met_pcc.update(labels, preds) + _, pcc = met_pcc.get() + np.testing.assert_almost_equal(pcc, ref) + def test_single_array_input(): pred = mx.nd.array([[1,2,3,4]]) label = pred + 0.1