35 changes: 34 additions & 1 deletion src/transformers/generation_utils.py
@@ -14,7 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, List, Optional, Tuple
import math
from typing import Callable, Iterable, List, Optional, Tuple

import torch
from torch import Tensor
@@ -72,6 +73,7 @@ def postprocess_next_token_scores(
repetition_penalty,
batch_size,
num_beams,
prefix_allowed_tokens_fn,
):
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
@@ -105,6 +107,22 @@ def postprocess_next_token_scores(
# Modify the scores in place by setting the banned tokens logits to `-inf`
set_scores_to_inf_for_banned_tokens(scores, banned_tokens)

if prefix_allowed_tokens_fn is not None:
    # get the tokens that `prefix_allowed_tokens_fn` allows at the next step for every (batch, beam)
    # pair; the previously generated tokens are passed reshaped to (batch_size, num_beams, seq_len)
    allowed_tokens = prefix_allowed_tokens_fn(input_ids.view(batch_size, num_beams, -1))
    # modify the scores by setting every logit that is not allowed to `-inf`
    mask = torch.full_like(scores, -math.inf)

    for i, beam_allowed_tokens_i in enumerate(
        [
            beam_allowed_tokens
            for batch_allowed_tokens in allowed_tokens
            for beam_allowed_tokens in batch_allowed_tokens
        ]
    ):
        mask[i, beam_allowed_tokens_i] = 0
    scores += mask

return scores

@torch.no_grad()
@@ -131,6 +149,7 @@ def generate(
attention_mask: Optional[torch.LongTensor] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
prefix_allowed_tokens_fn: Optional[Callable] = None,
**model_kwargs
) -> torch.LongTensor:
r"""
@@ -205,6 +224,14 @@
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
prefix_allowed_tokens_fn: (:obj:`Callable`, `optional`, defaults to :obj:`None`):
If provided, this function constrains generation to allowed tokens only. At each decoding step it is
called with :obj:`input_ids`, the previously generated tokens reshaped to a tensor of shape
:obj:`(batch_size, num_beams, sequence_length)`, and it has to return a nested list of the token ids
allowed at the next step: one list per batch entry, each containing one list of allowed ids per beam.

This argument is useful for constrained generation conditioned on the prefix. If not provided, no
constraint is applied.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.

@@ -486,6 +513,7 @@ def generate(
vocab_size=vocab_size,
attention_mask=attention_mask,
use_cache=use_cache,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
model_kwargs=model_kwargs,
)
else:
@@ -506,6 +534,7 @@ def generate(
batch_size=effective_batch_size,
attention_mask=attention_mask,
use_cache=use_cache,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
model_kwargs=model_kwargs,
)

@@ -529,6 +558,7 @@ def _generate_no_beam_search(
batch_size,
attention_mask,
use_cache,
prefix_allowed_tokens_fn,
model_kwargs,
):
"""
@@ -560,6 +590,7 @@ def _generate_no_beam_search(
repetition_penalty=repetition_penalty,
batch_size=batch_size,
num_beams=1,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)

# if model has past, then set the past variable to speed up decoding
@@ -635,6 +666,7 @@ def _generate_beam_search(
vocab_size,
attention_mask,
use_cache,
prefix_allowed_tokens_fn,
model_kwargs,
):
"""Generate sequences for each example with beam search."""
@@ -692,6 +724,7 @@ def _generate_beam_search(
repetition_penalty=repetition_penalty,
batch_size=batch_size,
num_beams=num_beams,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)

assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
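To make the masking step added to `postprocess_next_token_scores` concrete, here is a minimal standalone sketch of the same idea, with all sizes and token ids invented purely for illustration: a mask that is `-inf` everywhere is zeroed out at the allowed token ids of each (batch, beam) row and then added to the scores.

```python
import math

import torch

# Toy sizes, for illustration only.
batch_size, num_beams, vocab_size = 2, 3, 10
scores = torch.zeros(batch_size * num_beams, vocab_size)

# Nested structure matching the callback's contract: one list per batch entry,
# one inner list of allowed token ids per beam.
allowed_tokens = [
    [[1, 4], [2], [0, 5, 7]],  # batch entry 0, beams 0-2
    [[3], [6, 8], [9]],        # batch entry 1, beams 0-2
]

# Same scheme as the new block: -inf everywhere, 0 at the allowed ids.
mask = torch.full_like(scores, -math.inf)
for i, beam_allowed in enumerate(
    [beam for batch in allowed_tokens for beam in batch]
):
    mask[i, beam_allowed] = 0
scores += mask

# Each of the batch_size * num_beams rows now keeps only its allowed ids.
print(scores)
```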
11 changes: 10 additions & 1 deletion src/transformers/modeling_rag.py
@@ -15,7 +15,7 @@
"""RAG model implementation."""

from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Callable, List, Optional, Tuple

import torch

@@ -1225,6 +1225,7 @@ def generate(
bad_words_ids=None,
num_return_sequences=None,
decoder_start_token_id=None,
prefix_allowed_tokens_fn: Optional[Callable] = None,
n_docs=None,
**kwargs
):
@@ -1297,6 +1298,12 @@
function, where we set ``num_return_sequences`` to :obj:`num_beams`.
decoder_start_token_id (:obj:`int`, `optional`):
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
prefix_allowed_tokens_fn (:obj:`Callable`, `optional`, defaults to :obj:`None`):
If provided, this function constrains generation to allowed tokens only. At each decoding step it is
called with :obj:`input_ids`, the previously generated tokens reshaped to a tensor of shape
:obj:`(batch_size, num_beams, sequence_length)`, and it has to return a nested list of the token ids
allowed at the next step: one list per batch entry, each containing one list of allowed ids per beam.
This argument is useful for constrained generation conditioned on the prefix. If not provided, no
constraint is applied.
n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
Number of documents to retrieve and/or number of documents for which to generate an answer.

@@ -1424,6 +1431,7 @@ def extend_enc_output(tensor, num_beams=None):
vocab_size=vocab_size,
attention_mask=context_attention_mask,
use_cache=use_cache,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
model_kwargs=kwargs,
)
else:
@@ -1444,6 +1452,7 @@ def extend_enc_output(tensor, num_beams=None):
batch_size=batch_size,
attention_mask=context_attention_mask,
use_cache=use_cache,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
model_kwargs=kwargs,
)

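Finally, a usage sketch of the new argument under this PR's contract, where the callback receives the previously generated tokens as a `(batch_size, num_beams, sequence_length)` tensor and returns one list of allowed token ids per beam, grouped per batch entry. The checkpoint name and the whitelist of ids are placeholders, not part of the PR; the same keyword argument can be passed to the RAG `generate` method patched above.

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Placeholder checkpoint; any seq2seq model routed through beam search works the same way.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")

# Toy constraint: only the first 100 token ids are ever allowed. A real constraint
# would inspect the generated prefix in `input_ids` to decide what comes next.
allowed_ids = list(range(100))


def prefix_allowed_tokens_fn(input_ids):
    # `input_ids` has shape (batch_size, num_beams, sequence_length); return one
    # list of allowed token ids per beam, grouped per batch entry.
    batch_size, num_beams, _ = input_ids.shape
    return [[allowed_ids for _ in range(num_beams)] for _ in range(batch_size)]


inputs = tokenizer(["Is the sky blue?"], return_tensors="pt")
outputs = model.generate(
    inputs["input_ids"],
    num_beams=3,
    max_length=10,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```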