From e11c65171d27ebc4fa76ebd8291fdf64037cc6d3 Mon Sep 17 00:00:00 2001 From: Mat Kelcey Date: Tue, 9 Feb 2021 13:50:25 +1100 Subject: [PATCH] redo accuracy function in util. call from test --- test.py | 33 ++++++++------------------------- util.py | 17 ++++++++--------- 2 files changed, 16 insertions(+), 34 deletions(-) diff --git a/test.py b/test.py index e30c7e2..23bdfd4 100644 --- a/test.py +++ b/test.py @@ -28,12 +28,9 @@ num_classes = 10 -# TODO: move this into util so can run during training too -# ( as well as being run against validation and test ) - - +# convert to a prediction function that ensembles over all models @jit -def predict(params, imgs): +def predict_fn(imgs): logits = all_models_apply(params, imgs) batch_size = logits.shape[-2] logits = logits.reshape((-1, batch_size, num_classes)) # (M, B, 10) @@ -42,24 +39,10 @@ def predict(params, imgs): return predictions -num_correct = 0 -num_total = 0 -dataset = data.validation_dataset(batch_size=64) -for imgs, labels in dataset: - predictions = predict(params, imgs) - num_correct += jnp.sum(predictions == labels) - num_total += len(labels) - -accuracy = num_correct / num_total -print(num_correct, num_total) - -# # restore from save -# objax.io.load_var_collection(opts.saved_model, net.vars()) - -# # check against validation set -# accuracy = util.accuracy(net, data.validation_dataset(batch_size=128)) -# print("validation accuracy %0.3f" % accuracy) +# check against validation set +accuracy = util.accuracy(predict_fn, data.validation_dataset(batch_size=128)) +print("validation accuracy %0.3f" % accuracy) -# # check against test set -# accuracy = util.accuracy(net, data.test_dataset(batch_size=128)) -# print("test accuracy %0.3f" % accuracy) +# check against test set +accuracy = util.accuracy(predict_fn, data.test_dataset(batch_size=128)) +print("test accuracy %0.3f" % accuracy) diff --git a/util.py b/util.py index 2d6e6fa..e38cdf9 100644 --- a/util.py +++ b/util.py @@ -122,13 +122,12 @@ def stopped(self): # return float(losses_total / num_losses) -def accuracy(net, dataset): - # TODO: could go to NonEnsembleNet/EnsembleNet base class - y_pred = [] - y_true = [] +def accuracy(predict_fn, dataset): + num_correct = 0 + num_total = 0 for imgs, labels in dataset: - y_pred.extend(net.predict(imgs, single_result=True)) - y_true.extend(labels) - num_correct = np.equal(y_pred, y_true).sum() - num_total = len(y_pred) - return float(num_correct / num_total) + predictions = predict_fn(imgs) + num_correct += jnp.sum(predictions == labels) + num_total += len(labels) + accuracy = num_correct / num_total + return accuracy