diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 5b0780aeccee..f1cdae26a235 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -380,23 +380,27 @@ def update(self, labels, preds): Parameters ---------- labels : list of `NDArray` - The labels of the data. + The labels of the data with class indices as values, one per sample. preds : list of `NDArray` - Predicted values. + Prediction values for samples. Each prediction value can either be the class index, + or a vector of likelihoods for all classes. """ check_label_shapes(labels, preds) for label, pred_label in zip(labels, preds): if pred_label.shape != label.shape: pred_label = ndarray.argmax(pred_label, axis=self.axis) - pred_label = pred_label.asnumpy().astype('int32') - label = label.asnumpy().astype('int32') + pred_label = pred_label.astype('int32') + label = label.astype('int32') check_label_shapes(label, pred_label) - self.sum_metric += (pred_label.flat == label.flat).sum() - self.num_inst += len(pred_label.flat) + if pred_label.context != label.context: + pred_label = pred_label.as_in_context(label.context) + + self.sum_metric += (pred_label.flatten() == label.flatten()).sum().asscalar() + self.num_inst += numpy.prod(pred_label.shape) @register