Skip to content

Commit

Permalink
use objax rng generator for dropout key
Browse files Browse the repository at this point in the history
  • Loading branch information
matpalm committed Sep 22, 2020
1 parent e825e16 commit 973c42f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def __init__(self, num_models, num_classes, max_conv_size=64,
subkeys[5], (num_models, dense_kernel_size, num_classes)))
self.logits_bias = TrainVar(jnp.zeros((num_models, num_classes)))

self.dropout_key = StateVar(subkeys[6])
# create a rng key generator for doing dropout each call to logits
# (as required)
self.dropout_key = objax.random.Generator(seed)

def logits(self, inp, single_result, model_dropout):
"""return logits over inputs.
Expand Down Expand Up @@ -169,10 +171,8 @@ def logits(self, inp, single_result, model_dropout):
# but if we are doing model dropout then we take half the models
# by first 1) picking idxs that represents a random half of the
# models ...
new_dropout_key, permute_key = random.split(self.dropout_key.value)
self.dropout_key.assign(new_dropout_key)
idxs = jnp.arange(self.num_models)
idxs = jax.random.permutation(permute_key, idxs)
idxs = jax.random.permutation(self.dropout_key(), idxs)
idxs = idxs[:(self.num_models//2)]
# ... and then 2) slicing the variables out. all these will be of
# shape (M/2, ...)
Expand Down
2 changes: 1 addition & 1 deletion smoke_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ set -ex

python3 train.py \
--input-mode single \
--num-models 2 \
--num-models 4 \
--max-conv-size 32 \
--dense-kernel-size 32 \
--batch-size 32 \
Expand Down

0 comments on commit 973c42f

Please sign in to comment.