@@ -704,7 +704,7 @@ def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
704704 finished_flags , states )
705705
706706 def _is_finished (i , unused_alive_seq , alive_log_probs , unused_finished_seq ,
707- finished_scores , finished_in_finished , unused_states ):
707+ finished_scores , unused_finished_in_finished , unused_states ):
708708 """Checking termination condition.
709709
710710 We terminate when we decoded up to decode_length or the lowest scoring item
@@ -716,30 +716,33 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
716716 alive_log_probs: probabilities of the beams. [batch_size, beam_size]
717717 finished_scores: scores for each of these sequences.
718718 [batch_size, beam_size]
719- finished_in_finished: finished bools for each of these sequences.
720- [batch_size, beam_size]
721719
722720 Returns:
723721 Bool.
724722 """
725- if not stop_early :
726- return tf .less (i , decode_length )
727723 max_length_penalty = tf .pow (((5. + tf .to_float (decode_length )) / 6. ), alpha )
728724 # The best possible score of the most likely alive sequence.
729725 lower_bound_alive_scores = alive_log_probs [:, 0 ] / max_length_penalty
730726
731- # Now to compute the lowest score of a finished sequence in finished
732- # If the sequence isn't finished, we multiply it's score by 0. since
733- # scores are all -ve, taking the min will give us the score of the lowest
734- # finished item.
735- lowest_score_of_finished_in_finished = tf .reduce_min (
736- finished_scores * tf .to_float (finished_in_finished ), axis = 1 )
737- # If none of the sequences have finished, then the min will be 0 and
738- # we have to replace it by -ve INF if it is. The score of any seq in alive
739- # will be much higher than -ve INF and the termination condition will not
740- # be met.
741- lowest_score_of_finished_in_finished += (
742- (1. - tf .to_float (tf .reduce_any (finished_in_finished , 1 ))) * - INF )
727+ if not stop_early :
728+ # by considering the min score (in the top N beams) we ensure that
729+ # the decoder will keep decoding until there is at least one beam
730+ # (in the top N) that can be improved (w.r.t. the alive beams).
731+ # any unfinished beam will have score -INF - thus the min
732+ # will always be -INF if there is at least one unfinished beam -
733+ # which means the bound_is_met condition cannot be true in this case.
734+ lowest_score_of_finished_in_finished = tf .reduce_min (finished_scores )
735+ else :
736+ # by taking the max score we only care about the first beam;
737+ # as soon as this first beam cannot be beaten from the alive beams
738+ # the beam decoder can stop.
739+ # similarly to the above, if the top beam is not completed, its
740+ # finished_score is -INF, thus it will not activate the
741+ # bound_is_met condition. (i.e., decoder will keep going on).
742+ # note we need to find the max for every sequence separately - so, we need
743+ # to keep the batch dimension (see axis=1)
744+ lowest_score_of_finished_in_finished = tf .reduce_max (finished_scores ,
745+ axis = 1 )
743746
744747 bound_is_met = tf .reduce_all (
745748 tf .greater (lowest_score_of_finished_in_finished ,
0 commit comments