TF generate refactor - Beam Search #16374
Changes from all commits
```diff
@@ -216,14 +216,10 @@ def __init__(self, min_length: int, eos_token_id: int):
         self.min_length = min_length
         self.eos_token_id = eos_token_id
 
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
-        # create boolean flag to decide if min length penalty should be applied
-        cur_len = input_ids.shape[-1]
-        apply_penalty = 1 - tf.clip_by_value(cur_len - self.min_length, 0, 1)
-
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         # TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since
         # generate is not XLA - compileable anyways
-        if apply_penalty:
+        if cur_len < self.min_length:
             eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
             scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf"))
 
```
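For intuition, here is a minimal self-contained sketch of what the updated `__call__` does, using `tf.where` in place of the repo's `set_tensor_by_indices_to_value` helper: while fewer than `min_length` tokens have been generated, the EOS column is pinned to `-inf` so the sequence cannot end early.

```python
import tensorflow as tf

# Simplified stand-in for TFMinLengthLogitsProcessor.__call__ (sketch only):
# mask the EOS column while the sequence is still shorter than min_length.
def min_length_process(scores: tf.Tensor, cur_len: int, min_length: int, eos_token_id: int) -> tf.Tensor:
    if cur_len < min_length:
        eos_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == eos_token_id, scores.shape)
        scores = tf.where(eos_mask, float("-inf"), scores)
    return scores

scores = tf.zeros((2, 5))  # (batch_size, vocab_size)
out = min_length_process(scores, cur_len=3, min_length=10, eos_token_id=4)
print(out.numpy()[:, 4])   # [-inf -inf] -> EOS is unreachable until min_length
```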
```diff
@@ -259,8 +255,8 @@ def _create_score_penalties(self, input_ids, logits):
             np.put(token_penalties[i], prev_input_id, logit_penalties)
         return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
 
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
-        score_penalties = self._create_score_penalties(input_ids, scores)
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
+        score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)
 
         scores = tf.math.multiply(scores, score_penalties)
```
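The new `input_ids[:, :cur_len]` slice matters because, as the review discussion below notes, `input_ids` in this refactor is pre-padded to the maximum length, so the trailing columns are padding rather than generated tokens. A hypothetical illustration (token ids and pad id invented for the example):

```python
import tensorflow as tf

# Hypothetical padded batch: max_length = 6, but only 3 real tokens so far.
pad_token_id = 0
input_ids = tf.constant([[5, 7, 9, pad_token_id, pad_token_id, pad_token_id]])
cur_len = 3

# Without the slice, input_ids.shape[-1] would report 6 and the pad tokens
# would be (wrongly) penalized as if they had been generated.
real_tokens = input_ids[:, :cur_len]
print(input_ids.shape[-1])  # 6 -> no longer a valid proxy for cur_len
print(real_tokens.numpy())  # [[5 7 9]] -> what the penalty should see
```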
```diff
@@ -330,12 +326,12 @@ def _tokens_match(prev_tokens, tokens):
 
         return banned_tokens
 
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
 
         vocab_size = scores.shape[-1]
 
         # calculate a list of banned tokens according to bad words
-        banned_tokens = self.calc_banned_bad_words_ids(input_ids)
+        banned_tokens = self.calc_banned_bad_words_ids(input_ids[:, :cur_len])
 
         banned_tokens_indices_mask = []
         for banned_tokens_slice in banned_tokens:
```
```diff
@@ -365,12 +361,13 @@ def __init__(self, ngram_size: int):
             raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
         self.ngram_size = ngram_size
 
-    def calc_banned_ngram_tokens(self, prev_input_ids, num_hypos, cur_len):
+    def calc_banned_ngram_tokens(self, input_ids, num_hypos, cur_len):
         # Copied from fairseq for no_repeat_ngram in beam_search
         if cur_len + 1 < self.ngram_size:
             # return no banned tokens if we haven't generated ngram_size tokens yet
             return [[] for _ in range(num_hypos)]
         generated_ngrams = [{} for _ in range(num_hypos)]
+        prev_input_ids = input_ids[:, :cur_len]
         for idx in range(num_hypos):
             gen_tokens = prev_input_ids[idx].numpy().tolist()
             generated_ngram = generated_ngrams[idx]
```
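To see what this bans, here is a small plain-Python trace of the same fairseq-derived logic (a sketch, not the PR's code): with `ngram_size=2` and generated tokens `[5, 7, 5]`, the bigram `(5, 7)` has already occurred, so `7` is banned as the next token.

```python
# Plain-Python sketch of the banned-ngram lookup for a single hypothesis.
def banned_next_tokens(tokens, ngram_size=2):
    generated = {}
    # record every ngram as prefix -> list of observed next tokens
    for i in range(len(tokens) - ngram_size + 1):
        prefix = tuple(tokens[i : i + ngram_size - 1])
        generated.setdefault(prefix, []).append(tokens[i + ngram_size - 1])
    # the current prefix is the last (ngram_size - 1) generated tokens
    current_prefix = tuple(tokens[len(tokens) - ngram_size + 1 :])
    return generated.get(current_prefix, [])

print(banned_next_tokens([5, 7, 5]))  # [7] -> emitting 7 would repeat (5, 7)
```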
```diff
@@ -388,10 +385,9 @@ def _get_generated_ngrams(hypo_idx):
 
         return banned_tokens
 
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
 
         batch_size, vocab_size = scores.shape
-        cur_len = input_ids.shape[-1]
         banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
 
         # create banned_tokens boolean mask
```
```diff
@@ -406,3 +402,66 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
         )
 
         return scores
+
+
+class TFForcedBOSTokenLogitsProcessor(TFLogitsProcessor):
+    r"""
+    [`TFLogitsProcessor`] that enforces the specified token as the first generated token.
+
+    Args:
+        bos_token_id (`int`):
+            The id of the token to force as the first generated token.
+    """
+
+    def __init__(self, bos_token_id: int):
+        if bos_token_id < 0:
+            raise ValueError(f"The forced bos token id must be a non-negative integer, got {bos_token_id}")
+        self.bos_token_id = bos_token_id
+
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
+        if cur_len == 1:
+            batch_size, num_tokens = scores.shape
+            # sets the score to 0 in the bos_token_id column
+            scores = tf.zeros((batch_size, 1))
+            # sets the score to -inf everywhere else
+            if self.bos_token_id > 0:
+                scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.bos_token_id)), scores), axis=-1)
+            if self.bos_token_id < (num_tokens - 1):
+                scores = tf.concat(
+                    (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.bos_token_id))),
+                    axis=-1,
+                )
+        return scores
```

Contributor (on `if self.bos_token_id > 0:`): same comment as for EOS.

```diff
+
+
+class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor):
+    r"""
+    [`TFLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
+
+    Args:
+        max_length (`int`):
+            The maximum length of the sequence to be generated.
+        eos_token_id (`int`):
+            The id of the token to force as the last generated token when `max_length` is reached.
+    """
+
+    def __init__(self, max_length: int, eos_token_id: int):
+        self.max_length = max_length
+        if eos_token_id < 0:
+            raise ValueError(f"The forced eos token id must be a non-negative integer, got {eos_token_id}")
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
+        if cur_len == self.max_length - 1:
+            batch_size, num_tokens = scores.shape
+            # sets the score to 0 in the eos_token_id column
+            scores = tf.zeros((batch_size, 1))
+            # sets the score to -inf everywhere else
+            if self.eos_token_id > 0:
+                scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.eos_token_id)), scores), axis=-1)
+            if self.eos_token_id < (num_tokens - 1):
+                scores = tf.concat(
+                    (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.eos_token_id))),
+                    axis=-1,
+                )
+        return scores
```

Contributor (on `if self.eos_token_id > 0:`): (nit) think it'd be cleaner to raise a ValueError if `eos_token_id <= 0` in `__init__`.
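Assuming the two classes above are in scope, a quick usage sketch of the EOS variant (batch size, vocabulary size, and token ids invented for the example): at step `max_length - 1`, every column except `eos_token_id` becomes `-inf`, so any sampling strategy must emit EOS.

```python
import tensorflow as tf

batch_size, vocab_size = 2, 6
max_length, eos_token_id = 8, 2

processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)

scores = tf.random.normal((batch_size, vocab_size))
# input_ids is unused by this processor, so a placeholder is fine in a sketch
forced = processor(input_ids=None, scores=scores, cur_len=max_length - 1)
print(tf.argmax(forced, axis=-1).numpy())  # [2 2] -> EOS forced for every row
```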
Contributor: XLA greedy search was probably missing this as well in the logits processors, since it has the same padded `input_ids`.

Contributor: Yes, thanks for adding it! Just note that I don't really think we can make this processor XLA-compilable anyway, as it's very complex and numpy can't be used in XLA. `cur_len` is mostly added in Flax/JAX to make the processors XLA-compilable. But it doesn't hurt to add it here!
Contributor: `tf.unique` is not compatible with XLA because the output shape is dependent on the specific input data, and so cannot be inferred at compile time. However, there should be a way to make this logits processor XLA-compilable - there's probably some solution where you store counts in a sparse matrix and then use `triu()` or `tril()` followed by a matmul to see if a token has been preceded by the same token. Let me know if you want me to try that (here or in a separate PR).
Contributor: 🧠 I'd leave it to a subsequent PR - XLA-readiness is not the main priority here, and this PR is already very long.
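For reference, a rough sketch of the static-shape trick the suggestion above points at - this is one interpretation, using a dense one-hot count matrix rather than the sparse-matrix-plus-matmul variant described, but it shows why avoiding `tf.unique` makes the op XLA-compilable:

```python
import tensorflow as tf

# Sketch: count token occurrences with a static output shape (batch, vocab)
# instead of tf.unique, whose output shape depends on the input data.
@tf.function(jit_compile=True)  # compiles under XLA because all shapes are static
def token_counts(input_ids: tf.Tensor, vocab_size: int) -> tf.Tensor:
    one_hot = tf.one_hot(input_ids, depth=vocab_size)  # (batch, seq, vocab)
    return tf.reduce_sum(one_hot, axis=1)              # (batch, vocab)

ids = tf.constant([[3, 1, 3, 2]])
print(token_counts(ids, 5).numpy())  # [[0. 1. 1. 2. 0.]] -> token 3 seen twice
```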