Chat tokenization fixes in generate.py & API #1035
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1035
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit d90e33b with merge base 147c292.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
lgtm, remember to verify that
encoded = torch.tensor(tokens, dtype=torch.int, device=self.builder_args.device)
print(self.tokenizer.decode(tokens))
Just checking that this is an intentional print
Yes - this prints the prompt on the server side so that the full prompt can be tracked entirely from the server logs.
However, this raises a larger issue in the generate/API stack - we need to replace print statements with a logger so that users can choose not to print these debug messages.
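For illustration, here is a minimal sketch (not the PR's code) of routing that debug print through Python's standard logging module so users can opt out; the logger name and helper function are hypothetical:

```
import logging

logger = logging.getLogger("torchchat.generate")  # logger name is illustrative


def log_full_prompt(tokenizer, tokens):
    # Would replace: print(self.tokenizer.decode(tokens))
    logger.debug("Full prompt: %s", tokenizer.decode(tokens))


# Users who want the prompt echoed opt in explicitly, e.g.:
logging.basicConfig(level=logging.DEBUG)
```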
    and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]
):
    buffer = buffer[:-1]  # drop the eot_id from the output buffer
    pass
Why is this a pass again?
The callback function is only used in generate() for the CLI interactive chat to print results to stdout. I naively copied this code when refactoring the original generate.py and carried it over to openaiapi, where the callback isn't used.
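For context, a hedged sketch of the kind of CLI callback being discussed; the eot_id handling mirrors the diff above, but the function name, signature, and flush threshold are assumptions, not the PR's exact code:

```
def make_cli_callback(tokenizer):
    buffer = []

    def callback(x, done_generating=False):
        # Decode the newly generated token and accumulate it.
        buffer.append(tokenizer.decode([x.item()]))
        if done_generating and x.item() == tokenizer.special_tokens["<|eot_id|>"]:
            buffer.pop()  # drop the eot_id from the output buffer
        # Flush periodically (and at the end) to stdout for the interactive chat.
        if len(buffer) >= 4 or done_generating:
            print("".join(buffer), end="", flush=True)
            buffer.clear()

    return callback
```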
I believe this PR introduced a regression; fix proposal in #1061.
This fixes the following assert that is easy to repro in any chat session:

```
Traceback (most recent call last):
  File "/home/ubuntu/cali/torchchat/torchchat.py", line 69, in <module>
    generate_main(args)
  File "/home/ubuntu/cali/torchchat/generate.py", line 896, in main
    for _ in gen.chat(generator_args):
  File "/home/ubuntu/cali/torchchat/generate.py", line 748, in chat
    self.chat_formatter.encode_header(
  File "/home/ubuntu/cali/torchchat/generate.py", line 53, in encode_header
    tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
  File "/home/ubuntu/cali/torchchat/tokenizer/tiktoken.py", line 133, in encode
    assert type(s) is str
```

I believe this regressed with #1035.
Currently, only the chat() function encodes chat-style messages into the correct format during an interactive session.
I adapted the ChatFormat class into a generic _ChatFormatter base class; each formatter implements encode_dialog_prompt, which takes a series of message objects and encodes the system prompt and the user/assistant messages correctly. This lets us load in a whole conversation (i.e. the message objects in a completion request) and keep the API/server stateless.
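A minimal sketch of that pattern, assuming a Llama-3-style tokenizer with encode() and a special_tokens table (as seen in the diff above); the class names, token names, and overall structure are illustrative, not the PR's exact code:

```
from abc import ABC, abstractmethod
from typing import Dict, List


class _ChatFormatter(ABC):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @abstractmethod
    def encode_dialog_prompt(self, dialog: List[Dict[str, str]]) -> List[int]:
        """Encode a whole conversation (a list of {'role', 'content'} messages)."""
        ...


class Llama3ChatFormatter(_ChatFormatter):
    def _encode_header(self, role: str) -> List[int]:
        tokens = [self.tokenizer.special_tokens["<|start_header_id|>"]]
        tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_dialog_prompt(self, dialog: List[Dict[str, str]]) -> List[int]:
        tokens = [self.tokenizer.special_tokens["<|begin_of_text|>"]]
        for message in dialog:
            tokens.extend(self._encode_header(message["role"]))
            tokens.extend(
                self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
            )
            tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
        # Prime the model to respond as the assistant.
        tokens.extend(self._encode_header("assistant"))
        return tokens
```

Because the whole dialog is re-encoded on every request, the server never has to remember prior turns.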
Test:
Initially, I prompted with the following cURL request:
Then, I added the model's response and prompted it again with the following:
And got the following response:
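The original requests and responses are not preserved in this thread. As a hedged sketch of the stateless round trip being tested, assuming an OpenAI-style /v1/chat/completions endpoint at a local address (both assumptions) and placeholder messages rather than the original prompts:

```
import json
import urllib.request

URL = "http://127.0.0.1:5000/v1/chat/completions"  # assumed server address and path


def chat(messages):
    body = json.dumps({"model": "llama3", "messages": messages}).encode()
    req = urllib.request.Request(
        URL, data=body, headers={"Content-Type": "application/json"}
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())


# First turn: send only the user message.
history = [{"role": "user", "content": "Hello, who are you?"}]
reply = chat(history)

# Second turn: append the assistant's reply and the next user message, then
# resend the whole conversation -- the server itself stays stateless.
history.append(reply["choices"][0]["message"])
history.append({"role": "user", "content": "Tell me more."})
reply = chat(history)
```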