Skip to content

Commit

Permalink
use nd for accuracy calculation (apache#9583)
Browse files Browse the repository at this point in the history
* use nd for accuracy calculation

* check for context
  • Loading branch information
szha authored and zhreshold committed Jan 27, 2018
1 parent 6a4671b commit f5f1b91
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions python/mxnet/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,23 +380,27 @@ def update(self, labels, preds):
Parameters
----------
labels : list of `NDArray`
The labels of the data.
The labels of the data with class indices as values, one per sample.
preds : list of `NDArray`
Predicted values.
Prediction values for samples. Each prediction value can either be the class index,
or a vector of likelihoods for all classes.
"""
check_label_shapes(labels, preds)

for label, pred_label in zip(labels, preds):
if pred_label.shape != label.shape:
pred_label = ndarray.argmax(pred_label, axis=self.axis)
pred_label = pred_label.asnumpy().astype('int32')
label = label.asnumpy().astype('int32')
pred_label = pred_label.astype('int32')
label = label.astype('int32')

check_label_shapes(label, pred_label)

self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)
if pred_label.context != label.context:
pred_label = pred_label.as_in_context(label.context)

self.sum_metric += (pred_label.flatten() == label.flatten()).sum().asscalar()
self.num_inst += numpy.prod(pred_label.shape)


@register
Expand Down

0 comments on commit f5f1b91

Please sign in to comment.