Commit: working on pod slice
matpalm committed Feb 9, 2021
1 parent 6d9a587 commit 64fc08d
Showing 3 changed files with 39 additions and 26 deletions.
8 changes: 4 additions & 4 deletions data.py
@@ -48,7 +48,7 @@ def validation_dataset(batch_size, sample_data=False):
         split = 'train[80%:82%]'
     else:
         split = 'train[80%:90%]'
-    logging.info("validation_dataset %s" % split)
+    logging.debug("validation_dataset %s" % split)
     return _non_training_dataset(batch_size, split)


@@ -59,19 +59,19 @@ def test_dataset(batch_size, sample_data=False):
         split = 'train[90%:92%]'
     else:
         split = 'train[90%:]'
-    logging.info("test_dataset %s" % split)
+    logging.debug("test_dataset %s" % split)
     return _non_training_dataset(batch_size, split)
 
 
 def training_dataset(batch_size, shuffle_seed, num_inputs=1, sample_data=False):
-    logging.info("training dataset shuffle_seed %d" % shuffle_seed)
+    logging.debug("training_dataset shuffle_seed %d" % shuffle_seed)
 
     if sample_data:
         logging.warn("using small sample_data for training")
         split = 'train[:2%]'
     else:
         split = 'train[:80%]'
-    logging.info("training_dataset %s" % split)
+    logging.debug("training_dataset %s" % split)
 
     dataset = (tfds.load('eurosat/rgb', split=split,
                          as_supervised=True)
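For reference, the split strings above carve the single train split of eurosat/rgb into roughly 80/10/10 train/validation/test slices, with tiny 2% slices when sample_data is set for smoke tests. A minimal sketch of how those slice strings are consumed by tfds (the batching and shuffling done by _non_training_dataset is outside this diff, so it is omitted here):

import tensorflow_datasets as tfds

# the full-size slices used by training_dataset / validation_dataset / test_dataset
splits = {'train': 'train[:80%]',
          'validation': 'train[80%:90%]',
          'test': 'train[90%:]'}

for name, split in splits.items():
    # as_supervised gives (image, label) pairs; eurosat/rgb images are (64, 64, 3) uint8
    ds = tfds.load('eurosat/rgb', split=split, as_supervised=True)
    print(name, split, ds.cardinality().numpy())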
55 changes: 34 additions & 21 deletions train.py
@@ -20,9 +20,6 @@
 import wandb
 import logging
 
-logging.basicConfig(format='%(asctime)s %(message)s')
-logging.getLogger().setLevel(logging.INFO)
-
 
 def softmax_cross_entropy(logits, labels):
     one_hot = jax.nn.one_hot(
@@ -59,10 +56,14 @@ def train(opts):
     num_devices = len(jax.local_devices())
     num_models = num_devices * opts.models_per_device
 
+    # we make two rngs; one that is distinct per host and one
+    # that will be common across the pod
+    host_rng = jax.random.PRNGKey(opts.seed ^ jax.host_id())
+    pod_rng = jax.random.PRNGKey(opts.seed * 2)  # o_O
+
     logging.info("init models")
-    rng = jax.random.PRNGKey(opts.seed ^ jax.host_id())
-    keys = jax.random.split(rng, num_models)
-    logging.info("model keys %s" % list(keys))
+    keys = jax.random.split(host_rng, num_models)
+    logging.debug("model keys %s" % list(keys))
     representative_input = jnp.zeros((1, 64, 64, 3))
     params = vmap(lambda k: model.init(k, representative_input))(keys)

@@ -134,9 +135,6 @@ def total_ensemble_xent_loss(params, x, y_true):
     # --------------------------------
     # run training loop
 
-    # generate a shuffle key that will be the same for all hosts
-    shuffle_rng = jax.random.PRNGKey(opts.seed)
-
     for epoch in range(opts.epochs):
 
         # train for one epoch
@@ -145,8 +143,9 @@ def total_ensemble_xent_loss(params, x, y_true):
         total_training_loss = 0
         training_num_examples = 0
 
-        # split out a new shuffle seed for this epoch
-        shuffle_rng, shuffle_seed = jax.random.split(shuffle_rng)
+        # split out a new shuffle seed for this epoch common
+        # across pod
+        pod_rng, shuffle_seed = jax.random.split(pod_rng)
 
         # create dataset
         train_ds = data.training_dataset(batch_size=opts.batch_size,
@@ -156,11 +155,12 @@ def total_ensemble_xent_loss(params, x, y_true):
 
         for imgs, labels in train_ds:
 
-            logging.info("labels %s" % labels)
+            logging.debug("labels %s" % labels)
 
             # replicate batch across M devices
-            imgs = u.replicate(imgs)  # (M, B, H, W, 3)
-            labels = u.replicate(labels)  # (M, B)
+            # (M, B, H, W, 3)
+            imgs = u.replicate(imgs, replicas=num_devices)
+            labels = u.replicate(labels, replicas=num_devices)  # (M, B)
 
             # run across all the 4 rotations
             # for k in range(4):
@@ -169,13 +169,16 @@ def total_ensemble_xent_loss(params, x, y_true):
             # run some steps for this set, each with a different set of
             # dropout idxs
             for _ in range(opts.steps_per_batch):
-                rng, dropout_key = jax.random.split(rng)
+                host_rng, dropout_key = jax.random.split(host_rng)
+                logging.debug("dropout_key %s" % dropout_key[0])
                 sub_model_idxs = jax.random.randint(dropout_key, minval=0,
                                                     maxval=opts.models_per_device,
                                                     shape=(num_devices,))
+                logging.debug("sub_model_idxs %s" % sub_model_idxs)
                 params, opt_states, losses = p_update(params, opt_states,
                                                       sub_model_idxs,
                                                       imgs, labels)
+                logging.debug("losses %s" % losses)
 
                 total_training_loss += jnp.sum(losses)
                 training_num_examples += len(losses)
@@ -195,7 +198,8 @@ def total_ensemble_xent_loss(params, x, y_true):
         total_validation_loss = 0
         validation_num_examples = 0
         validation_data = data.validation_dataset(
-            batch_size=opts.batch_size)
+            batch_size=opts.batch_size,
+            sample_data=opts.sample_data)
         for imgs, labels in validation_data:
             total_validation_loss += total_ensemble_xent_loss(params, imgs,
                                                               labels)
@@ -217,9 +221,10 @@ def total_ensemble_xent_loss(params, x, y_true):
         else:
             logging.info("finished %s final validation_loss %f" %
                          (run, mean_validation_loss))
-
-    # return validation loss to ax
-    return mean_validation_loss
+        # return validation loss to ax
+        return mean_validation_loss
+    else:
+        return None
 
 
 if __name__ == '__main__':
@@ -244,18 +249,26 @@ def total_ensemble_xent_loss(params, x, y_true):
     # parser.add_argument('--input-mode', type=str, default='single',
     #                     help="whether inputs are across all models (single) or"
     #                          " one input per model (multiple). inv")
-    parser.add_argument('--max-conv-size', type=int, default=64)
-    parser.add_argument('--dense-kernel-size', type=int, default=16)
+    parser.add_argument('--max-conv-size', type=int, default=256)
+    parser.add_argument('--dense-kernel-size', type=int, default=32)
     parser.add_argument('--models-per-device', type=int, default=2)
     parser.add_argument('--learning-rate', type=float, default=1e-3)
     parser.add_argument('--batch-size', type=int, default=32)
     parser.add_argument('--steps-per-batch', type=int, default=4,
                         help='how many steps to run, each with new random'
                              ' dropout, per batch that is loaded')
     parser.add_argument('--epochs', type=int, default=2)
+    parser.add_argument('--log-level', type=str, default='INFO')
     parser.add_argument('--sample-data', action='store_true',
                         help='set for running test with small training data')
     opts = parser.parse_args()
     print(opts, file=sys.stderr)
 
+    # set logging level
+    numeric_level = getattr(logging, opts.log_level.upper(), None)
+    if not isinstance(numeric_level, int):
+        raise ValueError('Invalid log level: %s' % opts.log_level)
+    logging.basicConfig(format='%(asctime)s %(message)s')
+    logging.getLogger().setLevel(numeric_level)  # logging.INFO)
+
     train(opts)
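The core of this commit is the rng handling for a multi-host pod slice: host_rng (seeded with opts.seed ^ jax.host_id()) is different on every host and drives the per-step dropout_key / sub_model_idxs draws, while pod_rng (seeded with opts.seed * 2) is identical on every host, so each host splits out the same shuffle_seed per epoch and therefore sees the training data in the same order. A small self-contained sketch of that property; the two simulated hosts below are illustrative, not repo code:

import jax
import jax.numpy as jnp

def make_rngs(seed, host_id):
    host_rng = jax.random.PRNGKey(seed ^ host_id)  # distinct per host
    pod_rng = jax.random.PRNGKey(seed * 2)         # common across the pod
    return host_rng, pod_rng

# pretend two hosts of the pod were started with the same --seed
host0, pod0 = make_rngs(123, 0)
host1, pod1 = make_rngs(123, 1)

# per-epoch shuffle seeds agree across hosts ...
pod0, shuffle_seed0 = jax.random.split(pod0)
pod1, shuffle_seed1 = jax.random.split(pod1)
assert jnp.array_equal(shuffle_seed0, shuffle_seed1)

# ... while dropout keys (and hence sub_model_idxs) differ per host
host0, dropout_key0 = jax.random.split(host0)
host1, dropout_key1 = jax.random.split(host1)
assert not jnp.array_equal(dropout_key0, dropout_key1)

This mirrors the diff above: pod_rng is split once per epoch for the dataset shuffle, and host_rng is split once per training step for the random choice of sub models.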
2 changes: 1 addition & 1 deletion util.py
@@ -5,7 +5,7 @@
 import data
 #from objax.functional.loss import cross_entropy_logits_sparse
 import jax.numpy as jnp
-from jax import pmap, host_id
+from jax import pmap, host_id, jit
 from jax.tree_util import tree_map


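For context on the train.py change above: u.replicate(x, replicas=num_devices) stacks a batch once per local device so the arrays gain the leading (M, ...) axis that the pmap'd update expects. The replicate helper itself is not touched by this diff; a plausible sketch, assuming it is a simple stack over a pytree (illustrative only, not the repo's actual implementation):

import jax.numpy as jnp
from jax.tree_util import tree_map

def replicate(pytree, replicas):
    # stack `replicas` copies along a new leading axis,
    # e.g. imgs (B, H, W, 3) -> (replicas, B, H, W, 3), one slice per local device for pmap
    return tree_map(lambda x: jnp.stack([x] * replicas), pytree)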
