diff --git a/docs/source/en/internal/generation_utils.mdx b/docs/source/en/internal/generation_utils.mdx index c8b42d91848e..5a717edb98fb 100644 --- a/docs/source/en/internal/generation_utils.mdx +++ b/docs/source/en/internal/generation_utils.mdx @@ -178,6 +178,12 @@ generation. [[autodoc]] TFRepetitionPenaltyLogitsProcessor - __call__ +[[autodoc]] TFForcedBOSTokenLogitsProcessor + - __call__ + +[[autodoc]] TFForcedEOSTokenLogitsProcessor + - __call__ + [[autodoc]] FlaxLogitsProcessor - __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 841b830c3b1d..a47a1ed2540c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1699,6 +1699,8 @@ _import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"] _import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"] _import_structure["generation_tf_logits_process"] = [ + "TFForcedBOSTokenLogitsProcessor", + "TFForcedEOSTokenLogitsProcessor", "TFLogitsProcessor", "TFLogitsProcessorList", "TFLogitsWarper", @@ -3827,6 +3829,8 @@ # Benchmarks from .benchmark.benchmark_tf import TensorFlowBenchmark from .generation_tf_logits_process import ( + TFForcedBOSTokenLogitsProcessor, + TFForcedEOSTokenLogitsProcessor, TFLogitsProcessor, TFLogitsProcessorList, TFLogitsWarper, diff --git a/src/transformers/generation_tf_logits_process.py b/src/transformers/generation_tf_logits_process.py index 22ece088ae42..92bae58eb8fd 100644 --- a/src/transformers/generation_tf_logits_process.py +++ b/src/transformers/generation_tf_logits_process.py @@ -216,14 +216,10 @@ def __init__(self, min_length: int, eos_token_id: int): self.min_length = min_length self.eos_token_id = eos_token_id - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: - # create boolean flag to decide if min length penalty should be applied - cur_len = input_ids.shape[-1] - apply_penalty = 1 - tf.clip_by_value(cur_len - self.min_length, 0, 1) - + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: # TODO(Matt) - this if statement has to be rewritten for XLA. 
Leaving it now though since # generate is not XLA - compileable anyways - if apply_penalty: + if cur_len < self.min_length: eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape) scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf")) @@ -259,8 +255,8 @@ def _create_score_penalties(self, input_ids, logits): np.put(token_penalties[i], prev_input_id, logit_penalties) return tf.convert_to_tensor(token_penalties, dtype=tf.float32) - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: - score_penalties = self._create_score_penalties(input_ids, scores) + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: + score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores) scores = tf.math.multiply(scores, score_penalties) @@ -330,12 +326,12 @@ def _tokens_match(prev_tokens, tokens): return banned_tokens - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: vocab_size = scores.shape[-1] # calculate a list of banned tokens according to bad words - banned_tokens = self.calc_banned_bad_words_ids(input_ids) + banned_tokens = self.calc_banned_bad_words_ids(input_ids[:, :cur_len]) banned_tokens_indices_mask = [] for banned_tokens_slice in banned_tokens: @@ -365,12 +361,13 @@ def __init__(self, ngram_size: int): raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") self.ngram_size = ngram_size - def calc_banned_ngram_tokens(self, prev_input_ids, num_hypos, cur_len): + def calc_banned_ngram_tokens(self, input_ids, num_hypos, cur_len): # Copied from fairseq for no_repeat_ngram in beam_search if cur_len + 1 < self.ngram_size: # return no banned tokens if we haven't generated ngram_size tokens yet return [[] for _ in range(num_hypos)] generated_ngrams = [{} for _ in range(num_hypos)] + prev_input_ids = input_ids[:, :cur_len] for idx in range(num_hypos): gen_tokens = prev_input_ids[idx].numpy().tolist() generated_ngram = generated_ngrams[idx] @@ -388,10 +385,9 @@ def _get_generated_ngrams(hypo_idx): return banned_tokens - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: batch_size, vocab_size = scores.shape - cur_len = input_ids.shape[-1] banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len) # create banned_tokens boolean mask @@ -406,3 +402,66 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: ) return scores + + +class TFForcedBOSTokenLogitsProcessor(TFLogitsProcessor): + r""" + [`TFLogitsProcessor`] that enforces the specified token as the first generated token. + + Args: + bos_token_id (`int`): + The id of the token to force as the first generated token. 
+ """ + + def __init__(self, bos_token_id: int): + if bos_token_id < 0: + raise ValueError(f"The forced bos token id must be a non-negative integer, got {bos_token_id}") + self.bos_token_id = bos_token_id + + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: + if cur_len == 1: + batch_size, num_tokens = scores.shape + # sets the score to 0 in the bos_token_id column + scores = tf.zeros((batch_size, 1)) + # sets the score to -inf everywhere else + if self.bos_token_id > 0: + scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.bos_token_id)), scores), axis=-1) + if self.bos_token_id < (num_tokens - 1): + scores = tf.concat( + (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.bos_token_id))), + axis=-1, + ) + return scores + + +class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor): + r""" + [`TFLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached. + + Args: + max_length (`int`): + The maximum length of the sequence to be generated. + eos_token_id (`int`): + The id of the token to force as the last generated token when `max_length` is reached. + """ + + def __init__(self, max_length: int, eos_token_id: int): + self.max_length = max_length + if eos_token_id < 0: + raise ValueError(f"The forced eos token id must be a non-negative integer, got {eos_token_id}") + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: + if cur_len == self.max_length - 1: + batch_size, num_tokens = scores.shape + # sets the score to 0 in the eos_token_id column + scores = tf.zeros((batch_size, 1)) + # sets the score to -inf everywhere else + if self.eos_token_id > 0: + scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.eos_token_id)), scores), axis=-1) + if self.eos_token_id < (num_tokens - 1): + scores = tf.concat( + (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.eos_token_id))), + axis=-1, + ) + return scores diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index e96eb191e6a2..8668e6f8dcc2 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -16,12 +16,15 @@ import inspect from dataclasses import dataclass +from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import tensorflow as tf from .generation_tf_logits_process import ( + TFForcedBOSTokenLogitsProcessor, + TFForcedEOSTokenLogitsProcessor, TFLogitsProcessorList, TFMinLengthLogitsProcessor, TFNoBadWordsLogitsProcessor, @@ -560,7 +563,7 @@ def generate( num_beams = num_beams if num_beams is not None else self.config.num_beams do_sample = do_sample if do_sample is not None else self.config.do_sample - if num_beams == 1: + if do_sample is False or num_beams == 1: return self._generate( input_ids=input_ids, max_length=max_length, @@ -586,6 +589,8 @@ def generate( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + forced_bos_token_id=forced_bos_token_id, + forced_eos_token_id=forced_eos_token_id, ) # We cannot generate if the model does not have a LM head @@ -1377,8 +1382,8 @@ def _generate( the target language token. forced_eos_token_id (`int`, *optional*): The id of the token to force as the last generated token when `max_length` is reached. 
- model_specific_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. + model_kwargs: + Additional model specific kwargs will be forwarded to the `call` function of the model. Return: [`~utils.ModelOutput`] or `tf.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when @@ -1452,12 +1457,20 @@ def _generate( # 1. Set generation parameters if not already defined max_length = max_length if max_length is not None else self.config.max_length min_length = min_length if min_length is not None else self.config.min_length + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + forced_bos_token_id = ( + forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id + ) + forced_eos_token_id = ( + forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id + ) + output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1489,10 +1502,13 @@ def _generate( model_kwargs["output_hidden_states"] = output_hidden_states if use_cache is not None: model_kwargs["use_cache"] = use_cache + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.call).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id) # 4. Prepare model inputs which will be used for auto-regressive generation @@ -1519,6 +1535,7 @@ def _generate( # TODO(Matt, Joao, Patrick) - add more use cases here is_greedy_gen_mode = (num_beams == 1) and do_sample is False is_sample_gen_mode = (num_beams == 1) and do_sample is True + is_beam_gen_mode = (num_beams > 1) and do_sample is False # 6. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( @@ -1526,7 +1543,10 @@ def _generate( no_repeat_ngram_size=no_repeat_ngram_size, bad_words_ids=bad_words_ids, min_length=min_length, + max_length=max_length, eos_token_id=eos_token_id, + forced_bos_token_id=forced_bos_token_id, + forced_eos_token_id=forced_eos_token_id, ) # 7. go into different generation modes @@ -1571,18 +1591,62 @@ def _generate( **model_kwargs, ) + elif is_beam_gen_mode: + if num_beams < num_return_sequences: + raise ValueError( + "Greedy beam search decoding cannot return more sequences than it has beams. Please set " + f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectivelly)" + ) + + # 8. 
broadcast inputs to the desired number of beams + input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams) + + if "encoder_outputs" in model_kwargs: + model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams( + model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams + ) + + if "attention_mask" in model_kwargs: + model_kwargs["attention_mask"] = self._expand_to_num_beams( + model_kwargs["attention_mask"], num_beams=num_beams + ) + + # 9. run beam search + return self.beam_search( + input_ids, + max_length=max_length, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + length_penalty=length_penalty, + early_stopping=early_stopping, + logits_processor=logits_processor, + return_dict_in_generate=return_dict_in_generate, + num_return_sequences=num_return_sequences, + **model_kwargs, + ) + + else: + # TODO(Matt, Joao, Patrick) - add more sub-generation methods here + raise NotImplementedError("Beam sampling is currently not implemented.") + + @staticmethod + def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor: + return tf.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:]) + def _prepare_attention_mask_for_generation( self, - input_ids: tf.Tensor, + inputs: tf.Tensor, pad_token_id: int, ) -> tf.Tensor: - # prepare `attention_mask` if not passed - if (pad_token_id is not None) and tf.math.reduce_any(input_ids == pad_token_id): - return tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32) + is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64) + is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id) + # Check if input is input_ids and padded -> only then is attention_mask defined + if is_input_ids and is_pad_token_in_inputs: + return tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32) else: - return tf.ones(input_ids.shape[:2], dtype=tf.int32) + return tf.ones(inputs.shape[:2], dtype=tf.int32) - def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids: tf.Tensor, model_kwargs) -> Dict[str, Any]: + def _prepare_encoder_decoder_kwargs_for_generation(self, inputs_tensor: tf.Tensor, model_kwargs) -> Dict[str, Any]: # get encoder and store encoder outputs encoder = self.get_encoder() @@ -1595,11 +1659,9 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids: tf.Tensor, m } # vision models don't use `attention_mask`. 
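As an aside before the encoder-kwargs change below: the `_expand_to_num_beams` helper added above replicates every batch row `num_beams` times by inserting a beam axis and broadcasting over it, and the same helper is reused for `input_ids`, `attention_mask`, and the encoder's `last_hidden_state`. A minimal, self-contained sketch (the toy tensor values are invented for illustration):

```python
import tensorflow as tf


# Standalone sketch of the `_expand_to_num_beams` helper above: a beam axis is
# inserted and broadcast, so every batch row is repeated `num_beams` times.
def expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
    return tf.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])


input_ids = tf.constant([[5, 6, 7], [8, 9, 10]])  # (batch_size=2, seq_len=3)
expanded = expand_to_num_beams(input_ids, num_beams=3)
print(expanded.shape)  # (2, 3, 3), i.e. (batch_size, num_beams, seq_len)
print(expanded[0].numpy())  # first row repeated 3 times: [[5 6 7] [5 6 7] [5 6 7]]
```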
- signature = dict(inspect.signature(encoder.call).parameters) - if "attention_mask" not in signature: - encoder_kwargs.pop("attention_mask") - - encoder_outputs = encoder(input_ids, **encoder_kwargs) + encoder_kwargs["return_dict"] = True + encoder_kwargs[self.main_input_name] = inputs_tensor + encoder_outputs = encoder(**encoder_kwargs) model_kwargs["encoder_outputs"] = encoder_outputs return model_kwargs @@ -1757,7 +1819,10 @@ def _get_logits_processor( no_repeat_ngram_size: int, bad_words_ids: List[List[int]], min_length: int, + max_length: int, eos_token_id: int, + forced_bos_token_id: int, + forced_eos_token_id: int, ) -> TFLogitsProcessorList: """ This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`] @@ -1781,6 +1846,10 @@ def _get_logits_processor( processors.append(TFNoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) if min_length is not None and eos_token_id is not None and min_length > 0: processors.append(TFMinLengthLogitsProcessor(min_length, eos_token_id)) + if forced_bos_token_id is not None: + processors.append(TFForcedBOSTokenLogitsProcessor(forced_bos_token_id)) + if forced_eos_token_id is not None: + processors.append(TFForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) return processors @@ -1949,7 +2018,7 @@ def greedy_search_body_fn(generated, finished_sequences, next_tokens, current_po if not use_xla: input_ids = tf.reshape(generated.concat(), (-1, batch_size)) input_ids = tf.transpose(input_ids[: current_pos[0]]) - next_tokens_scores = logits_processor(input_ids, next_token_logits) + next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=current_pos[0]) # argmax next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32) @@ -2064,7 +2133,7 @@ def sample( input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`TFLogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`] + An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. logits_warper (`TFLogitsProcessorList`, *optional*): An instance of [`TFLogitsProcessorList`]. 
List of instances of class derived from [`TFLogitsWarper`] @@ -2184,7 +2253,7 @@ def sample( next_token_logits = outputs.logits[:, -1, :] # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len) next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required @@ -2259,6 +2328,527 @@ def sample( else: return input_ids + def beam_search( + self, + input_ids: tf.Tensor, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + length_penalty: Optional[float] = None, + early_stopping: Optional[bool] = None, + logits_processor: Optional[TFLogitsProcessorList] = None, + num_return_sequences: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + **model_kwargs, + ) -> Union[TFBeamSearchOutput, tf.Tensor]: + r""" + Generates sequences for models with a language modeling head using beam search with multinomial sampling. + + Parameters: + + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + max_length (`int`, *optional*, defaults to 20): + The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + length_penalty (`float`, *optional*, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. + early_stopping (`bool`, *optional*, defaults to `False`): + Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. + logits_processor (`[TFLogitsProcessorList]`, *optional*): + An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + num_return_sequences(`int`, *optional*, defaults to 1): + The number of independently computed returned sequences for each element in the batch. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + model_kwargs: + Additional model specific kwargs will be forwarded to the `call` function of the model. If model is an + encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_tf_utils.TFBeamSearchDecoderOnlyOutput`], + [`~generation_tf_utils.TFBeamSearchEncoderDecoderOutput`] or `tf.Tensor`: A `tf.Tensor` containing the + generated tokens (default behaviour) or a [`~generation_tf_utils.TFBeamSearchDecoderOnlyOutput`] if + `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a + [`~generation_tf_utils.TFBeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... 
AutoTokenizer, + ... TFAutoModelForSeq2SeqLM, + ... TFLogitsProcessorList, + ... TFMinLengthLogitsProcessor, + ... ) + >>> import tensorflow as tf + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids + + >>> # lets run beam search using 3 beams + >>> num_beams = 3 + >>> # define decoder start token ids + >>> input_ids = tf.ones((num_beams, 1), dtype=tf.int64) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()( + ... tf.repeat(encoder_input_ids, num_beams, axis=0), return_dict=True + ... ) + ... } + + >>> # instantiate logits processors + >>> logits_processor = TFLogitsProcessorList( + ... [TFMinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)] + ... ) + + >>> outputs = model.beam_search(input_ids, logits_processor=logits_processor, **model_kwargs) + + >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + ```""" + + def flatten_beam_dim(tensor, batch_axis=0): + """Flattens the first two dimensions of a non-scalar array.""" + # ignore scalars (e.g. cache index) + if tf.rank(tensor) == 0: + return tensor + return tf.reshape( + tensor, + tensor.shape[:batch_axis] + + [tensor.shape[batch_axis] * tensor.shape[batch_axis + 1]] + + tensor.shape[batch_axis + 2 :], + ) + + def unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=0): + """Unflattens the first, flat batch*beam dimension of a non-scalar array.""" + # ignore scalars (e.g. cache index) + if tf.rank(tensor) == 0: + return tensor + return tf.reshape( + tensor, tensor.shape[:batch_axis] + [batch_size, num_beams] + tensor.shape[batch_axis + 1 :] + ) + + def gather_beams(nested, beam_indices, batch_axis=0): + """Gathers the beam slices indexed by beam_indices into new beam array.""" + + def gather_fn(tensor): + # ignore scalars (e.g. cache index) + if tf.rank(tensor) == 0: + return tensor + else: + if batch_axis > 0: + # pushes all dimentions before the batch to the end, so we get (batch, beam_id, ...) + perm = [axis for axis in range(tf.rank(tensor)) if axis >= batch_axis] + list( + range(batch_axis) + ) + tensor = tf.transpose(tensor, perm=perm) + + gathered_tensor = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1) + if batch_axis > 0: + # transposes back to the original dimensions + perm = [axis for axis in range(tf.rank(tensor)) if axis >= batch_axis] + list( + range(batch_axis) + ) + perm = tf.math.invert_permutation(perm) + gathered_tensor = tf.transpose(gathered_tensor, perm=perm) + + return gathered_tensor + + return tf.nest.map_structure(gather_fn, nested) + + # 1. 
init beam_search values + logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() + + max_length = max_length if max_length is not None else self.config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + num_return_sequences = ( + num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_scores = output_scores if output_scores is not None else self.config.output_scores + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + + use_xla = not tf.executing_eagerly() + # TODO (Joao): fix cache format or find programatic way to detect cache index + # GPT2 and other models has a slightly different cache structure, with a different batch axis + model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self) + cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0 + + # 2. init `attentions`, `hidden_states`, and `scores` tuples + scores = [] if (return_dict_in_generate and output_scores) else None + decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None + cross_attentions = [] if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None + + # 3. init tensors to use for "xla-compileable" generate function + batch_size, num_beams, cur_len = input_ids.shape + + # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id` + sequences = tf.TensorArray( + element_shape=(batch_size, num_beams), + dtype=tf.int32, + dynamic_size=False, + size=max_length, + clear_after_read=False, + ) + running_sequences = tf.TensorArray( + element_shape=(batch_size, num_beams), + dtype=tf.int32, + dynamic_size=False, + size=max_length, + clear_after_read=False, + ) + intermediary_running_sequences = tf.TensorArray( + element_shape=(batch_size, num_beams * 2), + dtype=tf.int32, + dynamic_size=False, + size=max_length, + clear_after_read=False, + ) + for i in range(max_length): + sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams))) + running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams))) + intermediary_running_sequences = intermediary_running_sequences.write( + i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2)) + ) + + # write prompt to running_sequences + for i in range(cur_len): + running_sequences = running_sequences.write(i, input_ids[:, :, i]) + + # per batch,beam-item state bit indicating if sentence has finished. 
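A note on the `tf.TensorArray` bookkeeping set up above, before the per-beam finished flags that follow: sequences are stored position-first, one `(batch_size, num_beams)` slice per step, so that individual positions can be written inside the generation loop; `stack()` plus a transpose then recovers the usual batch-first layout. A toy round-trip with invented shapes (the real code writes a per-beam prompt, whereas this sketch shares one prompt across beams):

```python
import tensorflow as tf

batch_size, num_beams, max_length, pad_token_id = 2, 3, 5, 0

# one (batch_size, num_beams) slice per sequence position, pre-filled with pad tokens
sequences = tf.TensorArray(
    dtype=tf.int32,
    size=max_length,
    dynamic_size=False,
    clear_after_read=False,
    element_shape=(batch_size, num_beams),
)
for i in range(max_length):
    sequences = sequences.write(i, tf.fill((batch_size, num_beams), pad_token_id))

# write a prompt of length 2 into every beam
prompt = tf.constant([[11, 12], [21, 22]])  # (batch_size, cur_len)
for i in range(2):
    sequences = sequences.write(i, tf.broadcast_to(prompt[:, i, None], (batch_size, num_beams)))

# stack() gives (seq_len, batch, beams); transpose back to (batch, beams, seq_len)
seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
print(seq_last[0, 0].numpy())  # [11 12  0  0  0]
```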
+ is_sent_finished = tf.zeros((batch_size, num_beams), dtype=tf.bool) + + # per batch, beam-item score, logprobs + running_scores = tf.tile( + tf.expand_dims(tf.convert_to_tensor([0.0] + [-1.0e9] * (num_beams - 1)), axis=0), [batch_size, 1] + ) + scores = tf.ones((batch_size, num_beams)) * -1.0e9 + + # flatten beam dim + if "encoder_outputs" in model_kwargs: + model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim( + model_kwargs["encoder_outputs"]["last_hidden_state"] + ) + if "attention_mask" in model_kwargs: + model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"]) + + # 4. define "xla-compile-able" stop-condition and auto-regressive function + # define stop-condition and auto-regressive function + def beam_search_cond_fn( + cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs + ): + """ + Beam Search termination condition function -- halts the generation loop if any of these conditions becomes + False + """ + # 1. is less than max length? + not_max_length_yet = cur_len < max_length + + # 2. can the new beams still improve? + best_running_score = running_scores[:, :1] / (max_length**length_penalty) + worst_finished_score = tf.where( + is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9 + ) + improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score) + + # 3. is there still a beam that has not finished? + still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping) + + return not_max_length_yet & (still_open_beam | improvement_still_possible) + + def beam_search_body_fn( + cur_len, + running_sequences, + running_scores, + sequences, + scores, + is_sent_finished, + model_kwargs, + input_ids_length=1, + intermediary_running_sequences=None, + ): + """ + Beam Search iterative update function -- each iteration adds a new token and updates the best sequences + seen so far + """ + # TODO (joao): this loop is probably faster with gather/scatters, instead of using `tf.TensorArray`. + # Alternativelly, attempt to rewrite function with permuted axis, when enabling XLA. + + # 1. 
Forward current tokens + + # TF places the dynamic dimension (seq_len) in the first axis, we want it in the last + running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0]) + input_token = tf.slice( + running_sequences_seq_last, + (0, 0, cur_len - input_ids_length), + (batch_size, num_beams, input_ids_length), + ) + model_inputs = self.prepare_inputs_for_generation( + flatten_beam_dim(input_token), use_xla=use_xla, **model_kwargs + ) + model_outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams) + + # Store scores, attentions and hidden_states when required + if not use_xla and return_dict_in_generate: + if output_scores: + scores.append(model_outputs.logits[:, -1]) + if output_attentions and self.config.is_encoder_decoder: + decoder_attentions.append(model_outputs.decoder_attentions) + elif output_attentions and not self.config.is_encoder_decoder: + decoder_attentions.append(model_outputs.attentions) + if self.config.is_encoder_decoder: + cross_attentions.append(model_outputs.cross_attentions) + + if output_hidden_states and self.config.is_encoder_decoder: + decoder_hidden_states.append(model_outputs.decoder_hidden_states) + elif output_hidden_states and self.config.is_encoder_decoder: + decoder_hidden_states.append(model_outputs.hidden_states) + + # 2. Compute log probs + # get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and + # add new logprobs to existing running logprobs scores. + log_probs = tf.nn.log_softmax(logits) + log_probs = logits_processor( + flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len=cur_len + ) + log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams) + log_probs = log_probs + tf.expand_dims(running_scores, axis=2) + vocab_size = log_probs.shape[2] + log_probs = tf.reshape(log_probs, (batch_size, num_beams * vocab_size)) + + # 3. Retrieve top-K + # Each item in batch has num_beams * vocab_size candidate sequences. For each item, get the top 2*k + # candidates with the highest log-probabilities. We gather the top 2*K beams here so that even if the + # best K sequences reach EOS simultaneously, we have another K sequences remaining to continue the live + # beam search. + # Gather the top 2*K scores from _all_ beams. + # Gather 2*k top beams. + # Recover the beam index by floor division. + # Recover token id by modulo division and expand Id array for broadcasting. + # Update sequences for the 2*K top-k new sequences. + beams_to_keep = 2 * num_beams + topk_log_probs, topk_indices = tf.math.top_k(log_probs, k=beams_to_keep) + topk_beam_indices = topk_indices // vocab_size + topk_running_sequences_seq_last = gather_beams(running_sequences_seq_last, topk_beam_indices) + topk_ids = topk_indices % vocab_size + + # writes the new token + intermediary_running_sequences = intermediary_running_sequences.unstack( + tf.transpose(topk_running_sequences_seq_last, perm=[2, 0, 1]) + ) + topk_sequences = intermediary_running_sequences.write(cur_len, topk_ids) + topk_sequences_seq_last = tf.transpose(topk_sequences.stack(), perm=[1, 2, 0]) + + # 4. Check which sequences have ended + # Update current sequences: Did the top `num_beams` sequences reach an end marker? 
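A quick aside on the step above and the question just raised: the loop keeps `2 * num_beams` candidates per batch item precisely so that K live continuations survive even if the K best candidates all hit EOS at the same step, and only the first `num_beams` (best-scored) candidates are allowed to finalize a hypothesis. A toy illustration with invented token ids:

```python
import tensorflow as tf

# Among the 2*K candidates per batch item, only the first K (the best-scored
# ones) may finalize a hypothesis when they hit EOS.
num_beams, eos_token_id = 2, 3
topk_ids = tf.constant([[3, 1, 3, 2]])  # (batch=1, 2*num_beams) next-token candidates

eos_in_next_token = topk_ids == eos_token_id
# mask selecting the first `num_beams` of the 2*num_beams candidates
first_k_mask = tf.concat([tf.ones((num_beams,), tf.bool), tf.zeros((num_beams,), tf.bool)], axis=0)
did_topk_just_finished = eos_in_next_token & tf.broadcast_to(first_k_mask, eos_in_next_token.shape)
print(did_topk_just_finished.numpy())  # [[ True False False False]]: only the 1st candidate finalizes
```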
+ # To prevent these just finished sequences from being added to the current sequences + # set of active beam search sequences, set their log probs to a very large negative value. + eos_in_next_token = topk_sequences_seq_last[:, :, cur_len] == eos_token_id + if eos_token_id is None: + eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences_seq_last[:, :, cur_len].shape) + did_topk_just_finished = eos_in_next_token & tf.broadcast_to( + tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0), + eos_in_next_token.shape, + ) + + # non-top `num_beams` eos tokens can't be used to finish a beam, but the others can't be used in the next + # running sentences either + running_topk_log_probs = topk_log_probs + tf.cast(eos_in_next_token, tf.float32) * -1.0e9 + + # 5. Get running sequences scores for next + # Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams + # (from top 2*k beams). + next_topk_indices = tf.math.top_k(running_topk_log_probs, k=num_beams)[1] + next_running_sequences_seq_last, next_running_scores = gather_beams( + [topk_sequences_seq_last, running_topk_log_probs], next_topk_indices + ) + + # 6. Process topk logits + # Further process log probs: + # - add length penalty + # - make sure no scores can be added anymore if beam is full + # - make sure still running sequences cannot be chosen as finalized beam + topk_log_probs = topk_log_probs / (cur_len**length_penalty) + beams_in_batch_are_full = ( + tf.broadcast_to( + tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), did_topk_just_finished.shape + ) + & early_stopping + ) + add_penalty = ~did_topk_just_finished | beams_in_batch_are_full + topk_log_probs += tf.cast(add_penalty, tf.float32) * -1.0e9 + + # 7. Get scores, sequences, is sentence finished for next. + # Combine sequences, scores, and flags along the beam dimension and compare new finished sequence scores + # to existing finished scores and select the best from the new set of beams + sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0]) + merged_sequences = tf.concat([sequences_seq_last, topk_sequences_seq_last], axis=1) + merged_scores = tf.concat([scores, topk_log_probs], axis=1) + merged_is_sent_finished = tf.concat([is_sent_finished, did_topk_just_finished], axis=1) + topk_merged_indices = tf.math.top_k(merged_scores, k=num_beams)[1] + next_sequences_seq_last, next_scores, next_is_sent_finished = gather_beams( + [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices + ) + + # 8. Prepare data for the next iteration + # Determine the top k beam indices from the original set of all beams. With these, gather the top k + # beam-associated caches. 
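Before the cache gathering below, here is a small sketch of the `gather_beams` pattern it relies on: `tf.gather(..., axis=1, batch_dims=1)` reorders the beam axis independently for every batch item. Toy values for illustration:

```python
import tensorflow as tf

# Pick, per batch item, the beam rows named by `beam_indices` (same call as in
# `gather_beams` above, minus the scalar/transpose handling).
tensor = tf.constant([[[10, 11], [20, 21], [30, 31]],
                      [[40, 41], [50, 51], [60, 61]]])  # (batch=2, beams=3, ...)
beam_indices = tf.constant([[2, 0], [1, 1]])            # new beam order per batch item

gathered = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1)
print(gathered.numpy())
# [[[30 31] [10 11]]
#  [[50 51] [50 51]]]
```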
+            if "past_key_values" in model_outputs:
+                cache = tf.nest.map_structure(
+                    lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=cache_batch_axis),
+                    model_outputs.past_key_values,
+                )
+                next_running_indices = gather_beams(topk_beam_indices, next_topk_indices)
+                next_cache = gather_beams(cache, next_running_indices, batch_axis=cache_batch_axis)
+                model_outputs["past_key_values"] = tf.nest.map_structure(
+                    lambda tensor: flatten_beam_dim(tensor, batch_axis=cache_batch_axis), next_cache
+                )
+
+            if use_xla:
+                next_model_kwargs = self._update_model_kwargs_for_xla_generation(
+                    model_outputs, model_kwargs, cur_len, max_length
+                )
+            else:
+                next_model_kwargs = self._update_model_kwargs_for_generation(
+                    model_outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                )
+
+                # if we don't cache past key values we need the whole input
+                if model_kwargs.get("past", None) is None:
+                    input_ids_length = cur_len + 1
+                    # let's throw out `past` since we don't want `None` tensors
+                    model_kwargs.pop("past", None)
+
+            # 9. Prepare the `tf.TensorArray` for the next iteration
+            next_sequences = sequences.unstack(tf.transpose(next_sequences_seq_last, perm=[2, 0, 1]))
+            next_running_sequences = running_sequences.unstack(
+                tf.transpose(next_running_sequences_seq_last, perm=[2, 0, 1])
+            )
+
+            return (
+                cur_len + 1,
+                next_running_sequences,
+                next_running_scores,
+                next_sequences,
+                next_scores,
+                next_is_sent_finished,
+                next_model_kwargs,
+            )
+
+        # 5. run generation
+        # Adds the `intermediary_running_sequences` TensorArray into the body, needed as a scratchpad
+        beam_search_body_fn = partial(
+            beam_search_body_fn, intermediary_running_sequences=intermediary_running_sequences
+        )
+
+        # 1st generation step has to be run before to initialize `past`
+        beam_search_body_fn_first_iter = partial(beam_search_body_fn, input_ids_length=cur_len)
+        (
+            cur_len,
+            running_sequences,
+            running_scores,
+            sequences,
+            scores,
+            is_sent_finished,
+            model_kwargs,
+        ) = beam_search_body_fn_first_iter(
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
+        )
+
+        # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
+        # NOT yield EOS token though)
+        if beam_search_cond_fn(
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
+        ):
+            maximum_iterations = max_length - cur_len
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _ = tf.while_loop(
+                beam_search_cond_fn,
+                beam_search_body_fn,
+                (cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs),
+                maximum_iterations=maximum_iterations,
+            )
+
+        # 6. prepare outputs
+        # convert the sequences to tf.Tensor with shape (batch_size, num_beams, seq_len)
+        sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
+        running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
+
+        # Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
+        # running sequences for that batch item.
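The fallback implemented next can be seen in isolation: batch items where no beam ever emitted EOS keep their best *running* sequences instead of the (all-pad) finished ones. A toy sketch with invented values (note that the diff's `none_finished` variable actually holds "any beam finished"; the name below is chosen to make that explicit):

```python
import tensorflow as tf

is_sent_finished = tf.constant([[True, False], [False, False]])  # (batch, beams)
finished_seqs = tf.fill((2, 2, 4), 9)  # stand-in finished sequences
running_seqs = tf.fill((2, 2, 4), 7)   # stand-in still-running sequences

any_finished = tf.math.reduce_any(is_sent_finished, axis=1)  # [True, False]
output = tf.where(any_finished[:, None, None], finished_seqs, running_seqs)
print(output[:, 0, 0].numpy())  # [9 7]: batch 0 keeps finished beams, batch 1 falls back
```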
+        none_finished = tf.math.reduce_any(is_sent_finished, axis=1)
+        sequences_seq_last = tf.where(none_finished[:, None, None], sequences_seq_last, running_sequences_seq_last)
+        scores = tf.where(none_finished[:, None], scores, running_scores)
+
+        # Take best beams for each batch (the score is sorted in descending order)
+        sequences_seq_last = flatten_beam_dim(sequences_seq_last[:, :num_return_sequences, :])
+        scores = flatten_beam_dim(scores[:, :num_return_sequences])
+
+        if not use_xla:
+            # Cut for backward compatibility
+            sequences_seq_last = sequences_seq_last[:, :cur_len]
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+                encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+                encoder_hidden_states = (
+                    model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+                )
+
+                return TFBeamSearchEncoderDecoderOutput(
+                    sequences=sequences_seq_last,
+                    scores=scores,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                )
+            else:
+                return TFBeamSearchDecoderOnlyOutput(
+                    sequences=sequences_seq_last,
+                    scores=scores,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                )
+        else:
+            return sequences_seq_last
+

 def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
     # create logit penalties for already seen input_ids
@@ -2445,7 +3035,6 @@ def is_done(self, best_sum_logprobs, cur_len):
         If there are enough hypotheses and that none of the hypotheses being generated can become better than the
         worst one in the heap, then we are done with this sentence.
""" - if len(self) < self.num_beams: return False elif self.early_stopping: diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py index c5245108e8e9..30f50a29ff40 100644 --- a/src/transformers/models/rag/modeling_tf_rag.py +++ b/src/transformers/models/rag/modeling_tf_rag.py @@ -1245,7 +1245,10 @@ def extend_enc_output(tensor, num_beams=None): no_repeat_ngram_size=no_repeat_ngram_size, bad_words_ids=bad_words_ids, min_length=min_length, + max_length=max_length, eos_token_id=eos_token_id, + forced_bos_token_id=None, + forced_eos_token_id=None, ) model_kwargs["attention_mask"] = context_attention_mask diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index db6d8d467aa2..cbd3358ac8ff 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -17,6 +17,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFForcedBOSTokenLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFForcedEOSTokenLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFLogitsProcessor(metaclass=DummyObject): _backends = ["tf"] diff --git a/tests/generation/test_generation_logits_process.py b/tests/generation/test_generation_logits_process.py index b95110d0e06b..8489efd44d16 100644 --- a/tests/generation/test_generation_logits_process.py +++ b/tests/generation/test_generation_logits_process.py @@ -472,14 +472,14 @@ def test_forced_eos_token_logits_processor(self): logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id) - # check that all scores are -inf except the eos_token_id when max_length is reached + # check that all scores are -inf except the eos_token_id when max_length-1 is reached input_ids = ids_tensor((batch_size, 4), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) scores = logits_processor(input_ids, scores) self.assertTrue(torch.isneginf(scores[:, eos_token_id + 1 :]).all()) self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero - # check that eos_token_id is not forced if max_length is not reached + # check that eos_token_id is not forced if max_length-1 is not reached input_ids = ids_tensor((batch_size, 3), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) scores = logits_processor(input_ids, scores) diff --git a/tests/generation/test_generation_tf_logits_process.py b/tests/generation/test_generation_tf_logits_process.py index 2dc366f05c76..8a5ee053680e 100644 --- a/tests/generation/test_generation_tf_logits_process.py +++ b/tests/generation/test_generation_tf_logits_process.py @@ -26,6 +26,8 @@ import tensorflow as tf from transformers.generation_tf_logits_process import ( + TFForcedBOSTokenLogitsProcessor, + TFForcedEOSTokenLogitsProcessor, TFLogitsProcessorList, TFMinLengthLogitsProcessor, TFNoBadWordsLogitsProcessor, @@ -43,7 +45,7 @@ @require_tf class TFLogitsProcessorTest(unittest.TestCase): def _get_uniform_logits(self, batch_size: int, length: int): - scores = np.ones((batch_size, length), dtype=np.float32) / length + scores = tf.ones((batch_size, length), dtype=tf.float32) / length return scores def test_min_length_dist_processor(self): @@ -54,15 +56,17 @@ def test_min_length_dist_processor(self): 
min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) # check that min length is applied at length 5 - input_ids = ids_tensor((batch_size, 5), vocab_size=20) + cur_len = 5 + input_ids = ids_tensor((batch_size, cur_len), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores_before_min_length = min_dist_processor(input_ids, scores) + scores_before_min_length = min_dist_processor(input_ids, scores, cur_len) self.assertListEqual(scores_before_min_length[:, eos_token_id].numpy().tolist(), 4 * [-float("inf")]) # check that min length is not applied anymore at length 15 - input_ids = ids_tensor((batch_size, 15), vocab_size=20) + cur_len = 15 + input_ids = ids_tensor((batch_size, cur_len), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores_before_min_length = min_dist_processor(input_ids, scores) + scores_before_min_length = min_dist_processor(input_ids, scores, cur_len) self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores_before_min_length)).numpy()) def test_temperature_dist_warper(self): @@ -72,8 +76,10 @@ def test_temperature_dist_warper(self): scores = self._get_uniform_logits(batch_size=2, length=length) # tweak scores to not be uniform anymore + scores = scores.numpy() scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch + scores = tf.convert_to_tensor(scores) # compute softmax probs = tf.nn.softmax(scores, axis=-1) @@ -97,8 +103,11 @@ def test_temperature_dist_warper(self): self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :])) def test_repetition_penalty_dist_process(self): - input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32) vocab_size = 10 + cur_len = 2 + + input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32) + self.assertEqual(cur_len, input_ids.shape[1]) scores = self._get_uniform_logits(batch_size=2, length=vocab_size) @@ -109,7 +118,7 @@ def test_repetition_penalty_dist_process(self): rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0) - scores = rep_penalty_proc(input_ids, tf.identity(scores)) + scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len) # check that values were correctly changed self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2) @@ -188,15 +197,18 @@ def test_top_p_dist_warper(self): def test_no_repeat_ngram_dist_processor(self): vocab_size = 3 batch_size = 2 + cur_len = 4 input_ids = tf.constant([[1, 1, 2, 1], [0, 1, 0, 1]], dtype=tf.int32) + self.assertEqual(cur_len, input_ids.shape[1]) + scores = self._get_uniform_logits(batch_size, vocab_size) no_repeat_proc_2_gram = TFNoRepeatNGramLogitsProcessor(2) no_repeat_proc_3_gram = TFNoRepeatNGramLogitsProcessor(3) - filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, tf.identity(scores)) - filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, tf.identity(scores)) + filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, tf.identity(scores), cur_len) + filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, tf.identity(scores), cur_len) # 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch self.assertListEqual( @@ -212,14 +224,17 @@ def test_no_bad_words_dist_processor(self): vocab_size = 5 batch_size = 2 eos_token_id = 4 + cur_len = 4 input_ids = tf.constant([[0, 1, 3, 1], [0, 1, 0, 1]], dtype=tf.int32) + self.assertEqual(cur_len, input_ids.shape[1]) + bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]] scores = 
self._get_uniform_logits(batch_size, vocab_size) no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id) - filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores)) + filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores), cur_len) # batch 1: 1st, 2nd, and 4th (0, 1, 3) token are forbidden # batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden @@ -228,14 +243,65 @@ def test_no_bad_words_dist_processor(self): [[True, True, False, True, True], [True, True, True, False, True]], ) + def test_forced_bos_token_logits_processor(self): + vocab_size = 20 + batch_size = 4 + bos_token_id = 0 + + logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id) + + # check that all scores are -inf except the bos_token_id score + cur_len = 1 + input_ids = ids_tensor((batch_size, cur_len), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores = logits_processor(input_ids, scores, cur_len) + self.assertTrue( + tf.math.reduce_all(tf.math.is_inf(scores[:, bos_token_id + 1 :]) & (scores[:, bos_token_id + 1 :] < 0)) + ) + self.assertListEqual(scores[:, bos_token_id].numpy().tolist(), 4 * [0]) # score for bos_token_id shold be zero + + # check that bos_token_id is not forced if current length is greater than 1 + cur_len = 4 + input_ids = ids_tensor((batch_size, cur_len), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores = logits_processor(input_ids, scores, cur_len) + self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores)))) + + def test_forced_eos_token_logits_processor(self): + vocab_size = 20 + batch_size = 4 + eos_token_id = 0 + max_length = 5 + + logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id) + + # check that all scores are -inf except the eos_token_id when max_length-1 is reached + cur_len = 4 + input_ids = ids_tensor((batch_size, cur_len), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores = logits_processor(input_ids, scores, cur_len) + self.assertTrue( + tf.math.reduce_all(tf.math.is_inf(scores[:, eos_token_id + 1 :]) & (scores[:, eos_token_id + 1 :] < 0)) + ) + self.assertListEqual( + scores[:, eos_token_id].numpy().tolist(), 4 * [0] + ) # score for eos_token_id should be zero + + # check that eos_token_id is not forced if max_length-1 is not reached + cur_len = 3 + input_ids = ids_tensor((batch_size, cur_len), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores = logits_processor(input_ids, scores, cur_len) + self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores)))) + def test_processor_list(self): batch_size = 4 - sequence_length = 10 + cur_len = 10 vocab_size = 15 eos_token_id = 0 # dummy input_ids and scores - input_ids = ids_tensor((batch_size, sequence_length), vocab_size) + input_ids = ids_tensor((batch_size, cur_len), vocab_size) input_ids_comp = tf.identity(input_ids) scores = self._get_uniform_logits(batch_size, vocab_size) @@ -251,13 +317,13 @@ def test_processor_list(self): no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id) # no processor list - scores = min_dist_proc(input_ids, scores) + scores = min_dist_proc(input_ids, scores, cur_len) scores = temp_dist_warp(input_ids, scores) - scores = rep_penalty_proc(input_ids, scores) + scores = rep_penalty_proc(input_ids, scores, cur_len) scores = top_k_warp(input_ids, scores) scores = 
top_p_warp(input_ids, scores) - scores = no_repeat_proc(input_ids, scores) - scores = no_bad_words_dist_proc(input_ids, scores) + scores = no_repeat_proc(input_ids, scores, cur_len) + scores = no_bad_words_dist_proc(input_ids, scores, cur_len) # with processor list processor = TFLogitsProcessorList( @@ -271,7 +337,7 @@ def test_processor_list(self): no_bad_words_dist_proc, ] ) - scores_comp = processor(input_ids, scores_comp) + scores_comp = processor(input_ids, scores_comp, cur_len=cur_len) # remove inf scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9) diff --git a/tests/gpt2/test_modeling_tf_gpt2.py b/tests/gpt2/test_modeling_tf_gpt2.py index d6470c0d1526..5b78bdf1bd2e 100644 --- a/tests/gpt2/test_modeling_tf_gpt2.py +++ b/tests/gpt2/test_modeling_tf_gpt2.py @@ -536,7 +536,6 @@ def test_lm_generate_greedy_distilgpt2_beam_search_special(self): "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids], "no_repeat_ngram_size": 2, "do_sample": False, - "repetition_penalty": 1.3, "num_beams": 2, } @@ -544,8 +543,8 @@ def test_lm_generate_greedy_distilgpt2_beam_search_special(self): output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) expected_output_string = [ - "Today is a beautiful day and I hope you enjoy it.\nI am very happy to announce that", - "Yesterday was the first time I've ever seen a game where you can play with", + "Today is a beautiful day and a great day for all of us.\n\nI’m", + "Yesterday was the first day of the year for the second time in a row,", ] self.assertListEqual(output_strings, expected_output_string) diff --git a/tests/speech_to_text/test_modeling_tf_speech_to_text.py b/tests/speech_to_text/test_modeling_tf_speech_to_text.py index d8d538fe194a..897c54722a2f 100644 --- a/tests/speech_to_text/test_modeling_tf_speech_to_text.py +++ b/tests/speech_to_text/test_modeling_tf_speech_to_text.py @@ -508,7 +508,7 @@ def test_lm_head_model_random_beam_search_generate(self): # if bos token id is not defined model needs input_ids, num_return_sequences = 1 self._check_generated_ids(model.generate(input_features, do_sample=True, num_beams=2)) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): # generating more sequences than having beams leads is not possible model.generate(input_features, do_sample=False, num_return_sequences=3, num_beams=2) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index b7b4b68414a7..bbb0befeb631 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1179,7 +1179,7 @@ def test_lm_head_model_random_beam_search_generate(self): # num_return_sequences = 1 self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2)) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): # generating more sequences than having beams leads is not possible model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
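Taken together, these changes expose TF beam search through the regular `generate` entry point. A quick smoke-test sketch (it downloads the `t5-base` checkpoint, so it needs network access, and the decoded text may vary across checkpoint versions):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-base")

input_ids = tokenizer("translate English to German: How old are you?", return_tensors="tf").input_ids

# `do_sample=False` with `num_beams > 1` now dispatches to the new `beam_search` method
# via the `is_beam_gen_mode` branch added to `_generate`
outputs = model.generate(input_ids, num_beams=4, max_length=32, early_stopping=True)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```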