diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py index a59d508fd..6d0af1e6a 100644 --- a/tensor2tensor/layers/common_layers.py +++ b/tensor2tensor/layers/common_layers.py @@ -3724,7 +3724,7 @@ def double_discriminator(x, filters1=128, filters2=None, tf.reshape(net, [batch_size, -1]) net = tf.nn.relu(net) net = layers().Conv2D( - filters2, kernel_size, strides=strides, padding="SAME", name="conv2")(x) + filters2, kernel_size, strides=strides, padding="SAME", name="conv2")(net) if pure_mean: net2 = tf.reduce_mean(net, [1, 2]) else: