From 6d4bd4362223040179b53d4bef2a3b05b15593d0 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 28 Nov 2018 09:22:42 -0800 Subject: [PATCH 1/8] Change argsort to argpartition --- python/mxnet/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 6d9972074b67..2efcb802ce06 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -487,7 +487,7 @@ def update(self, labels, preds): for label, pred_label in zip(labels, preds): assert(len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims' - pred_label = numpy.argsort(pred_label.asnumpy().astype('float32'), axis=1) + pred_label = numpy.argpartition(pred_label.asnumpy().astype('float32'), -self.top_k) label = label.asnumpy().astype('int32') check_label_shapes(label, pred_label) num_samples = pred_label.shape[0] From 49050aacc3fa72345c9f133a5bec38e8bdedd25e Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 28 Nov 2018 15:39:18 -0800 Subject: [PATCH 2/8] Global statistics in metrics --- python/mxnet/callback.py | 4 +- python/mxnet/metric.py | 207 +++++++++++++++++++++++++---- python/mxnet/module/base_module.py | 8 +- 3 files changed, 182 insertions(+), 37 deletions(-) diff --git a/python/mxnet/callback.py b/python/mxnet/callback.py index e1c1714445df..bfec794f6220 100644 --- a/python/mxnet/callback.py +++ b/python/mxnet/callback.py @@ -113,7 +113,7 @@ def _callback(param): logging.info('Iter[%d] Batch[%d] Train-%s=%f', param.epoch, param.nbatch, name, value) if auto_reset: - param.eval_metric.reset() + param.eval_metric.reset_local() return _callback @@ -164,7 +164,7 @@ def __call__(self, param): if param.eval_metric is not None: name_value = param.eval_metric.get_name_value() if self.auto_reset: - param.eval_metric.reset() + param.eval_metric.reset_local() msg = 'Epoch[%d] Batch [%d-%d]\tSpeed: %.2f samples/sec' msg += '\t%s=%f'*len(name_value) logging.info(msg, param.epoch, count-self.frequent, count, speed, *sum(name_value, ())) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 2efcb802ce06..8cc9eed4d9eb 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -91,6 +91,10 @@ def __init__(self, name, output_names=None, self.output_names = output_names self.label_names = label_names self._kwargs = kwargs + if "has_global_stats" in kwargs: + self._has_global_stats = kwargs["has_global_stats"] + else: + self._has_global_stats = False self.reset() def __str__(self): @@ -148,6 +152,14 @@ def reset(self): """Resets the internal evaluation result to initial state.""" self.num_inst = 0 self.sum_metric = 0.0 + self.global_num_inst = 0 + self.global_sum_metric = 0.0 + + def reset_local(self): + """Resets the local portion of the internal evaluation results + to initial state.""" + self.num_inst = 0 + self.sum_metric = 0.0 def get(self): """Gets the current evaluation result. @@ -164,6 +176,24 @@ def get(self): else: return (self.name, self.sum_metric / self.num_inst) + def get_global(self): + """Gets the current global evaluation result. + + Returns + ------- + names : list of str + Name of the metrics. + values : list of float + Value of the evaluations. + """ + if self._has_global_stats: + if self.global_num_inst == 0: + return (self.name, float('nan')) + else: + return (self.name, self.global_sum_metric / self.global_num_inst) + else: + return self.get() + def get_name_value(self): """Returns zipped name and value pairs. 
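As a standalone illustration (separate from the patch itself, using made-up scores), the following numpy sketch shows why the argsort-to-argpartition swap in PATCH 1/8 is safe for a top-k check: argpartition leaves the top-k columns unordered, but the set of indices it returns matches the fully sorted result, which is all the metric needs.

    import numpy as np

    scores = np.array([[0.10, 0.40, 0.30, 0.20],
                       [0.70, 0.20, 0.06, 0.04]], dtype='float32')
    top_k = 2

    by_sort      = np.argsort(scores, axis=1)[:, -top_k:]
    by_partition = np.argpartition(scores, -top_k, axis=1)[:, -top_k:]

    # The last top_k columns of the partition hold the indices of the top_k
    # largest scores per row, in arbitrary order; the *set* of indices is the
    # same as with a full sort.
    assert np.array_equal(np.sort(by_sort, axis=1), np.sort(by_partition, axis=1))

A partial selection runs in roughly linear time per row instead of O(n log n) for a full sort, which matters since this computation runs single-threaded on the Python side.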
@@ -179,6 +209,24 @@ def get_name_value(self): value = [value] return list(zip(name, value)) + def get_global_name_value(self): + """Returns zipped name and value pairs for global results. + + Returns + ------- + list of tuples + A (name, value) tuple list. + """ + if self._has_global_stats: + name, value = self.get_global() + if not isinstance(name, list): + name = [name] + if not isinstance(value, list): + value = [value] + return list(zip(name, value)) + else: + return self.get_name_value() + # pylint: disable=invalid-name register = registry.get_register_func(EvalMetric, 'metric') alias = registry.get_alias_func(EvalMetric, 'metric') @@ -263,7 +311,8 @@ class CompositeEvalMetric(EvalMetric): def __init__(self, metrics=None, name='composite', output_names=None, label_names=None): super(CompositeEvalMetric, self).__init__( - name, output_names=output_names, label_names=label_names) + name, output_names=output_names, label_names=label_names, + has_global_stats=True) if metrics is None: metrics = [] self.metrics = [create(i) for i in metrics] @@ -325,6 +374,15 @@ def reset(self): except AttributeError: pass + def reset_local(self): + """Resets the local portion of the internal evaluation results + to initial state.""" + try: + for metric in self.metrics: + metric.reset_local() + except AttributeError: + pass + def get(self): """Returns the current evaluation result. @@ -347,6 +405,28 @@ def get(self): values.extend(value) return (names, values) + def get_global(self): + """Returns the current evaluation result. + + Returns + ------- + names : list of str + Name of the metrics. + values : list of float + Value of the evaluations. + """ + names = [] + values = [] + for metric in self.metrics: + name, value = metric.get_global() + if isinstance(name, string_types): + name = [name] + if isinstance(value, numeric_types): + value = [value] + names.extend(name) + values.extend(value) + return (names, values) + def get_config(self): config = super(CompositeEvalMetric, self).get_config() config.update({'metrics': [i.get_config() for i in self.metrics]}) @@ -395,7 +475,8 @@ def __init__(self, axis=1, name='accuracy', output_names=None, label_names=None): super(Accuracy, self).__init__( name, axis=axis, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) self.axis = axis def update(self, labels, preds): @@ -423,8 +504,11 @@ def update(self, labels, preds): check_label_shapes(label, pred_label) - self.sum_metric += (pred_label == label).sum() + num_correct = (pred_label == label).sum() + self.sum_metric += num_correct + self.global_sum_metric += num_correct self.num_inst += len(pred_label) + self.global_num_inst += len(pred_label) @register @@ -467,7 +551,8 @@ def __init__(self, top_k=1, name='top_k_accuracy', output_names=None, label_names=None): super(TopKAccuracy, self).__init__( name, top_k=top_k, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) self.top_k = top_k assert(self.top_k > 1), 'Please use Accuracy if top_k is no more than 1' self.name += '_%d' % self.top_k @@ -498,8 +583,11 @@ def update(self, labels, preds): num_classes = pred_label.shape[1] top_k = min(num_classes, self.top_k) for j in range(top_k): - self.sum_metric += (pred_label[:, num_classes - 1 - j].flat == label.flat).sum() + num_correct = (pred_label[:, num_classes - 1 - j].flat == label.flat).sum() + self.sum_metric += num_correct + self.global_sum_metric += 
num_correct self.num_inst += num_samples + self.global_num_inst += num_samples class _BinaryClassificationMetrics(object): @@ -649,7 +737,8 @@ def __init__(self, name='f1', self.average = average self.metrics = _BinaryClassificationMetrics() EvalMetric.__init__(self, name=name, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, labels, preds): """Updates the internal evaluation result. @@ -667,18 +756,25 @@ def update(self, labels, preds): for label, pred in zip(labels, preds): self.metrics.update_binary_stats(label, pred) + fscore = self.metrics.fscore if self.average == "macro": - self.sum_metric += self.metrics.fscore + self.sum_metric += fscore + self.global_sum_metric += fscore self.num_inst += 1 + self.global_num_inst += 1 self.metrics.reset_stats() else: - self.sum_metric = self.metrics.fscore * self.metrics.total_examples + self.sum_metric = fscore * self.metrics.total_examples + self.global_sum_metric = fscore * self.metrics.total_examples self.num_inst = self.metrics.total_examples + self.global_num_inst = self.metrics.total_examples def reset(self): """Resets the internal evaluation result to initial state.""" self.sum_metric = 0. - self.num_inst = 0. + self.num_inst = 0 + self.global_num_inst = 0 + self.global_sum_metric = 0.0 self.metrics.reset_stats() @@ -750,7 +846,8 @@ def __init__(self, name='mcc', self._average = average self._metrics = _BinaryClassificationMetrics() EvalMetric.__init__(self, name=name, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, labels, preds): """Updates the internal evaluation result. @@ -768,18 +865,25 @@ def update(self, labels, preds): for label, pred in zip(labels, preds): self._metrics.update_binary_stats(label, pred) + matthewscc = self._metrics.matthewscc if self._average == "macro": - self.sum_metric += self._metrics.matthewscc + self.sum_metric += matthewscc + self.global_sum_metric += matthewscc self.num_inst += 1 + self.global_num_inst += 1 self._metrics.reset_stats() else: - self.sum_metric = self._metrics.matthewscc * self._metrics.total_examples + self.sum_metric = matthewscc * self._metrics.total_examples + self.global_sum_metric = matthewscc * self._metrics.total_examples self.num_inst = self._metrics.total_examples + self.global_num_inst = self._metrics.total_examples def reset(self): """Resets the internal evaluation result to initial state.""" self.sum_metric = 0. self.num_inst = 0. + self.global_sum_metric = 0. + self.global_num_inst = 0. self._metrics.reset_stats() @@ -841,7 +945,8 @@ def __init__(self, ignore_label, axis=-1, name='perplexity', output_names=None, label_names=None): super(Perplexity, self).__init__( name, ignore_label=ignore_label, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) self.ignore_label = ignore_label self.axis = axis @@ -871,7 +976,9 @@ def update(self, labels, preds): loss -= ndarray.sum(ndarray.log(ndarray.maximum(1e-10, pred))).asscalar() num += pred.size self.sum_metric += loss + self.global_sum_metric += loss self.num_inst += num + self.global_num_inst += num def get(self): """Returns the current evaluation result. @@ -883,6 +990,17 @@ def get(self): """ return (self.name, math.exp(self.sum_metric/self.num_inst)) + def get_global(self): + """Returns the current global evaluation result. 
+ + Returns + ------- + Tuple of (str, float) + Representing name of the metric and evaluation result. + """ + num = self.global_num_inst if self.global_num_inst > 0 else float('nan') + return (self.name, math.exp(self.global_sum_metric/num)) + #################### # REGRESSION METRICS #################### @@ -921,7 +1039,8 @@ class MAE(EvalMetric): def __init__(self, name='mae', output_names=None, label_names=None): super(MAE, self).__init__( - name, output_names=output_names, label_names=label_names) + name, output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, labels, preds): """Updates the internal evaluation result. @@ -945,8 +1064,11 @@ def update(self, labels, preds): if len(pred.shape) == 1: pred = pred.reshape(pred.shape[0], 1) - self.sum_metric += numpy.abs(label - pred).mean() + mae = numpy.abs(label - pred).mean() + self.sum_metric += mae + self.global_sum_metric += mae self.num_inst += 1 # numpy.prod(label.shape) + self.global_num_inst += 1 # numpy.prod(label.shape) @register @@ -981,7 +1103,8 @@ class MSE(EvalMetric): def __init__(self, name='mse', output_names=None, label_names=None): super(MSE, self).__init__( - name, output_names=output_names, label_names=label_names) + name, output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, labels, preds): """Updates the internal evaluation result. @@ -1005,8 +1128,11 @@ def update(self, labels, preds): if len(pred.shape) == 1: pred = pred.reshape(pred.shape[0], 1) - self.sum_metric += ((label - pred)**2.0).mean() + mse = ((label - pred)**2.0).mean() + self.sum_metric += mse + self.global_sum_metric += mse self.num_inst += 1 # numpy.prod(label.shape) + self.global_num_inst += 1 # numpy.prod(label.shape) @register @@ -1041,7 +1167,8 @@ class RMSE(EvalMetric): def __init__(self, name='rmse', output_names=None, label_names=None): super(RMSE, self).__init__( - name, output_names=output_names, label_names=label_names) + name, output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, labels, preds): """Updates the internal evaluation result. 
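As an aside to the regression-metric hunks above: each update() call adds a single batch-level mean, so num_inst counts batches rather than samples (the commented-out numpy.prod(label.shape) hints at the per-element alternative). A small standalone numpy sketch with made-up arrays:

    import numpy as np

    label = np.array([[0.0, 1.0], [2.0, 3.0]])
    pred  = np.array([[0.5, 1.0], [2.0, 2.0]])

    mae  = np.abs(label - pred).mean()       # quantity added by MAE.update()
    mse  = ((label - pred) ** 2.0).mean()    # quantity added by MSE.update()
    rmse = np.sqrt(mse)                      # quantity added by RMSE.update()
    # num_inst (and global_num_inst) grow by 1 per batch, so get() and
    # get_global() report an average of these per-batch values.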
@@ -1065,8 +1192,11 @@ def update(self, labels, preds): if len(pred.shape) == 1: pred = pred.reshape(pred.shape[0], 1) - self.sum_metric += numpy.sqrt(((label - pred)**2.0).mean()) + rmse = numpy.sqrt(((label - pred)**2.0).mean()) + self.sum_metric += rmse + self.global_sum_metric += rmse self.num_inst += 1 + self.global_num_inst += 1 @register @@ -1110,7 +1240,8 @@ def __init__(self, eps=1e-12, name='cross-entropy', output_names=None, label_names=None): super(CrossEntropy, self).__init__( name, eps=eps, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) self.eps = eps def update(self, labels, preds): @@ -1134,8 +1265,11 @@ def update(self, labels, preds): assert label.shape[0] == pred.shape[0] prob = pred[numpy.arange(label.shape[0]), numpy.int64(label)] - self.sum_metric += (-numpy.log(prob + self.eps)).sum() + ce = (-numpy.log(prob + self.eps)).sum() + self.sum_metric += ce + self.global_sum_metric += ce self.num_inst += label.shape[0] + self.global_num_inst += label.shape[0] @register @alias('nll_loss') @@ -1178,7 +1312,8 @@ def __init__(self, eps=1e-12, name='nll-loss', output_names=None, label_names=None): super(NegativeLogLikelihood, self).__init__( name, eps=eps, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) self.eps = eps def update(self, labels, preds): @@ -1202,8 +1337,11 @@ def update(self, labels, preds): num_examples = pred.shape[0] assert label.shape[0] == num_examples, (label.shape[0], num_examples) prob = pred[numpy.arange(num_examples, dtype=numpy.int64), numpy.int64(label)] - self.sum_metric += (-numpy.log(prob + self.eps)).sum() + nll = (-numpy.log(prob + self.eps)).sum() + self.sum_metric += nll + self.global_sum_metric += nll self.num_inst += num_examples + self.global_num_inst += num_examples @register @alias('pearsonr') @@ -1238,7 +1376,8 @@ class PearsonCorrelation(EvalMetric): def __init__(self, name='pearsonr', output_names=None, label_names=None): super(PearsonCorrelation, self).__init__( - name, output_names=output_names, label_names=label_names) + name, output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, labels, preds): """Updates the internal evaluation result. 
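The cross-entropy and negative-log-likelihood hunks above both rely on the same fancy-indexing trick to pull out the probability assigned to each sample's true class; the local and global sums receive the identical quantity. A standalone numpy sketch with made-up predictions:

    import numpy as np

    pred  = np.array([[0.2, 0.3, 0.5],
                      [0.6, 0.1, 0.3]])
    label = np.array([2, 0])
    eps   = 1e-12

    # probability assigned to the true class of each sample
    prob = pred[np.arange(label.shape[0]), np.int64(label)]   # -> [0.5, 0.6]
    ce = (-np.log(prob + eps)).sum()    # added to sum_metric and global_sum_metric
    value = ce / label.shape[0]         # what get() / get_global() report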
@@ -1256,8 +1395,11 @@ def update(self, labels, preds): check_label_shapes(label, pred, False, True) label = label.asnumpy() pred = pred.asnumpy() - self.sum_metric += numpy.corrcoef(pred.ravel(), label.ravel())[0, 1] + pearson_corr = numpy.corrcoef(pred.ravel(), label.ravel())[0, 1] + self.sum_metric += pearson_corr + self.global_sum_metric += pearson_corr self.num_inst += 1 + self.global_num_inst += 1 @register @@ -1278,7 +1420,8 @@ class Loss(EvalMetric): def __init__(self, name='loss', output_names=None, label_names=None): super(Loss, self).__init__( - name, output_names=output_names, label_names=label_names) + name, output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, _, preds): @@ -1286,8 +1429,11 @@ def update(self, _, preds): preds = [preds] for pred in preds: - self.sum_metric += ndarray.sum(pred).asscalar() + loss = ndarray.sum(pred).asscalar() + self.sum_metric += loss + self.global_sum_metric += loss self.num_inst += pred.size + self.global_num_inst += pred.size @register @@ -1353,7 +1499,8 @@ def __init__(self, feval, name=None, allow_extra_outputs=False, super(CustomMetric, self).__init__( name, feval=feval, allow_extra_outputs=allow_extra_outputs, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) self._feval = feval self._allow_extra_outputs = allow_extra_outputs @@ -1379,10 +1526,14 @@ def update(self, labels, preds): if isinstance(reval, tuple): (sum_metric, num_inst) = reval self.sum_metric += sum_metric + self.global_sum_metric += sum_metric self.num_inst += num_inst + self.global_num_inst += num_inst else: self.sum_metric += reval + self.global_sum_metric += reval self.num_inst += 1 + self.global_num_inst += 1 def get_config(self): raise NotImplementedError("CustomMetric cannot be serialized") diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py index babea53d6e40..b842dc98cd9f 100644 --- a/python/mxnet/module/base_module.py +++ b/python/mxnet/module/base_module.py @@ -508,7 +508,6 @@ def fit(self, train_data, eval_data=None, eval_metric='acc', validation_metric = eval_metric if not isinstance(eval_metric, metric.EvalMetric): eval_metric = metric.create(eval_metric) - epoch_eval_metric = copy.deepcopy(eval_metric) ################################################################################ # training loop @@ -516,7 +515,6 @@ def fit(self, train_data, eval_data=None, eval_metric='acc', for epoch in range(begin_epoch, num_epoch): tic = time.time() eval_metric.reset() - epoch_eval_metric.reset() nbatch = 0 data_iter = iter(train_data) end_of_batch = False @@ -532,12 +530,8 @@ def fit(self, train_data, eval_data=None, eval_metric='acc', self.update_metric(eval_metric, [db.label for db in data_batch], pre_sliced=True) - self.update_metric(epoch_eval_metric, - [db.label for db in data_batch], - pre_sliced=True) else: self.update_metric(eval_metric, data_batch.label) - self.update_metric(epoch_eval_metric, data_batch.label) try: # pre fetch next batch @@ -550,7 +544,7 @@ def fit(self, train_data, eval_data=None, eval_metric='acc', monitor.toc_print() if end_of_batch: - eval_name_vals = epoch_eval_metric.get_name_value() + eval_name_vals = eval_metric.get_global_name_value() if batch_end_callback is not None: batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch, From da6ec43d08a0dfe642e932461bb96baa69b21447 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 29 Nov 2018 16:22:23 -0800 Subject: [PATCH 3/8] Fix 
lint --- python/mxnet/metric.py | 6 +++--- python/mxnet/module/base_module.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 8cc9eed4d9eb..7451872f8fc5 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -1265,9 +1265,9 @@ def update(self, labels, preds): assert label.shape[0] == pred.shape[0] prob = pred[numpy.arange(label.shape[0]), numpy.int64(label)] - ce = (-numpy.log(prob + self.eps)).sum() - self.sum_metric += ce - self.global_sum_metric += ce + cross_entropy = (-numpy.log(prob + self.eps)).sum() + self.sum_metric += cross_entropy + self.global_sum_metric += cross_entropy self.num_inst += label.shape[0] self.global_num_inst += label.shape[0] diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py index b842dc98cd9f..ca8463153686 100644 --- a/python/mxnet/module/base_module.py +++ b/python/mxnet/module/base_module.py @@ -22,7 +22,6 @@ import time import logging import warnings -import copy import numpy as np from .. import metric From 70455e6b3d81a9d1aeaa0e2a9c16ff52571d350e Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 30 Nov 2018 11:04:52 -0800 Subject: [PATCH 4/8] Fixes from review --- python/mxnet/metric.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 7451872f8fc5..9d23e817f20c 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -90,11 +90,8 @@ def __init__(self, name, output_names=None, self.name = str(name) self.output_names = output_names self.label_names = label_names + self._has_global_stats = kwargs.pop("has_global_stats", False) self._kwargs = kwargs - if "has_global_stats" in kwargs: - self._has_global_stats = kwargs["has_global_stats"] - else: - self._has_global_stats = False self.reset() def __str__(self): @@ -764,10 +761,8 @@ def update(self, labels, preds): self.global_num_inst += 1 self.metrics.reset_stats() else: - self.sum_metric = fscore * self.metrics.total_examples - self.global_sum_metric = fscore * self.metrics.total_examples - self.num_inst = self.metrics.total_examples - self.global_num_inst = self.metrics.total_examples + self.sum_metric = self.global_sum_metric = fscore * self.metrics.total_examples + self.num_inst = self.global_num_inst = self.metrics.total_examples def reset(self): """Resets the internal evaluation result to initial state.""" @@ -873,10 +868,8 @@ def update(self, labels, preds): self.global_num_inst += 1 self._metrics.reset_stats() else: - self.sum_metric = matthewscc * self._metrics.total_examples - self.global_sum_metric = matthewscc * self._metrics.total_examples - self.num_inst = self._metrics.total_examples - self.global_num_inst = self._metrics.total_examples + self.sum_metric = self.global_sum_metric = matthewscc * self._metrics.total_examples + self.num_inst = self.global_num_inst = self._metrics.total_examples def reset(self): """Resets the internal evaluation result to initial state.""" From d7e14cc890567f82dac044903702590609919d88 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 30 Nov 2018 20:47:51 -0800 Subject: [PATCH 5/8] Trigger From 3fb50c77d099a4f3b3f32bcd6543babf242b48b1 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 12 Dec 2018 11:41:44 -0800 Subject: [PATCH 6/8] Fixes from review, fix to F1, MCC and perplexity metrics, added test for global stats --- python/mxnet/metric.py | 129 ++++++++++++++++++++++----- tests/python/unittest/test_metric.py | 61 +++++++++++++ 2 
files changed, 166 insertions(+), 24 deletions(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 9d23e817f20c..28fb44a9802f 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -569,6 +569,10 @@ def update(self, labels, preds): for label, pred_label in zip(labels, preds): assert(len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims' + # Using argpartition here instead of argsort is safe because + # we do not care about the order of top k elements. It is + # much faster, which is important since that computation is + # single-threaded due to Python GIL. pred_label = numpy.argpartition(pred_label.asnumpy().astype('float32'), -self.top_k) label = label.asnumpy().astype('int32') check_label_shapes(label, pred_label) @@ -600,6 +604,10 @@ def __init__(self): self.false_negatives = 0 self.false_positives = 0 self.true_negatives = 0 + self.global_true_positives = 0 + self.global_false_negatives = 0 + self.global_false_positives = 0 + self.global_true_negatives = 0 def update_binary_stats(self, label, pred): """ @@ -627,10 +635,18 @@ def update_binary_stats(self, label, pred): label_true = (label == 1) label_false = 1 - label_true - self.true_positives += (pred_true * label_true).sum() - self.false_positives += (pred_true * label_false).sum() - self.false_negatives += (pred_false * label_true).sum() - self.true_negatives += (pred_false * label_false).sum() + tp = (pred_true * label_true).sum() + fp = (pred_true * label_false).sum() + fn = (pred_false * label_true).sum() + tn = (pred_false * label_false).sum() + self.true_positives += tp + self.global_true_positives += tp + self.false_positives += fp + self.global_false_positives += fp + self.false_negatives += fn + self.global_false_negatives += fn + self.true_negatives += tn + self.global_true_negatives += tn @property def precision(self): @@ -639,6 +655,13 @@ def precision(self): else: return 0. + @property + def global_precision(self): + if self.global_true_positives + self.global_false_positives > 0: + return float(self.global_true_positives) / (self.global_true_positives + self.global_false_positives) + else: + return 0. + @property def recall(self): if self.true_positives + self.false_negatives > 0: @@ -646,6 +669,13 @@ def recall(self): else: return 0. + @property + def global_recall(self): + if self.global_true_positives + self.global_false_negatives > 0: + return float(self.global_true_positives) / (self.global_true_positives + self.global_false_negatives) + else: + return 0. + @property def fscore(self): if self.precision + self.recall > 0: @@ -654,17 +684,33 @@ def fscore(self): return 0. @property - def matthewscc(self): + def global_fscore(self): + if self.global_precision + self.global_recall > 0: + return 2 * self.global_precision * self.global_recall / (self.global_precision + self.global_recall) + else: + return 0. + + def matthewscc(self, use_global=False): """ Calculate the Matthew's Correlation Coefficent """ - if not self.total_examples: - return 0. + if use_global: + if not self.global_total_examples: + return 0. + + true_pos = float(self.global_true_positives) + false_pos = float(self.global_false_positives) + false_neg = float(self.global_false_negatives) + true_neg = float(self.global_true_negatives) + else: + if not self.total_examples: + return 0. 
+ + true_pos = float(self.true_positives) + false_pos = float(self.false_positives) + false_neg = float(self.false_negatives) + true_neg = float(self.true_negatives) - true_pos = float(self.true_positives) - false_pos = float(self.false_positives) - false_neg = float(self.false_negatives) - true_neg = float(self.true_negatives) terms = [(true_pos + false_pos), (true_pos + false_neg), (true_neg + false_pos), @@ -679,11 +725,26 @@ def total_examples(self): return self.false_negatives + self.false_positives + \ self.true_negatives + self.true_positives + @property + def global_total_examples(self): + return self.global_false_negatives + self.global_false_positives + \ + self.global_true_negatives + self.global_true_positives + + def local_reset_stats(self): + self.false_positives = 0 + self.false_negatives = 0 + self.true_positives = 0 + self.true_negatives = 0 + def reset_stats(self): self.false_positives = 0 self.false_negatives = 0 self.true_positives = 0 self.true_negatives = 0 + self.global_false_positives = 0 + self.global_false_negatives = 0 + self.global_true_positives = 0 + self.global_true_negatives = 0 @register @@ -753,16 +814,17 @@ def update(self, labels, preds): for label, pred in zip(labels, preds): self.metrics.update_binary_stats(label, pred) - fscore = self.metrics.fscore if self.average == "macro": - self.sum_metric += fscore - self.global_sum_metric += fscore + self.sum_metric += self.metrics.fscore + self.global_sum_metric += self.metrics.global_fscore self.num_inst += 1 self.global_num_inst += 1 self.metrics.reset_stats() else: - self.sum_metric = self.global_sum_metric = fscore * self.metrics.total_examples - self.num_inst = self.global_num_inst = self.metrics.total_examples + self.sum_metric = self.metrics.fscore * self.metrics.total_examples + self.global_sum_metric = self.metrics.global_fscore * self.metrics.global_total_examples + self.num_inst = self.metrics.total_examples + self.global_num_inst = self.metrics.global_total_examples def reset(self): """Resets the internal evaluation result to initial state.""" @@ -772,6 +834,12 @@ def reset(self): self.global_sum_metric = 0.0 self.metrics.reset_stats() + def reset_local(self): + """Resets the internal evaluation result to initial state.""" + self.sum_metric = 0. + self.num_inst = 0 + self.metrics.local_reset_stats() + @register class MCC(EvalMetric): @@ -860,16 +928,18 @@ def update(self, labels, preds): for label, pred in zip(labels, preds): self._metrics.update_binary_stats(label, pred) - matthewscc = self._metrics.matthewscc if self._average == "macro": - self.sum_metric += matthewscc - self.global_sum_metric += matthewscc + self.sum_metric += self._metrics.matthewscc() + self.global_sum_metric += self._metrics.matthewscc(use_global=True) self.num_inst += 1 self.global_num_inst += 1 self._metrics.reset_stats() else: - self.sum_metric = self.global_sum_metric = matthewscc * self._metrics.total_examples - self.num_inst = self.global_num_inst = self._metrics.total_examples + self.sum_metric = self._metrics.matthewscc() * self._metrics.total_examples + self.global_sum_metric = self._metrics.matthewscc(use_global=True) * \ + self._metrics.global_total_examples + self.num_inst = self._metrics.total_examples + self.global_num_inst = self._metrics.global_total_examples def reset(self): """Resets the internal evaluation result to initial state.""" @@ -879,6 +949,12 @@ def reset(self): self.global_num_inst = 0. 
self._metrics.reset_stats() + def reset_local(self): + """Resets the internal evaluation result to initial state.""" + self.sum_metric = 0. + self.num_inst = 0. + self._metrics.local_reset_stats() + @register class Perplexity(EvalMetric): @@ -981,7 +1057,10 @@ def get(self): Tuple of (str, float) Representing name of the metric and evaluation result. """ - return (self.name, math.exp(self.sum_metric/self.num_inst)) + if self.num_inst == 0: + return (self.name, float('nan')) + else: + return (self.name, math.exp(self.sum_metric/self.num_inst)) def get_global(self): """Returns the current global evaluation result. @@ -991,8 +1070,10 @@ def get_global(self): Tuple of (str, float) Representing name of the metric and evaluation result. """ - num = self.global_num_inst if self.global_num_inst > 0 else float('nan') - return (self.name, math.exp(self.global_sum_metric/num)) + if self.global_num_inst == 0: + return (self.name, float('nan')) + else: + return (self.name, math.exp(self.global_sum_metric/self.global_num_inst)) #################### # REGRESSION METRICS diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index 26277d2acff5..1931beffa96e 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -18,6 +18,8 @@ import mxnet as mx import numpy as np import json +from common import with_seed +from copy import deepcopy def check_metric(metric, *args, **kwargs): metric = mx.metric.create(metric, *args, **kwargs) @@ -37,6 +39,65 @@ def test_metrics(): composite = mx.metric.create(['acc', 'f1']) check_metric(composite) +def _check_global_metric(metric, *args, shape=(10,10), use_same_shape=False, **kwargs): + def _create_pred_label(): + if use_same_shape: + pred = mx.nd.random.uniform(0, 1, shape=shape) + label = mx.nd.random.uniform(0, 1, shape=shape) + else: + # Make a random prediction + idx = np.random.rand(*shape).argsort(1) + pred = mx.nd.array(1 - 0.1 * idx) + # Label is half 1 and half 0 + # Setting all 0s or all 1s would make either + # MCC or F1 metrics always produce 0 + label = mx.nd.ones(shape[0]) + label[:shape[0] // 2] = 0 + return pred, label + + m1 = mx.metric.create(metric, *args, **kwargs) + m2 = deepcopy(m1) + # check that global stats are not reset when calling + # reset_local() + for i in range(10): + pred, label = _create_pred_label() + m1.update([label], [pred]) + m1.reset_local() + m2.update([label], [pred]) + assert m1.get_global() == m2.get() + + # check that reset_local() properly resets the local state + m1.reset_local() + m2.reset() + pred, label = _create_pred_label() + m1.update([label], [pred]) + m1.reset_local() + pred, label = _create_pred_label() + m1.update([label], [pred]) + m2.update([label], [pred]) + assert m1.get() == m2.get() + +@with_seed() +def test_global_metric(): + _check_global_metric('acc') + _check_global_metric('TopKAccuracy', top_k=3) + _check_global_metric('f1', shape=(10,2)) + _check_global_metric('f1', shape=(10,2), average='micro') + _check_global_metric('mcc', shape=(10,2)) + _check_global_metric('mcc', shape=(10,2), average='micro') + _check_global_metric('perplexity', -1) + _check_global_metric('pearsonr', use_same_shape=True) + _check_global_metric('nll_loss') + _check_global_metric('loss') + _check_global_metric('ce') + _check_global_metric('mae', use_same_shape=True) + _check_global_metric('mse', use_same_shape=True) + _check_global_metric('rmse', use_same_shape=True) + def custom_metric(label, pred): + return np.mean(np.abs(label-pred)) + 
_check_global_metric(custom_metric, use_same_shape=True) + _check_global_metric(['acc', 'f1'], shape=(10,2)) + def test_nll_loss(): metric = mx.metric.create('nll_loss') pred = mx.nd.array([[0.2, 0.3, 0.5], [0.6, 0.1, 0.3]]) From 18807be91af35a82fc008654259f682e180618e5 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 12 Dec 2018 11:49:51 -0800 Subject: [PATCH 7/8] Fix lint --- python/mxnet/metric.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 28fb44a9802f..ecb8e1c3bc22 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -635,18 +635,18 @@ def update_binary_stats(self, label, pred): label_true = (label == 1) label_false = 1 - label_true - tp = (pred_true * label_true).sum() - fp = (pred_true * label_false).sum() - fn = (pred_false * label_true).sum() - tn = (pred_false * label_false).sum() - self.true_positives += tp - self.global_true_positives += tp - self.false_positives += fp - self.global_false_positives += fp - self.false_negatives += fn - self.global_false_negatives += fn - self.true_negatives += tn - self.global_true_negatives += tn + true_pos = (pred_true * label_true).sum() + false_pos = (pred_true * label_false).sum() + false_neg = (pred_false * label_true).sum() + true_neg = (pred_false * label_false).sum() + self.true_positives += true_pos + self.global_true_positives += true_pos + self.false_positives += false_pos + self.global_false_positives += false_pos + self.false_negatives += false_neg + self.global_false_negatives += false_neg + self.true_negatives += true_neg + self.global_true_negatives += true_neg @property def precision(self): From 2ad4b1efeb1e5352fe232fb1dce9e49ee8adf593 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 12 Dec 2018 13:49:17 -0800 Subject: [PATCH 8/8] Fix compatibility with Python 2 --- tests/python/unittest/test_metric.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index 1931beffa96e..2821c4bbae3c 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -39,7 +39,7 @@ def test_metrics(): composite = mx.metric.create(['acc', 'f1']) check_metric(composite) -def _check_global_metric(metric, *args, shape=(10,10), use_same_shape=False, **kwargs): +def _check_global_metric(metric, *args, **kwargs): def _create_pred_label(): if use_same_shape: pred = mx.nd.random.uniform(0, 1, shape=shape) @@ -55,6 +55,8 @@ def _create_pred_label(): label[:shape[0] // 2] = 0 return pred, label + shape = kwargs.pop('shape', (10,10)) + use_same_shape = kwargs.pop('use_same_shape', False) m1 = mx.metric.create(metric, *args, **kwargs) m2 = deepcopy(m1) # check that global stats are not reset when calling
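Taken together, the series keeps two sets of counters per metric: the existing local ones, which the batch-end callbacks now clear via reset_local(), and global ones that only a full reset() clears. A minimal usage sketch, assuming the series is applied; the batch data and the values in the comments are made up for illustration:

    import mxnet as mx

    acc = mx.metric.create('acc')

    labels = [mx.nd.array([0, 1, 1])]
    preds  = [mx.nd.array([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])]

    acc.update(labels, preds)
    print(acc.get())                    # ('accuracy', 0.666...) -- local value
    acc.reset_local()                   # clears only sum_metric / num_inst
    print(acc.get_global_name_value())  # [('accuracy', 0.666...)] -- still intact
    acc.reset()                         # clears both local and global state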