diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py
index ecca2f2b0bf..744de31a6af 100644
--- a/autogen/retrieve_utils.py
+++ b/autogen/retrieve_utils.py
@@ -35,42 +35,42 @@ def num_tokens_from_text(
     text: str, model: str = "gpt-3.5-turbo-0613", return_tokens_per_name_and_message: bool = False
 ) -> Union[int, Tuple[int, int, int]]:
-    """Return the number of tokens used by a text."""
-    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-    try:
-        encoding = tiktoken.encoding_for_model(model)
-    except KeyError:
-        logger.debug("Warning: model not found. Using cl100k_base encoding.")
-        encoding = tiktoken.get_encoding("cl100k_base")
-    if model in {
-        "gpt-3.5-turbo-0613",
-        "gpt-3.5-turbo-16k-0613",
-        "gpt-4-0314",
-        "gpt-4-32k-0314",
-        "gpt-4-0613",
-        "gpt-4-32k-0613",
-    }:
-        tokens_per_message = 3
-        tokens_per_name = 1
-    elif model == "gpt-3.5-turbo-0301":
-        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
-        tokens_per_name = -1  # if there's a name, the role is omitted
-    elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
-        logger.warning("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
-        return num_tokens_from_text(text, model="gpt-3.5-turbo-0613")
-    elif "gpt-4" in model:
-        logger.warning("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
-        return num_tokens_from_text(text, model="gpt-4-0613")
+    """Return the number of tokens used by a text for different models."""
+
+    # Define token counts for known models
+    known_models = {
+        "gpt-3.5-turbo-0613": (3, 1),
+        "gpt-3.5-turbo-16k-0613": (3, 1),
+        "gpt-4-0314": (3, 1),
+        "gpt-4-32k-0314": (3, 1),
+        "gpt-4-0613": (3, 1),
+        "gpt-4-32k-0613": (3, 1),
+    }
+
+    # Check if the model is known and retrieve token counts
+    if model in known_models:
+        tokens_per_message, tokens_per_name = known_models[model]
     else:
-        raise NotImplementedError(
-            f"""num_tokens_from_text() is not implemented for model {model}. See """
-            f"""https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are """
-            f"""converted to tokens."""
-        )
+        logger.warning(f"Warning: Model '{model}' is not in the list of known models.")
+        # You can add support for additional models and their token counts here.
+        if model == "your-new-model-name":
+            tokens_per_message = 3
+            tokens_per_name = 1
+        else:
+            raise NotImplementedError(
+                f"num_tokens_from_text() is not implemented for model {model}. See "
+                f"https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are "
+                f"converted to tokens."
+            )
+
+    # Use tiktoken to calculate the number of tokens in the text
+    encoding = tiktoken.encoding_for_model(model)
+    token_count = len(encoding.encode(text))
+
     if return_tokens_per_name_and_message:
-        return len(encoding.encode(text)), tokens_per_message, tokens_per_name
+        return token_count, tokens_per_message, tokens_per_name
     else:
-        return len(encoding.encode(text))
+        return token_count
 
 
 def num_tokens_from_messages(messages: dict, model: str = "gpt-3.5-turbo-0613"):
diff --git a/test/test_num_tokens_from_text.py b/test/test_num_tokens_from_text.py
new file mode 100644
index 00000000000..8ce1c3f86ca
--- /dev/null
+++ b/test/test_num_tokens_from_text.py
@@ -0,0 +1,21 @@
+import unittest
+from autogen.retrieve_utils import num_tokens_from_text
+
+class TestNumTokensFromText(unittest.TestCase):
+
+    def test_known_model(self):
+        # Test with a known model and known token counts
+        text = "This is a test message."
+        model = "gpt-3.5-turbo-0613"
+        result = num_tokens_from_text(text, model)
+        self.assertEqual(result, 6)  # "This is a test message." encodes to 6 tokens for this model
+
+    def test_unknown_model(self):
+        # Test with an unknown model
+        text = "This is a test message."
+        model = "unknown-model"
+        with self.assertRaises(NotImplementedError):
+            num_tokens_from_text(text, model)
+
+if __name__ == '__main__':
+    unittest.main()
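For reference, a minimal usage sketch of the refactored helper after this change, assuming `autogen` and `tiktoken` are importable from the checkout (the printed counts are illustrative):

```python
from autogen.retrieve_utils import num_tokens_from_text

# Known model: returns a plain int token count.
print(num_tokens_from_text("This is a test message.", model="gpt-3.5-turbo-0613"))  # 6

# Opt in to the per-message/per-name overheads as well.
count, per_message, per_name = num_tokens_from_text(
    "This is a test message.",
    model="gpt-4-0613",
    return_tokens_per_name_and_message=True,
)
print(count, per_message, per_name)  # 6 3 1

# Model names outside the known set now raise NotImplementedError, including
# generic "gpt-3.5-turbo"/"gpt-4" variants that previously fell back to the
# -0613 token counts.
try:
    num_tokens_from_text("hello", model="unknown-model")
except NotImplementedError as exc:
    print(exc)
```

The new test module can be run with, e.g., `python -m unittest test.test_num_tokens_from_text`.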