Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Adds a multiclass-MCC metric derived from Pearson
Browse files Browse the repository at this point in the history
  • Loading branch information
tlby committed Mar 21, 2019
1 parent d671528 commit 3bb93ba
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 1 deletion.
132 changes: 131 additions & 1 deletion python/mxnet/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ class MCC(EvalMetric):
.. note::
This version of MCC only supports binary classification.
This version of MCC only supports binary classification. See PCC.
Parameters
----------
Expand Down Expand Up @@ -1476,6 +1476,136 @@ def update(self, labels, preds):
self.global_num_inst += 1


@register
class PCC(EvalMetric):
"""PCC is a multiclass equivalent for the Matthews correlation coefficient derived
from a discrete solution to the Pearson correlation coefficient.
.. math::
\\text{PCC} = \\frac {\\sum _{k}\\sum _{l}\\sum _{m}C_{kk}C_{lm}-C_{kl}C_{mk}}
{{\\sqrt {\\sum _{k}(\\sum _{l}C_{kl})(\\sum _{k'|k'\\neq k}\\sum _{l'}C_{k'l'})}}
{\\sqrt {\\sum _{k}(\\sum _{l}C_{lk})(\\sum _{k'|k'\\neq k}\\sum _{l'}C_{l'k'})}}}
defined in terms of a K x K confusion matrix C.
When there are more than two labels the PCC will no longer range between -1 and +1.
Instead the minimum value will be between -1 and 0 depending on the true distribution.
The maximum value is always +1.
Parameters
----------
name : str
Name of this metric instance for display.
output_names : list of str, or None
Name of predictions that should be used when updating with update_dict.
By default include all predictions.
label_names : list of str, or None
Name of labels that should be used when updating with update_dict.
By default include all labels.
Examples
--------
>>> # In this example the network almost always predicts positive
>>> false_positives = 1000
>>> false_negatives = 1
>>> true_positives = 10000
>>> true_negatives = 1
>>> predicts = [mx.nd.array(
[[.3, .7]]*false_positives +
[[.7, .3]]*true_negatives +
[[.7, .3]]*false_negatives +
[[.3, .7]]*true_positives
)]
>>> labels = [mx.nd.array(
[0]*(false_positives + true_negatives) +
[1]*(false_negatives + true_positives)
)]
>>> f1 = mx.metric.F1()
>>> f1.update(preds = predicts, labels = labels)
>>> pcc = mx.metric.PCC()
>>> pcc.update(preds = predicts, labels = labels)
>>> print f1.get()
('f1', 0.95233560306652054)
>>> print pcc.get()
('pcc', 0.01917751877733392)
"""
def __init__(self, name='pcc',
output_names=None, label_names=None,
has_global_stats=True):
self.k = 2
super(PCC, self).__init__(
name=name, output_names=output_names, label_names=label_names,
has_global_stats=has_global_stats)

def _grow(self, inc):
self.lcm = numpy.pad(
self.lcm, ((0, inc), (0, inc)), 'constant', constant_values=(0))
self.gcm = numpy.pad(
self.gcm, ((0, inc), (0, inc)), 'constant', constant_values=(0))
self.k += inc

def _calc_mcc(self, cmat):
n = cmat.sum()
x = cmat.sum(axis=1)
y = cmat.sum(axis=0)
cov_xx = numpy.sum(x * (n - x))
cov_yy = numpy.sum(y * (n - y))
if cov_xx == 0 or cov_yy == 0:
return float('nan')
i = cmat.diagonal()
cov_xy = numpy.sum(i * n - x * y)
return cov_xy / (cov_xx * cov_yy) ** 0.5

def update(self, labels, preds):
"""Updates the internal evaluation result.
Parameters
----------
labels : list of `NDArray`
The labels of the data.
preds : list of `NDArray`
Predicted values.
"""
labels, preds = check_label_shapes(labels, preds, True)

# update the confusion matrix
for label, pred in zip(labels, preds):
label = label.astype('int32').asnumpy()
pred = pred.argmax(axis=1).astype('int32').asnumpy()
n = max(pred.max(), label.max())
if n >= self.k:
self._grow(n + 1 - self.k)
bcm = numpy.zeros((self.k, self.k))
for i, j in zip(pred, label):
bcm[i, j] += 1
self.lcm += bcm
self.gcm += bcm

