Switching training/test phase of DropoutLayer #26

Closed
lucidfrontier45 opened this issue Nov 22, 2016 · 9 comments

Comments

@lucidfrontier45

Dropout and Batch Normalization are two major layer types that must behave differently in the training and test phases.

According to the API reference (http://tensorlayer.readthedocs.io/en/latest/modules/layers.html#dropout-layer), switching DropoutLayer between training and test is handled by feeding different values through feed_dict.
For example:

# for training
feed_dict = {x: X_train_a, y_: y_train_a}
feed_dict.update( network.all_drop )     # enable noise layers

# for testing
dp_dict = tl.utils.dict_to_one( network.all_drop ) # disable noise layers
feed_dict = {x: X_val_a, y_: y_val_a}
feed_dict.update(dp_dict)
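To see why feeding 1 disables the noise layers, here is a framework-free sketch of inverted dropout, the scheme TF1's tf.nn.dropout uses (kept units are scaled by 1/keep_prob). `dropout` and the all_drop contents are hypothetical; `dict_to_one` is only a stand-in that assumes tl.utils.dict_to_one maps every keep-probability key to 1.

```python
import random


def dropout(xs, keep, rng):
    """Inverted dropout: keep each unit with probability `keep` and scale it by 1/keep."""
    # rng.random() is in [0, 1), so keep == 1 keeps every unit and scales by 1 (identity).
    return [x / keep if rng.random() < keep else 0.0 for x in xs]


def dict_to_one(drop_dict):
    """Hypothetical stand-in for tl.utils.dict_to_one: set every keep probability to 1."""
    return {key: 1 for key in drop_dict}


# Hypothetical all_drop registry: dropout-layer name -> training keep probability.
all_drop = {"drop1": 0.8, "drop2": 0.5}
xs = [1.0, 2.0, 3.0, 4.0]
rng = random.Random(0)

train_out = dropout(xs, all_drop["drop2"], rng)  # each entry is 0.0 or 2 * x
test_feed = dict_to_one(all_drop)
test_out = dropout(xs, test_feed["drop2"], rng)  # identical to xs
```

Because each kept unit is scaled by 1/keep during training, its expected value already matches the input, so the test pass needs no extra rescaling.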

I couldn't find how to switch BatchNormLayer between phases in the tutorial or the API reference, but the DCGAN example (https://github.com/zsdonghao/dcgan) builds two separate networks: is_train=True is passed to BatchNormLayer for the training phase and is_train=False for the test phase.

I find it confusing that the switching method is not unified. Or is there a standard way to do this for batch norm?
For example, TFLearn switches between training and test with the tflearn.is_training method.

@lucidfrontier45
Author

I think one way to unify the API is to add a new Dropout layer that takes an is_train argument.
See my test implementation below:

class Dropout(Layer):
    # Assumes this lives in tensorlayer/layers.py (TL 1.x), where tf, Layer and
    # the module-level set_keep dict are already defined.
    def __init__(self,
                 layer=None,
                 keep=0.5,
                 is_train=True,
                 name='dropout_layer'):

        Layer.__init__(self, name=name)
        self.inputs = layer.outputs
        print("  tensorlayer:Instantiate Dropout %s: keep: %f" % (self.name, keep))

        set_keep[name] = tf.constant(keep, dtype=tf.float32)
        if is_train:
            # Training graph: apply dropout with a constant keep probability.
            self.outputs = tf.nn.dropout(self.inputs, set_keep[name], name=name)
        else:
            # Test graph: dropout is skipped, so the layer is an identity.
            self.outputs = self.inputs

        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        self.all_drop.update({set_keep[name]: keep})
        self.all_layers.extend([self.outputs])

@zsdonghao
Member

zsdonghao commented Nov 22, 2016

[NEW] FYI, the latest version of DropoutLayer has an is_fix setting; you can fix the keep probability by setting it to True.


Previous answer:

Would this be better?

class Dropout(Layer):
    def __init__(self,
                 layer=None,
                 keep=0.5,
                 is_fix=False,
                 name='dropout_layer'):

        Layer.__init__(self, name=name)
        self.inputs = layer.outputs
        print("  tensorlayer:Instantiate Dropout %s: keep: %f" % (self.name, keep))

        if is_fix:
            # Keep probability is a Python float baked into the graph; it cannot be fed.
            self.outputs = tf.nn.dropout(self.inputs, keep, name=name)
        else:
            # Keep probability is a placeholder, set at run time through feed_dict.
            set_keep[name] = tf.placeholder(tf.float32)
            self.outputs = tf.nn.dropout(self.inputs, set_keep[name], name=name)

        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        if not is_fix:
            # Only placeholders are registered for feeding.
            self.all_drop.update({set_keep[name]: keep})
        self.all_layers.extend([self.outputs])

@wagamamaz
Collaborator

@lucidfrontier45 Does @zsdonghao's code work for you? If yes, you can make a pull request.

@lucidfrontier45
Author

@wagamamaz

Before discussing my code or @zsdonghao's, I want to clarify how batch normalization is meant to be used.

class tensorlayer.layers.BatchNormLayer(
        layer = None,
        decay = 0.999,
        epsilon = 0.00001,
        act = tf.identity,
        is_train = None,
        beta_init = tf.zeros_initializer,
        gamma_init = tf.ones_initializer,
        name ='batchnorm_layer')

BatchNormLayer takes is_train as a constructor argument, so the phase is fixed when the graph is built, not at run time.
I couldn't find any example of batch normalization other than the DCGAN example, which builds two networks: one with is_train=True and the other with is_train=False. Is this the intended usage of BatchNormLayer?
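A framework-free sketch of what is_train switches inside a batch-norm layer may help frame the question: training normalizes with the current batch's mean and variance and updates exponential moving averages with `decay`, while the test path reuses those moving averages. `BatchNorm1D` is a hypothetical toy over a single scalar feature, not TensorLayer's implementation; the defaults mirror the decay and epsilon in the signature above, and gamma/beta start at 1 and 0 as with the ones/zeros initializers.

```python
import math


class BatchNorm1D:
    """Toy batch normalization over a batch of scalars (one feature)."""

    def __init__(self, decay=0.999, epsilon=1e-5):
        self.decay = decay
        self.epsilon = epsilon
        self.gamma, self.beta = 1.0, 0.0              # scale/shift (ones/zeros init)
        self.moving_mean, self.moving_var = 0.0, 1.0  # running statistics

    def __call__(self, batch, is_train):
        if is_train:
            # Training: normalize with this batch's statistics...
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # ...and fold them into exponential moving averages for test time.
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        else:
            # Test: use the stored moving averages; nothing is updated.
            mean, var = self.moving_mean, self.moving_var
        return [self.gamma * (x - mean) / math.sqrt(var + self.epsilon) + self.beta
                for x in batch]


bn = BatchNorm1D(decay=0.9)
train_out = bn([1.0, 2.0, 3.0, 4.0], is_train=True)
test_out = bn([2.5], is_train=False)  # a single example works because batch stats are not needed
```

With a constructor flag like this, a TF1 graph gets the test-time branch from a second network built with is_train=False, which is what the DCGAN example does.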

If so, I think it's confusing that DropoutLayer and BatchNormLayer have different APIs for switching between the training and test phases, and they should be unified.

One way is my Dropout implementation, which accepts an is_train argument.
Is @zsdonghao's code meant for switching between the training and test phases?

@zsdonghao
Member

@lucidfrontier45 Hi, your suggestion is good.

Currently, BatchNormLayer has is_train but DropoutLayer doesn't. However, if a model contains BatchNormLayer, we need to build separate inferences for training and testing, as in the PTB example. In that case, we can use

if is_train:
    network = DropoutLayer(network, 0.8, name='xxx')

instead of putting is_train inside DropoutLayer. Alternatively, we can enable/disable the dropout layer by setting feed_dict, as in the MNIST CNN example.
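The build-twice pattern can be sketched without TensorFlow: one builder function receives is_train, both calls share the same parameter store (the role we understand variable reuse to play in the PTB example), and dropout is inserted only for training. `build_network` and `params` are hypothetical names, and the elementwise scaling is a toy stand-in for a dense layer.

```python
import random

# Shared parameter store, created once and used by both networks.
params = {"w": [0.5, -1.0, 2.0]}


def build_network(is_train, keep=0.8, seed=0):
    """Return a forward function; dropout is inserted only when is_train is True."""
    rng = random.Random(seed)
    layers = []
    if is_train:
        # Counterpart of: network = DropoutLayer(network, 0.8, name='xxx')
        layers.append(lambda xs: [x / keep if rng.random() < keep else 0.0 for x in xs])
    # Toy stand-in for a dense layer: elementwise scaling by the shared weights.
    layers.append(lambda xs: [w * x for w, x in zip(params["w"], xs)])

    def forward(xs):
        for layer in layers:
            xs = layer(xs)
        return xs

    return forward


net_train = build_network(is_train=True)
net_test = build_network(is_train=False)

params["w"][0] = 1.0  # a weight update made while training...
test_out = net_test([1.0, 1.0, 1.0])  # ...is seen by the test network: [1.0, -1.0, 2.0]
```

The test network never contains a dropout step, so nothing has to be fed to disable it.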

Please let me know if you have any suggestions.

@zsdonghao
Copy link
Member

FYI, the latest version of DropoutLayer has an is_fix setting; you can fix the keep probability by setting it to True.

@lucidfrontier45
Author

@zsdonghao

if is_train:
    network = DropoutLayer(network, 0.8, name='xxx')

This looks fine to me. Thank you.

@zsdonghao
Member

zsdonghao commented Dec 27, 2016

IMPORTANT

@lucidfrontier45 @wagamamaz The latest version of TL has an is_fix argument, so you can do the following:

if is_train:
    network = DropoutLayer(network, 0.8, is_fix=True, name='xxx')
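Why guard is_fix=True with `if is_train:`? In the is_fix version posted earlier in the thread, a fixed keep probability is baked into the graph and never registered in all_drop, so feeding the output of tl.utils.dict_to_one cannot turn it off. The sketch below models only that bookkeeping; `ToyDropout` and its attributes are hypothetical, not TensorLayer's classes.

```python
class ToyDropout:
    """Models only the bookkeeping of a dropout layer, not the dropout itself."""

    def __init__(self, prev_all_drop, keep, is_fix, name):
        self.name = name
        self.keep = keep
        self.is_fix = is_fix
        self.all_drop = dict(prev_all_drop)  # inherited from the previous layer
        if not is_fix:
            # Switchable layer: its keep probability is registered for feeding.
            self.all_drop[name] = keep

    def effective_keep(self, feed):
        # A fixed layer ignores the feed; a switchable one reads its value from it.
        return self.keep if self.is_fix else feed[self.name]


switchable = ToyDropout({}, 0.8, is_fix=False, name="drop_a")
fixed = ToyDropout(switchable.all_drop, 0.8, is_fix=True, name="drop_b")

# What a dict_to_one-style feed would contain for a test pass.
test_feed = {key: 1 for key in fixed.all_drop}
```

The fixed layer would keep dropping units during evaluation, so it has to be left out of the test graph, which is what the `if is_train:` guard does.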

zsdonghao changed the title from "Switching training/test phase" to "Switching training/test phase of DropoutLayer" on Dec 27, 2016
@quelle1

quelle1 commented Dec 15, 2017

network = Conv2d(net_in, df_dim, (k, k), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2), padding='SAME', W_init=w_init, name='h0/conv2d')
tf.summary.histogram('h0/conv2d',tf.get_collection(tf.GraphKeys.VARIABLES, 'h0/conv2d'))

How do I get the variables of the network? TensorBoard shows nothing.

zsdonghao pushed a commit that referenced this issue May 4, 2019
add cloudpickle to requirement.txt

4 participants