
Commit 36e9131

mirkobronzi authored and afrozenator committed
changed stopping condition for the beam decoder (when returning all the beams) - fixed test accordingly (#965)
1 parent e5ecacf commit 36e9131

File tree

3 files changed: +62 -21 lines changed


tensor2tensor/models/transformer_test.py

Lines changed: 0 additions & 2 deletions
@@ -206,8 +206,6 @@ def testBeamVsFast(self):
       beam_res = beam_result.eval()
       fast_res = fast_result.eval()
 
-      self.assertEqual(fast_res.shape,
-                       (BATCH_SIZE, INPUT_LENGTH + decode_length))
       self.assertAllClose(beam_res, fast_res)
 
   def testTransformerWithoutProblem(self):

tensor2tensor/utils/beam_search.py

Lines changed: 20 additions & 17 deletions
@@ -704,7 +704,7 @@ def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
             finished_flags, states)
 
   def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
-                   finished_scores, finished_in_finished, unused_states):
+                   finished_scores, unused_finished_in_finished, unused_states):
     """Checking termination condition.
 
     We terminate when we decoded up to decode_length or the lowest scoring item
@@ -716,30 +716,33 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
       alive_log_probs: probabilities of the beams. [batch_size, beam_size]
       finished_scores: scores for each of these sequences.
         [batch_size, beam_size]
-      finished_in_finished: finished bools for each of these sequences.
-        [batch_size, beam_size]
 
     Returns:
      Bool.
     """
-    if not stop_early:
-      return tf.less(i, decode_length)
     max_length_penalty = tf.pow(((5. + tf.to_float(decode_length)) / 6.), alpha)
     # The best possible score of the most likely alive sequence.
     lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty
 
-    # Now to compute the lowest score of a finished sequence in finished
-    # If the sequence isn't finished, we multiply its score by 0. since
-    # scores are all -ve, taking the min will give us the score of the lowest
-    # finished item.
-    lowest_score_of_finished_in_finished = tf.reduce_min(
-        finished_scores * tf.to_float(finished_in_finished), axis=1)
-    # If none of the sequences have finished, then the min will be 0 and
-    # we have to replace it by -ve INF if it is. The score of any seq in alive
-    # will be much higher than -ve INF and the termination condition will not
-    # be met.
-    lowest_score_of_finished_in_finished += (
-        (1. - tf.to_float(tf.reduce_any(finished_in_finished, 1))) * -INF)
+    if not stop_early:
+      # by considering the min score (in the top N beams) we ensure that
+      # the decoder will keep decoding until there is at least one beam
+      # (in the top N) that can be improved (w.r.t. the alive beams).
+      # any unfinished beam will have score -INF - thus the min
+      # will always be -INF if there is at least one unfinished beam -
+      # which means the bound_is_met condition cannot be true in this case.
+      lowest_score_of_finished_in_finished = tf.reduce_min(finished_scores)
+    else:
+      # by taking the max score we only care about the first beam;
+      # as soon as this first beam cannot be beaten by the alive beams
+      # the beam decoder can stop.
+      # similarly to the above, if the top beam is not completed, its
+      # finished_score is -INF, thus it will not activate the
+      # bound_is_met condition (i.e., the decoder will keep going).
+      # note we need to find the max for every sequence separately - so, we
+      # need to keep the batch dimension (see axis=1)
+      lowest_score_of_finished_in_finished = tf.reduce_max(finished_scores,
+                                                           axis=1)
 
     bound_is_met = tf.reduce_all(
         tf.greater(lowest_score_of_finished_in_finished,
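
Note: the min-vs-max distinction in the comments above can be illustrated with a small NumPy sketch (illustrative only, not code from this commit; the scores and the INF stand-in below are made up):

import numpy as np

INF = 1e7  # stand-in for the module's INF sentinel; unfinished beams score -INF

# Made-up example: batch of 1, beam of 3; the third beam is still unfinished.
finished_scores = np.array([[-1.2, -2.5, -INF]])
lower_bound_alive_scores = np.array([-3.0])  # best possible alive score

# stop_early=False: the min over all finished scores is -INF while any beam is
# unfinished, so bound_is_met stays False and decoding continues.
lowest_when_returning_all_beams = np.min(finished_scores)
print(np.all(lowest_when_returning_all_beams > lower_bound_alive_scores))  # False

# stop_early=True: only the best finished score per sequence matters; once it
# beats the best possible alive score, decoding can stop.
best_finished_per_sequence = np.max(finished_scores, axis=1)
print(np.all(best_finished_per_sequence > lower_bound_alive_scores))  # True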

tensor2tensor/utils/beam_search_test.py

Lines changed: 42 additions & 2 deletions
@@ -129,7 +129,7 @@ def symbols_to_logits(ids):
       self.assertAllEqual([[[0, 0, 1]]], ids)
       self.assertAllClose([[0.7 * 0.6]], np.exp(probs))
 
-  def testNotGreedyBeamTwo(self):
+  def testNotGreedyBeamTwoWithStopEarly(self):
     batch_size = 1
     beam_size = 2
     vocab_size = 3
@@ -152,11 +152,51 @@ def symbols_to_logits(ids):
         decode_length,
         vocab_size,
         0.0,
-        eos_id=1)
+        eos_id=1,
+        stop_early=True)  # default value, but just to make this explicit
+
+    with self.test_session():
+      ids = final_ids.eval()
+      probs = final_probs.eval()
+      # given stop_early = True, the only 'assurance' is w.r.t. the first beam
+      # (i.e., other beams may not even be completed)
+      # so, we check only the first beam
+      first_beam = ids[:, 0]
+      first_probs = probs[:, 0]
+      self.assertAllEqual([[0, 2, 1]], first_beam)
+      self.assertAllClose([0.8 * 0.5], np.exp(first_probs))
+
+  def testNotGreedyBeamTwoWithoutStopEarly(self):
+    batch_size = 1
+    beam_size = 2
+    vocab_size = 3
+    decode_length = 3
+
+    initial_ids = tf.constant([0] * batch_size)  # GO
+    probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
+                                 [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
+                                 [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])
+
+    def symbols_to_logits(ids):
+      pos = tf.shape(ids)[1]
+      logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
+      return logits
+
+    final_ids, final_probs = beam_search.beam_search(
+        symbols_to_logits,
+        initial_ids,
+        beam_size,
+        decode_length,
+        vocab_size,
+        0.0,
+        eos_id=1,
+        stop_early=False)
 
     with self.test_session():
       ids = final_ids.eval()
       probs = final_probs.eval()
+      # given stop_early = False, the algorithm will return all the beams
+      # so we can test all of them here
       self.assertAllEqual([[[0, 2, 1, 0], [0, 2, 0, 1]]], ids)
       self.assertAllClose([[0.8 * 0.5, 0.8 * 0.4 * 0.9]], np.exp(probs))
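
Note: the values asserted in testNotGreedyBeamTwoWithoutStopEarly are just products of the per-step probabilities defined in the test; a small re-derivation for reference (not part of the commit):

import numpy as np

# Beam [0, 2, 1, 0]: token 2 at step 1 (p=0.8), then EOS (id 1) at step 2
# (p=0.5) -> probability 0.8 * 0.5.
# Beam [0, 2, 0, 1]: token 2 (p=0.8), token 0 at step 2 (p=0.4), then EOS at
# step 3 (p=0.9) -> probability 0.8 * 0.4 * 0.9.
expected_log_probs = np.log([0.8 * 0.5, 0.8 * 0.4 * 0.9])
print(expected_log_probs)  # final_probs[0] should be close to these values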
