Chat tokenization fixes in generate.py & API #1035

Merged: 2 commits into main, Aug 19, 2024
Conversation

@vmpuri (Contributor) commented Aug 15, 2024

Currently, only the chat() function encodes chat-style messages into the correct format during an interactive session.

I adapted the ChatFormat class into a generic _ChatFormatter base class. Each formatter implements encode_dialog_prompt, which takes a sequence of message objects and encodes the system prompt and the user/assistant messages correctly.

This lets us load in a whole conversation (i.e., the message objects in a completion request) and keep the API/server stateless.
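For reference, here is a minimal sketch of the shape of that abstraction. The class and method names follow the description above; the concrete Llama 3 encoding shown is illustrative and may differ in detail from the code in generate.py.

```
from abc import ABC, abstractmethod
from typing import Dict, List

class _ChatFormatter(ABC):
    """Base class for model-specific chat prompt encoders (sketch)."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @abstractmethod
    def encode_dialog_prompt(self, dialog: List[Dict[str, str]]) -> List[int]:
        """Encode a full conversation (system/user/assistant messages) into tokens."""
        ...

class Llama3ChatFormatter(_ChatFormatter):
    """Illustrative subclass: wraps each message in Llama 3 header/eot markers."""

    def encode_dialog_prompt(self, dialog: List[Dict[str, str]]) -> List[int]:
        tokens = [self.tokenizer.special_tokens["<|begin_of_text|>"]]
        for message in dialog:
            # Header: <|start_header_id|>role<|end_header_id|>\n\n
            tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
            tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
            tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
            tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
            # Body, terminated by the end-of-turn marker.
            tokens.extend(self.tokenizer.encode(message["content"].strip(), bos=False, eos=False))
            tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
        return tokens
```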

Test:

Initially, I prompted with the following cURL request:

```
curl http://127.0.0.1:5000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llama2",
    "stream": "true",
    "seed": "123",
    "max_tokens": "200",
    "messages": [
      {
        "role": "system",
        "content": "You are a helpful assistant. Be as brief as possible, no yapping. Do not include additional details unless asked."
      },
      {
        "role": "user",
        "content": "List 3 early jet fighters."
      }
    ]}'
```
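The same request can also be issued from Python; a small sketch using the requests library, assuming the server from this PR is running locally on port 5000:

```
import requests

payload = {
    "model": "llama2",
    "stream": "true",
    "seed": "123",
    "max_tokens": "200",
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant. Be as brief as possible, no yapping.",
        },
        {"role": "user", "content": "List 3 early jet fighters."},
    ],
}

# Stream the chat completion back from the local torchchat server.
with requests.post(
    "http://127.0.0.1:5000/v1/chat/completions", json=payload, stream=True
) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode())
```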

Then, I added the model's response and prompted it again with the following:

```
curl http://127.0.0.1:5000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llama2",
    "stream": "true",
    "seed": "123",
    "max_tokens": "200",
    "messages": [
      {
        "role": "system",
        "content": "You are a helpful assistant. Be as brief as possible, no yapping. Do not include additional details unless asked."
      },
      {
        "role": "user",
        "content": "List 3 early jet fighters."
      },
      {
        "role": "system",
        "content": "Certainly! Here are three early jet fighters:\n\n1. Gloster Meteor (UK, 1943)\n2. Messerschmitt Me 262 (Germany, 1944)\n3. Lockheed P-80 Shooting Star (US, 1948)"
      },
      {
        "role": "user",
        "content": "Were there any notable ones from the Soviet Union or Japan?"
      }
    ]
  }'
```

And got the following response:

```
Yes, here are a few notable early jet fighters from the Soviet Union and Japan:

Soviet Union:

1. MiG-15 (1947)
2. Lavochkin La-150 (1948)
3. Yakovlev Yak-15 (1946)

Japan:

1. Nakajima J7W (1944)
2. Mitsubishi J8M (1945)
3. Kawasaki Ki-200 (1945)

These early jet fighters represented significant advancements in aeronautical technology and played important roles in their respective countries' militaries during the Cold War era.
```

pytorch-bot commented Aug 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1035

✅ No Failures

As of commit d90e33b with merge base 147c292: 💚 Looks good so far! There are no failures yet. 💚
@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) Aug 15, 2024
@vmpuri requested a review from Jack-Khuu August 16, 2024 22:52
@vmpuri marked this pull request as ready for review August 16, 2024 22:52
@Jack-Khuu (Contributor) commented:

LGTM. Remember to verify that python3 torchchat.py generate and chat still work properly for llama2/3 with these changes as well.

Review thread on the following snippet:

```
encoded = torch.tensor(tokens, dtype=torch.int, device=self.builder_args.device)
print(self.tokenizer.decode(tokens))
```
Contributor:

Just checking that this is an intentional print.

Contributor Author (@vmpuri):

Yes - this prints the prompt on the server side so that the full prompt is easy to track solely from the server side.

However, this raises a larger issue in the generate/API stack: we need to replace print statements with a logger so that users can choose not to see these debug messages.
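A minimal sketch of the logger-based alternative mentioned above; the module name, logger name, and helper are assumptions, not current torchchat code:

```
import logging

# Hypothetical module-level logger; the real code could name and configure it differently.
logger = logging.getLogger("torchchat.generate")

def log_prompt(tokenizer, tokens) -> None:
    # Emit the decoded prompt at DEBUG level instead of printing unconditionally,
    # so users can silence it through standard logging configuration.
    logger.debug("Encoded prompt: %s", tokenizer.decode(tokens))
```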

Review thread on the following snippet:

```
    and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]
):
    buffer = buffer[:-1]  # drop the eot_id from the output buffer
    pass
```
Contributor:

Why is this a pass again?

Contributor Author (@vmpuri):

The callback function is only used in generate() for the CLI interactive chat, where it prints results to stdout. When refactoring the original generate.py I copied this code over naively to openaiapi, where the callback isn't used.
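A rough sketch of the pattern being described; the callback names and signatures here are illustrative, not the exact torchchat code. The CLI chat path drops the end-of-turn token and prints to stdout, while the API path can pass a no-op:

```
def cli_chat_callback(x, buffer, tokenizer, done_generating):
    # CLI interactive chat: strip the end-of-turn marker and echo text to stdout.
    if buffer and x.item() == tokenizer.special_tokens["<|eot_id|>"]:
        buffer = buffer[:-1]  # drop the eot_id from the printed output
    print("".join(buffer), end="", flush=True)

def api_callback(x, buffer, tokenizer, done_generating):
    # openaiapi streams tokens back to the client elsewhere, so nothing to do here.
    pass
```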

@vmpuri merged commit c7f56f2 into main Aug 19, 2024
51 checks passed
@vmpuri mentioned this pull request Aug 19, 2024
@prideout (Contributor) commented Aug 25, 2024

I believe this PR introduced a regression because encode_header now takes a string, but there is a spot in the Generator::chat method that is still passing a dictionary.

Fix proposal in #1061
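A self-contained sketch of the kind of mismatch being described; the stub below only mimics the tiktoken wrapper's type check, and the real call site is in Generator.chat with the actual fix in #1061:

```
from typing import List

class _TokenizerStub:
    """Hypothetical stand-in for the tiktoken wrapper's encode() type check."""

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        assert type(s) is str  # the assert that fires in the traceback below
        return [ord(c) for c in s]

def encode_header(tokenizer: _TokenizerStub, role: str) -> List[int]:
    # After this PR, encode_header expects the role string itself, not the message dict.
    return tokenizer.encode(role, bos=False, eos=False)

tok = _TokenizerStub()
message = {"role": "user", "content": "hello"}
# encode_header(tok, message)               # regression: passing the dict trips the assert
print(encode_header(tok, message["role"]))  # fix: pass only the role string
```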

Jack-Khuu pushed a commit that referenced this pull request Aug 25, 2024
This fixes the following assertion failure, which is easy to reproduce in any chat session:

```
Traceback (most recent call last):
    File "/home/ubuntu/cali/torchchat/torchchat.py", line 69, in <module>
        generate_main(args)
    File "/home/ubuntu/cali/torchchat/generate.py", line 896, in main
        for _ in gen.chat(generator_args):
    File "/home/ubuntu/cali/torchchat/generate.py", line 748, in chat
        self.chat_formatter.encode_header(
    File "/home/ubuntu/cali/torchchat/generate.py", line 53, in encode_header
        tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
    File "/home/ubuntu/cali/torchchat/tokenizer/tiktoken.py", line 133, in encode
        assert type(s) is str
```

I believe this regressed with #1035.