diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index f1cdae26a235..fc2b9014e8cc 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -399,7 +399,7 @@ def update(self, labels, preds): if pred_label.context != label.context: pred_label = pred_label.as_in_context(label.context) - self.sum_metric += (pred_label.flatten() == label.flatten()).sum().asscalar() + self.sum_metric += (pred_label.reshape((-1,)) == label.reshape((-1,))).sum().asscalar() self.num_inst += numpy.prod(pred_label.shape)