Commit
final, for now, pod slice version
matpalm committed Feb 9, 2021
1 parent 64fc08d commit 87ad647
Showing 4 changed files with 3 additions and 266 deletions.
249 changes: 0 additions & 249 deletions models.py
@@ -1,12 +1,3 @@
# import jax
# import jax.numpy as jnp
# from jax import random, lax, vmap
# from jax.nn.initializers import glorot_normal, he_normal
# from jax.nn.functions import gelu
# from functools import partial
# import objax
# from objax.variable import TrainVar, StateVar

import jax.numpy as jnp
from jax.nn import gelu
import haiku as hk
@@ -34,243 +25,3 @@ def build_model(opts):
m = partial(model, max_conv_size=opts.max_conv_size,
dense_kernel_size=opts.dense_kernel_size)
return hk.without_apply_rng(hk.transform(m))
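
For context, hk.transform plus hk.without_apply_rng (as used by build_model above) returns an init/apply pair of pure functions. The sketch below shows only that usage pattern; the model body, input shape and class count are illustrative placeholders, not this repository's actual network or training code.

# Hedged sketch, not part of the commit: how a haiku-transformed model like the
# one returned by build_model above is typically initialised and applied.
# The model body, shapes and class count below are placeholders.
import jax
import jax.numpy as jnp
import haiku as hk


def model(x, max_conv_size=64, dense_kernel_size=32):
    # stand-in body; the real network is defined elsewhere in models.py
    y = jax.nn.gelu(hk.Conv2D(min(32, max_conv_size), kernel_shape=3, stride=2)(x))
    y = jnp.mean(y, axis=(1, 2))                      # global spatial pooling
    y = jax.nn.gelu(hk.Linear(dense_kernel_size)(y))
    return hk.Linear(10)(y)                           # logits; 10 classes assumed


transformed = hk.without_apply_rng(hk.transform(model))
params = transformed.init(jax.random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))
logits = transformed.apply(params, jnp.zeros((1, 32, 32, 3)))  # no rng needed
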

# def _conv_layer(stride, activation, inp, kernel, bias):
# no_dilation = (1, 1)
# some_height_width = 10 # values don't matter; just shape of input
# input_shape = (1, some_height_width, some_height_width, 3)
# kernel_shape = (3, 3, 1, 1)
# input_kernel_output = ('NHWC', 'HWIO', 'NHWC')
# conv_dimension_numbers = lax.conv_dimension_numbers(input_shape,
# kernel_shape,
# input_kernel_output)
# block = lax.conv_general_dilated(inp, kernel, (stride, stride),
# 'VALID', no_dilation, no_dilation,
# conv_dimension_numbers)
# if bias is not None:
# block += bias
# if activation:
# block = activation(block)
# return block


# def _dense_layer(activation, inp, kernel, bias):
# block = jnp.dot(inp, kernel) + bias
# if activation:
# block = activation(block)
# return block

# # TODO: introduce base class for Ensemble & NonEnsemble to avoid
# # clumsy single_result & model_dropout on NonEnsemble


# class NonEnsembleNet(objax.Module):

# def __init__(self, num_classes, max_conv_size=64, dense_kernel_size=32,
# seed=0):
# key = random.PRNGKey(seed)
# subkeys = random.split(key, 6)

# # conv stack kernels and biases
# self.conv_kernels = objax.ModuleList()
# self.conv_biases = objax.ModuleList()
# input_channels = 3
# for i, output_channels in enumerate([32, 64, 128, 256]):
# output_channels = min(output_channels, max_conv_size)
# self.conv_kernels.append(TrainVar(he_normal()(
# subkeys[i], (3, 3, input_channels, output_channels))))
# self.conv_biases.append(TrainVar(jnp.zeros((output_channels))))
# input_channels = output_channels

# # dense layer kernel and bias
# self.dense_kernel = TrainVar(he_normal()(
# subkeys[4], (output_channels, dense_kernel_size)))
# self.dense_bias = TrainVar(jnp.zeros((dense_kernel_size)))

# # classifier layer kernel and bias
# self.logits_kernel = TrainVar(glorot_normal()(
# subkeys[5], (dense_kernel_size, num_classes)))
# self.logits_bias = TrainVar(jnp.zeros((num_classes)))

# def logits(self, inp, single_result=None, model_dropout=None):
# """return logits over inputs
# Args:
# inp: input images (B, HW, HW, 3)
# single_result: clumsily ignore for NonEnsembleNet :/
# model_dropout: clumsily ignore for NonEnsembleNet :/
# Returns:
# logit values for input images (B, C)
# """

# # conv stack -> (B, 3, 3, MCS) MaxConvSize
# y = inp
# for kernel, bias in zip(self.conv_kernels, self.conv_biases):
# y = _conv_layer(2, gelu, y, kernel.value, bias.value)

# # global spatial pooling -> (B, MCS)
# y = jnp.mean(y, axis=(1, 2))

# # dense layer with non linearity -> (B, DKS) DenseKernelSize
# y = _dense_layer(gelu, y, self.dense_kernel.value,
# self.dense_bias.value)

# # dense layer with no activation to number classes -> (B, C)
# logits = _dense_layer(
# None, y, self.logits_kernel.value, self.logits_bias.value)

# return logits

# def predict(self, inp, single_result=None):
# """return class predictions. i.e. argmax over logits.
# Args:
# inp: input images (B, HW, HW, 3)
# single_result: clumsily ignore for NonEnsembleNet :/
# Returns:
# prediction classes for input images (B,)
# """
# return jnp.argmax(self.logits(inp, single_result), axis=-1)


# class EnsembleNet(objax.Module):

# def __init__(self, num_models, num_classes, max_conv_size=64,
# dense_kernel_size=32, seed=0):
# self.num_models = num_models

# key = random.PRNGKey(seed)
# subkeys = random.split(key, 7)

# # conv stack kernels and biases
# self.conv_kernels = objax.ModuleList()
# self.conv_biases = objax.ModuleList()
# input_channels = 3
# for i, output_channels in enumerate([32, 64, 128, 256]):
# output_channels = min(output_channels, max_conv_size)
# self.conv_kernels.append(TrainVar(he_normal()(
# subkeys[i], (num_models, 3, 3, input_channels,
# output_channels))))
# self.conv_biases.append(
# TrainVar(jnp.zeros((num_models, output_channels))))
# input_channels = output_channels

# # dense layer kernel and bias
# self.dense_kernel = TrainVar(he_normal()(
# subkeys[4], (num_models, output_channels, dense_kernel_size)))
# self.dense_bias = TrainVar(jnp.zeros((num_models, dense_kernel_size)))

# # classifier layer kernel and bias
# self.logits_kernel = TrainVar(glorot_normal()(
# subkeys[5], (num_models, dense_kernel_size, num_classes)))
# self.logits_bias = TrainVar(jnp.zeros((num_models, num_classes)))

# # create a rng key generator for doing dropout each call to logits
# # (as required)
# self.dropout_key = objax.random.Generator(seed)

# def logits(self, inp, single_result, model_dropout):
# """return logits over inputs.
# Args:
# inp: input images. either (B, HW, HW, 3) in which case all models
# will get the same images or (M, B, HW, HW, 3) in which case each
# model will get a different image.
# single_result: if true return single logits value for ensemble.
# otherwise return logits for each sub model.
# model_dropout: if true then only run 50% of the models during
# training.
# Returns:
# logit values for input images. either (B, C) if in single_result mode
# or (M, B, C) otherwise.
# Raises:
# Exception: if input images are (M, B, HW, HW, 3) and in single_result
# mode.
# """

# if not model_dropout:
# # if we are not doing model dropout then just take all the model
# # variables as they are. these will be of shape (M, ...)
# conv_kernels = [k.value for k in self.conv_kernels]
# conv_biases = [b.value for b in self.conv_biases]
# dense_kernel = self.dense_kernel.value
# dense_bias = self.dense_bias.value
# logits_kernel = self.logits_kernel.value
# logits_bias = self.logits_bias.value
# else:
# # but if we are doing model dropout then we take half the models
# # by first 1) picking idxs that represents a random half of the
# # models ...
# idxs = jnp.arange(self.num_models)
# idxs = jax.random.permutation(self.dropout_key(), idxs)
# idxs = idxs[:(self.num_models//2)]
# # ... and then 2) slicing the variables out. all these will be of
# # shape (M/2, ...)
# conv_kernels = [k.value[idxs] for k in self.conv_kernels]
# conv_biases = [b.value[idxs] for b in self.conv_biases]
# dense_kernel = self.dense_kernel.value[idxs]
# dense_bias = self.dense_bias.value[idxs]
# logits_kernel = self.logits_kernel.value[idxs]
# logits_bias = self.logits_bias.value[idxs]

# if len(inp.shape) == 4:
# # single_input mode; inp (B, HW, HW, 3)
# # apply first convolution as vmap against just inp
# y = vmap(partial(_conv_layer, 2, gelu, inp))(conv_kernels[0],
# conv_biases[0])
# elif len(inp.shape) == 5:
# # multi_input mode; inp (M, B, HW, HW, 3)
# if single_result:
# raise Exception("self.single_result=True not valid when passed"
# " an image per model")
# if inp.shape[0] != self.num_models:
# raise Exception("when passing (M, B, HW, HW, 3) the leading"
# " dimension needs to match the number of models"
# " in the ensemble.")
# # apply all convolutions, including first, as vmap against both y
# # and kernel, bias
# y = vmap(partial(_conv_layer, 2, gelu))(inp, conv_kernels[0],
# conv_biases[0])
# else:
# raise Exception("unexpected input shape")

# # y now (M, B, HW/2, HW/2, 32)

# # rest of the convolution stack can be applied as vmap against both y
# # and conv kernels, biases.
# # final result is (M, B, 3, 3, 256|max_conv_size)
# for kernel, bias in zip(conv_kernels[1:], conv_biases[1:]):
# y = vmap(partial(_conv_layer, 2, gelu))(y, kernel, bias)

# # global spatial pooling. (M, B, dense_kernel_size)
# y = jnp.mean(y, axis=(2, 3))

# # dense layer with non linearity. (M, B, dense_kernel_size)
# y = vmap(partial(_dense_layer, gelu))(y, dense_kernel, dense_bias)

# # dense layer with no activation to number classes.
# # (M, B, num_classes)
# logits = vmap(partial(_dense_layer, None))(y, logits_kernel,
# logits_bias)

# # if single result sum logits over models to represent single
# # ensemble result (B, num_classes)
# if single_result:
# logits = jnp.sum(logits, axis=0)

# return logits

# def predict(self, inp, single_result):
# """return class predictions. i.e. argmax over logits.
# Args:
# inp: input images. either (B, HW, HW, 3) in which case all models
# will get the same images or (M, B, HW, HW, 3) in which case each
# model will get a different image.
# single_result: if true a single prediction for ensemble is returned.
# otherwise return predictions per sub model.
# Returns:
# prediction classes for input images. either (B,) if in single_result
# mode or (M, B) otherwise.
# Raises:
# Exception: if input images are (M, B, HW, HW, 3) and in single_result
# mode.
# """

# return jnp.argmax(self.logits(inp, single_result, False), axis=-1)
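
The removed EnsembleNet above captured two patterns that its comments describe: per-model parameters stacked on a leading M (num_models) axis with the forward pass vmapped over that axis, and model dropout implemented by slicing out a random half of that axis. Below is a minimal, self-contained sketch of the same pattern using a toy dense-only ensemble as an assumption, not the deleted class itself.

# Hedged sketch, not part of the commit: the per-model-axis + vmap pattern used
# by the removed EnsembleNet. Parameters carry a leading M axis, the forward
# pass is vmapped over it, and "model dropout" slices a random half of M.
import jax
import jax.numpy as jnp
from jax import vmap

num_models, num_classes, dim = 4, 10, 32
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

# per-model dense parameters, shape (M, ...)
kernels = jax.random.normal(k1, (num_models, dim, num_classes)) * 0.05
biases = jnp.zeros((num_models, num_classes))


def single_model_logits(x, kernel, bias):
    # x (B, dim) -> (B, num_classes) for one ensemble member
    return jnp.dot(x, kernel) + bias


x = jax.random.normal(k2, (8, dim))  # same batch fed to every member

# vmap over the model axis of the parameters only -> (M, B, num_classes)
per_model = vmap(single_model_logits, in_axes=(None, 0, 0))(x, kernels, biases)

# single_result mode: reduce over models -> (B, num_classes)
ensemble_logits = jnp.sum(per_model, axis=0)

# model dropout: keep a random half of the members by slicing the M axis
idxs = jax.random.permutation(k3, jnp.arange(num_models))[:num_models // 2]
dropped = vmap(single_model_logits, in_axes=(None, 0, 0))(
    x, kernels[idxs], biases[idxs])  # (M/2, B, num_classes)
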
4 changes: 0 additions & 4 deletions train.py
@@ -245,10 +245,6 @@ def total_ensemble_xent_loss(params, x, y_true):
parser.add_argument('--group', type=str,
help='w&b init group', default=None)
parser.add_argument('--seed', type=int, default=0)
# parser.add_argument('--num-models', type=int, default=1)
# parser.add_argument('--input-mode', type=str, default='single',
# help="whether inputs are across all models (single) or"
# " one input per model (multiple). inv")
parser.add_argument('--max-conv-size', type=int, default=256)
parser.add_argument('--dense-kernel-size', type=int, default=32)
parser.add_argument('--models-per-device', type=int, default=2)
4 changes: 3 additions & 1 deletion tune_with_ax.py
@@ -9,6 +9,8 @@
import json
import tensorflow as tf

raise Exception("need to port to pod version")

# tf.config.experimental.set_visible_devices([], "GPU")

parser = argparse.ArgumentParser(
@@ -18,7 +20,7 @@
parser.add_argument('--mode', type=str, required=True,
help='mode; one of siso, simo, simo_ld or mimo')
parser.add_argument('--num-models', type=int, default=1)
parser.add_argument('--run-time-sec', type=int, default=60*10)
parser.add_argument('--run-time-sec', type=int, default=60 * 10)
parser.add_argument('--epochs', type=int, default=60)

cmd_line_opts = parser.parse_args()
12 changes: 0 additions & 12 deletions util.py
@@ -114,18 +114,6 @@ def stopped(self):
return self.decided_to_stop


# def mean_loss(net, dataset):
# # TODO: could go to NonEnsembleNet/EnsembleNet base class
# losses_total = 0
# num_losses = 0
# for imgs, labels in dataset:
# logits = net.logits(imgs, single_result=True, model_dropout=False)
# losses = cross_entropy_logits_sparse(logits, labels)
# losses_total += jnp.sum(losses)
# num_losses += len(losses)
# return float(losses_total / num_losses)


def accuracy(predict_fn, dataset):
num_correct = 0
num_total = 0
