Add support for different models in num_tokens_from_text function #90
Changes from all commits:
3034a9b, 098e020, 4499dc2, 0130a98
autogen/retrieve_utils.py

@@ -35,42 +35,42 @@
 def num_tokens_from_text(
     text: str, model: str = "gpt-3.5-turbo-0613", return_tokens_per_name_and_message: bool = False
 ) -> Union[int, Tuple[int, int, int]]:
-    """Return the number of tokens used by a text."""
-    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-    try:
-        encoding = tiktoken.encoding_for_model(model)
-    except KeyError:
-        logger.debug("Warning: model not found. Using cl100k_base encoding.")
-        encoding = tiktoken.get_encoding("cl100k_base")
-    if model in {
-        "gpt-3.5-turbo-0613",
-        "gpt-3.5-turbo-16k-0613",
-        "gpt-4-0314",
-        "gpt-4-32k-0314",
-        "gpt-4-0613",
-        "gpt-4-32k-0613",
-    }:
-        tokens_per_message = 3
-        tokens_per_name = 1
-    elif model == "gpt-3.5-turbo-0301":
-        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
-        tokens_per_name = -1  # if there's a name, the role is omitted
-    elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
-        logger.warning("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
-        return num_tokens_from_text(text, model="gpt-3.5-turbo-0613")
-    elif "gpt-4" in model:
-        logger.warning("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
-        return num_tokens_from_text(text, model="gpt-4-0613")
+    """Return the number of tokens used by a text for different models."""
+
+    # Define token counts for known models
+    known_models = {
+        "gpt-3.5-turbo-0613": (3, 1),
+        "gpt-3.5-turbo-16k-0613": (3, 1),
+        "gpt-4-0314": (3, 1),
+        "gpt-4-32k-0314": (3, 1),
+        "gpt-4-0613": (3, 1),
+        "gpt-4-32k-0613": (3, 1),
+    }
+
Review comment: We can add a parameter to the function, say:

    if isinstance(model_token, dict):
        known_models.update(model_token)

The parameter can then be passed in by the caller.
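A minimal sketch of that suggestion, assuming a keyword parameter named model_token mapping model names to (tokens_per_message, tokens_per_name); the name and shape come from the comment above, not from the merged code:

    from typing import Dict, Optional, Tuple

    KNOWN_MODELS: Dict[str, Tuple[int, int]] = {
        "gpt-3.5-turbo-0613": (3, 1),
        "gpt-4-0613": (3, 1),
    }

    def resolve_token_counts(
        model: str,
        model_token: Optional[Dict[str, Tuple[int, int]]] = None,
    ) -> Tuple[int, int]:
        """Look up (tokens_per_message, tokens_per_name), allowing caller overrides."""
        known_models = dict(KNOWN_MODELS)
        if isinstance(model_token, dict):
            known_models.update(model_token)  # caller-supplied counts take precedence
        if model not in known_models:
            raise NotImplementedError(f"num_tokens_from_text() is not implemented for model {model}.")
        return known_models[model]

    # Usage: support a custom model without editing the function body.
    print(resolve_token_counts("my-fine-tune", model_token={"my-fine-tune": (3, 1)}))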
+    # Check if the model is known and retrieve token counts
+    if model in known_models:
+        tokens_per_message, tokens_per_name = known_models[model]
     else:
-        raise NotImplementedError(
-            f"""num_tokens_from_text() is not implemented for model {model}. See """
-            f"""https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are """
-            f"""converted to tokens."""
-        )
+        logger.warning(f"Warning: Model '{model}' is not in known models. Using default token counts.")
+        # You can add support for additional models and their token counts here.
+        if model == "your-new-model-name":
+            tokens_per_message = 3
+            tokens_per_name = 1
+        else:
+            raise NotImplementedError(
+                f"num_tokens_from_text() is not implemented for model {model}. See "
+                f"https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are "
+                f"converted to tokens."
+            )
Review comment on lines +56 to +64: Suggested change (the proposed diff is not preserved here).
+
+    # Use tiktoken to calculate the number of tokens in the text
+    encoding = tiktoken.encoding_for_model(model)
Review comment (suggested change): a try/except is needed here.
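The deleted code already contained that fallback; restoring it would look like this, with cl100k_base as the fallback encoding exactly as in the pre-PR version:

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: fall back to the cl100k_base encoding.
        logger.debug("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")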
+    token_count = len(encoding.encode(text))
+
     if return_tokens_per_name_and_message:
-        return len(encoding.encode(text)), tokens_per_message, tokens_per_name
+        return token_count, tokens_per_message, tokens_per_name
     else:
-        return len(encoding.encode(text))
+        return token_count


 def num_tokens_from_messages(messages: dict, model: str = "gpt-3.5-turbo-0613"):
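A short usage sketch of the updated function; the import path follows the test file below, and actual counts depend on tiktoken's vocabulary:

    from autogen.retrieve_utils import num_tokens_from_text

    # Token count only.
    count = num_tokens_from_text("This is a test message.")

    # Also return the per-message and per-name overheads for a specific model.
    count, tokens_per_message, tokens_per_name = num_tokens_from_text(
        "This is a test message.",
        model="gpt-4-0613",
        return_tokens_per_name_and_message=True,
    )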
New test file:

@@ -0,0 +1,21 @@
+import unittest
+from autogen.retrieve_utils import num_tokens_from_text
+
+class TestNumTokensFromText(unittest.TestCase):
+
+    def test_known_model(self):
+        # Test with a known model and known token counts
+        text = "This is a test message."
+        model = "gpt-3.5-turbo-0613"
+        result = num_tokens_from_text(text, model)
+        self.assertEqual(result, 6)  # Adjust the expected token count
+
+    def test_unknown_model(self):
+        # Test with an unknown model
+        text = "This is a test message."
+        model = "unknown-model"
+        with self.assertRaises(NotImplementedError):
+            num_tokens_from_text(text, model)
Review comment on lines +17 to +18: Need to update.
+
+if __name__ == '__main__':
+    unittest.main()
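If the model_token override suggested above and the try/except fallback were both adopted, the unknown-model test might be extended along these lines (hypothetical API, not part of this PR):

    def test_unknown_model_with_custom_counts(self):
        # With caller-supplied counts, an unknown model should no longer raise.
        text = "This is a test message."
        result = num_tokens_from_text(text, "my-custom-model", model_token={"my-custom-model": (3, 1)})
        self.assertIsInstance(result, int)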
Review comment: Why is gpt-3.5-turbo-0301 not in the known models?
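For reference, the deleted code handled gpt-3.5-turbo-0301 explicitly; carrying that over to the dict form would look like this, with the (4, -1) values taken from the removed elif branch:

    known_models = {
        "gpt-3.5-turbo-0613": (3, 1),
        "gpt-3.5-turbo-16k-0613": (3, 1),
        "gpt-4-0314": (3, 1),
        "gpt-4-32k-0314": (3, 1),
        "gpt-4-0613": (3, 1),
        "gpt-4-32k-0613": (3, 1),
        # From the removed branch: every message follows
        # <|start|>{role/name}\n{content}<|end|>\n, and if there's a name,
        # the role is omitted.
        "gpt-3.5-turbo-0301": (4, -1),
    }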