Bugfix and optimization in end_of_generation_condition()
1. Bugfix

The previous implementation did not verify that `end_string` was encoded
into a single token, which could trigger the end of generation earlier
than intended (see discussion in NVIDIA#7187).
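
To illustrate, here is a toy sketch of the failure mode (the
whitespace "tokenizer" and vocabulary are hypothetical, but the
extraction logic mirrors the previous implementation):

    # Toy vocabulary: the end string "END OF TEXT" maps to three tokens.
    VOCAB = {"<extra_id_1>": 0, "END": 1, "OF": 2, "TEXT": 3}

    def text_to_ids(text):
        return [VOCAB[w] for w in text.split()]

    end_string = "END OF TEXT"
    ids_1 = text_to_ids(f"<extra_id_1> {end_string}")  # [0, 1, 2, 3]
    ids_2 = text_to_ids("<extra_id_1>")                # [0]
    # The old code kept only the FIRST extra token as "the" end token:
    token_id = ids_1[len(ids_2):][0]                   # 1, i.e. "END"
    # Generation then stops on any "END" token, even though the full
    # end string "END OF TEXT" was never produced.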

2. Optimization

The previous implementation scaled linearly with the batch size and
quadratically with the length of the generated sequence, which could
lead to significant overhead in some situations.

The new implementation is much more efficient in "normal" situations
(where the end of generation is identified by a set of unique tokens),
and raises a warning when it needs to fall back to the inefficient
string-matching path.
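
As a rough, hypothetical back-of-the-envelope comparison (not a
benchmark): the old check decoded every sequence in the batch at every
step, so the total number of tokens decoded over a full generation grew
as batch_size * seq_len**2 / 2, whereas the fast path of the new check
decodes nothing:

    batch_size, seq_len = 32, 1024
    # Old: at step t, decode t tokens for each of the batch_size sequences.
    old_decoded = sum(batch_size * t for t in range(1, seq_len + 1))
    print(f"{old_decoded:,}")  # 16,793,600 tokens decoded overall
    # New (all end strings map to unique tokens): a single vectorized
    # token-id comparison per step, no decoding at all.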

Note that the new implementation does not behave exactly like the old
one, because the string comparison is skipped when all end strings have
unique tokens associated with them. For instance, if the model had
generated the string
        "Some string.<|endoftext|>"
(where "<|endoftext|>" was actually generated as a string, and not as a
single token), the previous implementation would have considered it the
end of generation (assuming `end_strings` has length > 1), while the new
one does not. The previous behavior was likely a bug, though, since we
expect models to generate the special tokens associated with end strings
when those tokens exist (for instance, the standard case
`end_strings=["<|endoftext|>"]` has always been handled by simply
comparing the last token to `eod_id`).
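
For reference, a minimal sketch of the new fast path (the token ids are
made up; 50256 stands in for an end-of-text token id):

    import torch

    prev = torch.tensor([50256, 13, 50256, 198])  # last generated token per sample
    end_tokens = torch.tensor([50256])            # unique tokens for all end strings
    is_end = torch.isin(prev, end_tokens)
    print(is_end)  # tensor([ True, False,  True, False])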

Fixes NVIDIA#7187
odelalleau committed Aug 18, 2023
1 parent ef730aa commit 2dd5a09
Showing 1 changed file with 77 additions and 20 deletions.
nemo/collections/nlp/modules/common/text_generation_strategy.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 import abc
-from typing import List, Tuple
+import warnings
+from typing import List, Set, Tuple
 
 import torch

@@ -50,6 +51,7 @@ class TextGenerationStrategy:
     def __init__(self, model):
         self.model = model
         self.model.eval()
+        self._end_of_generation_cache = None
 
     def forward_step(self, batch, tensor_shape):
         fwd_bwd_function = get_forward_backward_func()
@@ -147,26 +149,22 @@ def end_of_generation_condition(
         returns:
             a boolean tensor indicating whether the generation should stop
         """
-        if len(end_strings) == 1 and end_strings[0] == END_OF_SEQ:
+        if (len(end_strings) == 1 and end_strings[0] == END_OF_SEQ) or not end_strings:
             # Simple scenario: only finish on end of document token.
             return prev == eod_id
-        else:
-            tokenizer = self.model.tokenizer
-            conditions = []
-            end_tokens = set()
-            end_tokens.add(eod_id)
-            for end_string in end_strings:
-                ids_1 = tokenizer.text_to_ids(f'<extra_id_1>{end_string}')
-                ids_2 = tokenizer.text_to_ids('<extra_id_1>')
-                if len(ids_1) <= len(ids_2):
-                    continue
-                token_id = ids_1[len(ids_2) :][0]
-                end_tokens.add(token_id)
-            for p, token_item in zip(prev, tokens):
-                text = tokenizer.ids_to_text(token_item.tolist())
-                conditions.append(
-                    any([text.endswith(end_string) for end_string in end_strings] + [p.item() in end_tokens])
-                )
-            return torch.tensor(conditions, dtype=torch.bool, device=tokens.device)
+
+        end_tokens, end_strings_to_check = self._get_end_of_generation_tokens_and_strings(eod_id, end_strings)
+        assert end_tokens
+
+        is_end = torch.isin(prev, torch.tensor(list(end_tokens), dtype=prev.dtype, device=prev.device))
+
+        if end_strings_to_check:
+            # The loop below is inefficient (see warning in `_get_end_of_generation_tokens_and_strings()`)
+            for idx, token_seq in enumerate(tokens):
+                text = self.model.tokenizer.ids_to_text(token_seq.tolist())
+                is_end[idx] |= any(text.endswith(end_string) for end_string in end_strings_to_check)
+
+        return is_end
 
     def post_generation_process(self, output):
         """
@@ -176,6 +174,65 @@ def post_generation_process(self, output):
         """
         return output
 
+    def _get_end_of_generation_tokens_and_strings(
+        self, eod_id: int, end_strings: List[str]
+    ) -> Tuple[Set[int], List[str]]:
+        """
+        return the tokens and strings indicating the end of generation
+        Args:
+            eod_id (int): the end of document token id
+            end_strings (List[str]): the list of end of generation strings
+        Returns:
+            a pair `(tokens, strings)` where `tokens` is a set of tokens (int) and `strings` is a list of strings,
+            which must all be used to identify the end of generation (`tokens` always contains `eod_id`, while
+            `strings` may be empty if all end strings are associated to unique tokens)
+        """
+        tokenizer = self.model.tokenizer
+        # A cache is used to remember which end strings are associated to unique tokens vs. which ones
+        # require an actual string comparison.
+        if self._end_of_generation_cache is None or self._end_of_generation_cache["tokenizer"] is not tokenizer:
+            # Invalidate the cache.
+            self._end_of_generation_cache = {
+                "tokenizer": tokenizer,
+                "end_string_to_token": {END_OF_SEQ: eod_id},
+                "end_strings_to_check": set(),
+            }
+        end_string_to_token = self._end_of_generation_cache["end_string_to_token"]
+
+        end_tokens = {eod_id}  # always include `eod_id`, even if `END_OF_SEQ` is not within `end_strings`
+        end_strings_to_check = []  # will contain end strings that have no associated special token
+
+        for end_string in end_strings:
+            try:
+                end_tokens.add(end_string_to_token[end_string])
+                continue
+            except KeyError:
+                if end_string in self._end_of_generation_cache["end_strings_to_check"]:
+                    end_strings_to_check.append(end_string)
+                    continue
+
+            # `end_string` does not exist in the cache yet: check if `end_string` is a special token for
+            # the tokenizer. Ideally, we would simply use `tokenizer.text_to_ids(end_string)`, but some
+            # tokenizers (e.g., SentencePiece) may prefix the special token with another token associated
+            # to an empty string. The code below is thus meant to extract the special token associated to
+            # `end_string` (if it exists). Note that using "This is a sequence." as reference is arbitrary.
+            ids_ref = tokenizer.text_to_ids("This is a sequence.")
+            ids_with_end_string = tokenizer.text_to_ids(f"This is a sequence.{end_string}")
+            if len(ids_with_end_string) == len(ids_ref) + 1 and ids_with_end_string[:-1] == ids_ref:
+                # We can assume that the extra token is the one corresponding to `end_string`.
+                end_string_to_token[end_string] = ids_with_end_string[-1]
+                end_tokens.add(ids_with_end_string[-1])
+            else:
+                # No special token.
+                warnings.warn(
+                    f"The end string '{end_string}' has no associated special token: this may slow down "
+                    "generation (consider using a different tokenizer or modifying `end_strings`)"
+                )
+                self._end_of_generation_cache["end_strings_to_check"].add(end_string)
+                end_strings_to_check.append(end_string)
+
+        return end_tokens, end_strings_to_check
+
 
 class GPTModelTextGenerationStrategy(TextGenerationStrategy):
     def __init__(self, model):
