add --drop-logit and, more importantly, change validation and test data loading to not load entire dataset. this may have been the OOM root cause
matpalm committed Sep 13, 2020
1 parent 76b2128 commit 1e04b18
Showing 8 changed files with 93 additions and 48 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,3 +3,5 @@ __pycache__/
logs/
wandb/
saved_models/
collages/
ax_client_snapshot.json
10 changes: 9 additions & 1 deletion README.md
@@ -1,4 +1,12 @@
# training ensemble nets

ensemble nets are a single pass way of representing M models in a neural
net ensemble. they make heavy use of jax's vmap to provide super execution.
net ensemble. see this [blog post](http://matpalm.com/blog/ensemble_nets)

to reproduce section N

```
python3 tune_with_ax.py
```

etc
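As a quick illustration of the "single pass over M models" idea the README points to (a toy sketch, not code from this repo; `forward`, `ws` and the shapes are made up):

```python
import jax
import jax.numpy as jnp

def forward(w, x):
    # one "model": a single dense layer mapping a (B, D) batch to (B, C) logits
    return x @ w

M, B, D, C = 4, 8, 16, 10
ws = jax.random.normal(jax.random.PRNGKey(0), (M, D, C))  # M weight sets stacked on a leading axis
x = jax.random.normal(jax.random.PRNGKey(1), (B, D))      # one shared batch

# vmap over the model axis of ws while broadcasting x: all M models run in one call
logits = jax.vmap(forward, in_axes=(0, None))(ws, x)      # (M, B, C)
ensemble_logits = jnp.sum(logits, axis=0)                 # (B, C), the single_result view
print(logits.shape, ensemble_logits.shape)
```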
26 changes: 17 additions & 9 deletions data.py
@@ -10,6 +10,7 @@ def _convert_dtype(x):
return tf.cast(x, tf.float32) / 255


@tf.autograph.experimental.do_not_convert
def _augment_and_convert_dtype(x, y):
# rotate 0, 90, 180 or 270 deg
k = tf.random.uniform([], 0, 3, dtype=tf.int32)
@@ -26,23 +27,30 @@ def _augment_and_convert_dtype(x, y):
return x, y


@lru_cache()
def _entire_split(ds_split):
x, y = tfds.load('eurosat/rgb', split=ds_split, shuffle_files=False,
batch_size=-1, as_supervised=True)
return np.array(_convert_dtype(x)), np.array(y)
def _non_training_dataset(batch_size, ds_split):

@tf.autograph.experimental.do_not_convert
def _convert_image_dtype(x, y):
return _convert_dtype(x), y

def validation_dataset():
dataset = (tfds.load('eurosat/rgb', split=ds_split,
as_supervised=True)
.map(_convert_image_dtype, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.shuffle(1024))
return tfds.as_numpy(dataset)


def validation_dataset(batch_size):
# 2700 records
# [293, 307, 335, 258, 253, 194, 239, 284, 243, 294]
return _entire_split('train[80%:90%]')
return _non_training_dataset(batch_size, 'train[80%:90%]')


def test_dataset():
def test_dataset(batch_size):
# 2700 records
# [307, 300, 296, 221, 262, 216, 251, 296, 250, 301]
return _entire_split('train[90%:]')
return _non_training_dataset(batch_size, 'train[90%:]')


def training_dataset(batch_size, num_inputs=1):
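The heart of this change: the old `_entire_split` pulled a whole split into memory with `batch_size=-1` (the suspected OOM cause), while `_non_training_dataset` streams fixed-size batches through a `tf.data` pipeline. A minimal standalone sketch of the two patterns (split name as in the diff, but this is not the repo's exact code):

```python
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

def entire_split(ds_split):
    # old pattern: batch_size=-1 materialises every decoded image of the split at once
    x, y = tfds.load('eurosat/rgb', split=ds_split, as_supervised=True, batch_size=-1)
    return np.array(tf.cast(x, tf.float32) / 255), np.array(y)

def batched_split(ds_split, batch_size=128):
    # new pattern: only ever holds around batch_size decoded images at a time
    ds = (tfds.load('eurosat/rgb', split=ds_split, as_supervised=True)
          .map(lambda x, y: (tf.cast(x, tf.float32) / 255, y),
               num_parallel_calls=tf.data.AUTOTUNE)
          .batch(batch_size))
    return tfds.as_numpy(ds)  # iterable of (imgs, labels) numpy pairs

for imgs, labels in batched_split('train[90%:]'):
    pass  # e.g. accumulate a metric per batch instead of over the whole split
```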
32 changes: 26 additions & 6 deletions models.py
@@ -114,7 +114,7 @@ def __init__(self, num_models, num_classes, max_conv_size=64,
self.num_models = num_models

key = random.PRNGKey(seed)
subkeys = random.split(key, 6)
subkeys = random.split(key, 7)

# conv stack kernels and biases
self.conv_kernels = objax.ModuleList()
@@ -139,7 +139,9 @@ def __init__(self, num_models, num_classes, max_conv_size=64,
subkeys[5], (num_models, dense_kernel_size, num_classes)))
self.logits_bias = TrainVar(jnp.zeros((num_models, num_classes)))

def logits(self, inp, single_result):
self.dropout_key = subkeys[6]

def logits(self, inp, single_result, logits_dropout=False):
"""return logits over inputs.
Args:
inp: input images. either (B, HW, HW, 3) in which case all models
@@ -194,9 +196,27 @@ def logits(self, inp, single_result):
logits = vmap(partial(_dense_layer, None))(
y, self.logits_kernel.value, self.logits_bias.value)

# in the dropout case, randomly drop 50% of the logits
if logits_dropout:
# extract size of logits for making mask
num_models = logits.shape[0]
batch_size = logits.shape[1]
num_classes = logits.shape[2]
# make a new (M, B) drop out mask of 0s & 1s
self.dropout_key, key = random.split(self.dropout_key)
mask = jax.random.randint(key, (num_models, batch_size),
minval=0, maxval=2)
# broadcast it along the logit axis to make (M, B, C)
# TODO: should be doable in jnp.tile (?)
mask = mask.reshape((num_models, batch_size, 1))
mask = jnp.broadcast_to(mask,
(num_models, batch_size, num_classes))
# apply mask
logits *= mask

# if single result sum logits over models to represent single
# ensemble result (B, num_classes)
if single_result:
# sum logits over models to represent single ensemble result
# (B, num_classes)
logits = jnp.sum(logits, axis=0)

return logits
@@ -215,7 +235,7 @@ def predict_proba(self, inp, single_result):
mode.
"""

return jax.nn.softmax(self.logits(inp, single_result), axis=-1)
return jax.nn.softmax(self.logits(inp, single_result, False), axis=-1)

def predict(self, inp, single_result):
"""return class predictions. i.e. argmax over logits.
@@ -231,4 +251,4 @@ def predict(self, inp, single_result):
mode.
"""

return jnp.argmax(self.logits(inp, single_result), axis=-1)
return jnp.argmax(self.logits(inp, single_result, False), axis=-1)
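The logit dropout added above builds a per-(model, example) keep/drop mask and broadcasts it across classes. A standalone sketch with toy shapes (not the repo's exact code); it also shows that a trailing singleton axis covers the `jnp.tile` TODO in the diff:

```python
import jax
import jax.numpy as jnp

M, B, C = 4, 8, 10  # models, batch, classes
key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (M, B, C))

key, subkey = jax.random.split(key)
mask = jax.random.randint(subkey, (M, B), minval=0, maxval=2)  # 0s and 1s, drops ~50%

# broadcast over the class axis: each (model, example) pair keeps or loses
# its entire logit vector
masked = logits * mask[:, :, None]                             # (M, B, C)

# in the single_result path the sum over models then only includes the
# surviving ensemble members for each example
ensemble_logits = jnp.sum(masked, axis=0)                      # (B, C)
```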
6 changes: 2 additions & 4 deletions test.py
@@ -31,11 +31,9 @@
objax.io.load_var_collection(opts.saved_model, net.vars())

# check against validation set
imgs, labels = data.validation_dataset()
accuracy = util.accuracy(net.predict(imgs, single_result=True), labels)
accuracy = util.accuracy(net, data.validation_dataset(batch_size=128))
print("validation accuracy %0.3f" % accuracy)

# check against test set
imgs, labels = data.test_dataset()
accuracy = util.accuracy(net.predict(imgs, single_result=True), labels)
accuracy = util.accuracy(net, data.test_dataset(batch_size=128))
print("test accuracy %0.3f" % accuracy)
26 changes: 8 additions & 18 deletions train.py
@@ -82,15 +82,17 @@ def train(opts):
# to one output). this is also the form of the loss called during validation
# loss calculation where the imgs, labels is the entire split.
def cross_entropy(imgs, labels):
logits = net.logits(imgs, single_result=True)
logits = net.logits(imgs, single_result=True,
logits_dropout=opts.logits_dropout)
return jnp.mean(cross_entropy_logits_sparse(logits, labels))

# in multiple input mode we get an output per model; so the logits are
# (num_models, batch_size, num_classes) i.e. (M, B, C) with labels
# (M, B). in this case we flatten the logits to (M*B, C) and the labels
# to (M*B,) for the cross entropy calculation.
def nested_cross_entropy(imgs, labels):
logits = net.logits(imgs, single_result=False)
logits = net.logits(imgs, single_result=False,
logits_dropout=opts.logits_dropout)
m, b, c = logits.shape
logits = logits.reshape((m*b, c))
labels = labels.reshape((m*b,))
@@ -109,12 +111,6 @@ def train_step(imgs, labels):
train_step = objax.Jit(train_step,
gradient_loss.vars() + optimiser.vars())

# create jitted call for validation loss
calculate_validation_loss = objax.Jit(cross_entropy, net.vars())

# read entire validation set
validation_imgs, validation_labels = data.validation_dataset()

# set up checkpointing; just need more ckpts than early stopping
# patience
ckpt_dir = "saved_models/"
@@ -149,8 +145,8 @@ def train_step(imgs, labels):
ckpt.save(net.vars(), idx=epoch)

# check validation loss
validation_loss = float(calculate_validation_loss(validation_imgs,
validation_labels))
validation_dataset = data.validation_dataset(opts.batch_size)
validation_loss = util.mean_loss(net, validation_dataset)
print("epoch", epoch, "validation_loss", validation_loss)
sys.stdout.flush()
if wandb_enabled:
@@ -160,22 +156,15 @@
if early_stopping.should_stop(validation_loss):
break

# final validation metrics
validation_predictions = net.predict(validation_imgs, single_result=True)
validation_accuracy = util.accuracy(validation_predictions,
validation_labels)

# close out wandb run
if wandb_enabled:
wandb.config.early_stopped = early_stopping.stopped()
wandb.log({'final_validation_loss': validation_loss,
'final_validation_accuracy': validation_accuracy},
wandb.log({'final_validation_loss': validation_loss},
step=opts.epochs)
wandb.join()
else:
print("early_stopping.stopped()", early_stopping.stopped())
print("final validation_loss", validation_loss)
print("final validation accuracy", validation_accuracy)

# return validation loss to ax
return validation_loss
@@ -233,6 +222,7 @@ def _callback(q, opts):
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--learning-rate', type=float, default=1e-3)
parser.add_argument('--epochs', type=int, default=2)
parser.add_argument('--logits-dropout', action='store_true')
opts = parser.parse_args()
print(opts, file=sys.stderr)

12 changes: 5 additions & 7 deletions tune_with_ax.py
@@ -15,8 +15,6 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--group', type=str, default=None,
help='wandb group. if none, no logging')
# parser.add_argument('--prior-run-logs', type=str, default=None,
# help='a comma seperated prior runs to prime ax client with')
cmd_line_opts = parser.parse_args()
print(cmd_line_opts, file=sys.stderr)

@@ -79,10 +77,6 @@
minimize=True,
)

# if cmd_line_opts.prior_run_logs is not None:
# for log_tsv in cmd_line_opts.prior_run_logs.split(","):
# u.prime_ax_client_with_prior_run(ax, log_tsv)

u.ensure_dir_exists("logs/%s" % cmd_line_opts.group)
log = open("logs/%s/ax_trials.tsv" % cmd_line_opts.group, "w")
print("trial_index\tparameters\truntime\tfinal_loss", file=log)
@@ -104,7 +98,8 @@ class Opts(object):
opts.dense_kernel_size = parameters['dense_kernel_size']
opts.batch_size = 32 # parameters['batch_size']
opts.learning_rate = parameters['learning_rate']
opts.epochs = 50 # max to run, we also use early stopping
opts.epochs = 60 # max to run, we also use early stopping
opts.logits_dropout = False # not yet under tuning

# run
start_time = time.time()
Expand All @@ -131,3 +126,6 @@ class Opts(object):
print(log_msg, file=log)
print(log_msg)
log.flush()

# save ax state
ax.save_to_json_file()
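The snapshot written by `ax.save_to_json_file()` (the newly gitignored `ax_client_snapshot.json`) makes the tuning session resumable. A hedged sketch of what a resume might look like, assuming the standard ax-platform `AxClient` API and the default snapshot path; none of this is in the diff:

```python
from ax.service.ax_client import AxClient

# restore the experiment state saved at the end of the tuning loop
ax = AxClient.load_from_json_file("ax_client_snapshot.json")

parameters, trial_index = ax.get_next_trial()   # keep proposing trials where we left off
final_loss = 0.123                              # placeholder: run train.train(opts) here
ax.complete_trial(trial_index=trial_index, raw_data=final_loss)

ax.save_to_json_file("ax_client_snapshot.json")  # re-snapshot after each trial
```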
27 changes: 24 additions & 3 deletions util.py
@@ -2,6 +2,9 @@
import numpy as np
import datetime
import time
import data
from objax.functional.loss import cross_entropy_logits_sparse
import jax.numpy as jnp


def DTS():
@@ -85,7 +88,25 @@ def stopped(self):
return self.decided_to_stop


def accuracy(predictions, labels):
num_correct = np.equal(predictions, labels).sum()
num_total = len(predictions)
def mean_loss(net, dataset):
# TODO: could go to NonEnsembleNet/EnsembleNet base class
losses_total = 0
num_losses = 0
for imgs, labels in dataset:
logits = net.logits(imgs, single_result=True, logits_dropout=False)
losses = cross_entropy_logits_sparse(logits, labels)
losses_total += jnp.sum(losses)
num_losses += len(losses)
return losses_total / num_losses


def accuracy(net, dataset):
# TODO: could go to NonEnsembleNet/EnsembleNet base class
y_pred = []
y_true = []
for imgs, labels in dataset:
y_pred.extend(net.predict(imgs, single_result=True))
y_true.extend(labels)
num_correct = np.equal(y_pred, y_true).sum()
num_total = len(y_pred)
return num_correct / num_total
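One detail worth noting in `mean_loss`: it accumulates a running sum and count rather than averaging per-batch means, which matters once the last batch is smaller than the rest. A toy check with made-up numbers:

```python
import numpy as np

per_batch_losses = [np.array([0.9, 1.1, 1.0, 1.0]), np.array([2.0])]  # ragged final batch

total, n = 0.0, 0
for losses in per_batch_losses:
    total += losses.sum()
    n += len(losses)

print(total / n)                                      # 1.2 -- true mean over all 5 examples
print(np.mean([b.mean() for b in per_batch_losses]))  # 1.5 -- skewed by the short batch
```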
