This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit ca628e4

Lukasz Kaiser authored and Copybara-Service committed
Clean up stochastic discrete video model by moving code and using discretization layers.
PiperOrigin-RevId: 217386874
1 parent be78912 commit ca628e4

4 files changed: +104 additions, −69 deletions

tensor2tensor/layers/common_layers.py

Lines changed: 4 additions & 2 deletions
@@ -3019,8 +3019,10 @@ def get_res():
   if is_xla_compiled():
     return get_res()
   else:
-    return tf.cond(
-        tf.less(tf.train.get_global_step(), steps), get_res, lambda: x1)
+    cur_step = tf.train.get_global_step()
+    if cur_step is None:
+      return x1  # Step not available, probably eval mode, don't mix.
+    return tf.cond(tf.less(cur_step, steps), get_res, lambda: x1)
 
 
 def brelu(x):
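Why the guard: tf.train.get_global_step() returns None when no global-step variable exists in the graph (typical of standalone eval or inference graphs), and the old code then fed None straight into tf.less, which fails. A minimal sketch of the same pattern in TF 1.x; mix_after_warmup is a hypothetical stand-in for common_layers.mix:

import tensorflow as tf

def mix_after_warmup(x_mixed, x_plain, warmup_steps):
  # Same guard as in common_layers.mix: if no global step exists
  # (e.g. a pure eval graph), skip mixing and return the plain value.
  cur_step = tf.train.get_global_step()
  if cur_step is None:
    return x_plain
  # Before warmup_steps, use the mixed value; afterwards, the plain one.
  return tf.cond(tf.less(cur_step, warmup_steps),
                 lambda: x_mixed, lambda: x_plain)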

tensor2tensor/layers/discretization.py

Lines changed: 79 additions & 2 deletions
@@ -783,6 +783,81 @@ def discrete_bottleneck(inputs,
   return outputs_dense, outputs_discrete, extra_loss, embed_fn, neg_q_entropy
 
 
+def predict_bits_with_lstm(prediction_source, state_size, total_num_bits,
+                           target_bits=None, bits_at_once=8, temperature=1.0):
+  """Predict a sequence of bits (a latent) with LSTM, both training and infer.
+
+  Given a tensor on which the predictions are based (prediction_source), we use
+  a single-layer LSTM with state of size state_size to predict total_num_bits,
+  which we predict in groups of size bits_at_once. During training, we use
+  target_bits as input to the LSTM (teacher forcing) and return the target_bits
+  together with the prediction loss. During inference, we sample with the given
+  temperature and return the predicted sequence and loss 0.
+
+  Args:
+    prediction_source: a Tensor of shape [batch_size, ...] used to create
+      the initial state and the first input to the LSTM.
+    state_size: python integer, the size of the LSTM state.
+    total_num_bits: python integer, how many bits in total to predict.
+    target_bits: a tensor of shape [batch_size, total_num_bits] used during
+      training as the target to predict; each element should be -1 or 1.
+    bits_at_once: python integer, how many bits to predict at once.
+    temperature: python float, temperature used for sampling during inference.
+
+  Returns:
+    a pair (bits, loss) with the predicted bit sequence, which is a Tensor of
+    shape [batch_size, total_num_bits] with elements either -1 or 1, and a loss
+    used to train the predictions against the provided target_bits.
+  """
+
+  with tf.variable_scope("predict_bits_with_lstm"):
+    # Layers and cell state creation.
+    lstm_cell = tf.contrib.rnn.LSTMCell(state_size)
+    discrete_predict = tf.layers.Dense(2**bits_at_once, name="discrete_predict")
+    discrete_embed = tf.layers.Dense(state_size, name="discrete_embed")
+    batch_size = common_layers.shape_list(prediction_source)[0]
+    layer_pred = tf.layers.flatten(prediction_source)
+    prediction = tf.layers.dense(layer_pred, state_size, name="istate")
+    c_state = tf.layers.dense(layer_pred, state_size, name="cstate")
+    m_state = tf.layers.dense(layer_pred, state_size, name="mstate")
+    state = (c_state, m_state)
+
+    # Prediction mode if no targets are given.
+    if target_bits is None:
+      outputs = []
+      for i in range(total_num_bits // bits_at_once):
+        output, state = lstm_cell(prediction, state)
+        discrete_logits = discrete_predict(output)
+        discrete_samples = common_layers.sample_with_temperature(
+            discrete_logits, temperature)
+        outputs.append(tf.expand_dims(discrete_samples, axis=1))
+        prediction = discrete_embed(tf.one_hot(discrete_samples, 256))
+      outputs = tf.concat(outputs, axis=1)
+      outputs = int_to_bit(outputs, bits_at_once)
+      outputs = tf.reshape(outputs, [batch_size, total_num_bits])
+      return 2 * outputs - 1, 0.0
+
+    # Training mode, calculating loss.
+    assert total_num_bits % bits_at_once == 0
+    d_pred = tf.reshape(tf.maximum(tf.stop_gradient(target_bits), 0), [
+        batch_size, total_num_bits // bits_at_once, bits_at_once])
+    d_int = bit_to_int(d_pred, bits_at_once)
+    tf.summary.histogram("target_integers", tf.reshape(d_int, [-1]))
+    d_hot = tf.one_hot(d_int, 2**bits_at_once, axis=-1)
+    d_pred = discrete_embed(d_hot)
+    pred = tf.concat([tf.expand_dims(prediction, axis=1), d_pred], axis=1)
+    outputs = []
+    for i in range(total_num_bits // bits_at_once):
+      output, state = lstm_cell(pred[:, i, :], state)
+      outputs.append(tf.expand_dims(output, axis=1))
+    outputs = tf.concat(outputs, axis=1)
+    d_int_pred = discrete_predict(outputs)
+    pred_loss = tf.losses.sparse_softmax_cross_entropy(
+        logits=d_int_pred, labels=d_int)
+    pred_loss = tf.reduce_mean(pred_loss)
+    return target_bits, pred_loss
+
+
 # New API for discretization bottlenecks:
 # * Each method is separate and provides 2 functions:
 #   * The [method]_bottleneck function returns discretized state.
@@ -1281,6 +1356,7 @@ def tanh_discrete_bottleneck(x, bottleneck_bits, bottleneck_noise,
                              discretize_warmup_steps, mode):
   """Simple discretization through tanh, flip bottleneck_noise many bits."""
   x = tf.layers.dense(x, bottleneck_bits, name="tanh_discrete_bottleneck")
+  d0 = tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x))) - 1.0
   if mode == tf.estimator.ModeKeys.TRAIN:
     x += tf.truncated_normal(
         common_layers.shape_list(x), mean=0.0, stddev=0.2)
@@ -1292,7 +1368,7 @@ def tanh_discrete_bottleneck(x, bottleneck_bits, bottleneck_noise,
   d *= noise
   d = common_layers.mix(d, x, discretize_warmup_steps,
                         mode == tf.estimator.ModeKeys.TRAIN)
-  return d, 0.0
+  return d, d0
 
 
 def tanh_discrete_unbottleneck(x, hidden_size):
@@ -1345,9 +1421,10 @@ def isemhash_unbottleneck(x, hidden_size, isemhash_filter_size_multiplier=1.0):
 def parametrized_bottleneck(x, hparams):
   """Meta-function calling all the above bottlenecks with hparams."""
   if hparams.bottleneck_kind == "tanh_discrete":
-    return tanh_discrete_bottleneck(
+    d, _ = tanh_discrete_bottleneck(
         x, hparams.bottleneck_bits, hparams.bottleneck_noise * 0.5,
         hparams.discretize_warmup_steps, hparams.mode)
+    return d, 0.0
   if hparams.bottleneck_kind == "isemhash":
     return isemhash_bottleneck(
         x, hparams.bottleneck_bits, hparams.bottleneck_noise * 0.5,
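For orientation, here is how the two modes of predict_bits_with_lstm are driven, following the docstring above. This is a sketch under TF 1.x assumptions; the batch size, source shape, state size, and bit count are illustrative, and each mode gets its own graph because the function creates variables under a fixed scope:

import tensorflow as tf
from tensor2tensor.layers import discretization

# Inference-time graph: no targets, bits are sampled autoregressively.
with tf.Graph().as_default():
  source = tf.zeros([8, 16, 16, 32])  # any [batch_size, ...] tensor
  sampled_bits, zero_loss = discretization.predict_bits_with_lstm(
      source, 128, 64, temperature=0.5)
  # sampled_bits: [8, 64] with elements in {-1, 1}; zero_loss is 0.0.

# Training-time graph: teacher-force on known bits, get a scalar loss.
with tf.Graph().as_default():
  source = tf.zeros([8, 16, 16, 32])
  target = 2.0 * tf.to_float(
      tf.less(0.5, tf.random_uniform([8, 64]))) - 1.0  # stand-in bits
  _, pred_loss = discretization.predict_bits_with_lstm(
      source, 128, 64, target_bits=target)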

tensor2tensor/models/video/basic_stochastic.py

Lines changed: 20 additions & 64 deletions
@@ -72,43 +72,26 @@ def inject_latent(self, layer, inputs, target):
     filters = hparams.hidden_size
     kernel = (4, 4)
     layer_shape = common_layers.shape_list(layer)
-    batch_size = layer_shape[0]
-    state_size = hparams.latent_predictor_state_size
-    lstm_cell = tf.contrib.rnn.LSTMCell(state_size)
-    discrete_predict = tfl.Dense(256, name="discrete_predict")
-    discrete_embed = tfl.Dense(state_size, name="discrete_embed")
-
-    def add_d(layer, d):
-      z_mul = tfl.dense(d, final_filters, name="unbottleneck_mul")
+
+    def add_bits(layer, bits):
+      z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
       if not hparams.complex_addn:
         return layer + z_mul
       layer *= tf.nn.sigmoid(z_mul)
-      z_add = tfl.dense(d, final_filters, name="unbottleneck_add")
+      z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
       layer += z_add
       return layer
 
     if self.is_predicting:
       if hparams.full_latent_tower:
         rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
+        bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
       else:
-        layer_pred = tfl.flatten(layer)
-        prediction = tfl.dense(layer_pred, state_size, name="istate")
-        c_state = tfl.dense(layer_pred, state_size, name="cstate")
-        m_state = tfl.dense(layer_pred, state_size, name="mstate")
-        state = (c_state, m_state)
-        outputs = []
-        for i in range(hparams.bottleneck_bits // 8):
-          output, state = lstm_cell(prediction, state)
-          discrete_logits = discrete_predict(output)
-          discrete_samples = common_layers.sample_with_temperature(
-              discrete_logits, hparams.latent_predictor_temperature)
-          outputs.append(tf.expand_dims(discrete_samples, axis=1))
-          prediction = discrete_embed(tf.one_hot(discrete_samples, 256))
-        outputs = tf.concat(outputs, axis=1)
-        outputs = discretization.int_to_bit(outputs, 8)
-        rand = tf.reshape(outputs, [batch_size, 1, 1, hparams.bottleneck_bits])
-      d = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
-      return add_d(layer, d), 0.0
+        bits, _ = discretization.predict_bits_with_lstm(
+            layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
+            temperature=hparams.latent_predictor_temperature)
+        bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
+      return add_bits(layer, bits), 0.0
 
     # Embed.
     frames = tf.concat(inputs + [target], axis=-1)
@@ -131,43 +114,16 @@ def add_d(layer, d):
     else:
       x = common_layers.double_discriminator(x)
       x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
-    x = tfl.dense(x, hparams.bottleneck_bits, name="bottleneck")
-    x0 = tf.tanh(x)
-    d = x0 + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x0)) - 1.0 - x0)
-    pred_loss = 0.0
+
+    bits, bits_clean = discretization.tanh_discrete_bottleneck(
+        x, hparams.bottleneck_bits, hparams.bottleneck_noise,
+        hparams.discretize_warmup_steps, hparams.mode)
     if not hparams.full_latent_tower:
-      d_pred = tf.reshape(tf.maximum(tf.stop_gradient(d), 0), [
-          batch_size, hparams.bottleneck_bits // 8, 8])
-      d_int = discretization.bit_to_int(d_pred, 8)
-      tf.summary.histogram("d_int", tf.reshape(d_int, [-1]))
-      d_hot = tf.one_hot(d_int, 256, axis=-1)
-      d_pred = discrete_embed(d_hot)
-      layer_pred = tfl.flatten(layer)
-      prediction0 = tfl.dense(layer_pred, state_size, name="istate")
-      c_state = tfl.dense(layer_pred, state_size, name="cstate")
-      m_state = tfl.dense(layer_pred, state_size, name="mstate")
-      pred = tf.concat([tf.expand_dims(prediction0, axis=1), d_pred], axis=1)
-      state = (c_state, m_state)
-      outputs = []
-      for i in range(hparams.bottleneck_bits // 8):
-        output, state = lstm_cell(pred[:, i, :], state)
-        outputs.append(tf.expand_dims(output, axis=1))
-      outputs = tf.concat(outputs, axis=1)
-      d_int_pred = discrete_predict(outputs)
-      pred_loss = tf.losses.sparse_softmax_cross_entropy(
-          logits=d_int_pred, labels=d_int)
-      pred_loss = tf.reduce_mean(pred_loss)
-    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
-      x += tf.truncated_normal(
-          common_layers.shape_list(x), mean=0.0, stddev=0.2)
-    x = tf.tanh(x)
-    noise = tf.random_uniform(common_layers.shape_list(x))
-    noise = 2.0 * tf.to_float(tf.less(hparams.bottleneck_noise, noise)) - 1.0
-    x *= noise
-    d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 - x)
-    p = common_layers.inverse_lin_decay(hparams.discrete_warmup_steps)
-    d = tf.where(tf.less(tf.random_uniform([batch_size]), p), d, x)
-    return add_d(layer, d), pred_loss
+      _, pred_loss = discretization.predict_bits_with_lstm(
+          layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
+          target_bits=bits_clean)
+
+    return add_bits(layer, bits), pred_loss
 
 
 @registry.register_hparams
@@ -224,7 +180,7 @@ def next_frame_basic_stochastic_discrete():
   hparams.learning_rate_schedule = "linear_warmup * constant"
   hparams.add_hparam("bottleneck_bits", 64)
   hparams.add_hparam("bottleneck_noise", 0.02)
-  hparams.add_hparam("discrete_warmup_steps", 40000)
+  hparams.add_hparam("discretize_warmup_steps", 40000)
   hparams.add_hparam("full_latent_tower", False)
   hparams.add_hparam("latent_predictor_state_size", 128)
   hparams.add_hparam("latent_predictor_temperature", 0.5)

tensor2tensor/rl/trainer_model_based_params.py

Lines changed: 1 addition & 1 deletion
@@ -203,7 +203,7 @@ def rlmb_base_stochastic_discrete():
   """Base setting with stochastic discrete model."""
   hparams = rlmb_base()
   hparams.learning_rate_bump = 1.0
-  hparams.grayscale = True
+  hparams.grayscale = False
   hparams.generative_model = "next_frame_basic_stochastic_discrete"
   hparams.generative_model_params = "next_frame_basic_stochastic_discrete"
   return hparams
