diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 6e06a6fdf816..9521e2d21636 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Iterable, List, Optional, Tuple
+import math
+from typing import Callable, Iterable, List, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -72,6 +73,7 @@ def postprocess_next_token_scores(
         repetition_penalty,
         batch_size,
         num_beams,
+        prefix_allowed_tokens_fn,
     ):
         # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
         if repetition_penalty != 1.0:
@@ -105,6 +107,22 @@ def postprocess_next_token_scores(
             # Modify the scores in place by setting the banned tokens logits to `-inf`
            set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
 
+        if prefix_allowed_tokens_fn is not None:
+            # compute the tokens allowed at the next step according to `prefix_allowed_tokens_fn`
+            allowed_tokens = prefix_allowed_tokens_fn(input_ids.view(batch_size, num_beams, -1))
+            # Modify the scores by adding `-inf` to every token that is not allowed
+            mask = torch.full_like(scores, -math.inf)
+
+            for i, beam_allowed_tokens_i in enumerate(
+                [
+                    beam_allowed_tokens
+                    for batch_allowed_tokens in allowed_tokens
+                    for beam_allowed_tokens in batch_allowed_tokens
+                ]
+            ):
+                mask[i, beam_allowed_tokens_i] = 0
+            scores += mask
+
         return scores
 
     @torch.no_grad()
@@ -131,6 +149,7 @@ def generate(
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_start_token_id: Optional[int] = None,
         use_cache: Optional[bool] = None,
+        prefix_allowed_tokens_fn: Optional[Callable] = None,
         **model_kwargs
     ) -> torch.LongTensor:
         r"""
@@ -205,6 +224,14 @@ def generate(
             use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether or not the model should use the past last key/values attentions (if applicable to the model)
                 to speed up decoding.
+            prefix_allowed_tokens_fn: (:obj:`Callable`, `optional`, defaults to :obj:`None`):
+                If provided, it has to be a function that takes :obj:`input_ids` as its only argument. At each
+                generation step, this function is called with :obj:`input_ids` containing the previously generated
+                tokens, reshaped to :obj:`(batch_size, num_beams, sequence_length)`. It has to return a nested list
+                (one list per batch element, one per beam) of the token ids allowed at the next step.
+
+                This argument is useful for constrained generation conditioned on the prefix. If not provided, no
+                constraint is applied.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
 
@@ -486,6 +513,7 @@ def generate(
                 vocab_size=vocab_size,
                 attention_mask=attention_mask,
                 use_cache=use_cache,
+                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                 model_kwargs=model_kwargs,
             )
         else:
@@ -506,6 +534,7 @@ def generate(
                 batch_size=effective_batch_size,
                 attention_mask=attention_mask,
                 use_cache=use_cache,
+                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                 model_kwargs=model_kwargs,
             )
 
@@ -529,6 +558,7 @@ def _generate_no_beam_search(
         batch_size,
         attention_mask,
         use_cache,
+        prefix_allowed_tokens_fn,
         model_kwargs,
     ):
         """
@@ -560,6 +590,7 @@ def _generate_no_beam_search(
                 repetition_penalty=repetition_penalty,
                 batch_size=batch_size,
                 num_beams=1,
+                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
             )
 
             # if model has past, then set the past variable to speed up decoding
@@ -635,6 +666,7 @@ def _generate_beam_search(
         vocab_size,
         attention_mask,
         use_cache,
+        prefix_allowed_tokens_fn,
         model_kwargs,
     ):
         """Generate sequences for each example with beam search."""
@@ -692,6 +724,7 @@ def _generate_beam_search(
                 repetition_penalty=repetition_penalty,
                 batch_size=batch_size,
                 num_beams=num_beams,
+                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
             )
 
             assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
diff --git a/src/transformers/modeling_rag.py b/src/transformers/modeling_rag.py
index c5809e1436a0..7f135394f68b 100644
--- a/src/transformers/modeling_rag.py
+++ b/src/transformers/modeling_rag.py
@@ -15,7 +15,7 @@
 """RAG model implementation."""
 
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
 
@@ -1225,6 +1225,7 @@ def generate(
         bad_words_ids=None,
         num_return_sequences=None,
         decoder_start_token_id=None,
+        prefix_allowed_tokens_fn: Optional[Callable] = None,
         n_docs=None,
         **kwargs
     ):
@@ -1297,6 +1298,12 @@ def generate(
                 function, where we set ``num_return_sequences`` to :obj:`num_beams`.
             decoder_start_token_id (:obj:`int`, `optional`):
                 If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
+            prefix_allowed_tokens_fn (:obj:`Callable`, `optional`):
+                If provided, it has to be a function that takes :obj:`input_ids` as its only argument. At each
+                generation step, this function is called with :obj:`input_ids` containing the previously generated
+                tokens, reshaped to :obj:`(batch_size, num_beams, sequence_length)`. It has to return a nested list
+                (one list per batch element, one per beam) of the token ids allowed at the next step. This argument
+                is useful for constrained generation conditioned on the prefix. If not provided, no constraint is applied.
             n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.
 
@@ -1424,6 +1431,7 @@ def extend_enc_output(tensor, num_beams=None):
                 vocab_size=vocab_size,
                 attention_mask=context_attention_mask,
                 use_cache=use_cache,
+                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                 model_kwargs=kwargs,
             )
         else:
@@ -1444,6 +1452,7 @@ def extend_enc_output(tensor, num_beams=None):
                 batch_size=batch_size,
                 attention_mask=context_attention_mask,
                 use_cache=use_cache,
+                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                 model_kwargs=kwargs,
             )
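
For reference, a minimal usage sketch of the hook this diff introduces (the sketch itself is not part of the patch). As documented above, the callable receives the previously generated tokens as a tensor of shape (batch_size, num_beams, sequence_length) and must return, for every batch element and every beam, the list of token ids allowed at the next step. The GPT-2 checkpoint and the toy whitelist constraint below are illustrative assumptions only:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Toy constraint (illustrative): only these token ids, plus EOS, may ever be generated.
allowed_ids = tokenizer.encode(" yes no maybe") + [tokenizer.eos_token_id]


def prefix_allowed_tokens_fn(input_ids):
    # input_ids: (batch_size, num_beams, sequence_length) tensor of the tokens generated so far.
    # Return one list of allowed token ids per batch element and per beam.
    batch_size, num_beams, _ = input_ids.shape
    return [[allowed_ids for _ in range(num_beams)] for _ in range(batch_size)]


prompt = tokenizer("The answer is", return_tensors="pt").input_ids
output = model.generate(
    prompt,
    max_length=prompt.shape[-1] + 5,
    num_beams=2,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(tokenizer.decode(output[0]))

A real constraint function would typically inspect the prefix contained in input_ids (for example, by walking a trie of allowed strings) rather than returning a fixed set; that is what makes prefix-conditioned constrained generation possible.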