Bugfix and optimization in end_of_generation_condition()
#7267
Conversation
Force-pushed from ab56968 to 3cd7a89
# tokenizers (e.g., SentencePiece) may prefix the special token with another token associated
# to an empty string. The code below is thus meant to extract the special token associated to
# `end_string` (if it exists). Note that using "This is a sequence." as reference is arbitrary.
ids_ref = tokenizer.text_to_ids("This is a sequence.")
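For context, here is a minimal runnable sketch of the heuristic this snippet implements, assuming the `text_to_ids` API shown above (the function name and the `None` fallback are illustrative, not the PR's actual code):

```python
def extract_end_token_id(tokenizer, end_string, reference="This is a sequence."):
    """Try to find the single token id that encodes `end_string` (sketch).

    Tokenize a reference sentence with and without `end_string` appended:
    if exactly one extra token appears and the prefix tokens are unchanged,
    that extra token is the id associated with `end_string`.
    """
    ids_ref = tokenizer.text_to_ids(reference)
    ids_with_end_string = tokenizer.text_to_ids(reference + end_string)
    if len(ids_with_end_string) == len(ids_ref) + 1 and ids_with_end_string[:-1] == ids_ref:
        return ids_with_end_string[-1]
    # No unique token found (e.g., the tokenizer merged characters across
    # the boundary): the caller must fall back to string matching.
    return None
```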
"This is a sequence."
is not a good prefix string since the .
character might be merged with other characters in the end_string
. Ideally we should use some special token so it will always stand alone.
I was just commenting on this, but the code checks `ids_with_end_string[:-1] == ids_ref` below, and `.` is usually an independent token in most tokenizers. Ideally a separate special token might be better, but do we have any token that we can guarantee will always be present?
Odds-wise, I feel having `.` as the special token is more likely to work than any other option. Worst case, we do a string check, but it still doesn't give an incorrect answer.
I started to write an explanation for my reasoning here, but this brought up a doubt in my mind that is very important to clear up, as it may invalidate a strong assumption made in this PR: is it possible that we may want to use an end string associated with a unique token (e.g., `<extra_id_1>`) and yet expect the model to end a response with the string "<extra_id_1>" but without generating this token?

I assumed no (similar to how we don't stop generation if the model generates the string "<|endoftext|>" instead of `eos_id`), but can you confirm this is correct?
"This is a sequence."
is not a good prefix string since the.
character might be merged with other characters in theend_string
. Ideally we should use some special token so it will always stand alone.
So, following offline discussion, there are three cases:

- `end_string` is actually a special token for the tokenizer (e.g., `<extra_id_1>` typically is, when we use it). Then we are guaranteed that it will be tokenized as a single token, and we will identify it properly.
- `end_string` is not a special token for the tokenizer, and is tokenized into more than one token: in that case we don't care whether or not it is merged with the preceding `.`, because we will need to use string comparisons anyway.
- `end_string` is not a special token for the tokenizer, and is tokenized into a single token: in that case either we identify the single token and rely on token comparison, or we don't (because of the tokenizer merging the `.`) and we fall back to string matching. I would argue that the latter is safer, because (a) it will ensure we always end generation correctly, and (b) it shows a warning that may alert the user that something may be off (there's a good chance they didn't expect the tokenizer to merge `end_string` with other characters). And if the tokenizer merges the `.` with `end_string`, it seems particularly important to be aware of it, since this sequence of characters is likely to be quite common at the end of responses. So it seems to me that it's actually a better option than `<extra_id_1>` (though I admit I hadn't thought it through before ;))
The assumption is that whether or not end_string is a special token, the generation will end with it inclusively; it is up to the tokenizer to decide whether the `ids_to_text` method wants to show it. I think `<extra_id_1>` is a better prefix because if it is a special token, it will work as expected. If it is not, the `>` is less likely to be merged with the end_string, which is the cleaner case you mentioned in point 3. If `.` is merged with a partial end string during generation, the string match might not capture it correctly. E.g., if end_string is "hello" and `.` is merged into "hel", the generated tokens are ".hel", "lo", and it won't trigger the string match check since we use the `endswith` method.
Alright, let's go with `<extra_id_1>`; I don't think it's a big deal anyway, it shouldn't matter much in practice. This is done in c9a6d71 (I also rebased on top of main, but there are no changes in previous commits).

I still want to address the point below, since a similar situation may still happen with `<extra_id_1>`:
> If `.` is merged with a partial end string during generation, the string match might not capture it correctly. E.g., if end_string is "hello" and `.` is merged into "hel", the generated tokens are ".hel", "lo", and it won't trigger the string match check since we use the `endswith` method.
If `.` is merged with "hel", then this will trigger the string match, because the second condition of this check won't be satisfied (we check both that there are N+1 tokens and that the first N tokens are the same):

if len(ids_with_end_string) == len(ids_ref) + 1 and ids_with_end_string[:-1] == ids_ref:

As a result, a warning will be displayed, and the comparison will be made with `text.endswith("hello")`, which will match any generation ending with "hello", regardless of which tokens this corresponds to.

Note that there could still be situations where the model generates "hello" without generation ending, e.g., if it generates the token ".hel" followed by "lo world". But this case was not handled previously either, and it is unclear that we should stop there (since we can't truncate the model output mid-token, the response would actually end with "hello world" rather than "hello").
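To make the merge scenario concrete, here is a small self-contained example with made-up token ids (the ids are hypothetical; the check is the one quoted above):

```python
# Hypothetical tokenization of the reference sentence "This is a sequence."
ids_ref = [100, 101, 102]  # last token (102) is "."
# With end_string "hello" appended, "." merges into ".hel": [".hel", "lo"]
ids_with_end_string = [100, 101, 250, 251]

# The length condition holds (N+1 tokens), but the prefix check fails
# because "." (102) was replaced by ".hel" (250)...
unique_token_found = (
    len(ids_with_end_string) == len(ids_ref) + 1
    and ids_with_end_string[:-1] == ids_ref
)
assert not unique_token_found
# ...so the code warns and falls back to text.endswith("hello").
```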
Yeah, I think it is a fundamental problem with the string match method. Maybe the workaround is not to use the `endswith` method, and instead to truncate the extra characters after generation stops (a sketch of this idea follows below). The string match method is already a hack anyway; maybe add a TODO comment, and we can come back to this in the future if we see it causing any problems.
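A minimal sketch of that post-truncation idea (not code from this PR; the function name is made up):

```python
def truncate_at_end_string(text: str, end_string: str):
    """Stop on the first occurrence of `end_string` anywhere in the decoded
    text, then drop the extra characters generated past it."""
    idx = text.find(end_string)
    if idx == -1:
        return text, False  # end string not generated (yet)
    return text[: idx + len(end_string)], True

# E.g., tokens ".hel" + "lo world" decode to text ending in "hello world":
# endswith("hello") misses it, but find-and-truncate catches it.
print(truncate_at_end_string("Say hello world", "hello"))  # ('Say hello', True)
```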
> maybe add a TODO comment and we might come back to this in the future if we see this causes any problems.

Good idea, added in 0753075
LGTM! Everything looks good to me; there are some pending comments by Yi, but once those are resolved I think we should be good to merge.
Bugfix and optimization in `end_of_generation_condition()`

1. Bugfix

The previous implementation did not verify that `end_string` was encoded into a single token, which could trigger the end of generation earlier than intended (see discussion in NVIDIA#7187).

2. Optimization

The previous implementation scaled linearly with the batch size and quadratically with the length of the generated sequence, which could lead to significant overhead in some situations. The new implementation is much more efficient in "normal" situations (where the end of generation is identified by a set of unique tokens), and raises a warning when it needs to fall back to the inefficient string matching case.

Note that it does not behave exactly the same as before, because we skip the string comparison when the end strings all have unique tokens associated with them. For instance, in the previous implementation, if the model had generated the string "Some string.<|endoftext|>" (where "<|endoftext|>" would really be generated as a string, and not as a single token), then the previous implementation would have considered it to be the end of generation (assuming `end_strings` has length > 1), while the new one would not. The previous behavior was likely a bug though, since we expect models to generate the special tokens associated with end strings when they exist (for instance, the standard case `end_strings=["<|endoftext|>"]` has always been handled by just comparing the last token to `eod_id`).

Fixes NVIDIA#7187

Signed-off-by: Olivier Delalleau <[email protected]>
Add warning when model is not in eval mode during generation

Systematically calling `model.eval()` does not seem like a good idea, as it might have side effects leading to unexpected behavior. It would be better to raise an exception if one attempts to generate while in training mode, but this may break existing code, so sticking to a warning for now.

Signed-off-by: Olivier Delalleau <[email protected]>
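For illustration, a minimal sketch of what such a warning could look like (the message and placement are assumptions; the PR adds its warning in `__init__()`):

```python
import warnings

import torch.nn as nn


def warn_if_training(model: nn.Module) -> None:
    # Raising an exception here could break existing code, so only warn.
    if model.training:
        warnings.warn(
            "Generation was requested while the model is in training mode: "
            "results may be affected (e.g., by dropout). Consider calling "
            "model.eval() first."
        )
```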
Use "<extra_id_1>" as prefix string

Signed-off-by: Olivier Delalleau <[email protected]>
Force-pushed from 3cd7a89 to c9a6d71
LGTM. Thanks for refining the end-of-generation logic.
Add TODO for potential failure mode of the string match mechanism

Signed-off-by: Olivier Delalleau <[email protected]>
Force-pushed from 9c3523f to 0753075
LGTM, thank you!
What does this PR do?

It fixes a bug in `end_of_generation_condition()` (#7187) and makes it significantly faster in some cases.

Collection: nlp
Changelog
Detailed explanation
The previous implementation did not verify that `end_string` was encoded into a single token, which could trigger the end of generation earlier than intended (see discussion in #7187).

The previous implementation scaled linearly with the batch size and quadratically with the length of the generated sequence, which could lead to significant overhead in some situations. The new implementation is much more efficient in "normal" situations (where the end of generation is identified by a set of unique tokens), and raises a warning when it needs to fall back to the inefficient string matching case.
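A hedged sketch of this fast path, assuming each end string has been mapped to a unique token id as described above (names and shapes are illustrative, not the PR's actual code):

```python
import torch


def end_of_generation(tokens: torch.Tensor, end_token_ids: set, eod_id: int) -> torch.Tensor:
    """Per-step check that is O(batch) instead of decoding and string-matching
    every sequence: a sample is done when its last generated token is the EOD
    token or any token uniquely associated with an end string.

    tokens: LongTensor of shape (batch_size, current_length).
    """
    last = tokens[:, -1]
    done = last == eod_id
    for token_id in end_token_ids:
        done |= last == token_id
    return done  # BoolTensor of shape (batch_size,)
```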
Note that it does not behave exactly the same as before, because we skip the string comparison when the end strings all have unique tokens associated with them. For instance, in the previous implementation, if the model had generated the string "Some string.<|endoftext|>" (where "<|endoftext|>" would really be generated as a string, and not as a single token), then the previous implementation would have considered it to be the end of generation (assuming `end_strings` has length > 1), while the new one would not. The previous behavior was likely a bug though, since we expect models to generate the special tokens associated with end strings when they exist (for instance, the standard case `end_strings=["<|endoftext|>"]` has always been handled by just comparing the last token to `eod_id`).

See the commit message of ab56968 for the explanation regarding the warning that was added in `__init__()`.
Tests
Hopefully there are existing tests on CI -- I have tested this myself on my own jobs.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs to various areas.
Additional Information
Fixes #7187