
Commit 5ea09f3
Author: Giovanni Campagna
transformer_fast_decode, beam search: take an optional cache and return it
Some models, e.g. semantic parsing models with copying mechanisms, want to use the output of the Transformer for multiple predictions. One way to do so is to modify symbols_to_logits_fn to generate the additional predictions and save them in the cache dictionary. For that to work, fast_decode() must accept an externally supplied cache, and must return it to the caller after the decoding loop.
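
For illustration, a minimal sketch of the pattern this change enables (the wrapper, the extra_head callable, and the "extra_preds" cache key are hypothetical, not part of this commit): the caller seeds the cache with an extra entry, the wrapped symbols_to_logits_fn updates it on every decode step, and fast_decode() returns it alongside the decoded ids.

import tensorflow as tf

def with_extra_predictions(base_symbols_to_logits_fn, extra_head):
  """Wraps a (ids, i, cache) -> (logits, cache) fn so that every decode
  step also appends one extra prediction to a tensor kept in the cache."""

  def symbols_to_logits_fn(ids, i, cache):
    logits, cache = base_symbols_to_logits_fn(ids, i, cache)
    # Hypothetical extra head; any per-step tensor works, as long as the
    # entry was seeded in the cache handed to fast_decode().
    step_pred = extra_head(logits)  # [batch, depth]
    cache["extra_preds"] = tf.concat(
        [cache["extra_preds"], tf.expand_dims(step_pred, 1)], axis=1)
    return logits, cache

  return symbols_to_logits_fn

This mirrors how the existing "f" cache entries grow along axis 1 inside the decoding loop, so the same shape invariants apply.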
1 parent: 36e1446

7 files changed (+58, -20 lines)

tensor2tensor/layers/latent_layers.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ def symbols_to_logits_fn(ids):
 
   initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
   length = tf.shape(latents_dense_in)[1]
-  ids, _ = beam_search.beam_search(
+  ids, _, _ = beam_search.beam_search(
       symbols_to_logits_fn,
       initial_ids,
       1,

tensor2tensor/models/research/transformer_nat.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ def symbols_to_logits_fn(ids):
 
   initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
   length = tf.shape(latents_dense_in)[1]
-  ids, _ = beam_search.beam_search(
+  ids, _, _ = beam_search.beam_search(
       symbols_to_logits_fn,
       initial_ids,
       beam_size=1,

tensor2tensor/models/research/transformer_vae.py

Lines changed: 1 addition & 1 deletion
@@ -286,7 +286,7 @@ def symbols_to_logits_fn(ids):
 
   initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
   length = tf.shape(latents_dense_in)[1]
-  ids, _ = beam_search.beam_search(
+  ids, _, _ = beam_search.beam_search(
       symbols_to_logits_fn, initial_ids, beam_size, length,
       vocab_size, alpha=0.0, eos_id=-1, stop_early=False)
 

tensor2tensor/models/transformer.py

Lines changed: 9 additions & 6 deletions
@@ -812,7 +812,8 @@ def fast_decode(encoder_output,
                 eos_id=beam_search.EOS_ID,
                 batch_size=None,
                 force_decode_length=False,
-                scope_prefix="body/"):
+                scope_prefix="body/",
+                cache=None):
   """Given encoder output and a symbols to logits function, does fast decoding.
 
   Implements both greedy and beam search decoding, uses beam search iff
@@ -859,7 +860,9 @@ def fast_decode(encoder_output,
   vars_3d_num_heads = (
       hparams.num_heads if hparams.get("attention_variables_3d") else 0)
 
-  cache = {
+  if cache is None:
+    cache = dict()
+  cache.update({
       "layer_%d" % layer: {
           "k":
               common_attention.split_heads(
@@ -870,7 +873,7 @@ def fast_decode(encoder_output,
           "f":
               tf.zeros([batch_size, 0, hparams.hidden_size]),
       } for layer in range(num_layers)
-  }
+  })
 
   if encoder_output is not None:
     for layer in range(num_layers):
@@ -894,7 +897,7 @@ def fast_decode(encoder_output,
 
   if beam_size > 1:  # Beam Search
     initial_ids = sos_id * tf.ones([batch_size], dtype=tf.int32)
-    decoded_ids, scores = beam_search.beam_search(
+    decoded_ids, scores, cache = beam_search.beam_search(
         symbols_to_logits_fn,
         initial_ids,
         beam_size,
@@ -940,7 +943,7 @@ def is_not_finished(i, hit_eos, *_):
     hit_eos = tf.fill([batch_size], False)
     next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
    initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
-    _, _, _, decoded_ids, _, log_prob = tf.while_loop(
+    _, _, _, decoded_ids, cache, log_prob = tf.while_loop(
         is_not_finished,
         inner_loop, [
             tf.constant(0), hit_eos, next_id, decoded_ids, cache,
@@ -956,7 +959,7 @@ def is_not_finished(i, hit_eos, *_):
         ])
     scores = log_prob
 
-  return {"outputs": decoded_ids, "scores": scores}
+  return {"outputs": decoded_ids, "scores": scores, "cache": cache}
 
 
 @registry.register_model
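
On the caller side, a minimal sketch of how the new argument and the new "cache" return key fit together (batch_size, depth, and the surrounding decode inputs are assumed to be in scope; the "extra_preds" key is the hypothetical one from the sketch above):

# Seed the cache with an extra entry of width 0 so it can grow along
# axis 1 by one step per decode iteration.
seeded_cache = {"extra_preds": tf.zeros([batch_size, 0, depth])}

ret = transformer.fast_decode(
    encoder_output=encoder_output,
    encoder_decoder_attention_bias=encoder_decoder_attention_bias,
    symbols_to_logits_fn=symbols_to_logits_fn,
    hparams=hparams,
    decode_length=decode_length,
    vocab_size=vocab_size,
    beam_size=beam_size,
    cache=seeded_cache)  # new: externally supplied cache

decoded_ids = ret["outputs"]
extra_preds = ret["cache"]["extra_preds"]  # new: cache returned after the loop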

tensor2tensor/utils/beam_search.py

Lines changed: 2 additions & 2 deletions
@@ -505,7 +505,7 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
         tf.less(i, decode_length), tf.logical_not(bound_is_met))
 
   (_, alive_seq, alive_log_probs, finished_seq, finished_scores,
-   finished_flags, _) = tf.while_loop(
+   finished_flags, states) = tf.while_loop(
       _is_finished,
       inner_loop, [
           tf.constant(0), alive_seq, alive_log_probs, finished_seq,
@@ -535,4 +535,4 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
       tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
   finished_scores = tf.where(
       tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
-  return finished_seq, finished_scores
+  return finished_seq, finished_scores, states
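
With these two hunks, beam_search() returns three values, so a caller that passes a states dict can read it back after the decoding loop; a minimal sketch of the updated call (the states contents are the caller's own):

final_ids, final_scores, final_states = beam_search.beam_search(
    symbols_to_logits_fn,
    initial_ids,
    beam_size,
    decode_length,
    vocab_size,
    alpha=0.0,
    states=states)
# Callers that do not need the states simply discard the third value:
# final_ids, final_scores, _ = beam_search.beam_search(...)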

tensor2tensor/utils/beam_search_test.py

Lines changed: 43 additions & 8 deletions
@@ -37,7 +37,7 @@ def symbols_to_logits(_):
       # Just return random logits
       return tf.random_uniform((batch_size * beam_size, vocab_size))
 
-    final_ids, final_probs = beam_search.beam_search(
+    final_ids, final_probs, _ = beam_search.beam_search(
         symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
         0.)
 
@@ -113,7 +113,7 @@ def symbols_to_logits(ids):
       logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
       return logits
 
-    final_ids, final_probs = beam_search.beam_search(
+    final_ids, final_probs, _ = beam_search.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -144,7 +144,7 @@ def symbols_to_logits(ids):
       logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
       return logits
 
-    final_ids, final_probs = beam_search.beam_search(
+    final_ids, final_probs, _ = beam_search.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -173,7 +173,7 @@ def symbols_to_logits(ids):
       logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
       return logits
 
-    final_ids, final_probs = beam_search.beam_search(
+    final_ids, final_probs, _ = beam_search.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -213,7 +213,7 @@ def symbols_to_logits(ids):
       logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
       return logits
 
-    final_ids, final_scores = beam_search.beam_search(
+    final_ids, final_scores, _ = beam_search.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -256,7 +256,7 @@ def symbols_to_logits(ids):
       return logits
 
     # Disable early stopping
-    final_ids, final_scores = beam_search.beam_search(
+    final_ids, final_scores, _ = beam_search.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -302,7 +302,7 @@ def symbols_to_logits(ids, _, states):
     states["state"] = tf.placeholder_with_default(
         states["state"], shape=(None, 1))
 
-    final_ids, _ = beam_search.beam_search(
+    final_ids, _, _ = beam_search.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,
@@ -319,6 +319,41 @@ def symbols_to_logits(ids, _, states):
     except tf.errors.InvalidArgumentError as e:
       raise AssertionError(e.message)
 
+  def testStatesAfterLoop(self):
+    batch_size = 1
+    beam_size = 1
+    vocab_size = 2
+    decode_length = 3
+
+    initial_ids = tf.constant([0] * batch_size)  # GO
+    probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])
+
+    def symbols_to_logits(ids, _, states):
+      pos = tf.shape(ids)[1] - 1
+      logits = tf.to_float(tf.log(probabilities[pos, :]))
+      states["state"] += 1
+      return logits, states
+
+    states = {
+        "state": tf.zeros((batch_size, 1)),
+    }
+    states["state"] = tf.placeholder_with_default(
+        states["state"], shape=(None, 1))
+
+    _, _, final_states = beam_search.beam_search(
+        symbols_to_logits,
+        initial_ids,
+        beam_size,
+        decode_length,
+        vocab_size,
+        0.0,
+        eos_id=1,
+        states=states)
+
+    with self.test_session() as sess:
+      final_states = sess.run(final_states)
+    self.assertAllEqual([[1]], final_states["state"])
+
   def testStateBeamTwo(self):
     batch_size = 1
     beam_size = 2
@@ -352,7 +387,7 @@ def symbols_to_logits(ids, _, states):
     states["state"] = tf.placeholder_with_default(
         states["state"], shape=(None, 1))
 
-    final_ids, _ = beam_search.beam_search(
+    final_ids, _, _ = beam_search.beam_search(
         symbols_to_logits,
         initial_ids,
         beam_size,

tensor2tensor/utils/t2t_model.py

Lines changed: 1 addition & 1 deletion
@@ -698,7 +698,7 @@ def symbols_to_logits_fn(ids):
       inputs = features["inputs"]
       decode_length = (common_layers.shape_list(inputs)[1] +
                        features.get("decode_length", decode_length))
-    ids, scores = beam_search.beam_search(
+    ids, scores, _ = beam_search.beam_search(
         symbols_to_logits_fn,
         initial_ids,
         beam_size,
