From 196d1f4b034e00a49e0798bc64a3d680e8a50f6c Mon Sep 17 00:00:00 2001 From: Robert Stone Date: Thu, 29 Aug 2019 10:13:45 -0700 Subject: [PATCH] [MXNET-1399] multiclass-mcc metric enhancements (#14874) * multiclass-mcc metric enhancements * Rename metric from "PCC" to "mMCC" because though the math is derived from Pearson CC, it's utility is as a multiclass extension of Mathews CC. * Harden mx.metric.mMCC.update to more variations of input format, similar to mx.metric.Accuracy.update. * Harden mx.metric.PCC.update to more variations of input format, similar to mx.metric.Accuracy.update. * Enhance testcases for mx.metric.PCC. --- python/mxnet/metric.py | 6 +++++- tests/python/unittest/test_metric.py | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 4bfd1b36bf89..07ec2ef4d61d 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -1572,7 +1572,11 @@ def update(self, labels, preds): # update the confusion matrix for label, pred in zip(labels, preds): label = label.astype('int32', copy=False).asnumpy() - pred = pred.asnumpy().argmax(axis=1) + pred = pred.asnumpy() + if pred.shape != label.shape: + pred = pred.argmax(axis=1) + else: + pred = pred.astype('int32', copy=False) n = max(pred.max(), label.max()) if n >= self.k: self._grow(n + 1 - self.k) diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index a70ffdb88987..0ae8aeaa697f 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -304,6 +304,13 @@ def test_pcc(): _, pear = met_pear.get() np.testing.assert_almost_equal(pcc, pear) + # pcc should also accept pred as scalar rather than softmax vector + # like acc does + met_pcc.reset() + met_pcc.update(labels, [p.argmax(axis=1) for p in preds]) + _, chk = met_pcc.get() + np.testing.assert_almost_equal(pcc, chk) + # check multiclass case against reference implementation CM = [ [ 23, 13, 3 ],