@@ -72,43 +72,26 @@ def inject_latent(self, layer, inputs, target):
7272 filters = hparams .hidden_size
7373 kernel = (4 , 4 )
7474 layer_shape = common_layers .shape_list (layer )
75- batch_size = layer_shape [0 ]
76- state_size = hparams .latent_predictor_state_size
77- lstm_cell = tf .contrib .rnn .LSTMCell (state_size )
78- discrete_predict = tfl .Dense (256 , name = "discrete_predict" )
79- discrete_embed = tfl .Dense (state_size , name = "discrete_embed" )
80-
81- def add_d (layer , d ):
82- z_mul = tfl .dense (d , final_filters , name = "unbottleneck_mul" )
75+
76+ def add_bits (layer , bits ):
77+ z_mul = tfl .dense (bits , final_filters , name = "unbottleneck_mul" )
8378 if not hparams .complex_addn :
8479 return layer + z_mul
8580 layer *= tf .nn .sigmoid (z_mul )
86- z_add = tfl .dense (d , final_filters , name = "unbottleneck_add" )
81+ z_add = tfl .dense (bits , final_filters , name = "unbottleneck_add" )
8782 layer += z_add
8883 return layer
8984
9085 if self .is_predicting :
9186 if hparams .full_latent_tower :
9287 rand = tf .random_uniform (layer_shape [:- 1 ] + [hparams .bottleneck_bits ])
88+ bits = 2.0 * tf .to_float (tf .less (0.5 , rand )) - 1.0
9389 else :
94- layer_pred = tfl .flatten (layer )
95- prediction = tfl .dense (layer_pred , state_size , name = "istate" )
96- c_state = tfl .dense (layer_pred , state_size , name = "cstate" )
97- m_state = tfl .dense (layer_pred , state_size , name = "mstate" )
98- state = (c_state , m_state )
99- outputs = []
100- for i in range (hparams .bottleneck_bits // 8 ):
101- output , state = lstm_cell (prediction , state )
102- discrete_logits = discrete_predict (output )
103- discrete_samples = common_layers .sample_with_temperature (
104- discrete_logits , hparams .latent_predictor_temperature )
105- outputs .append (tf .expand_dims (discrete_samples , axis = 1 ))
106- prediction = discrete_embed (tf .one_hot (discrete_samples , 256 ))
107- outputs = tf .concat (outputs , axis = 1 )
108- outputs = discretization .int_to_bit (outputs , 8 )
109- rand = tf .reshape (outputs , [batch_size , 1 , 1 , hparams .bottleneck_bits ])
110- d = 2.0 * tf .to_float (tf .less (0.5 , rand )) - 1.0
111- return add_d (layer , d ), 0.0
90+ bits , _ = discretization .predict_bits_with_lstm (
91+ layer , hparams .latent_predictor_state_size , hparams .bottleneck_bits ,
92+ temperature = hparams .latent_predictor_temperature )
93+ bits = tf .expand_dims (tf .expand_dims (bits , axis = 1 ), axis = 2 )
94+ return add_bits (layer , bits ), 0.0
11295
11396 # Embed.
11497 frames = tf .concat (inputs + [target ], axis = - 1 )
@@ -131,43 +114,16 @@ def add_d(layer, d):
131114 else :
132115 x = common_layers .double_discriminator (x )
133116 x = tf .expand_dims (tf .expand_dims (x , axis = 1 ), axis = 1 )
134- x = tfl . dense ( x , hparams . bottleneck_bits , name = "bottleneck" )
135- x0 = tf . tanh ( x )
136- d = x0 + tf . stop_gradient ( 2.0 * tf . to_float ( tf . less ( 0.0 , x0 )) - 1.0 - x0 )
137- pred_loss = 0.0
117+
118+ bits , bits_clean = discretization . tanh_discrete_bottleneck (
119+ x , hparams . bottleneck_bits , hparams . bottleneck_noise ,
120+ hparams . discretize_warmup_steps , hparams . mode )
138121 if not hparams .full_latent_tower :
139- d_pred = tf .reshape (tf .maximum (tf .stop_gradient (d ), 0 ), [
140- batch_size , hparams .bottleneck_bits // 8 , 8 ])
141- d_int = discretization .bit_to_int (d_pred , 8 )
142- tf .summary .histogram ("d_int" , tf .reshape (d_int , [- 1 ]))
143- d_hot = tf .one_hot (d_int , 256 , axis = - 1 )
144- d_pred = discrete_embed (d_hot )
145- layer_pred = tfl .flatten (layer )
146- prediction0 = tfl .dense (layer_pred , state_size , name = "istate" )
147- c_state = tfl .dense (layer_pred , state_size , name = "cstate" )
148- m_state = tfl .dense (layer_pred , state_size , name = "mstate" )
149- pred = tf .concat ([tf .expand_dims (prediction0 , axis = 1 ), d_pred ], axis = 1 )
150- state = (c_state , m_state )
151- outputs = []
152- for i in range (hparams .bottleneck_bits // 8 ):
153- output , state = lstm_cell (pred [:, i , :], state )
154- outputs .append (tf .expand_dims (output , axis = 1 ))
155- outputs = tf .concat (outputs , axis = 1 )
156- d_int_pred = discrete_predict (outputs )
157- pred_loss = tf .losses .sparse_softmax_cross_entropy (
158- logits = d_int_pred , labels = d_int )
159- pred_loss = tf .reduce_mean (pred_loss )
160- if hparams .mode == tf .estimator .ModeKeys .TRAIN :
161- x += tf .truncated_normal (
162- common_layers .shape_list (x ), mean = 0.0 , stddev = 0.2 )
163- x = tf .tanh (x )
164- noise = tf .random_uniform (common_layers .shape_list (x ))
165- noise = 2.0 * tf .to_float (tf .less (hparams .bottleneck_noise , noise )) - 1.0
166- x *= noise
167- d = x + tf .stop_gradient (2.0 * tf .to_float (tf .less (0.0 , x )) - 1.0 - x )
168- p = common_layers .inverse_lin_decay (hparams .discrete_warmup_steps )
169- d = tf .where (tf .less (tf .random_uniform ([batch_size ]), p ), d , x )
170- return add_d (layer , d ), pred_loss
122+ _ , pred_loss = discretization .predict_bits_with_lstm (
123+ layer , hparams .latent_predictor_state_size , hparams .bottleneck_bits ,
124+ target_bits = bits_clean )
125+
126+ return add_bits (layer , bits ), pred_loss
171127
172128
173129@registry .register_hparams
@@ -224,7 +180,7 @@ def next_frame_basic_stochastic_discrete():
224180 hparams .learning_rate_schedule = "linear_warmup * constant"
225181 hparams .add_hparam ("bottleneck_bits" , 64 )
226182 hparams .add_hparam ("bottleneck_noise" , 0.02 )
227- hparams .add_hparam ("discrete_warmup_steps " , 40000 )
183+ hparams .add_hparam ("discretize_warmup_steps " , 40000 )
228184 hparams .add_hparam ("full_latent_tower" , False )
229185 hparams .add_hparam ("latent_predictor_state_size" , 128 )
230186 hparams .add_hparam ("latent_predictor_temperature" , 0.5 )