6 changes: 6 additions & 0 deletions docs/source/en/internal/generation_utils.mdx
@@ -178,6 +178,12 @@ generation.
 [[autodoc]] TFRepetitionPenaltyLogitsProcessor
     - __call__

+[[autodoc]] TFForcedBOSTokenLogitsProcessor
+    - __call__
+
+[[autodoc]] TFForcedEOSTokenLogitsProcessor
+    - __call__
+
 [[autodoc]] FlaxLogitsProcessor
     - __call__

4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -1699,6 +1699,8 @@
_import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
_import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
_import_structure["generation_tf_logits_process"] = [
"TFForcedBOSTokenLogitsProcessor",
"TFForcedEOSTokenLogitsProcessor",
"TFLogitsProcessor",
"TFLogitsProcessorList",
"TFLogitsWarper",
@@ -3827,6 +3829,8 @@
         # Benchmarks
         from .benchmark.benchmark_tf import TensorFlowBenchmark
         from .generation_tf_logits_process import (
+            TFForcedBOSTokenLogitsProcessor,
+            TFForcedEOSTokenLogitsProcessor,
             TFLogitsProcessor,
             TFLogitsProcessorList,
             TFLogitsWarper,
85 changes: 72 additions & 13 deletions src/transformers/generation_tf_logits_process.py
@@ -216,14 +216,10 @@ def __init__(self, min_length: int, eos_token_id: int):
         self.min_length = min_length
         self.eos_token_id = eos_token_id

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
-        # create boolean flag to decide if min length penalty should be applied
-        cur_len = input_ids.shape[-1]
-        apply_penalty = 1 - tf.clip_by_value(cur_len - self.min_length, 0, 1)
-
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         # TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since
         # generate is not XLA - compileable anyways
-        if apply_penalty:
+        if cur_len < self.min_length:
             eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
             scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf"))
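
The TODO above flags the Python-level `if` as incompatible with XLA tracing. As a minimal sketch (my own, not part of this PR), the same penalty can be expressed with tensor ops so both branches live in the graph; this assumes `cur_len` arrives as a scalar tensor under XLA:

```python
import tensorflow as tf

def xla_min_length_penalty(scores, cur_len, min_length, eos_token_id):
    # mask that is True only in the eos_token_id column
    eos_mask = tf.range(scores.shape[-1]) == eos_token_id
    # copy of scores with the EOS column forced to -inf
    penalized = tf.where(eos_mask[tf.newaxis, :], float("-inf"), scores)
    # branch with tf.where instead of a Python `if`, so the op graph is static
    return tf.where(cur_len < min_length, penalized, scores)
```
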
@@ -259,8 +255,8 @@ def _create_score_penalties(self, input_ids, logits):
             np.put(token_penalties[i], prev_input_id, logit_penalties)
         return tf.convert_to_tensor(token_penalties, dtype=tf.float32)

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
-        score_penalties = self._create_score_penalties(input_ids, scores)
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:

Comment thread on this change:

Contributor (author): XLA greedy search was probably missing this as well in the logits processors, since it has the same padded input_ids.

Contributor: Yes, thanks for adding it! Just note that I don't really think we can make this processor XLA-compilable anyway, as it's very complex and numpy can't be used in XLA. cur_len is mostly added in Flax/JAX to make the processors XLA-compilable. But it doesn't hurt to add it here!

@Rocketknight1 (Member, Apr 4, 2022): tf.unique is not compatible with XLA because the output shape is dependent on the specific input data, and so cannot be inferred at compile time. However, there should be a way to make this logits processor XLA-compilable - there's probably some solution where you store counts in a sparse matrix and then use triu() or tril() followed by a matmul to see if a token has been preceded by the same token. Let me know if you want me to try that (here or in a separate PR).

Contributor (author): 🧠 I'd leave it to a subsequent PR; XLA-readiness is not the main priority here, and this PR is already very long.

+        score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)

         scores = tf.math.multiply(scores, score_penalties)

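Following up on the thread above, here is a rough sketch (mine, not from this PR or the library) of a numpy-free, static-shape variant of the repetition penalty, using `tf.one_hot` plus a reduction instead of per-row indexing; the `penalty` semantics (divide positive scores, multiply negative ones) follow the CTRL-style convention this processor implements:

```python
import tensorflow as tf

def xla_friendly_repetition_penalty(input_ids, scores, cur_len, penalty):
    vocab_size = scores.shape[-1]
    # 1.0 for positions that hold real (non-padding) tokens
    positions = tf.range(tf.shape(input_ids)[-1])
    valid = tf.cast(positions < cur_len, scores.dtype)                # [seq]
    # presence mask: which vocab ids appear in each row's real prefix
    one_hot = tf.one_hot(input_ids, vocab_size, dtype=scores.dtype)   # [batch, seq, vocab]
    seen = tf.reduce_max(one_hot * valid[tf.newaxis, :, tf.newaxis], axis=1)  # [batch, vocab]
    # penalize already-seen tokens: shrink positive scores, amplify negative ones
    penalized = tf.where(scores < 0, scores * penalty, scores / penalty)
    return tf.where(seen > 0, penalized, scores)
```
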
@@ -330,12 +326,12 @@ def _tokens_match(prev_tokens, tokens):

         return banned_tokens

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:

         vocab_size = scores.shape[-1]

         # calculate a list of banned tokens according to bad words
-        banned_tokens = self.calc_banned_bad_words_ids(input_ids)
+        banned_tokens = self.calc_banned_bad_words_ids(input_ids[:, :cur_len])

         banned_tokens_indices_mask = []
         for banned_tokens_slice in banned_tokens:
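
To make the new `[:, :cur_len]` slicing concrete, a tiny illustration (hypothetical values) of why it matters once generation pads `input_ids` out to a fixed length for XLA, as the first comment in the thread above points out:

```python
import tensorflow as tf

max_length, cur_len, pad_token_id = 8, 3, 0
# XLA-style generation keeps input_ids pre-padded to a static max_length
input_ids = tf.constant([[17, 23, 42] + [pad_token_id] * (max_length - cur_len)])
# without the slice, the trailing pad ids would be scanned for bad words or
# repeated n-grams as if they had actually been generated
real_prefix = input_ids[:, :cur_len]  # -> [[17, 23, 42]]
```
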
@@ -365,12 +361,13 @@ def __init__(self, ngram_size: int):
             raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
         self.ngram_size = ngram_size

-    def calc_banned_ngram_tokens(self, prev_input_ids, num_hypos, cur_len):
+    def calc_banned_ngram_tokens(self, input_ids, num_hypos, cur_len):
         # Copied from fairseq for no_repeat_ngram in beam_search
         if cur_len + 1 < self.ngram_size:
             # return no banned tokens if we haven't generated ngram_size tokens yet
             return [[] for _ in range(num_hypos)]
         generated_ngrams = [{} for _ in range(num_hypos)]
+        prev_input_ids = input_ids[:, :cur_len]
         for idx in range(num_hypos):
             gen_tokens = prev_input_ids[idx].numpy().tolist()
             generated_ngram = generated_ngrams[idx]
@@ -388,10 +385,9 @@ def _get_generated_ngrams(hypo_idx):

         return banned_tokens

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:

         batch_size, vocab_size = scores.shape
-        cur_len = input_ids.shape[-1]
         banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)

         # create banned_tokens boolean mask
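
For readers new to this processor, a tiny pure-Python walk-through (illustrative values, not from the PR) of the banning rule that `calc_banned_ngram_tokens` implements with `ngram_size=2`: a token is banned if emitting it would repeat an n-gram already present in the prefix.

```python
prev_tokens = [5, 3, 5]   # tokens generated so far (one hypothesis)
ngram_size = 2

# index every observed n-gram by its (ngram_size - 1)-token prefix
generated = {}
for i in range(len(prev_tokens) - ngram_size + 1):
    *prefix, last = prev_tokens[i : i + ngram_size]
    generated.setdefault(tuple(prefix), []).append(last)

# the next token is banned if it would complete an already-seen n-gram
current_prefix = tuple(prev_tokens[-(ngram_size - 1):])
banned = generated.get(current_prefix, [])
assert banned == [3]  # emitting 3 would repeat the bigram (5, 3)
```
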
@@ -406,3 +402,66 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
         )

         return scores
+
+
+class TFForcedBOSTokenLogitsProcessor(TFLogitsProcessor):
+    r"""
+    [`TFLogitsProcessor`] that enforces the specified token as the first generated token.
+
+    Args:
+        bos_token_id (`int`):
+            The id of the token to force as the first generated token.
+    """
+
+    def __init__(self, bos_token_id: int):
+        if bos_token_id < 0:
+            raise ValueError(f"The forced bos token id must be a non-negative integer, got {bos_token_id}")
+        self.bos_token_id = bos_token_id
+
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
+        if cur_len == 1:
+            batch_size, num_tokens = scores.shape
+            # sets the score to 0 in the bos_token_id column
+            scores = tf.zeros((batch_size, 1))
+            # sets the score to -inf everywhere else
+            if self.bos_token_id > 0:

Comment on this change:

Contributor: same comment as for EOS.

+                scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.bos_token_id)), scores), axis=-1)
+            if self.bos_token_id < (num_tokens - 1):
+                scores = tf.concat(
+                    (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.bos_token_id))),
+                    axis=-1,
+                )
+        return scores
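
A quick walk-through of the concat construction above, with illustrative values (batch_size=1, num_tokens=5, bos_token_id=2): the single zero column is padded with -inf blocks on either side so that only the forced column keeps a finite score.

```python
import tensorflow as tf

batch_size, num_tokens, bos_token_id = 1, 5, 2
zero_col = tf.zeros((batch_size, 1))                               # the forced column
left = tf.broadcast_to(-float("inf"), (batch_size, bos_token_id))  # columns 0..1
right = tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - bos_token_id))  # columns 3..4
forced = tf.concat((left, zero_col, right), axis=-1)
print(forced.numpy())  # [[-inf -inf 0. -inf -inf]] -> argmax is always token 2
```
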


+class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor):
+    r"""
+    [`TFLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
+
+    Args:
+        max_length (`int`):
+            The maximum length of the sequence to be generated.
+        eos_token_id (`int`):
+            The id of the token to force as the last generated token when `max_length` is reached.
+    """
+
+    def __init__(self, max_length: int, eos_token_id: int):
+        self.max_length = max_length
+        if eos_token_id < 0:
+            raise ValueError(f"The forced eos token id must be a non-negative integer, got {eos_token_id}")
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
+        if cur_len == self.max_length - 1:
+            batch_size, num_tokens = scores.shape
+            # sets the score to 0 in the eos_token_id column
+            scores = tf.zeros((batch_size, 1))
+            # sets the score to -inf everywhere else
+            if self.eos_token_id > 0:

Comment on this change:

Contributor: (nit) I think it'd be cleaner to raise a ValueError if eos_token_id <= 0 in __init__(); that should never really be the case. But maybe let's leave it for a follow-up PR.

+                scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.eos_token_id)), scores), axis=-1)
+            if self.eos_token_id < (num_tokens - 1):
+                scores = tf.concat(
+                    (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.eos_token_id))),
+                    axis=-1,
+                )
+        return scores
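
To close the loop, a hedged usage sketch (token ids and shapes are made up) showing the two new processors being called directly with the `cur_len` argument this PR threads through:

```python
import tensorflow as tf
from transformers import TFForcedBOSTokenLogitsProcessor, TFForcedEOSTokenLogitsProcessor

batch_size, vocab_size, max_length = 2, 10, 8
input_ids = tf.zeros((batch_size, 1), dtype=tf.int32)   # dummy prompt
scores = tf.random.normal((batch_size, vocab_size))

bos_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=3)
eos_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=2)

# at cur_len == 1 only bos_token_id keeps a finite score...
forced_bos = bos_processor(input_ids, scores, cur_len=1)
assert tf.reduce_all(tf.argmax(forced_bos, axis=-1) == 3)

# ...and at cur_len == max_length - 1 only eos_token_id does
forced_eos = eos_processor(input_ids, scores, cur_len=max_length - 1)
assert tf.reduce_all(tf.argmax(forced_eos, axis=-1) == 2)
```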