Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for different models in num_tokens_from_text function #90

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 33 additions & 33 deletions autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,42 +35,42 @@
def num_tokens_from_text(
    text: str,
    model: str = "gpt-3.5-turbo-0613",
    return_tokens_per_name_and_message: bool = False,
    model_token: dict = None,
) -> Union[int, Tuple[int, int, int]]:
    """Return the number of tokens used by a text for different models.

    Based on the OpenAI cookbook:
    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

    Args:
        text: The text whose tokens are counted.
        model: Model name used to pick the tokenizer encoding and the
            per-message / per-name token overheads.
        return_tokens_per_name_and_message: If True, also return the
            per-message and per-name overheads alongside the token count.
        model_token: Optional mapping of model name to a
            ``(tokens_per_message, tokens_per_name)`` tuple, letting callers
            register custom models without modifying this function.

    Returns:
        The token count, or a ``(token_count, tokens_per_message,
        tokens_per_name)`` tuple when ``return_tokens_per_name_and_message``
        is True.

    Raises:
        NotImplementedError: If the model is unknown and cannot be mapped to
            a known model family.
    """
    # Per-model (tokens_per_message, tokens_per_name) overheads.
    known_models = {
        # every message follows <|start|>{role/name}\n{content}<|end|>\n;
        # if there's a name, the role is omitted (hence -1)
        "gpt-3.5-turbo-0301": (4, -1),
        "gpt-3.5-turbo-0613": (3, 1),
        "gpt-3.5-turbo-16k-0613": (3, 1),
        "gpt-4-0314": (3, 1),
        "gpt-4-32k-0314": (3, 1),
        "gpt-4-0613": (3, 1),
        "gpt-4-32k-0613": (3, 1),
    }
    # Allow callers (e.g. via retrieve_config) to support extra models
    # without editing this function.
    if isinstance(model_token, dict):
        known_models.update(model_token)

    if model in known_models:
        tokens_per_message, tokens_per_name = known_models[model]
    elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
        # Unpinned model names may update over time; assume the latest pinned
        # snapshot and forward all flags so the return shape is preserved.
        logger.warning("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
        return num_tokens_from_text(
            text,
            model="gpt-3.5-turbo-0613",
            return_tokens_per_name_and_message=return_tokens_per_name_and_message,
        )
    elif "gpt-4" in model:
        logger.warning("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
        return num_tokens_from_text(
            text,
            model="gpt-4-0613",
            return_tokens_per_name_and_message=return_tokens_per_name_and_message,
        )
    else:
        raise NotImplementedError(
            f"""num_tokens_from_text() is not implemented for model {model}. See """
            f"""https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are """
            f"""converted to tokens."""
        )

    # tiktoken raises KeyError for names it does not know; fall back to the
    # cl100k_base encoding used by the gpt-3.5/gpt-4 families.
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.warning("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    token_count = len(encoding.encode(text))
    if return_tokens_per_name_and_message:
        return token_count, tokens_per_message, tokens_per_name
    else:
        return token_count


def num_tokens_from_messages(messages: dict, model: str = "gpt-3.5-turbo-0613"):
Expand Down
21 changes: 21 additions & 0 deletions test/test_num_tokens_from_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest

from autogen.retrieve_utils import num_tokens_from_text


class TestNumTokensFromText(unittest.TestCase):
    """Unit tests for num_tokens_from_text."""

    def test_known_model(self):
        """A known pinned model returns the expected token count for a short text."""
        result = num_tokens_from_text("This is a test message.", "gpt-3.5-turbo-0613")
        self.assertEqual(result, 6)  # Adjust the expected token count

    def test_unknown_model(self):
        """An unrecognized model name raises NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            num_tokens_from_text("This is a test message.", "unknown-model")


if __name__ == '__main__':
    unittest.main()
Loading