Switching training/test phase of DropoutLayer #26
I think one way to unify the API is to add a new class:

```python
class Dropout(Layer):
    def __init__(self,
                 layer=None,
                 keep=0.5,
                 is_train=True,
                 name='dropout_layer'):
        Layer.__init__(self, name=name)
        self.inputs = layer.outputs
        print("  tensorlayer: Instantiate Dropout %s: keep: %f" % (self.name, keep))
        set_keep[name] = tf.constant(keep, dtype=tf.float32)
        if is_train:
            self.outputs = tf.nn.dropout(self.inputs, set_keep[name], name=name)
        else:
            # test phase: the layer becomes a pass-through
            self.outputs = self.inputs
        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        self.all_drop.update({set_keep[name]: keep})
        self.all_layers.extend([self.outputs])
```
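A usage sketch for this proposal (`prev_layer` is an assumed prior layer): since the phase is fixed at construction time, switching means building the layer once per phase.

```python
# Training graph: dropout active with keep probability 0.8.
net_train = Dropout(prev_layer, keep=0.8, is_train=True, name='drop1')

# Test graph: the same layer degenerates to a pass-through.
net_test = Dropout(prev_layer, keep=0.8, is_train=False, name='drop1')
```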
[NEW] FYI, the latest version of […]. Previous answer: This may be better?

```python
class Dropout(Layer):
    def __init__(self,
                 layer=None,
                 keep=0.5,
                 is_fix=False,
                 name='dropout_layer'):
        Layer.__init__(self, name=name)
        self.inputs = layer.outputs
        print("  tensorlayer: Instantiate Dropout %s: keep: %f" % (self.name, keep))
        if is_fix:
            # keep probability baked into the graph as a constant
            self.outputs = tf.nn.dropout(self.inputs, keep, name=name)
        else:
            # keep probability fed at run time through feed_dict
            set_keep[name] = tf.placeholder(tf.float32)
            self.outputs = tf.nn.dropout(self.inputs, set_keep[name], name=name)
        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        if not is_fix:
            self.all_drop.update({set_keep[name]: keep})
        self.all_layers.extend([self.outputs])
```
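A usage sketch for the two modes of this class (assuming a prior layer `prev_layer` and the module-level `set_keep` dict used above):

```python
# Default mode: the keep probability is a placeholder registered in
# all_drop, so it must be supplied via feed_dict on every sess.run.
net = Dropout(prev_layer, keep=0.8, is_fix=False, name='drop1')
# train: sess.run(out, feed_dict={**data, **net.all_drop})
# test:  sess.run(out, feed_dict={**data, **{p: 1.0 for p in net.all_drop}})

# Fixed mode: the keep probability is a constant; nothing to feed,
# but the layer must be omitted or rebuilt for the test phase.
net_fix = Dropout(prev_layer, keep=0.8, is_fix=True, name='drop2')
```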
@lucidfrontier45 Does @zsdonghao's code work for you? If yes, you can make a pull request.
Before talking about my code or @zsdonghao's, I want to make clear how to use batch normalization:

```python
class tensorlayer.layers.BatchNormLayer(
    layer=None,
    decay=0.999,
    epsilon=0.00001,
    act=tf.identity,
    is_train=None,
    beta_init=tf.zeros_initializer,
    gamma_init=tf.ones_initializer,
    name='batchnorm_layer')
```

If so, I think it's confusing that the switching method differs between layers. One way is my […].
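For reference, the two-network pattern from the DCGAN example mentioned in this thread looks roughly like this (a sketch under assumptions; `build_net`, the layer sizes, and names are illustrative, and `set_name_reuse` may or may not be needed depending on the TL version):

```python
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import InputLayer, DenseLayer, BatchNormLayer

def build_net(x, is_train, reuse):
    # Same architecture for both phases; variables are shared via reuse,
    # only the is_train flag of BatchNormLayer differs.
    with tf.variable_scope("model", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        net = InputLayer(x, name='input')
        net = DenseLayer(net, n_units=256, act=tf.nn.relu, name='dense1')
        net = BatchNormLayer(net, is_train=is_train, name='bn1')
    return net

x = tf.placeholder(tf.float32, [None, 784])
net_train = build_net(x, is_train=True, reuse=False)  # training graph
net_test = build_net(x, is_train=False, reuse=True)   # test graph, shared weights
```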
@lucidfrontier45 Hi, your suggestion is good. Now you can write:

```python
if is_train:
    network = DropoutLayer(network, 0.8, name='xxx')
```

instead of putting the […] into `feed_dict`. Please let me know if you have any suggestions.
FYI, the latest version of […]
```python
if is_train:
    network = DropoutLayer(network, 0.8, name='xxx')
```

This looks fine to me. Thank you.
IMPORTANT @lucidfrontier45 @wagamamaz: the latest version of TL has an `is_fix` argument, so you can write:

```python
if is_train:
    network = DropoutLayer(network, 0.8, is_fix=True, name='xxx')
```
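With `is_fix=True` the switch becomes an ordinary Python condition at graph-construction time, e.g. (a sketch; `build` and the layer names are illustrative):

```python
def build(x, is_train):
    net = InputLayer(x, name='input')
    net = DenseLayer(net, n_units=800, act=tf.nn.relu, name='dense1')
    if is_train:
        # dropout exists only in the training graph
        net = DropoutLayer(net, 0.8, is_fix=True, name='drop1')
    net = DenseLayer(net, n_units=10, name='output')
    return net
```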
How can I get the variables of the network? TensorBoard shows nothing.
Dropout and Batch Normalization are two major structures that should behave differently in the training and test phases.

According to the API reference (http://tensorlayer.readthedocs.io/en/latest/modules/layers.html#dropout-layer), switching train/test for the `DropoutLayer` is handled by feeding different values through `feed_dict`, e.g. as in the sketch below.

I couldn't find how to switch for `BatchNormLayer` in the tutorial or API reference, but according to the DCGAN example (https://github.com/zsdonghao/dcgan), it creates two different networks: for the training phase `is_train=True` is passed to `BatchNormLayer`, and `is_train=False` for the test phase.

I think it is confusing that the switching method is not unified. Or, is there any standard way for batch norm? For example, TFLearn switches training/test with the `tflearn.is_training` method.
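A sketch of that `feed_dict` pattern in the style of the TensorLayer tutorials (`x`, `y_`, `train_op`, `acc`, and the data arrays are assumed to exist):

```python
import tensorlayer as tl

# Training phase: feed the stored keep probabilities so dropout is active.
feed_dict = {x: X_train_batch, y_: y_train_batch}
feed_dict.update(network.all_drop)
sess.run(train_op, feed_dict=feed_dict)

# Test phase: force every keep probability to 1.0 to disable dropout.
dp_dict = tl.utils.dict_to_one(network.all_drop)
feed_dict = {x: X_test, y_: y_test}
feed_dict.update(dp_dict)
print(sess.run(acc, feed_dict=feed_dict))
```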