forked from NVIDIA/NeMo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bugfix and optimization in
end_of_generation_condition()
1. Bugfix The previous implementation did not verify that `end_string` was encoded into a single token, which could trigger the end of generation earlier than intended (see discussion in NVIDIA#7187) 2. Optimization The previous implementation was scaling linearly with the batch size and quadratically with the length of the generated sequence, which could lead to a significant overhead in some situations. The new implementation is much more efficient in "normal" situations (where the end of generation is identified by a set of unique tokens), and raises a warning when it needs to fallback to the inefficient string matching case. Note that it does not behave exactly the same as before, because we skip the string comparison when the end strings all have unique tokens associated to them. For instance, in the previous implementation, if the model had generated the string "Some string.<|endoftext|>" (where "<|endoftext|>" would really be generated as a string, and not as a single token), then the previous implementation would have considered it to be the end of generation (assuming `end_strings` has length > 1), while the new one would not. The previous behavior was likely a bug though, since we expect models to generate the special tokens associated to end strings when they exist (for instance, the standard case `end_strings=["<|endoftext|>"]` has always been handled by just comparing the last token to `eod_id`). Fixes NVIDIA#7187
- Loading branch information
1 parent
ef730aa
commit 2dd5a09
Showing
1 changed file
with
77 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters