diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index fc2b9014e8cc..e91fd3b13ee6 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -28,6 +28,7 @@ from .base import numeric_types, string_types from . import ndarray from . import registry +from .context import cpu def check_label_shapes(labels, preds, shape=0): @@ -388,6 +389,7 @@ def update(self, labels, preds): """ check_label_shapes(labels, preds) + results = [] for label, pred_label in zip(labels, preds): if pred_label.shape != label.shape: pred_label = ndarray.argmax(pred_label, axis=self.axis) @@ -399,8 +401,10 @@ def update(self, labels, preds): if pred_label.context != label.context: pred_label = pred_label.as_in_context(label.context) - self.sum_metric += (pred_label.reshape((-1,)) == label.reshape((-1,))).sum().asscalar() - self.num_inst += numpy.prod(pred_label.shape) + self.num_inst += pred_label.size + results.append((pred_label.reshape((-1,)) == label.reshape((-1,))) + .sum().as_in_context(cpu())) + self.sum_metric += ndarray.add_n(*results).asscalar() @register