redo accuracy function in util. call from test
matpalm committed Feb 9, 2021
1 parent 1678cb9 commit e11c651
Showing 2 changed files with 16 additions and 34 deletions.
test.py (33 changes: 8 additions & 25 deletions)
@@ -28,12 +28,9 @@
 
 num_classes = 10
 
-# TODO: move this into util so can run during training too
-# ( as well as being run against validation and test )
-
-
+# convert to a prediction function that ensembles over all models
 @jit
-def predict(params, imgs):
+def predict_fn(imgs):
     logits = all_models_apply(params, imgs)
     batch_size = logits.shape[-2]
     logits = logits.reshape((-1, batch_size, num_classes))  # (M, B, 10)
@@ -42,24 +39,10 @@ def predict(params, imgs):
     return predictions
 
 
-num_correct = 0
-num_total = 0
-dataset = data.validation_dataset(batch_size=64)
-for imgs, labels in dataset:
-    predictions = predict(params, imgs)
-    num_correct += jnp.sum(predictions == labels)
-    num_total += len(labels)
-
-accuracy = num_correct / num_total
-print(num_correct, num_total)
-
-# # restore from save
-# objax.io.load_var_collection(opts.saved_model, net.vars())
-
-# # check against validation set
-# accuracy = util.accuracy(net, data.validation_dataset(batch_size=128))
-# print("validation accuracy %0.3f" % accuracy)
+# check against validation set
+accuracy = util.accuracy(predict_fn, data.validation_dataset(batch_size=128))
+print("validation accuracy %0.3f" % accuracy)
 
-# # check against test set
-# accuracy = util.accuracy(net, data.test_dataset(batch_size=128))
-# print("test accuracy %0.3f" % accuracy)
+# check against test set
+accuracy = util.accuracy(predict_fn, data.test_dataset(batch_size=128))
+print("test accuracy %0.3f" % accuracy)
util.py (17 changes: 8 additions & 9 deletions)
@@ -122,13 +122,12 @@ def stopped(self):
 # return float(losses_total / num_losses)
 
 
-def accuracy(net, dataset):
-    # TODO: could go to NonEnsembleNet/EnsembleNet base class
-    y_pred = []
-    y_true = []
+def accuracy(predict_fn, dataset):
+    num_correct = 0
+    num_total = 0
     for imgs, labels in dataset:
-        y_pred.extend(net.predict(imgs, single_result=True))
-        y_true.extend(labels)
-    num_correct = np.equal(y_pred, y_true).sum()
-    num_total = len(y_pred)
-    return float(num_correct / num_total)
+        predictions = predict_fn(imgs)
+        num_correct += jnp.sum(predictions == labels)
+        num_total += len(labels)
+    accuracy = num_correct / num_total
+    return accuracy
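
The TODO removed from test.py ("move this into util so can run during training too") points at the payoff of this refactor: the same helper can now be called from a training loop as well as from the test script. A rough sketch of that usage, with train_step, train_dataset, num_epochs and make_predict_fn as hypothetical stand-ins:

# hedged sketch only; train_step, train_dataset, num_epochs and make_predict_fn
# are hypothetical names, not code from this repo
import util, data

for epoch in range(num_epochs):
    for imgs, labels in train_dataset:
        params = train_step(params, imgs, labels)  # hypothetical update step
    # rebuild the prediction fn so it closes over the freshly updated params
    predict_fn = make_predict_fn(params)
    val_acc = util.accuracy(predict_fn, data.validation_dataset(batch_size=128))
    print("epoch %d validation accuracy %0.3f" % (epoch, val_acc))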
