diff --git a/src/transformers/generation/beam_search.py b/src/transformers/generation/beam_search.py
index d22fbaf280de..11c6f6ac99a6 100644
--- a/src/transformers/generation/beam_search.py
+++ b/src/transformers/generation/beam_search.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import warnings
 from abc import ABC, abstractmethod
 from collections import UserDict
 from typing import List, Optional, Tuple
@@ -156,6 +155,7 @@ def __init__(
         self,
         batch_size: int,
         num_beams: int,
+        max_length: int,
         device: torch.device,
         length_penalty: Optional[float] = 1.0,
         do_early_stopping: Optional[bool] = False,
@@ -167,6 +167,7 @@ def __init__(
         self.device = device
         self.length_penalty = length_penalty
         self.do_early_stopping = do_early_stopping
+        self.max_length = max_length
         self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
         self.num_beam_groups = num_beam_groups
         self.group_size = self.num_beams // self.num_beam_groups
@@ -177,6 +178,7 @@ def __init__(
                 num_beams=self.num_beams,
                 length_penalty=self.length_penalty,
                 early_stopping=self.do_early_stopping,
+                max_length=self.max_length,
             )
             for _ in range(batch_size)
         ]
@@ -194,13 +196,6 @@ def __init__(
                 f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
             )
 
-        if "max_length" in kwargs:
-            warnings.warn(
-                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
-                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
-                ", or `group_beam_search(...)`."
-            )
-
     @property
     def is_done(self) -> bool:
         return self._done.all()
@@ -424,6 +419,7 @@ def __init__(
         self,
         batch_size: int,
         num_beams: int,
+        max_length: int,
         constraints: List[Constraint],
         device: torch.device,
         length_penalty: Optional[float] = 1.0,
@@ -436,6 +432,7 @@ def __init__(
         self.device = device
         self.length_penalty = length_penalty
         self.do_early_stopping = do_early_stopping
+        self.max_length = max_length
         self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
         self.num_beam_groups = num_beam_groups
         self.group_size = self.num_beams // self.num_beam_groups
@@ -447,6 +444,7 @@ def __init__(
                 num_beams=self.num_beams,
                 length_penalty=self.length_penalty,
                 early_stopping=self.do_early_stopping,
+                max_length=self.max_length,
             )
             for _ in range(batch_size)
         ]
@@ -464,13 +462,6 @@ def __init__(
                 f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
             )
 
-        if "max_length" in kwargs:
-            warnings.warn(
-                "Passing `max_length` to ConstrainedBeamSearchScorer is deprecated and has no effect. "
-                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
-                ", or `group_beam_search(...)`."
-            )
-
     @property
     def is_done(self) -> bool:
         return self._done.all()
@@ -851,12 +842,13 @@ def finalize(
 
 
 class BeamHypotheses:
-    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
+    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: int):
         """
         Initialize n-best list of hypotheses.
         """
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
+        self.max_length = max_length
         self.num_beams = num_beams
         self.beams = []
         self.worst_score = 1e9
@@ -892,6 +884,9 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
         elif self.early_stopping:
             return True
         else:
-            cur_score = best_sum_logprobs / cur_len**self.length_penalty
+            if self.length_penalty > 0.0:
+                cur_score = best_sum_logprobs / self.max_length**self.length_penalty
+            else:
+                cur_score = best_sum_logprobs / cur_len**self.length_penalty
             ret = self.worst_score >= cur_score
             return ret
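Reviewer note on the `BeamHypotheses.is_done` change above: log-probabilities are non-positive, so with a positive `length_penalty` the denominator can grow faster than the numerator shrinks, and a still-running beam might yet beat the worst finished hypothesis at a longer length. Dividing by `cur_len**length_penalty` therefore underestimated the best reachable score and could end the search too early. A minimal standalone sketch with toy numbers, not the transformers implementation:

```python
# Optimistic bound on the score a running beam can still reach, as in the
# updated `is_done` heuristic. `best_sum_logprobs` is negative (a sum of
# log-probabilities), so for length_penalty > 0 the bound must assume the
# beam continues all the way to max_length.
def best_possible_score(best_sum_logprobs: float, cur_len: int, max_length: int, length_penalty: float) -> float:
    if length_penalty > 0.0:
        return best_sum_logprobs / max_length**length_penalty
    # zero or negative penalty: growing the length never raises the score
    return best_sum_logprobs / cur_len**length_penalty

# With best_sum_logprobs=-6.0, cur_len=10, max_length=20, length_penalty=1.0,
# the old bound was -6.0 / 10 = -0.6, but the beam could still reach
# -6.0 / 20 = -0.3, so a worst finished score of e.g. -0.5 was still beatable.
print(best_possible_score(-6.0, 10, 20, 1.0))  # -0.3
```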
""" self.length_penalty = length_penalty self.early_stopping = early_stopping + self.max_length = max_length self.num_beams = num_beams self.beams = [] self.worst_score = 1e9 @@ -892,6 +884,9 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: elif self.early_stopping: return True else: - cur_score = best_sum_logprobs / cur_len**self.length_penalty + if self.length_penalty > 0.0: + cur_score = best_sum_logprobs / self.max_length**self.length_penalty + else: + cur_score = best_sum_logprobs / cur_len**self.length_penalty ret = self.worst_score >= cur_score return ret diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 5d936ce5b1dc..38bf786495ef 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -797,11 +797,14 @@ def beam_search_cond_fn(state): not_max_length_yet = state.cur_len < max_length # 2. can the new beams still improve? - best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty) + if length_penalty > 0.0: + best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty) + else: + best_running_score = state.running_scores[:, -1:] / (state.cur_len**length_penalty) worst_finished_score = jnp.where( state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7) ) - improvement_still_possible = jnp.all(worst_finished_score < best_running_score) + improvement_still_possible = jnp.any(best_running_score > worst_finished_score) # 3. is there still a beam that has not finished? still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping) diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index e437e55f48a3..1418f4c11afe 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -3066,16 +3066,19 @@ def beam_search_cond_fn( not_max_length_yet = cur_len < max_length # 2. can the new beams still improve? - best_running_score = running_scores[:, :1] / (max_length**length_penalty) + if length_penalty > 0.0: + best_running_score = running_scores[:, :1] / (max_length**length_penalty) + else: + best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty) worst_finished_score = tf.where( is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9 ) - improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score) + improvement_still_possible = tf.math.reduce_any(best_running_score > worst_finished_score) # 3. is there still a beam that has not finished? 
diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index e437e55f48a3..1418f4c11afe 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -3066,16 +3066,19 @@ def beam_search_cond_fn(
             not_max_length_yet = cur_len < max_length
 
             # 2. can the new beams still improve?
-            best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+            if length_penalty > 0.0:
+                best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+            else:
+                best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
             worst_finished_score = tf.where(
                 is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
             )
-            improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score)
+            improvement_still_possible = tf.math.reduce_any(best_running_score > worst_finished_score)
 
             # 3. is there still a beam that has not finished?
             still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)
 
-            return not_max_length_yet & (still_open_beam | improvement_still_possible)
+            return not_max_length_yet & still_open_beam & improvement_still_possible
 
         def beam_search_body_fn(
             cur_len,
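The TF hunk also rewires how the three sub-conditions combine. `still_open_beam` is `~(all_finished & early_stopping)`, which is constantly True whenever `early_stopping=False`; under the old OR that made `improvement_still_possible` dead code, so decoding always ran to `max_length`. A sketch with plain booleans (hypothetical helper names, not the library API):

```python
def old_cond(not_max_length_yet: bool, still_open_beam: bool, improvement_still_possible: bool) -> bool:
    return not_max_length_yet & (still_open_beam | improvement_still_possible)

def new_cond(not_max_length_yet: bool, still_open_beam: bool, improvement_still_possible: bool) -> bool:
    return not_max_length_yet & still_open_beam & improvement_still_possible

# early_stopping=False keeps still_open_beam True, so only the new condition
# actually stops once no running beam can beat the finished ones:
print(old_cond(True, True, False))  # True  -> old loop kept decoding
print(new_cond(True, True, False))  # False -> new loop exits early
```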
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 03ad4a25a1d9..e98f5705dedc 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1406,6 +1406,7 @@ def generate(
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
                 num_beams=generation_config.num_beams,
+                max_length=stopping_criteria.max_length,
                 device=inputs_tensor.device,
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
@@ -1442,6 +1443,7 @@ def generate(
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size * generation_config.num_return_sequences,
                 num_beams=generation_config.num_beams,
+                max_length=stopping_criteria.max_length,
                 device=inputs_tensor.device,
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
@@ -1577,6 +1579,7 @@ def typeerror():
                 constraints=final_constraints,
                 batch_size=batch_size,
                 num_beams=generation_config.num_beams,
+                max_length=stopping_criteria.max_length,
                 device=inputs_tensor.device,
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
@@ -2534,6 +2537,7 @@ def beam_search(
         >>> # instantiate beam scorer
         >>> beam_scorer = BeamSearchScorer(
         ...     batch_size=1,
+        ...     max_length=model.config.max_length,
         ...     num_beams=num_beams,
         ...     device=model.device,
         ... )
@@ -3543,7 +3547,11 @@ def constrained_beam_search(
 
         >>> # instantiate beam scorer
         >>> beam_scorer = ConstrainedBeamSearchScorer(
-        ...     batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
+        ...     batch_size=1,
+        ...     num_beams=num_beams,
+        ...     max_length=model.config.max_length,
+        ...     device=model.device,
+        ...     constraints=constraints,
         ... )
 
         >>> # instantiate logits processors
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 461e06ec4f75..5272e2720385 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -1562,6 +1562,7 @@ def extend_enc_output(tensor, num_beams=None):
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
             num_beams=generation_config.num_beams,
+            max_length=generation_config.max_length,
             device=self.device,
             length_penalty=generation_config.length_penalty,
             do_early_stopping=generation_config.early_stopping,
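With `generate()` now threading `stopping_criteria.max_length` (or `generation_config.max_length` in RAG) into the scorer, direct callers must pass it too. A sketch mirroring the updated docstring examples; `t5-small` is just a placeholder checkpoint:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BeamSearchScorer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
num_beams = 3

# `max_length` is now a required constructor argument: the scorer's
# BeamHypotheses need it to bound how much a running beam can still improve.
beam_scorer = BeamSearchScorer(
    batch_size=1,
    num_beams=num_beams,
    max_length=model.config.max_length,
    device=model.device,
)
```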
diff --git a/tests/generation/test_beam_search.py b/tests/generation/test_beam_search.py
index 72202ae2dad9..b1dc31dc2652 100644
--- a/tests/generation/test_beam_search.py
+++ b/tests/generation/test_beam_search.py
@@ -66,6 +66,7 @@ def prepare_beam_scorer(self, **kwargs):
         return BeamSearchScorer(
             batch_size=kwargs.get("batch_size", self.batch_size),
             num_beams=kwargs.get("num_beams", self.num_beams),
+            max_length=kwargs.get("max_length", self.max_length),
             device=torch_device,
             length_penalty=kwargs.get("length_penalty", self.length_penalty),
             do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
@@ -81,7 +82,7 @@ def prepare_inputs(self):
 
     def check_beam_hypotheses(self, input_ids, *args):
         # check that correct number of beam hypotheses is set in beam scorer
-        beam_scorer = self.prepare_beam_scorer(do_early_stopping=True)
+        beam_scorer = self.prepare_beam_scorer(do_early_stopping=True, max_length=input_ids.shape[-1])
         beam_hyp = beam_scorer._beam_hyps[0]
 
         self.parent.assertEqual(len(beam_scorer._beam_hyps), self.batch_size)
@@ -100,7 +101,7 @@ def check_beam_hypotheses(self, input_ids, *args):
         self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
 
         # re-init
-        beam_scorer = self.prepare_beam_scorer(do_early_stopping=False)
+        beam_scorer = self.prepare_beam_scorer(do_early_stopping=False, max_length=input_ids.shape[-1])
         beam_hyp = beam_scorer._beam_hyps[0]
 
         # add `num_beams + 1` beams to change `worst_score`
@@ -291,6 +292,7 @@ def prepare_constrained_beam_scorer(self, **kwargs):
             constraints=kwargs.get("constraints", self.constraints),
             batch_size=kwargs.get("batch_size", self.batch_size),
             num_beams=kwargs.get("num_beams", self.num_beams),
+            max_length=kwargs.get("max_length", self.max_length),
             device=torch_device,
             length_penalty=kwargs.get("length_penalty", self.length_penalty),
             do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
@@ -309,7 +311,9 @@ def prepare_inputs(self):
 
     def check_beam_hypotheses(self, input_ids, *args):
         # check that correct number of beam hypotheses is set in beam scorer
-        constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=True)
+        constrained_beam_scorer = self.prepare_constrained_beam_scorer(
+            do_early_stopping=True, max_length=input_ids.shape[-1]
+        )
         beam_hyp = constrained_beam_scorer._beam_hyps[0]
 
         self.parent.assertEqual(len(constrained_beam_scorer._beam_hyps), self.batch_size)
@@ -328,7 +332,9 @@ def check_beam_hypotheses(self, input_ids, *args):
         self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
 
         # re-init
-        constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=False)
+        constrained_beam_scorer = self.prepare_constrained_beam_scorer(
+            do_early_stopping=False, max_length=input_ids.shape[-1]
+        )
         beam_hyp = constrained_beam_scorer._beam_hyps[0]
 
         # add `num_beams + 1` beams to change `worst_score`
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index a03f0d12b9d1..29c9f89190a3 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -174,6 +174,7 @@ def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
             num_beams=beam_kwargs["num_beams"],
+            max_length=max_length,
             device=torch_device,
             length_penalty=beam_kwargs["length_penalty"],
             do_early_stopping=beam_kwargs["early_stopping"],
@@ -194,6 +195,7 @@ def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_seque
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
             num_beams=beam_kwargs["num_beams"],
+            max_length=max_length,
             device=torch_device,
             length_penalty=beam_kwargs["length_penalty"],
             do_early_stopping=beam_kwargs["early_stopping"],
@@ -214,6 +216,7 @@ def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints,
             batch_size=batch_size,
             constraints=constraints,
             num_beams=beam_kwargs["num_beams"],
+            max_length=max_length,
             device=torch_device,
             length_penalty=beam_kwargs["length_penalty"],
             do_early_stopping=beam_kwargs["early_stopping"],
@@ -1898,6 +1901,7 @@ def test_max_length_backward_compat_beam_search(self):
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
             num_beams=num_beams,
+            max_length=max_length,
             device=torch_device,
         )
         with self.assertWarns(UserWarning):
@@ -1930,6 +1934,7 @@ def test_max_length_backward_compat_group_beam_search(self):
         diverse_beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
             num_beams=num_beams,
+            max_length=max_length,
             device=torch_device,
             num_beam_hyps_to_keep=num_return_sequences,
             num_beam_groups=num_beam_groups,
@@ -1991,6 +1996,7 @@ def test_max_length_warning_if_different(self):
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
             num_beams=num_beams,
+            max_length=max_length,
             device=torch_device,
         )
         with self.assertWarns(UserWarning):
@@ -2008,6 +2014,7 @@ def test_max_length_warning_if_different(self):
         diverse_beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
             num_beams=num_beams,
+            max_length=max_length,
             device=torch_device,
             num_beam_hyps_to_keep=num_return_sequences,
             num_beam_groups=num_beam_groups,
@@ -2022,59 +2029,6 @@ def test_max_length_warning_if_different(self):
             **model_kwargs,
         )
 
-    def test_beam_search_warning_if_max_length_is_passed(self):
-        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
-        bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
-            torch_device
-        )
-
-        batch_size = 1
-        num_beams = 3
-
-        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-        input_ids = input_ids.expand(num_beams, -1)
-        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
-
-        # pretend decoder_input_ids correspond to first encoder input id
-        decoder_input_ids = input_ids[:, :1]
-
-        stopping_criteria_max_length = 18
-        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
-
-        with self.assertWarns(UserWarning):
-            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
-                num_beams=num_beams,
-                device=torch_device,
-                max_length=10,
-            )
-
-        generated_ids = bart_model.beam_search(
-            decoder_input_ids,
-            num_beams=num_beams,
-            stopping_criteria=stopping_criteria,
-            beam_scorer=beam_scorer,
-            **model_kwargs,
-        )
-
-        beam_scorer_no_max_len = BeamSearchScorer(
-            batch_size=batch_size,
-            num_beams=num_beams,
-            device=torch_device,
-        )
-
-        generated_ids_no_max_len = bart_model.beam_search(
-            decoder_input_ids,
-            num_beams=num_beams,
-            stopping_criteria=stopping_criteria,
-            beam_scorer=beam_scorer_no_max_len,
-            **model_kwargs,
-        )
-
-        # BeamSearchScorer max_length should not influence "real" max_length
-        self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
-
     def test_custom_stopping_criteria_overload_error(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
         bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
@@ -2744,6 +2698,7 @@ def test_beam_search_example_integration(self):
         beam_scorer = BeamSearchScorer(
             batch_size=1,
             num_beams=num_beams,
+            max_length=model.config.max_length,
             device=model.device,
         )
 
@@ -2924,7 +2879,11 @@ def test_constrained_beam_search_example_integration(self):
 
         # instantiate beam scorer
         beam_scorer = ConstrainedBeamSearchScorer(
-            batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
+            batch_size=1,
+            num_beams=num_beams,
+            max_length=model.config.max_length,
+            device=model.device,
+            constraints=constraints,
         )
 
         # instantiate logits processors
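The deleted `test_beam_search_warning_if_max_length_is_passed` covered the old contract, under which `max_length` passed to the scorer was deprecated and ignored; that contract no longer exists. A hedged migration sketch (values invented):

```python
import torch
from transformers import BeamSearchScorer

# Before this change: max_length=... raised a UserWarning and had no effect.
# After: it is required, and feeds the early-stopping bound in BeamHypotheses.
scorer = BeamSearchScorer(
    batch_size=1,
    num_beams=3,
    max_length=18,
    device=torch.device("cpu"),
)
```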
diff --git a/tests/models/t5/test_modeling_flax_t5.py b/tests/models/t5/test_modeling_flax_t5.py
index f4bd54e97af1..10e6622bb7df 100644
--- a/tests/models/t5/test_modeling_flax_t5.py
+++ b/tests/models/t5/test_modeling_flax_t5.py
@@ -1076,7 +1076,7 @@ def test_summarization(self):
         expected_summaries = [
             'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
             " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
-            " magazine says . all 150 on board were killed when germanwings flight 9525 crashed .",
+            " magazine says . all 150 on board were killed in the crash .",
             "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
             " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
             " court, Palestinians may be subject to counter-charges as well .",
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 4dcc14d80703..201fbd009465 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -1851,6 +1851,17 @@ def _generate_and_check_results(model, config, inputs_dict):
             config.eos_token_id = None  # Generate until max length
             config.do_sample = False
 
+            # fix config for models with additional sequence-length limiting settings
+            for var_name in ["max_position_embeddings", "max_target_positions"]:
+                attr = getattr(config, var_name, None)
+                if attr is not None and attr < generate_kwargs["max_new_tokens"]:
+                    try:
+                        setattr(config, var_name, generate_kwargs["max_new_tokens"])
+                    except NotImplementedError:
+                        # xlnet will raise an exception when trying to set
+                        # max_position_embeddings.
+                        pass
+
             model = model_class(config)
 
             if model.supports_xla_generation:
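The common-test fix above raises per-model sequence-length caps before building the model, since some configs limit positions below the requested `max_new_tokens`. A standalone sketch of the same clamp on a dummy config object (not a real transformers config):

```python
class DummyConfig:
    # shorter than the 16 new tokens requested below, as for some tiny test models
    max_position_embeddings = 8

generate_kwargs = {"max_new_tokens": 16}
config = DummyConfig()

for var_name in ["max_position_embeddings", "max_target_positions"]:
    attr = getattr(config, var_name, None)
    if attr is not None and attr < generate_kwargs["max_new_tokens"]:
        try:
            setattr(config, var_name, generate_kwargs["max_new_tokens"])
        except NotImplementedError:
            # some configs (e.g. xlnet) expose this as a read-only property
            pass

print(config.max_position_embeddings)  # 16
```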