self.num_inst += 1
self.global_num_inst += 1

@property
def sum_metric(self):
return self._calc_mcc(self.lcm) * self.num_inst

@property
def global_sum_metric(self):
return self._calc_mcc(self.gcm) * self.global_num_inst

def reset(self):
"""Resets the internal evaluation result to initial state."""
self.global_num_inst = 0.
self.gcm = numpy.zeros((self.k, self.k))
self.reset_local()

def reset_local(self):
"""Resets the local portion of the internal evaluation results
to initial state."""
self.num_inst = 0.
self.lcm = numpy.zeros((self.k, self.k))


@register
class Loss(EvalMetric):
"""Dummy metric for directly printing loss.
Expand Down
82 changes: 82 additions & 0 deletions tests/python/unittest/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_metrics():
check_metric('mcc')
check_metric('perplexity', -1)
check_metric('pearsonr')
check_metric('pcc')
check_metric('nll_loss')
check_metric('loss')
composite = mx.metric.create(['acc', 'f1'])
Expand Down Expand Up @@ -89,6 +90,7 @@ def test_global_metric():
_check_global_metric('mcc', shape=(10,2), average='micro')
_check_global_metric('perplexity', -1)
_check_global_metric('pearsonr', use_same_shape=True)
_check_global_metric('pcc', shape=(10,2))
_check_global_metric('nll_loss')
_check_global_metric('loss')
_check_global_metric('ce')
Expand Down Expand Up @@ -253,6 +255,86 @@ def test_pearsonr():
_, pearsonr = metric.get()
assert pearsonr == pearsonr_expected

def cm_batch(cm):
# generate a batch yielding a given confusion matrix
n = len(cm)
ident = np.identity(n)
labels = []
preds = []
for i in range(n):
for j in range(n):
labels += [ i ] * cm[i][j]
preds += [ ident[j] ] * cm[i][j]
return ([ mx.nd.array(labels, dtype='int32') ], [ mx.nd.array(preds) ])

def test_pcc():
labels, preds = cm_batch([
[ 7, 3 ],
[ 2, 5 ],
])
met_pcc = mx.metric.create('pcc')
met_pcc.update(labels, preds)
_, pcc = met_pcc.get()

# pcc should agree with mcc for binary classification
met_mcc = mx.metric.create('mcc')
met_mcc.update(labels, preds)
_, mcc = met_mcc.get()
np.testing.assert_almost_equal(pcc, mcc)

# pcc should agree with Pearson for binary classification
met_pear = mx.metric.create('pearsonr')
met_pear.update(labels, [p.argmax(axis=1) for p in preds])
_, pear = met_pear.get()
np.testing.assert_almost_equal(pcc, pear)

# check multiclass case against reference implementation
CM = [
[ 23, 13, 3 ],
[ 7, 19, 11 ],
[ 2, 5, 17 ],
]
K = 3
ref = sum(
CM[k][k] * CM[l][m] - CM[k][l] * CM[m][k]
for k in range(K)
for l in range(K)
for m in range(K)
) / (sum(
sum(CM[k][l] for l in range(K)) * sum(
sum(CM[f][g] for g in range(K))
for f in range(K)
if f != k
)
for k in range(K)
) * sum(
sum(CM[l][k] for l in range(K)) * sum(
sum(CM[f][g] for f in range(K))
for g in range(K)
if g != k
)
for k in range(K)
)) ** 0.5
labels, preds = cm_batch(CM)
met_pcc.reset()
met_pcc.update(labels, preds)
_, pcc = met_pcc.get()
np.testing.assert_almost_equal(pcc, ref)

# things that should not change metric score:
# * order
# * batch size
# * update frequency
labels = [ [ i ] for i in labels[0] ]
labels.reverse()
preds = [ [ i.reshape((1, -1)) ] for i in preds[0] ]
preds.reverse()

met_pcc.reset()
for l, p in zip(labels, preds):
met_pcc.update(l, p)
assert pcc == met_pcc.get()[1]

def test_single_array_input():
pred = mx.nd.array([[1,2,3,4]])
label = pred + 0.1
Expand Down

0 comments on commit 3bb93ba

Please sign in to comment.