
multiclass-mcc metric enhancements
 * Rename the metric from "PCC" to "mMCC": although the math is derived
   from the Pearson CC, its utility is as a multiclass extension of the Matthews CC.
 * Harden mx.metric.mMCC.update to accept more variations of input format,
   mirroring mx.metric.Accuracy.update (see the usage sketch below).
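A minimal usage sketch tying both bullets together (the registered name 'mmcc' follows this commit's rename; the label and prediction values are purely illustrative):

```python
import mxnet as mx

labels = [mx.nd.array([0, 1, 1, 0])]
# (N, K) probability rows: the metric takes the argmax internally
probs = [mx.nd.array([[0.7, 0.3], [0.2, 0.8], [0.4, 0.6], [0.9, 0.1]])]
# class indices shaped like the labels: accepted after this change
indices = [mx.nd.array([0, 1, 1, 0])]

metric = mx.metric.create('mmcc')
metric.update(labels, probs)
metric.update(labels, indices)
print(metric.get())  # ('mmcc', <running score over both updates>)
```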
tlby committed May 3, 2019
1 parent d09f68a commit c0b08f2
Showing 2 changed files with 33 additions and 29 deletions.
28 changes: 16 additions & 12 deletions python/mxnet/metric.py
@@ -860,7 +860,7 @@ class MCC(EvalMetric):
.. note::
-This version of MCC only supports binary classification. See PCC.
+This version of MCC only supports binary classification. See mMCC.
Parameters
----------
@@ -1477,18 +1477,18 @@ def update(self, labels, preds):


@register
-class PCC(EvalMetric):
-    """PCC is a multiclass equivalent for the Matthews correlation coefficient derived
+class mMCC(EvalMetric):
+    """mMCC is a multiclass equivalent for the Matthews correlation coefficient derived
from a discrete solution to the Pearson correlation coefficient.
.. math::
-\\text{PCC} = \\frac {\\sum _{k}\\sum _{l}\\sum _{m}C_{kk}C_{lm}-C_{kl}C_{mk}}
+\\text{mMCC} = \\frac {\\sum _{k}\\sum _{l}\\sum _{m}C_{kk}C_{lm}-C_{kl}C_{mk}}
{{\\sqrt {\\sum _{k}(\\sum _{l}C_{kl})(\\sum _{k'|k'\\neq k}\\sum _{l'}C_{k'l'})}}
{\\sqrt {\\sum _{k}(\\sum _{l}C_{lk})(\\sum _{k'|k'\\neq k}\\sum _{l'}C_{l'k'})}}}
defined in terms of a K x K confusion matrix C.
-When there are more than two labels the PCC will no longer range between -1 and +1.
+When there are more than two labels the mMCC will no longer range between -1 and +1.
Instead the minimum value will be between -1 and 0 depending on the true distribution.
The maximum value is always +1.
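For readers skimming the diff, the docstring formula above (with its doubled backslashes unescaped) renders as the following; this is only a restatement for readability, not part of the change:

```latex
\text{mMCC} = \frac{\sum_{k}\sum_{l}\sum_{m} C_{kk}C_{lm} - C_{kl}C_{mk}}
  {\sqrt{\sum_{k}\Bigl(\sum_{l}C_{kl}\Bigr)\Bigl(\sum_{k'\neq k}\sum_{l'}C_{k'l'}\Bigr)}\,
   \sqrt{\sum_{k}\Bigl(\sum_{l}C_{lk}\Bigr)\Bigl(\sum_{k'\neq k}\sum_{l'}C_{l'k'}\Bigr)}}
```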
@@ -1522,18 +1522,18 @@ class PCC(EvalMetric):
)]
>>> f1 = mx.metric.F1()
>>> f1.update(preds = predicts, labels = labels)
->>> pcc = mx.metric.PCC()
->>> pcc.update(preds = predicts, labels = labels)
+>>> mmcc = mx.metric.mMCC()
+>>> mmcc.update(preds = predicts, labels = labels)
>>> print f1.get()
('f1', 0.95233560306652054)
->>> print pcc.get()
-('pcc', 0.01917751877733392)
+>>> print mmcc.get()
+('mmcc', 0.01917751877733392)
"""
-def __init__(self, name='pcc',
+def __init__(self, name='mmcc',
output_names=None, label_names=None,
has_global_stats=True):
self.k = 2
-super(PCC, self).__init__(
+super(mMCC, self).__init__(
name=name, output_names=output_names, label_names=label_names,
has_global_stats=has_global_stats)

@@ -1572,7 +1572,11 @@ def update(self, labels, preds):
# update the confusion matrix
for label, pred in zip(labels, preds):
label = label.astype('int32', copy=False).asnumpy()
-pred = pred.asnumpy().argmax(axis=1)
+pred = pred.asnumpy()
+if pred.shape != label.shape:
+    pred = pred.argmax(axis=1)
+else:
+    pred = pred.astype('int32', copy=False)
n = max(pred.max(), label.max())
if n >= self.k:
self._grow(n + 1 - self.k)
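The last hunk above is the "harden update" half of the commit: predictions are argmaxed only when their shape differs from the labels', otherwise they are taken as ready-made class indices. A standalone NumPy sketch of that dispatch (the helper name `to_class_indices` is hypothetical, not code from the diff):

```python
import numpy as np

def to_class_indices(label, pred):
    """Mirror the shape check in the hunk above: reduce an (N, K)
    probability matrix with argmax, pass through predictions that
    already have the labels' shape as class indices."""
    label = np.asarray(label, dtype='int32')
    pred = np.asarray(pred)
    if pred.shape != label.shape:
        pred = pred.argmax(axis=1)
    else:
        pred = pred.astype('int32')
    return label, pred

labels = [0, 1, 1]
print(to_class_indices(labels, [[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]])[1])  # [0 1 1]
print(to_class_indices(labels, [0, 1, 1])[1])                             # [0 1 1]
```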
34 changes: 17 additions & 17 deletions tests/python/unittest/test_metric.py
@@ -34,7 +34,7 @@ def test_metrics():
check_metric('mcc')
check_metric('perplexity', -1)
check_metric('pearsonr')
-check_metric('pcc')
+check_metric('mmcc')
check_metric('nll_loss')
check_metric('loss')
composite = mx.metric.create(['acc', 'f1'])
@@ -90,7 +90,7 @@ def test_global_metric():
_check_global_metric('mcc', shape=(10,2), average='micro')
_check_global_metric('perplexity', -1)
_check_global_metric('pearsonr', use_same_shape=True)
-_check_global_metric('pcc', shape=(10,2))
+_check_global_metric('mmcc', shape=(10,2))
_check_global_metric('nll_loss')
_check_global_metric('loss')
_check_global_metric('ce')
@@ -267,26 +267,26 @@ def cm_batch(cm):
preds += [ ident[j] ] * cm[i][j]
return ([ mx.nd.array(labels, dtype='int32') ], [ mx.nd.array(preds) ])

-def test_pcc():
+def test_mmcc():
labels, preds = cm_batch([
[ 7, 3 ],
[ 2, 5 ],
])
-met_pcc = mx.metric.create('pcc')
-met_pcc.update(labels, preds)
-_, pcc = met_pcc.get()
+met_mmcc = mx.metric.create('mmcc')
+met_mmcc.update(labels, preds)
+_, mmcc = met_mmcc.get()

-# pcc should agree with mcc for binary classification
+# mmcc should agree with mcc for binary classification
met_mcc = mx.metric.create('mcc')
met_mcc.update(labels, preds)
_, mcc = met_mcc.get()
-np.testing.assert_almost_equal(pcc, mcc)
+np.testing.assert_almost_equal(mmcc, mcc)

-# pcc should agree with Pearson for binary classification
+# mmcc should agree with Pearson for binary classification
met_pear = mx.metric.create('pearsonr')
met_pear.update(labels, [p.argmax(axis=1) for p in preds])
_, pear = met_pear.get()
-np.testing.assert_almost_equal(pcc, pear)
+np.testing.assert_almost_equal(mmcc, pear)

# check multiclass case against reference implementation
CM = [
@@ -316,10 +316,10 @@ def test_pcc():
for k in range(K)
)) ** 0.5
labels, preds = cm_batch(CM)
-met_pcc.reset()
-met_pcc.update(labels, preds)
-_, pcc = met_pcc.get()
-np.testing.assert_almost_equal(pcc, ref)
+met_mmcc.reset()
+met_mmcc.update(labels, preds)
+_, mmcc = met_mmcc.get()
+np.testing.assert_almost_equal(mmcc, ref)

# things that should not change metric score:
# * order
@@ -330,10 +330,10 @@ def test_pcc():
preds = [ [ i.reshape((1, -1)) ] for i in preds[0] ]
preds.reverse()

-met_pcc.reset()
+met_mmcc.reset()
for l, p in zip(labels, preds):
-met_pcc.update(l, p)
-assert pcc == met_pcc.get()[1]
+met_mmcc.update(l, p)
+assert mmcc == met_mmcc.get()[1]

def test_single_array_input():
pred = mx.nd.array([[1,2,3,4]])
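The multiclass case in test_mmcc above is checked against a reference value computed straight from the confusion matrix. A self-contained sketch of that reference computation (the function name and the trace/marginal form are mine, but it is algebraically equivalent to the docstring formula over a K x K confusion matrix C with rows as true classes and columns as predictions):

```python
import numpy as np

def multiclass_mcc(C):
    """Reference multiclass MCC from a K x K confusion matrix."""
    C = np.asarray(C, dtype=np.float64)
    s = C.sum()        # total samples
    c = np.trace(C)    # correctly classified samples
    t = C.sum(axis=1)  # per-class true counts (row sums)
    p = C.sum(axis=0)  # per-class predicted counts (column sums)
    num = c * s - (t * p).sum()
    den = np.sqrt(s ** 2 - (t ** 2).sum()) * np.sqrt(s ** 2 - (p ** 2).sum())
    return num / den if den else 0.0

# On the binary matrix used at the top of the test this reproduces
# the ordinary 2x2 MCC, which is exactly what the test asserts.
print(multiclass_mcc([[7, 3], [2, 5]]))  # ~0.4085
```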

