diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 4bfd1b36bf89..07ec2ef4d61d 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -1572,7 +1572,11 @@ def update(self, labels, preds): # update the confusion matrix for label, pred in zip(labels, preds): label = label.astype('int32', copy=False).asnumpy() - pred = pred.asnumpy().argmax(axis=1) + pred = pred.asnumpy() + if pred.shape != label.shape: + pred = pred.argmax(axis=1) + else: + pred = pred.astype('int32', copy=False) n = max(pred.max(), label.max()) if n >= self.k: self._grow(n + 1 - self.k) diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index d8dca753bda4..266d7cc862df 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -288,6 +288,13 @@ def test_pcc(): _, pear = met_pear.get() np.testing.assert_almost_equal(pcc, pear) + # pcc should also accept pred as scalar rather than softmax vector + # like acc does + met_pcc.reset() + met_pcc.update(labels, [p.argmax(axis=1) for p in preds]) + _, chk = met_pcc.get() + np.testing.assert_almost_equal(pcc, chk) + # check multiclass case against reference implementation CM = [ [ 23, 13, 3 ],