switch from logit-dropout to model-dropout
matpalm committed Sep 22, 2020
1 parent 36ab1c4 commit e825e16
Showing 7 changed files with 138 additions and 683 deletions.
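
In brief: the old --logits-dropout flag applied a 50% dropout mask over (sub model, example) pairs to the logits after the full forward pass, while the new --model-dropout flag slices a random half of the stacked per-model parameters out before the forward pass, so only M/2 of the sub models are actually run on a training step. A minimal sketch of that selection step (not the repository's code; it assumes per-model parameters stacked along a leading axis of size M):

import jax
import jax.numpy as jnp

def pick_random_half(key, stacked_params):
    # stacked_params has shape (M, ...); choose a random half of the
    # sub model indices ...
    num_models = stacked_params.shape[0]
    idxs = jax.random.permutation(key, jnp.arange(num_models))
    idxs = idxs[:num_models // 2]
    # ... and slice those sub models' parameters out, giving (M/2, ...)
    return stacked_params[idxs]

key = jax.random.PRNGKey(0)
logits_kernel = jnp.zeros((4, 32, 10))  # e.g. (M, dense_kernel_size, num_classes)
print(pick_random_half(key, logits_kernel).shape)  # (2, 32, 10)
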
638 changes: 0 additions & 638 deletions blog/ensemble_quality_on_test.ipynb

This file was deleted.

83 changes: 83 additions & 0 deletions generate_blog_images_for_model.py
@@ -0,0 +1,83 @@
import argparse
import models
import numpy as np
import jax.numpy as jnp
import objax
import jax
import data
import util
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import sys


def cm_plot(cm):
    labels = ['Annual Crop', 'Forest', 'Herbaceous Vegetation', 'Highway',
              'Industrial Buildings', 'Pasture', 'Permanent Crop',
              'Residential Buildings', 'River', 'Sea & Lake']
    disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                  display_labels=labels)
    return disp.plot(include_values=True,
                     cmap='viridis', ax=None, xticks_rotation='vertical',
                     values_format=None)


def save_plot(y_true, y_pred, title, fname):
    plot = cm_plot(confusion_matrix(y_true, y_pred))
    plot.figure_.suptitle(title)
    plot.figure_.savefig(fname, bbox_inches='tight', transparent=True)


def save_sub_model_plots(y_true, logits, num_models, title_template,
                         fname_template):
    for m in range(num_models):
        y_pred = jnp.argmax(logits[m], axis=-1)
        num_correct = np.equal(y_pred, y_true).sum()
        num_total = len(y_true)
        print("model %d accuracy %0.3f" % (m, float(num_correct / num_total)))
        save_plot(y_true, y_pred, title_template % m, fname_template % m)


def print_validation_test_accuracy(net):
    print("validation %0.3f" % util.accuracy(
        net, data.validation_dataset(batch_size=100)))
    print("test %0.3f" % util.accuracy(net, data.test_dataset(batch_size=100)))


def logits_and_y_true_for_test_set(net, num_models):
    logits = []
    y_true = []
    for imgs, labels in data.test_dataset(batch_size=100):
        logits.append(net.logits(imgs, single_result=False,
                                 model_dropout=False))
        y_true.extend(labels)
    logits = jnp.stack(logits)                       # (27, M, 100, 10)
    logits = logits.transpose((1, 0, 2, 3))          # (M, 27, 100, 10)
    logits = logits.reshape((num_models, 2700, 10))  # (M, 2700, 10)
    return logits, y_true


parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--num-models', type=int, default=1)
parser.add_argument('--max-conv-size', type=int)
parser.add_argument('--dense-kernel-size', type=int)
parser.add_argument('--ckpt-file', type=str)
opts = parser.parse_args()
print(opts, file=sys.stderr)

net = models.EnsembleNet(num_models=opts.num_models,
                         num_classes=10,
                         max_conv_size=opts.max_conv_size,
                         dense_kernel_size=opts.dense_kernel_size,
                         seed=0)
objax.io.load_var_collection(opts.ckpt_file, net.vars())

logits, y_true = logits_and_y_true_for_test_set(net, opts.num_models)

y_pred = jnp.argmax(logits.sum(axis=0), axis=-1)

print_validation_test_accuracy(net)

save_plot(y_true, y_pred, "ensemble", "cm.ensemble.png")
save_sub_model_plots(y_true, logits, opts.num_models,
                     "sub model %d", "cm.model_%d.png")
80 changes: 45 additions & 35 deletions models.py
@@ -5,7 +5,7 @@
from jax.nn.functions import gelu
from functools import partial
import objax
from objax.variable import TrainVar
from objax.variable import TrainVar, StateVar


def _conv_layer(stride, activation, inp, kernel, bias):
@@ -34,7 +34,7 @@ def _dense_layer(activation, inp, kernel, bias):
    return block

# TODO: introduce basecase for Ensemble & NonEnsemble to avoid
# clumsy single_result & logits_dropout on NonEnsemble
# clumsy single_result & model_dropout on NonEnsemble


class NonEnsembleNet(objax.Module):
@@ -65,12 +65,12 @@ def __init__(self, num_classes, max_conv_size=64, dense_kernel_size=32,
            subkeys[5], (dense_kernel_size, num_classes)))
        self.logits_bias = TrainVar(jnp.zeros((num_classes)))

    def logits(self, inp, single_result, logits_dropout=False):
    def logits(self, inp, single_result=None, model_dropout=None):
        """return logits over inputs
        Args:
          inp: input images (B, HW, HW, 3)
          single_result: clumsily ignore for NonEnsembleNet :/
          logits_dropout: clumsily ignore for NonEnsembleNet :/
          model_dropout: clumsily ignore for NonEnsembleNet :/
        Returns:
          logit values for input images (B, C)
        """
@@ -93,7 +93,7 @@ def logits(self, inp, single_result, logits_dropout=False):

        return logits

    def predict(self, inp, single_result):
    def predict(self, inp, single_result=None):
        """return class predictions. i.e. argmax over logits.
        Args:
          inp: input images (B, HW, HW, 3)
@@ -136,17 +136,18 @@ def __init__(self, num_models, num_classes, max_conv_size=64,
            subkeys[5], (num_models, dense_kernel_size, num_classes)))
        self.logits_bias = TrainVar(jnp.zeros((num_models, num_classes)))

        self.dropout_key = subkeys[6]
        self.dropout_key = StateVar(subkeys[6])

    def logits(self, inp, single_result, logits_dropout=False):
    def logits(self, inp, single_result, model_dropout):
        """return logits over inputs.
        Args:
          inp: input images. either (B, HW, HW, 3) in which case all models
            will get the same images or (M, B, HW, HW, 3) in which case each
            model will get a different image.
          single_result: if true return single logits value for ensemble.
            otherwise return logits for each sub model.
          logits_dropout: if true then apply 50% dropout to logits
          model_dropout: if true then only run 50% of the models during
            training.
        Returns:
          logit values for input images. either (B, C) if in single_result mode
          or (M, B, C) otherwise.
@@ -155,11 +156,38 @@ def logits(self, inp, single_result, logits_dropout=False):
            mode.
        """

        if not model_dropout:
            # if we are not doing model dropout then just take all the model
            # variables as they are. these will be of shape (M, ...)
            conv_kernels = [k.value for k in self.conv_kernels]
            conv_biases = [b.value for b in self.conv_biases]
            dense_kernel = self.dense_kernel.value
            dense_bias = self.dense_bias.value
            logits_kernel = self.logits_kernel.value
            logits_bias = self.logits_bias.value
        else:
            # but if we are doing model dropout then we take half the models
            # by first 1) picking idxs that represent a random half of the
            # models ...
            new_dropout_key, permute_key = random.split(self.dropout_key.value)
            self.dropout_key.assign(new_dropout_key)
            idxs = jnp.arange(self.num_models)
            idxs = jax.random.permutation(permute_key, idxs)
            idxs = idxs[:(self.num_models // 2)]
            # ... and then 2) slicing the variables out. all these will be of
            # shape (M/2, ...)
            conv_kernels = [k.value[idxs] for k in self.conv_kernels]
            conv_biases = [b.value[idxs] for b in self.conv_biases]
            dense_kernel = self.dense_kernel.value[idxs]
            dense_bias = self.dense_bias.value[idxs]
            logits_kernel = self.logits_kernel.value[idxs]
            logits_bias = self.logits_bias.value[idxs]

        if len(inp.shape) == 4:
            # single_input mode; inp (B, HW, HW, 3)
            # apply first convolution as vmap against just inp
            y = vmap(partial(_conv_layer, 2, gelu, inp))(
                self.conv_kernels[0].value, self.conv_biases[0].value)
            y = vmap(partial(_conv_layer, 2, gelu, inp))(conv_kernels[0],
                                                         conv_biases[0])
        elif len(inp.shape) == 5:
            # multi_input mode; inp (M, B, HW, HW, 3)
            if single_result:
@@ -171,8 +199,8 @@ def logits(self, inp, single_result, logits_dropout=False):
                    " in the ensemble.")
            # apply all convolutions, including first, as vmap against both y
            # and kernel, bias
            y = vmap(partial(_conv_layer, 2, gelu))(
                inp, self.conv_kernels[0].value, self.conv_biases[0].value)
            y = vmap(partial(_conv_layer, 2, gelu))(inp, conv_kernels[0],
                                                    conv_biases[0])
        else:
            raise Exception("unexpected input shape")

@@ -181,37 +209,19 @@ def logits(self, inp, single_result, logits_dropout=False):
        # rest of the convolution stack can be applied as vmap against both y
        # and conv kernels, biases.
        # final result is (M, B, 3, 3, 256|max_conv_size)
        for kernel, bias in zip(self.conv_kernels[1:], self.conv_biases[1:]):
            y = vmap(partial(_conv_layer, 2, gelu))(
                y, kernel.value, bias.value)
        for kernel, bias in zip(conv_kernels[1:], conv_biases[1:]):
            y = vmap(partial(_conv_layer, 2, gelu))(y, kernel, bias)

        # global spatial pooling. (M, B, dense_kernel_size)
        y = jnp.mean(y, axis=(2, 3))

        # dense layer with non linearity. (M, B, dense_kernel_size)
        y = vmap(partial(_dense_layer, gelu))(
            y, self.dense_kernel.value, self.dense_bias.value)
        y = vmap(partial(_dense_layer, gelu))(y, dense_kernel, dense_bias)

        # dense layer with no activation to number classes.
        # (M, B, num_classes)
        logits = vmap(partial(_dense_layer, None))(
            y, self.logits_kernel.value, self.logits_bias.value)

        # if dropout case randomly drop 50% of logits
        if logits_dropout:
            # extract size of logits for making mask
            num_models = logits.shape[0]
            batch_size = logits.shape[1]
            num_classes = logits.shape[2]
            # make a new (M, B) drop out mask of 50% 0s & 1s
            self.dropout_key, key = random.split(self.dropout_key)
            mask = jax.random.randint(key, (num_models, batch_size),
                                      minval=0, maxval=2)
            # tile it along the logit axis to make (M, B, C)
            mask = mask.reshape((num_models, batch_size, 1))
            mask = jnp.tile(mask, (1, 1, num_classes))
            # apply mask
            logits *= mask
        logits = vmap(partial(_dense_layer, None))(y, logits_kernel,
                                                   logits_bias)

        # if single result sum logits over models to represent single
        # ensemble result (B, num_classes)
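
For reference, a minimal usage sketch of the new argument (not part of the commit; it assumes 64x64 RGB inputs and a 4 model ensemble):

import jax.numpy as jnp
import models

net = models.EnsembleNet(num_models=4, num_classes=10,
                         max_conv_size=64, dense_kernel_size=32, seed=0)
imgs = jnp.zeros((8, 64, 64, 3))  # (B, HW, HW, 3), single_input mode

# evaluation style call: all 4 sub models run, logits summed -> (8, 10)
print(net.logits(imgs, single_result=True, model_dropout=False).shape)

# training style call with model dropout: only a random half of the
# sub models run -> (2, 8, 10)
print(net.logits(imgs, single_result=False, model_dropout=True).shape)
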
2 changes: 1 addition & 1 deletion smoke_test.sh
@@ -16,4 +16,4 @@ python3 train.py \
--dense-kernel-size 32 \
--batch-size 32 \
--epochs 2 \
--logits-dropout
--model-dropout
8 changes: 4 additions & 4 deletions train.py
@@ -53,7 +53,7 @@ def train(opts):
        wandb.config.seed = opts.seed
        wandb.config.learning_rate = opts.learning_rate
        wandb.config.batch_size = opts.batch_size
        wandb.config.logits_dropout = opts.logits_dropout
        wandb.config.model_dropout = opts.model_dropout
    else:
        print("not using wandb", file=sys.stderr)

@@ -84,7 +84,7 @@ def train(opts):
    # loss calculation where the imgs, labels is the entire split.
    def cross_entropy(imgs, labels):
        logits = net.logits(imgs, single_result=True,
                            logits_dropout=opts.logits_dropout)
                            model_dropout=opts.model_dropout)
        return jnp.mean(cross_entropy_logits_sparse(logits, labels))

    # in multiple input mode we get an output per model; so the logits are
@@ -93,7 +93,7 @@ def cross_entropy(imgs, labels):
    # to (M*B,) for the cross entropy calculation.
    def nested_cross_entropy(imgs, labels):
        logits = net.logits(imgs, single_result=False,
                            logits_dropout=opts.logits_dropout)
                            model_dropout=opts.model_dropout)
        m, b, c = logits.shape
        logits = logits.reshape((m*b, c))
        labels = labels.reshape((m*b,))
@@ -224,7 +224,7 @@ def _callback(q, opts):
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--learning-rate', type=float, default=1e-3)
parser.add_argument('--epochs', type=int, default=2)
parser.add_argument('--logits-dropout', action='store_true')
parser.add_argument('--model-dropout', action='store_true')
opts = parser.parse_args()
print(opts, file=sys.stderr)

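
Where train.py flattens the per-model logits for the loss, a shape-only sketch of that step (dummy data, not the repository's code; it assumes objax's cross_entropy_logits_sparse, the same loss function train.py uses):

import jax.numpy as jnp
from objax.functional.loss import cross_entropy_logits_sparse

m, b, c = 4, 32, 10                          # models, batch size, classes
logits = jnp.zeros((m, b, c))                # one set of logits per sub model
labels = jnp.zeros((m, b), dtype=jnp.int32)  # per-model labels in multi input mode
losses = cross_entropy_logits_sparse(logits.reshape((m * b, c)),
                                     labels.reshape((m * b,)))
print(jnp.mean(losses))
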
8 changes: 4 additions & 4 deletions tune_with_ax.py
@@ -109,19 +109,19 @@ class Opts(object):
    if cmd_line_opts.mode == 'siso':
        opts.input_mode = 'single'
        opts.num_models = 1
        opts.logits_dropout = False  # N/A for multi_input
        opts.model_dropout = False  # N/A for multi_input
    elif cmd_line_opts.mode == 'simo':
        opts.input_mode = 'single'
        opts.num_models = cmd_line_opts.num_models
        opts.logits_dropout = False  # not yet under tuning
        opts.model_dropout = False
    elif cmd_line_opts.mode == 'simo_ld':
        opts.input_mode = 'single'
        opts.num_models = cmd_line_opts.num_models
        opts.logits_dropout = True  # not yet under tuning
        opts.model_dropout = True
    else:  # mimo
        opts.input_mode = 'multiple'
        opts.num_models = cmd_line_opts.num_models
        opts.logits_dropout = False  # N/A for multi_input
        opts.model_dropout = False  # N/A for multi_input

    opts.max_conv_size = parameters['max_conv_size']
    opts.dense_kernel_size = parameters['dense_kernel_size']
2 changes: 1 addition & 1 deletion util.py
@@ -93,7 +93,7 @@ def mean_loss(net, dataset):
    losses_total = 0
    num_losses = 0
    for imgs, labels in dataset:
        logits = net.logits(imgs, single_result=True, logits_dropout=False)
        logits = net.logits(imgs, single_result=True, model_dropout=False)
        losses = cross_entropy_logits_sparse(logits, labels)
        losses_total += jnp.sum(losses)
        num_losses += len(losses)
