From c0b08f2e8804c79ca65ca0ac76794438fb9c7a3c Mon Sep 17 00:00:00 2001
From: Robert Stone
Date: Fri, 3 May 2019 15:23:14 -0700
Subject: [PATCH] multiclass-mcc metric enhancements

* Rename the metric from "PCC" to "mMCC": although the math is derived from
  the Pearson CC, its utility is as a multiclass extension of the Matthews CC.
* Harden mx.metric.mMCC.update against more variations of input format,
  similar to mx.metric.Accuracy.update.
---
 python/mxnet/metric.py               | 28 +++++++++++++----
 tests/python/unittest/test_metric.py | 34 ++++++++++++++--------------
 2 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 4bfd1b36bf89..fd85797e3ad9 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -860,7 +860,7 @@ class MCC(EvalMetric):
 
     .. note::
 
-        This version of MCC only supports binary classification. See PCC.
+        This version of MCC only supports binary classification. See mMCC.
 
     Parameters
     ----------
@@ -1477,18 +1477,18 @@ def update(self, labels, preds):
 
 
 @register
-class PCC(EvalMetric):
-    """PCC is a multiclass equivalent for the Matthews correlation coefficient derived
+class mMCC(EvalMetric):
+    """mMCC is a multiclass equivalent for the Matthews correlation coefficient derived
     from a discrete solution to the Pearson correlation coefficient.
 
     .. math::
-        \\text{PCC} = \\frac {\\sum _{k}\\sum _{l}\\sum _{m}C_{kk}C_{lm}-C_{kl}C_{mk}}
+        \\text{mMCC} = \\frac {\\sum _{k}\\sum _{l}\\sum _{m}C_{kk}C_{lm}-C_{kl}C_{mk}}
         {{\\sqrt {\\sum _{k}(\\sum _{l}C_{kl})(\\sum _{k'|k'\\neq k}\\sum _{l'}C_{k'l'})}}
         {\\sqrt {\\sum _{k}(\\sum _{l}C_{lk})(\\sum _{k'|k'\\neq k}\\sum _{l'}C_{l'k'})}}}
 
     defined in terms of a K x K confusion matrix C.
 
-    When there are more than two labels the PCC will no longer range between -1 and +1.
+    When there are more than two labels the mMCC will no longer range between -1 and +1.
     Instead the minimum value will be between -1 and 0 depending on the true distribution.
     The maximum value is always +1.
@@ -1522,18 +1522,18 @@ class PCC(EvalMetric):
     )]
     >>> f1 = mx.metric.F1()
     >>> f1.update(preds = predicts, labels = labels)
-    >>> pcc = mx.metric.PCC()
-    >>> pcc.update(preds = predicts, labels = labels)
+    >>> mmcc = mx.metric.mMCC()
+    >>> mmcc.update(preds = predicts, labels = labels)
     >>> print f1.get()
     ('f1', 0.95233560306652054)
-    >>> print pcc.get()
-    ('pcc', 0.01917751877733392)
+    >>> print mmcc.get()
+    ('mmcc', 0.01917751877733392)
     """
 
-    def __init__(self, name='pcc',
+    def __init__(self, name='mmcc',
                  output_names=None, label_names=None,
                  has_global_stats=True):
         self.k = 2
-        super(PCC, self).__init__(
+        super(mMCC, self).__init__(
             name=name, output_names=output_names, label_names=label_names,
             has_global_stats=has_global_stats)
@@ -1572,7 +1572,11 @@ def update(self, labels, preds):
         # update the confusion matrix
         for label, pred in zip(labels, preds):
             label = label.astype('int32', copy=False).asnumpy()
-            pred = pred.asnumpy().argmax(axis=1)
+            pred = pred.asnumpy()
+            if pred.shape != label.shape:
+                pred = pred.argmax(axis=1)
+            else:
+                pred = pred.astype('int32', copy=False)
             n = max(pred.max(), label.max())
             if n >= self.k:
                 self._grow(n + 1 - self.k)
diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py
index d8dca753bda4..ed1a922ae60f 100644
--- a/tests/python/unittest/test_metric.py
+++ b/tests/python/unittest/test_metric.py
@@ -34,7 +34,7 @@ def test_metrics():
     check_metric('mcc')
     check_metric('perplexity', -1)
     check_metric('pearsonr')
-    check_metric('pcc')
+    check_metric('mmcc')
     check_metric('nll_loss')
     check_metric('loss')
     composite = mx.metric.create(['acc', 'f1'])
@@ -90,7 +90,7 @@ def test_global_metric():
     _check_global_metric('mcc', shape=(10,2), average='micro')
     _check_global_metric('perplexity', -1)
     _check_global_metric('pearsonr', use_same_shape=True)
-    _check_global_metric('pcc', shape=(10,2))
+    _check_global_metric('mmcc', shape=(10,2))
     _check_global_metric('nll_loss')
     _check_global_metric('loss')
     _check_global_metric('ce')
@@ -267,26 +267,26 @@ def cm_batch(cm):
             preds += [ ident[j] ] * cm[i][j]
     return ([ mx.nd.array(labels, dtype='int32') ], [ mx.nd.array(preds) ])
 
-def test_pcc():
+def test_mmcc():
     labels, preds = cm_batch([
         [ 7, 3 ],
         [ 2, 5 ],
     ])
-    met_pcc = mx.metric.create('pcc')
-    met_pcc.update(labels, preds)
-    _, pcc = met_pcc.get()
+    met_mmcc = mx.metric.create('mmcc')
+    met_mmcc.update(labels, preds)
+    _, mmcc = met_mmcc.get()
 
-    # pcc should agree with mcc for binary classification
+    # mmcc should agree with mcc for binary classification
     met_mcc = mx.metric.create('mcc')
     met_mcc.update(labels, preds)
     _, mcc = met_mcc.get()
-    np.testing.assert_almost_equal(pcc, mcc)
+    np.testing.assert_almost_equal(mmcc, mcc)
 
-    # pcc should agree with Pearson for binary classification
+    # mmcc should agree with Pearson for binary classification
     met_pear = mx.metric.create('pearsonr')
     met_pear.update(labels, [p.argmax(axis=1) for p in preds])
     _, pear = met_pear.get()
-    np.testing.assert_almost_equal(pcc, pear)
+    np.testing.assert_almost_equal(mmcc, pear)
 
     # check multiclass case against reference implementation
     CM = [
@@ -316,10 +316,10 @@ def test_pcc():
         for k in range(K)
     )) ** 0.5
     labels, preds = cm_batch(CM)
-    met_pcc.reset()
-    met_pcc.update(labels, preds)
-    _, pcc = met_pcc.get()
-    np.testing.assert_almost_equal(pcc, ref)
+    met_mmcc.reset()
+    met_mmcc.update(labels, preds)
+    _, mmcc = met_mmcc.get()
+    np.testing.assert_almost_equal(mmcc, ref)
 
     # things that should not change metric score:
     # * order
     # * batch size
     # * update frequency
     labels = [ [ i ] for i in labels[0] ]
     labels.reverse()
     preds = [ [ i.reshape((1, -1)) ] for i in preds[0] ]
     preds.reverse()
 
-    met_pcc.reset()
+    met_mmcc.reset()
     for l, p in zip(labels, preds):
-        met_pcc.update(l, p)
-    assert pcc == met_pcc.get()[1]
+        met_mmcc.update(l, p)
+    assert mmcc == met_mmcc.get()[1]
 
 def test_single_array_input():
     pred = mx.nd.array([[1,2,3,4]])
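
A note on the renamed metric, outside the patch itself: the docstring formula above can be sanity-checked with a few lines of NumPy. The sketch below is illustrative only; the function name multiclass_mcc and the example matrix are made up here, and the matrix [[7, 3], [2, 5]] simply mirrors the binary case used in test_mmcc, where the value should coincide with the ordinary Matthews correlation coefficient.

    import numpy as np

    def multiclass_mcc(C):
        """Evaluate the docstring formula on a K x K confusion matrix C,
        where C[k][l] counts samples of true class k predicted as class l."""
        C = np.asarray(C, dtype=np.float64)
        n = C.sum()               # total sample count
        t = C.sum(axis=1)         # row sums: samples whose true class is k
        p = C.sum(axis=0)         # column sums: samples predicted as class k
        # numerator: sum_k sum_l sum_m (C_kk * C_lm - C_kl * C_mk)
        num = np.trace(C) * n - t.dot(p)
        # denominator factors: sum_k t_k * (n - t_k) and sum_k p_k * (n - p_k)
        den_t = n * n - t.dot(t)
        den_p = n * n - p.dot(p)
        return num / np.sqrt(den_t * den_p)

    # binary case from the test: agrees with the ordinary (binary) MCC
    print(multiclass_mcc([[7, 3],
                          [2, 5]]))

With t and p the row and column sums, the triple sums in the docstring collapse to trace(C)*n - t.p over sqrt((n^2 - t.t)(n^2 - p.p)), which is the same collapsed form scikit-learn uses for its multiclass matthews_corrcoef.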
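On the second bullet (hardened input handling), the caller-side effect is easiest to see from a small usage sketch. It assumes the patch is applied, so the class is exposed as mx.metric.mMCC (with the pre-rename code the same calls go through mx.metric.PCC / mx.metric.create('pcc')): update() accepts either a (batch, num_class) score matrix, from which the argmax is taken, or a vector of already-decided class indices with the same shape as the label, mirroring mx.metric.Accuracy.

    import mxnet as mx

    labels = [mx.nd.array([0, 1, 1, 0])]

    # format 1: per-class scores, one row per sample; argmax is applied internally
    prob_preds = [mx.nd.array([[0.8, 0.2],
                               [0.3, 0.7],
                               [0.1, 0.9],
                               [0.6, 0.4]])]

    # format 2: class indices with the same shape as the label; used as-is
    index_preds = [mx.nd.array([0, 1, 1, 0])]

    m1 = mx.metric.mMCC()   # name introduced by this patch (PCC upstream)
    m1.update(labels, prob_preds)

    m2 = mx.metric.mMCC()
    m2.update(labels, index_preds)

    # both formats describe the same predictions, so the scores should match
    print(m1.get())
    print(m2.get())

With the pre-patch code the second form fails, because argmax(axis=1) is applied unconditionally to a 1-D array; the shape check added in update() is what makes both forms work.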