diff --git a/python/mxnet/callback.py b/python/mxnet/callback.py index e1c1714445df..bfec794f6220 100644 --- a/python/mxnet/callback.py +++ b/python/mxnet/callback.py @@ -113,7 +113,7 @@ def _callback(param): logging.info('Iter[%d] Batch[%d] Train-%s=%f', param.epoch, param.nbatch, name, value) if auto_reset: - param.eval_metric.reset() + param.eval_metric.reset_local() return _callback @@ -164,7 +164,7 @@ def __call__(self, param): if param.eval_metric is not None: name_value = param.eval_metric.get_name_value() if self.auto_reset: - param.eval_metric.reset() + param.eval_metric.reset_local() msg = 'Epoch[%d] Batch [%d-%d]\tSpeed: %.2f samples/sec' msg += '\t%s=%f'*len(name_value) logging.info(msg, param.epoch, count-self.frequent, count, speed, *sum(name_value, ())) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 6d9972074b67..ecb8e1c3bc22 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -90,6 +90,7 @@ def __init__(self, name, output_names=None, self.name = str(name) self.output_names = output_names self.label_names = label_names + self._has_global_stats = kwargs.pop("has_global_stats", False) self._kwargs = kwargs self.reset() @@ -148,6 +149,14 @@ def reset(self): """Resets the internal evaluation result to initial state.""" self.num_inst = 0 self.sum_metric = 0.0 + self.global_num_inst = 0 + self.global_sum_metric = 0.0 + + def reset_local(self): + """Resets the local portion of the internal evaluation results + to initial state.""" + self.num_inst = 0 + self.sum_metric = 0.0 def get(self): """Gets the current evaluation result. @@ -164,6 +173,24 @@ def get(self): else: return (self.name, self.sum_metric / self.num_inst) + def get_global(self): + """Gets the current global evaluation result. + + Returns + ------- + names : list of str + Name of the metrics. + values : list of float + Value of the evaluations. + """ + if self._has_global_stats: + if self.global_num_inst == 0: + return (self.name, float('nan')) + else: + return (self.name, self.global_sum_metric / self.global_num_inst) + else: + return self.get() + def get_name_value(self): """Returns zipped name and value pairs. @@ -179,6 +206,24 @@ def get_name_value(self): value = [value] return list(zip(name, value)) + def get_global_name_value(self): + """Returns zipped name and value pairs for global results. + + Returns + ------- + list of tuples + A (name, value) tuple list. + """ + if self._has_global_stats: + name, value = self.get_global() + if not isinstance(name, list): + name = [name] + if not isinstance(value, list): + value = [value] + return list(zip(name, value)) + else: + return self.get_name_value() + # pylint: disable=invalid-name register = registry.get_register_func(EvalMetric, 'metric') alias = registry.get_alias_func(EvalMetric, 'metric') @@ -263,7 +308,8 @@ class CompositeEvalMetric(EvalMetric): def __init__(self, metrics=None, name='composite', output_names=None, label_names=None): super(CompositeEvalMetric, self).__init__( - name, output_names=output_names, label_names=label_names) + name, output_names=output_names, label_names=label_names, + has_global_stats=True) if metrics is None: metrics = [] self.metrics = [create(i) for i in metrics] @@ -325,6 +371,15 @@ def reset(self): except AttributeError: pass + def reset_local(self): + """Resets the local portion of the internal evaluation results + to initial state.""" + try: + for metric in self.metrics: + metric.reset_local() + except AttributeError: + pass + def get(self): """Returns the current evaluation result. 
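Note on the pattern introduced above: every metric now keeps two accumulator pairs. The local pair (sum_metric / num_inst) drives the periodic Speedometer logging and is cleared by reset_local(), while the global pair keeps accumulating until a full reset() and backs get_global() / get_global_name_value(). A minimal standalone sketch of that pattern with a toy accuracy-style metric (illustration only, not code taken from the patch):

class ToyMetric(object):
    """Toy accuracy-style metric illustrating reset() vs reset_local()."""

    def __init__(self, name='toy-acc'):
        self.name = name
        self.reset()

    def reset(self):
        # clear both local and global state (start of a new epoch)
        self.sum_metric = 0.0
        self.num_inst = 0
        self.global_sum_metric = 0.0
        self.global_num_inst = 0

    def reset_local(self):
        # clear only the local logging window
        self.sum_metric = 0.0
        self.num_inst = 0

    def update(self, labels, preds):
        correct = sum(1 for l, p in zip(labels, preds) if l == p)
        self.sum_metric += correct
        self.global_sum_metric += correct
        self.num_inst += len(labels)
        self.global_num_inst += len(labels)

    def get(self):
        return (self.name, float('nan') if self.num_inst == 0
                else self.sum_metric / self.num_inst)

    def get_global(self):
        return (self.name, float('nan') if self.global_num_inst == 0
                else self.global_sum_metric / self.global_num_inst)

m = ToyMetric()
m.update([1, 0, 1], [1, 1, 1])   # 2/3 correct
m.reset_local()                  # what Speedometer(auto_reset=True) now does
m.update([0, 0, 1], [0, 0, 1])   # 3/3 correct
print(m.get())                   # ('toy-acc', 1.0)    -- last window only
print(m.get_global())            # ('toy-acc', 0.833..) -- everything since reset()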
@@ -347,6 +402,28 @@ def get(self): values.extend(value) return (names, values) + def get_global(self): + """Returns the current evaluation result. + + Returns + ------- + names : list of str + Name of the metrics. + values : list of float + Value of the evaluations. + """ + names = [] + values = [] + for metric in self.metrics: + name, value = metric.get_global() + if isinstance(name, string_types): + name = [name] + if isinstance(value, numeric_types): + value = [value] + names.extend(name) + values.extend(value) + return (names, values) + def get_config(self): config = super(CompositeEvalMetric, self).get_config() config.update({'metrics': [i.get_config() for i in self.metrics]}) @@ -395,7 +472,8 @@ def __init__(self, axis=1, name='accuracy', output_names=None, label_names=None): super(Accuracy, self).__init__( name, axis=axis, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) self.axis = axis def update(self, labels, preds): @@ -423,8 +501,11 @@ def update(self, labels, preds): check_label_shapes(label, pred_label) - self.sum_metric += (pred_label == label).sum() + num_correct = (pred_label == label).sum() + self.sum_metric += num_correct + self.global_sum_metric += num_correct self.num_inst += len(pred_label) + self.global_num_inst += len(pred_label) @register @@ -467,7 +548,8 @@ def __init__(self, top_k=1, name='top_k_accuracy', output_names=None, label_names=None): super(TopKAccuracy, self).__init__( name, top_k=top_k, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) self.top_k = top_k assert(self.top_k > 1), 'Please use Accuracy if top_k is no more than 1' self.name += '_%d' % self.top_k @@ -487,7 +569,11 @@ def update(self, labels, preds): for label, pred_label in zip(labels, preds): assert(len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims' - pred_label = numpy.argsort(pred_label.asnumpy().astype('float32'), axis=1) + # Using argpartition here instead of argsort is safe because + # we do not care about the order of top k elements. It is + # much faster, which is important since that computation is + # single-threaded due to Python GIL. 
+ pred_label = numpy.argpartition(pred_label.asnumpy().astype('float32'), -self.top_k) label = label.asnumpy().astype('int32') check_label_shapes(label, pred_label) num_samples = pred_label.shape[0] @@ -498,8 +584,11 @@ def update(self, labels, preds): num_classes = pred_label.shape[1] top_k = min(num_classes, self.top_k) for j in range(top_k): - self.sum_metric += (pred_label[:, num_classes - 1 - j].flat == label.flat).sum() + num_correct = (pred_label[:, num_classes - 1 - j].flat == label.flat).sum() + self.sum_metric += num_correct + self.global_sum_metric += num_correct self.num_inst += num_samples + self.global_num_inst += num_samples class _BinaryClassificationMetrics(object): @@ -515,6 +604,10 @@ def __init__(self): self.false_negatives = 0 self.false_positives = 0 self.true_negatives = 0 + self.global_true_positives = 0 + self.global_false_negatives = 0 + self.global_false_positives = 0 + self.global_true_negatives = 0 def update_binary_stats(self, label, pred): """ @@ -542,10 +635,18 @@ def update_binary_stats(self, label, pred): label_true = (label == 1) label_false = 1 - label_true - self.true_positives += (pred_true * label_true).sum() - self.false_positives += (pred_true * label_false).sum() - self.false_negatives += (pred_false * label_true).sum() - self.true_negatives += (pred_false * label_false).sum() + true_pos = (pred_true * label_true).sum() + false_pos = (pred_true * label_false).sum() + false_neg = (pred_false * label_true).sum() + true_neg = (pred_false * label_false).sum() + self.true_positives += true_pos + self.global_true_positives += true_pos + self.false_positives += false_pos + self.global_false_positives += false_pos + self.false_negatives += false_neg + self.global_false_negatives += false_neg + self.true_negatives += true_neg + self.global_true_negatives += true_neg @property def precision(self): @@ -554,6 +655,13 @@ def precision(self): else: return 0. + @property + def global_precision(self): + if self.global_true_positives + self.global_false_positives > 0: + return float(self.global_true_positives) / (self.global_true_positives + self.global_false_positives) + else: + return 0. + @property def recall(self): if self.true_positives + self.false_negatives > 0: @@ -561,6 +669,13 @@ def recall(self): else: return 0. + @property + def global_recall(self): + if self.global_true_positives + self.global_false_negatives > 0: + return float(self.global_true_positives) / (self.global_true_positives + self.global_false_negatives) + else: + return 0. + @property def fscore(self): if self.precision + self.recall > 0: @@ -569,17 +684,33 @@ def fscore(self): return 0. @property - def matthewscc(self): + def global_fscore(self): + if self.global_precision + self.global_recall > 0: + return 2 * self.global_precision * self.global_recall / (self.global_precision + self.global_recall) + else: + return 0. + + def matthewscc(self, use_global=False): """ Calculate the Matthew's Correlation Coefficent """ - if not self.total_examples: - return 0. + if use_global: + if not self.global_total_examples: + return 0. + + true_pos = float(self.global_true_positives) + false_pos = float(self.global_false_positives) + false_neg = float(self.global_false_negatives) + true_neg = float(self.global_true_negatives) + else: + if not self.total_examples: + return 0. 
+ + true_pos = float(self.true_positives) + false_pos = float(self.false_positives) + false_neg = float(self.false_negatives) + true_neg = float(self.true_negatives) - true_pos = float(self.true_positives) - false_pos = float(self.false_positives) - false_neg = float(self.false_negatives) - true_neg = float(self.true_negatives) terms = [(true_pos + false_pos), (true_pos + false_neg), (true_neg + false_pos), @@ -594,11 +725,26 @@ def total_examples(self): return self.false_negatives + self.false_positives + \ self.true_negatives + self.true_positives + @property + def global_total_examples(self): + return self.global_false_negatives + self.global_false_positives + \ + self.global_true_negatives + self.global_true_positives + + def local_reset_stats(self): + self.false_positives = 0 + self.false_negatives = 0 + self.true_positives = 0 + self.true_negatives = 0 + def reset_stats(self): self.false_positives = 0 self.false_negatives = 0 self.true_positives = 0 self.true_negatives = 0 + self.global_false_positives = 0 + self.global_false_negatives = 0 + self.global_true_positives = 0 + self.global_true_negatives = 0 @register @@ -649,7 +795,8 @@ def __init__(self, name='f1', self.average = average self.metrics = _BinaryClassificationMetrics() EvalMetric.__init__(self, name=name, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, labels, preds): """Updates the internal evaluation result. @@ -669,18 +816,30 @@ def update(self, labels, preds): if self.average == "macro": self.sum_metric += self.metrics.fscore + self.global_sum_metric += self.metrics.global_fscore self.num_inst += 1 + self.global_num_inst += 1 self.metrics.reset_stats() else: self.sum_metric = self.metrics.fscore * self.metrics.total_examples + self.global_sum_metric = self.metrics.global_fscore * self.metrics.global_total_examples self.num_inst = self.metrics.total_examples + self.global_num_inst = self.metrics.global_total_examples def reset(self): """Resets the internal evaluation result to initial state.""" self.sum_metric = 0. - self.num_inst = 0. + self.num_inst = 0 + self.global_num_inst = 0 + self.global_sum_metric = 0.0 self.metrics.reset_stats() + def reset_local(self): + """Resets the internal evaluation result to initial state.""" + self.sum_metric = 0. + self.num_inst = 0 + self.metrics.local_reset_stats() + @register class MCC(EvalMetric): @@ -750,7 +909,8 @@ def __init__(self, name='mcc', self._average = average self._metrics = _BinaryClassificationMetrics() EvalMetric.__init__(self, name=name, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, labels, preds): """Updates the internal evaluation result. 
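For context on the matthewscc change above: it is turned from a property into a regular method taking use_global, so the same formula can be evaluated over either the per-window or the epoch-wide confusion counts. A standalone illustration of the quantity it computes, with made-up counts and the usual convention of returning 0 when the denominator vanishes (illustration only, not code from the patch):

import math

def mcc(tp, fp, fn, tn):
    # MCC = (TP*TN - FP*FN) / sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN))
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denom == 0:
        return 0.0
    return (tp * tn - fp * fn) / denom

print(mcc(tp=90, fp=5, fn=10, tn=95))   # roughly 0.85 for a mostly-correct classifier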
@@ -769,19 +929,32 @@ def update(self, labels, preds): self._metrics.update_binary_stats(label, pred) if self._average == "macro": - self.sum_metric += self._metrics.matthewscc + self.sum_metric += self._metrics.matthewscc() + self.global_sum_metric += self._metrics.matthewscc(use_global=True) self.num_inst += 1 + self.global_num_inst += 1 self._metrics.reset_stats() else: - self.sum_metric = self._metrics.matthewscc * self._metrics.total_examples + self.sum_metric = self._metrics.matthewscc() * self._metrics.total_examples + self.global_sum_metric = self._metrics.matthewscc(use_global=True) * \ + self._metrics.global_total_examples self.num_inst = self._metrics.total_examples + self.global_num_inst = self._metrics.global_total_examples def reset(self): """Resets the internal evaluation result to initial state.""" self.sum_metric = 0. self.num_inst = 0. + self.global_sum_metric = 0. + self.global_num_inst = 0. self._metrics.reset_stats() + def reset_local(self): + """Resets the internal evaluation result to initial state.""" + self.sum_metric = 0. + self.num_inst = 0. + self._metrics.local_reset_stats() + @register class Perplexity(EvalMetric): @@ -841,7 +1014,8 @@ def __init__(self, ignore_label, axis=-1, name='perplexity', output_names=None, label_names=None): super(Perplexity, self).__init__( name, ignore_label=ignore_label, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) self.ignore_label = ignore_label self.axis = axis @@ -871,7 +1045,9 @@ def update(self, labels, preds): loss -= ndarray.sum(ndarray.log(ndarray.maximum(1e-10, pred))).asscalar() num += pred.size self.sum_metric += loss + self.global_sum_metric += loss self.num_inst += num + self.global_num_inst += num def get(self): """Returns the current evaluation result. @@ -881,7 +1057,23 @@ def get(self): Tuple of (str, float) Representing name of the metric and evaluation result. """ - return (self.name, math.exp(self.sum_metric/self.num_inst)) + if self.num_inst == 0: + return (self.name, float('nan')) + else: + return (self.name, math.exp(self.sum_metric/self.num_inst)) + + def get_global(self): + """Returns the current global evaluation result. + + Returns + ------- + Tuple of (str, float) + Representing name of the metric and evaluation result. + """ + if self.global_num_inst == 0: + return (self.name, float('nan')) + else: + return (self.name, math.exp(self.global_sum_metric/self.global_num_inst)) #################### # REGRESSION METRICS @@ -921,7 +1113,8 @@ class MAE(EvalMetric): def __init__(self, name='mae', output_names=None, label_names=None): super(MAE, self).__init__( - name, output_names=output_names, label_names=label_names) + name, output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, labels, preds): """Updates the internal evaluation result. 
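A note on the Perplexity.get_global() override above: unlike the plain ratio metrics, perplexity is exp(total negative log-likelihood / total count), so the exponentiation must happen after dividing whichever pair of accumulators is chosen, hence the dedicated override rather than the base-class ratio. A standalone sketch of the arithmetic, using made-up per-token probabilities (not code from the patch):

import math

probs_batch1 = [0.5, 0.25, 0.125]
probs_batch2 = [0.5, 0.5, 0.5, 0.5]

def nll_and_count(probs):
    # summed negative log-likelihood plus token count for one batch
    return -sum(math.log(max(1e-10, p)) for p in probs), len(probs)

nll1, n1 = nll_and_count(probs_batch1)
nll2, n2 = nll_and_count(probs_batch2)

local_ppl = math.exp(nll2 / n2)                   # get() after a reset_local() before batch 2
global_ppl = math.exp((nll1 + nll2) / (n1 + n2))  # get_global() over both batches
print(local_ppl, global_ppl)                      # 2.0 and about 2.69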
@@ -945,8 +1138,11 @@ def update(self, labels, preds): if len(pred.shape) == 1: pred = pred.reshape(pred.shape[0], 1) - self.sum_metric += numpy.abs(label - pred).mean() + mae = numpy.abs(label - pred).mean() + self.sum_metric += mae + self.global_sum_metric += mae self.num_inst += 1 # numpy.prod(label.shape) + self.global_num_inst += 1 # numpy.prod(label.shape) @register @@ -981,7 +1177,8 @@ class MSE(EvalMetric): def __init__(self, name='mse', output_names=None, label_names=None): super(MSE, self).__init__( - name, output_names=output_names, label_names=label_names) + name, output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, labels, preds): """Updates the internal evaluation result. @@ -1005,8 +1202,11 @@ def update(self, labels, preds): if len(pred.shape) == 1: pred = pred.reshape(pred.shape[0], 1) - self.sum_metric += ((label - pred)**2.0).mean() + mse = ((label - pred)**2.0).mean() + self.sum_metric += mse + self.global_sum_metric += mse self.num_inst += 1 # numpy.prod(label.shape) + self.global_num_inst += 1 # numpy.prod(label.shape) @register @@ -1041,7 +1241,8 @@ class RMSE(EvalMetric): def __init__(self, name='rmse', output_names=None, label_names=None): super(RMSE, self).__init__( - name, output_names=output_names, label_names=label_names) + name, output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, labels, preds): """Updates the internal evaluation result. @@ -1065,8 +1266,11 @@ def update(self, labels, preds): if len(pred.shape) == 1: pred = pred.reshape(pred.shape[0], 1) - self.sum_metric += numpy.sqrt(((label - pred)**2.0).mean()) + rmse = numpy.sqrt(((label - pred)**2.0).mean()) + self.sum_metric += rmse + self.global_sum_metric += rmse self.num_inst += 1 + self.global_num_inst += 1 @register @@ -1110,7 +1314,8 @@ def __init__(self, eps=1e-12, name='cross-entropy', output_names=None, label_names=None): super(CrossEntropy, self).__init__( name, eps=eps, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) self.eps = eps def update(self, labels, preds): @@ -1134,8 +1339,11 @@ def update(self, labels, preds): assert label.shape[0] == pred.shape[0] prob = pred[numpy.arange(label.shape[0]), numpy.int64(label)] - self.sum_metric += (-numpy.log(prob + self.eps)).sum() + cross_entropy = (-numpy.log(prob + self.eps)).sum() + self.sum_metric += cross_entropy + self.global_sum_metric += cross_entropy self.num_inst += label.shape[0] + self.global_num_inst += label.shape[0] @register @alias('nll_loss') @@ -1178,7 +1386,8 @@ def __init__(self, eps=1e-12, name='nll-loss', output_names=None, label_names=None): super(NegativeLogLikelihood, self).__init__( name, eps=eps, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) self.eps = eps def update(self, labels, preds): @@ -1202,8 +1411,11 @@ def update(self, labels, preds): num_examples = pred.shape[0] assert label.shape[0] == num_examples, (label.shape[0], num_examples) prob = pred[numpy.arange(num_examples, dtype=numpy.int64), numpy.int64(label)] - self.sum_metric += (-numpy.log(prob + self.eps)).sum() + nll = (-numpy.log(prob + self.eps)).sum() + self.sum_metric += nll + self.global_sum_metric += nll self.num_inst += num_examples + self.global_num_inst += num_examples @register @alias('pearsonr') @@ -1238,7 +1450,8 @@ class PearsonCorrelation(EvalMetric): def __init__(self, name='pearsonr', 
output_names=None, label_names=None): super(PearsonCorrelation, self).__init__( - name, output_names=output_names, label_names=label_names) + name, output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, labels, preds): """Updates the internal evaluation result. @@ -1256,8 +1469,11 @@ def update(self, labels, preds): check_label_shapes(label, pred, False, True) label = label.asnumpy() pred = pred.asnumpy() - self.sum_metric += numpy.corrcoef(pred.ravel(), label.ravel())[0, 1] + pearson_corr = numpy.corrcoef(pred.ravel(), label.ravel())[0, 1] + self.sum_metric += pearson_corr + self.global_sum_metric += pearson_corr self.num_inst += 1 + self.global_num_inst += 1 @register @@ -1278,7 +1494,8 @@ class Loss(EvalMetric): def __init__(self, name='loss', output_names=None, label_names=None): super(Loss, self).__init__( - name, output_names=output_names, label_names=label_names) + name, output_names=output_names, label_names=label_names, + has_global_stats=True) def update(self, _, preds): @@ -1286,8 +1503,11 @@ def update(self, _, preds): preds = [preds] for pred in preds: - self.sum_metric += ndarray.sum(pred).asscalar() + loss = ndarray.sum(pred).asscalar() + self.sum_metric += loss + self.global_sum_metric += loss self.num_inst += pred.size + self.global_num_inst += pred.size @register @@ -1353,7 +1573,8 @@ def __init__(self, feval, name=None, allow_extra_outputs=False, super(CustomMetric, self).__init__( name, feval=feval, allow_extra_outputs=allow_extra_outputs, - output_names=output_names, label_names=label_names) + output_names=output_names, label_names=label_names, + has_global_stats=True) self._feval = feval self._allow_extra_outputs = allow_extra_outputs @@ -1379,10 +1600,14 @@ def update(self, labels, preds): if isinstance(reval, tuple): (sum_metric, num_inst) = reval self.sum_metric += sum_metric + self.global_sum_metric += sum_metric self.num_inst += num_inst + self.global_num_inst += num_inst else: self.sum_metric += reval + self.global_sum_metric += reval self.num_inst += 1 + self.global_num_inst += 1 def get_config(self): raise NotImplementedError("CustomMetric cannot be serialized") diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py index babea53d6e40..ca8463153686 100644 --- a/python/mxnet/module/base_module.py +++ b/python/mxnet/module/base_module.py @@ -22,7 +22,6 @@ import time import logging import warnings -import copy import numpy as np from .. 
import metric @@ -508,7 +507,6 @@ def fit(self, train_data, eval_data=None, eval_metric='acc', validation_metric = eval_metric if not isinstance(eval_metric, metric.EvalMetric): eval_metric = metric.create(eval_metric) - epoch_eval_metric = copy.deepcopy(eval_metric) ################################################################################ # training loop @@ -516,7 +514,6 @@ def fit(self, train_data, eval_data=None, eval_metric='acc', for epoch in range(begin_epoch, num_epoch): tic = time.time() eval_metric.reset() - epoch_eval_metric.reset() nbatch = 0 data_iter = iter(train_data) end_of_batch = False @@ -532,12 +529,8 @@ def fit(self, train_data, eval_data=None, eval_metric='acc', self.update_metric(eval_metric, [db.label for db in data_batch], pre_sliced=True) - self.update_metric(epoch_eval_metric, - [db.label for db in data_batch], - pre_sliced=True) else: self.update_metric(eval_metric, data_batch.label) - self.update_metric(epoch_eval_metric, data_batch.label) try: # pre fetch next batch @@ -550,7 +543,7 @@ def fit(self, train_data, eval_data=None, eval_metric='acc', monitor.toc_print() if end_of_batch: - eval_name_vals = epoch_eval_metric.get_name_value() + eval_name_vals = eval_metric.get_global_name_value() if batch_end_callback is not None: batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch, diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index 26277d2acff5..2821c4bbae3c 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -18,6 +18,8 @@ import mxnet as mx import numpy as np import json +from common import with_seed +from copy import deepcopy def check_metric(metric, *args, **kwargs): metric = mx.metric.create(metric, *args, **kwargs) @@ -37,6 +39,67 @@ def test_metrics(): composite = mx.metric.create(['acc', 'f1']) check_metric(composite) +def _check_global_metric(metric, *args, **kwargs): + def _create_pred_label(): + if use_same_shape: + pred = mx.nd.random.uniform(0, 1, shape=shape) + label = mx.nd.random.uniform(0, 1, shape=shape) + else: + # Make a random prediction + idx = np.random.rand(*shape).argsort(1) + pred = mx.nd.array(1 - 0.1 * idx) + # Label is half 1 and half 0 + # Setting all 0s or all 1s would make either + # MCC or F1 metrics always produce 0 + label = mx.nd.ones(shape[0]) + label[:shape[0] // 2] = 0 + return pred, label + + shape = kwargs.pop('shape', (10,10)) + use_same_shape = kwargs.pop('use_same_shape', False) + m1 = mx.metric.create(metric, *args, **kwargs) + m2 = deepcopy(m1) + # check that global stats are not reset when calling + # reset_local() + for i in range(10): + pred, label = _create_pred_label() + m1.update([label], [pred]) + m1.reset_local() + m2.update([label], [pred]) + assert m1.get_global() == m2.get() + + # check that reset_local() properly resets the local state + m1.reset_local() + m2.reset() + pred, label = _create_pred_label() + m1.update([label], [pred]) + m1.reset_local() + pred, label = _create_pred_label() + m1.update([label], [pred]) + m2.update([label], [pred]) + assert m1.get() == m2.get() + +@with_seed() +def test_global_metric(): + _check_global_metric('acc') + _check_global_metric('TopKAccuracy', top_k=3) + _check_global_metric('f1', shape=(10,2)) + _check_global_metric('f1', shape=(10,2), average='micro') + _check_global_metric('mcc', shape=(10,2)) + _check_global_metric('mcc', shape=(10,2), average='micro') + _check_global_metric('perplexity', -1) + _check_global_metric('pearsonr', use_same_shape=True) + 
_check_global_metric('nll_loss') + _check_global_metric('loss') + _check_global_metric('ce') + _check_global_metric('mae', use_same_shape=True) + _check_global_metric('mse', use_same_shape=True) + _check_global_metric('rmse', use_same_shape=True) + def custom_metric(label, pred): + return np.mean(np.abs(label-pred)) + _check_global_metric(custom_metric, use_same_shape=True) + _check_global_metric(['acc', 'f1'], shape=(10,2)) + def test_nll_loss(): metric = mx.metric.create('nll_loss') pred = mx.nd.array([[0.2, 0.3, 0.5], [0.6, 0.1, 0.3]])
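Taken together, the intended usage looks like the following sketch (made-up labels and predictions, not part of the test file, and it assumes this patch is applied): Speedometer's auto_reset now clears only the local window, while the epoch summary that fit() logs comes from the untouched global accumulators.

import mxnet as mx

metric = mx.metric.create('acc')
for batch in range(100):
    label = mx.nd.array([1, 0, 1, 1])
    pred = mx.nd.array([[0.2, 0.8], [0.7, 0.3], [0.1, 0.9], [0.6, 0.4]])
    metric.update([label], [pred])
    if (batch + 1) % 50 == 0:
        # what Speedometer(auto_reset=True) does every `frequent` batches
        print('interval', metric.get_name_value())
        metric.reset_local()
print('epoch', metric.get_global_name_value())   # accuracy over all 100 batches
metric.reset()                                   # full reset before the next epoch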