tensorflow version error #59

Closed
xin-xinhanggao opened this issue Nov 30, 2016 · 5 comments · Fixed by #60

Comments

@xin-xinhanggao

When I upgrade TensorFlow to the latest version and run `python main.py --dataset mnist --is_train True`, I get:

    Traceback (most recent call last):
      File "main.py", line 60, in <module>
        tf.app.run()
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 43, in run
        sys.exit(main(sys.argv[:1] + flags_passthrough))
      File "main.py", line 38, in main
        dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir, sample_dir=FLAGS.sample_dir)
      File "/Users/apple/Desktop/test/DCGAN-tensorflow/model.py", line 67, in __init__
        self.build_model()
      File "/Users/apple/Desktop/test/DCGAN-tensorflow/model.py", line 86, in build_model
        self.sampler = self.sampler(self.z, self.y)
      File "/Users/apple/Desktop/test/DCGAN-tensorflow/model.py", line 354, in sampler
        h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
      File "/Users/apple/Desktop/test/DCGAN-tensorflow/ops.py", line 34, in __call__
        ema_apply_op = self.ema.apply([batch_mean, batch_var])
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/moving_averages.py", line 391, in apply
        self._averages[var], var, decay, zero_debias=zero_debias))
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/moving_averages.py", line 70, in assign_moving_average
        update_delta = _zero_debias(variable, value, decay)
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/moving_averages.py", line 177, in _zero_debias
        trainable=False)
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 1024, in get_variable
        custom_getter=custom_getter)
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 850, in get_variable
        custom_getter=custom_getter)
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 346, in get_variable
        validate_shape=validate_shape)
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 331, in _true_getter
        caching_device=caching_device, validate_shape=validate_shape)
      File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 650, in _get_single_variable
        "VarScope?" % name)
    ValueError: Variable g_bn0/g_bn0_2/g_bn0_2/moments_1/moments_1/mean/ExponentialMovingAverage/biased does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

The previous TensorFlow version worked fine. Which TensorFlow version are you using?

@ppwwyyxx

I've seen the same problem. It is caused by a change in TF 0.11: the exponential moving average now creates its variables through tf.get_variable (the traceback shows this for the `biased` variable), so it can no longer be applied under a reuse=True scope. See tensorflow/tensorflow#2740.
You can modify the code to force reuse=False around the exponential moving average, as was done here.
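
The `_zero_debias` frame in the traceback is where the new variable comes from: a zero-debiased moving average keeps extra state (a raw "biased" average plus a step count) next to the average itself. Below is a minimal pure-Python sketch of that idea, assuming the usual bias-correction formula (the one Adam uses); `ZeroDebiasedEMA` and its attributes are hypothetical names, not TensorFlow's implementation.

```python
class ZeroDebiasedEMA:
    """Toy EMA with bias correction (pure Python sketch, not TF code)."""

    def __init__(self, decay):
        self.decay = decay
        self.biased = 0.0  # raw average, zero-initialized, so it starts biased toward 0
        self.step = 0      # number of updates seen so far

    def update(self, value):
        self.biased = self.decay * self.biased + (1.0 - self.decay) * value
        self.step += 1
        # Dividing by (1 - decay**step) cancels the pull toward the zero initializer.
        return self.biased / (1.0 - self.decay ** self.step)


ema = ZeroDebiasedEMA(decay=0.9)
debiased = [ema.update(5.0) for _ in range(3)]
```

For a constant input of 5.0 the corrected value is 5.0 from the first step, while the raw `biased` state is still only about 1.355 after three steps. That extra state is what TF has to create as a variable, and creating a variable is what a reuse=True scope refuses to do.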

@AStrangeQuark

@ppwwyyxx Could you explain how to apply such a modification to this code? I don't know TF well enough to figure out where exactly the EMA should be moved out of the reuse=True scope (or... whatever is going on...).

@nemtiax

nemtiax commented Dec 1, 2016

We're trying to update the code to use tf.contrib.layers.batch_norm to avoid this issue, but the results we see on MNIST are noticeably worse. Are there any important differences between the batch_norm code used here in ops.py and the contrib.layers version that might explain the discrepancy?
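
One difference worth ruling out: `tf.contrib.layers.batch_norm` defaults to `decay=0.999`, `epsilon=0.001` and `scale=False` (no learned gamma), and by default its moving-average updates go into `tf.GraphKeys.UPDATE_OPS`, so they only run if those ops are made dependencies of the train op (or `updates_collections=None` is passed). Whether these match the values in ops.py has to be checked against that file. The pure-Python toy below (`ToyBatchNorm` is a hypothetical single-feature helper, and the 0.9 decay is just an illustrative "fast" setting) shows one possible effect: with `decay=0.999`, the moving statistics used at inference lag far behind the batch statistics after 100 updates.

```python
import math


class ToyBatchNorm:
    """Hypothetical single-feature batch norm that tracks EMA statistics."""

    def __init__(self, decay, epsilon, scale=True):
        self.decay = decay
        self.epsilon = epsilon
        self.gamma = 1.0 if scale else None  # scale=False: no learned gamma
        self.beta = 0.0
        self.moving_mean = 0.0
        self.moving_var = 1.0

    def __call__(self, xs, training):
        if training:
            # Normalize with the current batch statistics and update the EMAs.
            mean = sum(xs) / len(xs)
            var = sum((x - mean) ** 2 for x in xs) / len(xs)
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        else:
            # Inference: use the accumulated moving statistics instead.
            mean, var = self.moving_mean, self.moving_var
        gamma = 1.0 if self.gamma is None else self.gamma
        return [gamma * (x - mean) / math.sqrt(var + self.epsilon) + self.beta for x in xs]


fast = ToyBatchNorm(decay=0.9, epsilon=1e-5)
slow = ToyBatchNorm(decay=0.999, epsilon=1e-3, scale=False)
batch = [2.0, 3.0, 4.0]  # batch mean is 3.0
for _ in range(100):
    fast(batch, training=True)
    slow(batch, training=True)
```

After the loop, `fast.moving_mean` is essentially 3.0 while `slow.moving_mean` is still below 0.3, so anything normalized with the slow moving statistics would be badly off-center. This is only one candidate explanation for the MNIST gap, not a diagnosis.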

@ppwwyyxx

ppwwyyxx commented Dec 2, 2016

You can add the line

            with tf.variable_scope(tf.get_variable_scope(), reuse=False):

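before the ema.apply call in ops.py:35, indenting that call under it so the moving-average variables are created with reuse=False.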
Note that this is a quick hack and the real fix should come from tensorflow.
Also a duplicate of #57.
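
Why forcing reuse=False helps can be seen in a toy model of the create-versus-reuse rule. This is a hypothetical pure-Python stand-in, not TensorFlow's variable store (which also handles nested scopes and name scopes); `g_bn0/beta` is an illustrative weight name, while the EMA variable name is copied from the traceback.

```python
# Toy model of the TF1 get_variable reuse rule (hypothetical; not TensorFlow code).
def get_variable(store, name, reuse):
    """Fetch or create `name` in the dict `store`, following the reuse flag."""
    exists = name in store
    if reuse and not exists:
        # reuse=True only looks up existing variables, as in the error above.
        raise ValueError("Variable %s does not exist" % name)
    if not reuse and exists:
        raise ValueError("Variable %s already exists" % name)
    if not exists:
        store[name] = 0.0  # create with a zero initializer
    return store[name]


EMA_NAME = ("g_bn0/g_bn0_2/g_bn0_2/moments_1/moments_1/mean/"
            "ExponentialMovingAverage/biased")

store = {}
# Training graph: the weights are created (reuse=False).
get_variable(store, "g_bn0/beta", reuse=False)
# Sampler graph: built under reuse=True so the weights are shared...
get_variable(store, "g_bn0/beta", reuse=True)
# ...but ema.apply needs a brand-new EMA variable under a new name scope.
try:
    get_variable(store, EMA_NAME, reuse=True)
    failed = False
except ValueError:
    failed = True
# The workaround: create just that variable with reuse=False.
get_variable(store, EMA_NAME, reuse=False)
```

The weight lookup succeeds under reuse=True because the variable already exists; the EMA lookup fails because its name is new, and creating it with reuse=False succeeds for the same reason.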

@shimafoolad

On tensorflow 1.12.0, I had the same problem and fixed it by adding the line:

        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):

before ema.apply, indenting the ema.apply call inside the new block.
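
`tf.AUTO_REUSE` relaxes the rule further: get_variable creates the variable if it is missing and reuses it otherwise, so the same code path works whether or not the EMA variables already exist. The toy below is again hypothetical pure Python (`AUTO_REUSE`, `variable_scope` and `get_variable` here are stand-ins, not the TF objects). It also shows that the override lasts only for the with-block, after which the outer reuse=True applies again to the remaining weight lookups.

```python
from contextlib import contextmanager

AUTO_REUSE = object()  # hypothetical stand-in for tf.AUTO_REUSE
reuse_stack = [True]   # pretend we are already inside the sampler's reuse=True scope


@contextmanager
def variable_scope(reuse):
    """Toy scope: the given reuse flag applies only inside the with-block."""
    reuse_stack.append(reuse)
    try:
        yield
    finally:
        reuse_stack.pop()


def get_variable(store, name):
    """Fetch or create `name` in the dict `store` under the innermost reuse flag."""
    reuse = reuse_stack[-1]
    exists = name in store
    if reuse is not AUTO_REUSE:
        if reuse and not exists:
            raise ValueError("Variable %s does not exist" % name)
        if not reuse and exists:
            raise ValueError("Variable %s already exists" % name)
    store.setdefault(name, 0.0)  # create with a zero initializer if missing
    return store[name]


store = {}
with variable_scope(AUTO_REUSE):
    get_variable(store, "ema/biased")  # missing, so it is created
    get_variable(store, "ema/biased")  # present, so it is reused without error
outer_reuse = reuse_stack[-1]  # back to True once the block exits
```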
