diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index e437e55f48a3..c9f3adda9e01 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -3066,16 +3066,20 @@ def beam_search_cond_fn( not_max_length_yet = cur_len < max_length # 2. can the new beams still improve? - best_running_score = running_scores[:, :1] / (max_length**length_penalty) + best_running_score = running_scores[:, :1] / tf.cast(cur_len, dtype=running_scores.dtype) ** length_penalty worst_finished_score = tf.where( is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9 ) improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score) # 3. is there still a beam that has not finished? - still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping) + # still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping) + still_open_beam = ~(tf.math.reduce_all(is_sent_finished)) - return not_max_length_yet & (still_open_beam | improvement_still_possible) + _early_stopping = tf.constant(early_stopping > 0, dtype=tf.bool) + + # return not_max_length_yet & (still_open_beam | improvement_still_possible) + return not_max_length_yet & (still_open_beam | (~_early_stopping & improvement_still_possible)) def beam_search_body_fn( cur_len,