This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MX-9588] Add micro averaging strategy for F1 metric (#9777)
* add macro/micro f1 and test and binary abstraction

* make average an option

* use metric.create

* add decimal for float division

* add default in docstring, reference generic base class in error msg

* expand on docstring

* use scikit in test

* Revert "use scikit in test"

This reverts commit 797c01c.

* use composition

* minibatches
sethah authored and szha committed Feb 15, 2018
1 parent 7d4b4c0 commit d03182f
Showing 2 changed files with 149 additions and 52 deletions.
143 changes: 104 additions & 39 deletions python/mxnet/metric.py
@@ -475,6 +475,85 @@ def update(self, labels, preds):
        self.num_inst += num_samples


+class _BinaryClassificationMetrics(object):
+    """
+    Private container class for classification metric statistics. True/false positive and
+    true/false negative counts are sufficient statistics for various classification metrics.
+    This class provides the machinery to track those statistics across mini-batches of
+    (label, prediction) pairs.
+    """
+
+    def __init__(self):
+        self.true_positives = 0
+        self.false_negatives = 0
+        self.false_positives = 0
+        self.true_negatives = 0
+
+    def update_binary_stats(self, label, pred):
+        """
+        Update various binary classification counts for a single (label, pred)
+        pair.
+
+        Parameters
+        ----------
+        label : `NDArray`
+            The labels of the data.
+
+        pred : `NDArray`
+            Predicted values.
+        """
+        pred = pred.asnumpy()
+        label = label.asnumpy().astype('int32')
+        pred_label = numpy.argmax(pred, axis=1)
+
+        check_label_shapes(label, pred)
+        if len(numpy.unique(label)) > 2:
+            raise ValueError("%s currently only supports binary classification."
+                             % self.__class__.__name__)
+
+        for y_pred, y_true in zip(pred_label, label):
+            if y_pred == 1 and y_true == 1:
+                self.true_positives += 1.
+            elif y_pred == 1 and y_true == 0:
+                self.false_positives += 1.
+            elif y_pred == 0 and y_true == 1:
+                self.false_negatives += 1.
+            else:
+                self.true_negatives += 1.
+
+    @property
+    def precision(self):
+        if self.true_positives + self.false_positives > 0:
+            return self.true_positives / (self.true_positives + self.false_positives)
+        else:
+            return 0.
+
+    @property
+    def recall(self):
+        if self.true_positives + self.false_negatives > 0:
+            return self.true_positives / (self.true_positives + self.false_negatives)
+        else:
+            return 0.
+
+    @property
+    def fscore(self):
+        if self.precision + self.recall > 0:
+            return 2 * self.precision * self.recall / (self.precision + self.recall)
+        else:
+            return 0.
+
+    @property
+    def total_examples(self):
+        return self.false_negatives + self.false_positives + \
+               self.true_negatives + self.true_positives
+
+    def reset_stats(self):
+        self.false_positives = 0
+        self.false_negatives = 0
+        self.true_positives = 0
+        self.true_negatives = 0
+
+
@register
class F1(EvalMetric):
    """Computes the F1 score of a binary classification problem.
@@ -503,21 +582,27 @@ class F1(EvalMetric):
    label_names : list of str, or None
        Name of labels that should be used when updating with update_dict.
        By default include all labels.
+    average : str, default 'macro'
+        Strategy to be used for aggregating across mini-batches.
+            "macro": average the F1 scores for each batch.
+            "micro": compute a single F1 score across all batches.

    Examples
    --------
    >>> predicts = [mx.nd.array([[0.3, 0.7], [0., 1.], [0.4, 0.6]])]
    >>> labels = [mx.nd.array([0., 1., 1.])]
-    >>> acc = mx.metric.F1()
-    >>> acc.update(preds = predicts, labels = labels)
-    >>> print acc.get()
+    >>> f1 = mx.metric.F1()
+    >>> f1.update(preds = predicts, labels = labels)
+    >>> print f1.get()
    ('f1', 0.8)
    """

    def __init__(self, name='f1',
-                 output_names=None, label_names=None):
-        super(F1, self).__init__(
-            name, output_names=output_names, label_names=label_names)
+                 output_names=None, label_names=None, average="macro"):
+        self.average = average
+        self.metrics = _BinaryClassificationMetrics()
+        EvalMetric.__init__(self, name=name,
+                            output_names=output_names, label_names=label_names)

    def update(self, labels, preds):
        """Updates the internal evaluation result.
@@ -533,41 +618,21 @@ def update(self, labels, preds):
        check_label_shapes(labels, preds)

        for label, pred in zip(labels, preds):
-            pred = pred.asnumpy()
-            label = label.asnumpy().astype('int32')
-            pred_label = numpy.argmax(pred, axis=1)
-
-            check_label_shapes(label, pred)
-            if len(numpy.unique(label)) > 2:
-                raise ValueError("F1 currently only supports binary classification.")
-
-            true_positives, false_positives, false_negatives = 0., 0., 0.
-
-            for y_pred, y_true in zip(pred_label, label):
-                if y_pred == 1 and y_true == 1:
-                    true_positives += 1.
-                elif y_pred == 1 and y_true == 0:
-                    false_positives += 1.
-                elif y_pred == 0 and y_true == 1:
-                    false_negatives += 1.
+            self.metrics.update_binary_stats(label, pred)

-            if true_positives + false_positives > 0:
-                precision = true_positives / (true_positives + false_positives)
-            else:
-                precision = 0.
-
-            if true_positives + false_negatives > 0:
-                recall = true_positives / (true_positives + false_negatives)
-            else:
-                recall = 0.
-
-            if precision + recall > 0:
-                f1_score = 2 * precision * recall / (precision + recall)
-            else:
-                f1_score = 0.
-
-            self.sum_metric += f1_score
+        if self.average == "macro":
+            self.sum_metric += self.metrics.fscore
            self.num_inst += 1
+            self.metrics.reset_stats()
+        else:
+            self.sum_metric = self.metrics.fscore * self.metrics.total_examples
+            self.num_inst = self.metrics.total_examples
+
+    def reset(self):
+        """Resets the internal evaluation result to initial state."""
+        self.sum_metric = 0.
+        self.num_inst = 0.
+        self.metrics.reset_stats()


@register
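Aside (not part of the commit): a minimal sketch of how the two strategies added here differ, assuming a build of MXNet that includes this change and using only the mx.metric.F1 / mx.metric.create entry points exercised in the diff. "macro" contributes one F1 score per update() call, while "micro" keeps pooling the true/false positive/negative counts and reports a single F1 over everything seen so far.

import mxnet as mx

macro_f1 = mx.metric.F1(average="macro")
micro_f1 = mx.metric.create("f1", average="micro")

# Batch A: argmax predictions [1, 0, 0] vs. labels [1, 1, 0] -> tp=1, fn=1, F1 = 2/3
# Batch B: argmax prediction [1] vs. label [1] -> tp=1, F1 = 1.0
batches = [
    (mx.nd.array([1, 1, 0]), mx.nd.array([[0.2, 0.8], [0.7, 0.3], [0.6, 0.4]])),
    (mx.nd.array([1]), mx.nd.array([[0.3, 0.7]])),
]
for label, pred in batches:
    macro_f1.update([label], [pred])
    micro_f1.update([label], [pred])

print(macro_f1.get())  # macro: mean of per-batch scores, (2/3 + 1) / 2 ~= 0.833
print(micro_f1.get())  # micro: pooled counts, 2*2 / (2*2 + 0 + 1) = 0.8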
58 changes: 45 additions & 13 deletions tests/python/unittest/test_metric.py
@@ -26,7 +26,6 @@ def check_metric(metric, *args, **kwargs):

    assert metric.get_config() == metric2.get_config()

-
def test_metrics():
    check_metric('acc', axis=0)
    check_metric('f1')
@@ -56,18 +55,51 @@ def test_acc():
    assert acc == expected_acc

def test_f1():
-    pred = mx.nd.array([[0.3, 0.7], [1., 0], [0.4, 0.6], [0.6, 0.4], [0.9, 0.1]])
-    label = mx.nd.array([0, 1, 1, 1, 1])
-    positives = np.argmax(pred, axis=1).sum().asscalar()
-    true_positives = (np.argmax(pred, axis=1) == label).sum().asscalar()
-    precision = true_positives / positives
-    overall_positives = label.sum().asscalar()
-    recall = true_positives / overall_positives
-    f1_expected = 2 * (precision * recall) / (precision + recall)
-    metric = mx.metric.create('f1')
-    metric.update([label], [pred])
-    _, f1 = metric.get()
-    assert f1 == f1_expected
+    microF1 = mx.metric.create("f1", average="micro")
+    macroF1 = mx.metric.F1(average="macro")
+
+    assert np.isnan(macroF1.get()[1])
+    assert np.isnan(microF1.get()[1])
+
+    # check divide by zero
+    pred = mx.nd.array([[0.9, 0.1],
+                        [0.8, 0.2]])
+    label = mx.nd.array([0, 0])
+    macroF1.update([label], [pred])
+    microF1.update([label], [pred])
+    assert macroF1.get()[1] == 0.0
+    assert microF1.get()[1] == 0.0
+    macroF1.reset()
+    microF1.reset()
+
+    pred11 = mx.nd.array([[0.1, 0.9],
+                          [0.5, 0.5]])
+    label11 = mx.nd.array([1, 0])
+    pred12 = mx.nd.array([[0.85, 0.15],
+                          [1.0, 0.0]])
+    label12 = mx.nd.array([1, 0])
+    pred21 = mx.nd.array([[0.6, 0.4]])
+    label21 = mx.nd.array([0])
+    pred22 = mx.nd.array([[0.2, 0.8]])
+    label22 = mx.nd.array([1])
+
+    microF1.update([label11, label12], [pred11, pred12])
+    macroF1.update([label11, label12], [pred11, pred12])
+    assert microF1.num_inst == 4
+    assert macroF1.num_inst == 1
+    # f1 = 2 * tp / (2 * tp + fp + fn)
+    fscore1 = 2. * (1) / (2 * 1 + 1 + 0)
+    np.testing.assert_almost_equal(microF1.get()[1], fscore1)
+    np.testing.assert_almost_equal(macroF1.get()[1], fscore1)
+
+    microF1.update([label21, label22], [pred21, pred22])
+    macroF1.update([label21, label22], [pred21, pred22])
+    assert microF1.num_inst == 6
+    assert macroF1.num_inst == 2
+    fscore2 = 2. * (1) / (2 * 1 + 0 + 0)
+    fscore_total = 2. * (1 + 1) / (2 * (1 + 1) + (1 + 0) + (0 + 0))
+    np.testing.assert_almost_equal(microF1.get()[1], fscore_total)
+    np.testing.assert_almost_equal(macroF1.get()[1], (fscore1 + fscore2) / 2.)

def test_perplexity():
    pred = mx.nd.array([[0.8, 0.2], [0.2, 0.8], [0, 1.]])
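Aside (not part of the commit): the squashed history shows a scikit-learn comparison was tried in the test and then reverted (797c01c). As an offline sanity check, the hand-computed fscore_total above agrees with scikit-learn's binary F1 on the pooled predictions, assuming scikit-learn is available (it is not a dependency of this test):

import numpy as np
from sklearn.metrics import f1_score

# Labels and argmax predictions pooled from (label11, pred11) ... (label22, pred22).
y_true = np.array([1, 0, 1, 0, 0, 1])
y_pred = np.array([1, 0, 0, 0, 0, 1])

# tp=2, fp=0, fn=1 -> 2*tp / (2*tp + fp + fn) = 4/5 = 0.8, matching fscore_total.
assert np.isclose(f1_score(y_true, y_pred), 0.8)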
