Add token_count_util #421

Merged · 6 commits · Oct 27, 2023
autogen/agentchat/contrib/retrieve_user_proxy_agent.py (11 changes: 6 additions & 5 deletions)

@@ -6,7 +6,8 @@
     raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
 from autogen.agentchat.agent import Agent
 from autogen.agentchat import UserProxyAgent
-from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, num_tokens_from_text
+from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
+from autogen.token_count_utils import count_token
 from autogen.code_utils import extract_code
 
 from typing import Callable, Dict, Optional, Union, List, Tuple, Any
@@ -124,8 +125,8 @@ def __init__(
 - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
     This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None.
 - custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
-    The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
-    Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
+    The function should take (text: str, model: str) as input and return the token count (int). The value of retrieve_config["model"] will be passed to the function.
+    Default is autogen.token_count_utils.count_token, which uses tiktoken and may not be accurate for non-OpenAI models.
 - custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
     Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
 **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
@@ -180,7 +181,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
         self._get_or_create = (
             self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
         )
-        self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
+        self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
         self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
         self._context_max_tokens = self._max_tokens * 0.8
         self._collection = True if self._docs_path is None else False  # whether the collection is created
@@ -244,7 +245,7 @@ def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
                 continue
             if results["ids"][0][idx] in self._doc_ids:
                 continue
-            _doc_tokens = num_tokens_from_text(doc, custom_token_count_function=self.custom_token_count_function)
+            _doc_tokens = self.custom_token_count_function(doc, self._model)
             if _doc_tokens > self._context_max_tokens:
                 func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
                 print(colored(func_print, "green"), flush=True)
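With this change, the retriever counts document tokens by calling `custom_token_count_function(doc, self._model)` directly, so a custom counter only needs to match the new `(text, model)` contract. A minimal sketch of plugging one in (the whitespace heuristic and agent name below are illustrative, not part of this PR):

```python
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent


def my_token_counter(text: str, model: str) -> int:
    # Crude whitespace approximation; swap in the tokenizer that
    # actually matches your non-OpenAI model.
    return len(text.split())


rag_proxy = RetrieveUserProxyAgent(
    name="rag_proxy",
    human_input_mode="NEVER",
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",
        "model": "gpt-4",  # passed as the second argument to the counter
        "custom_token_count_function": my_token_counter,
    },
)
```

If the key is omitted, the agent now falls back to `autogen.token_count_utils.count_token` rather than the removed `num_tokens_from_text`.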
autogen/retrieve_utils.py (80 changes: 3 additions & 77 deletions)

@@ -14,7 +14,7 @@
 import chromadb.utils.embedding_functions as ef
 import logging
 import pypdf
-
+from autogen.token_count_utils import count_token
 
 logger = logging.getLogger(__name__)
 TEXT_FORMATS = [
@@ -37,80 +37,6 @@
 VALID_CHUNK_MODES = frozenset({"one_line", "multi_lines"})
 
 
-def num_tokens_from_text(
-    text: str,
-    model: str = "gpt-3.5-turbo-0613",
-    return_tokens_per_name_and_message: bool = False,
-    custom_token_count_function: Callable = None,
-) -> Union[int, Tuple[int, int, int]]:
-    """Return the number of tokens used by a text.
-    Args:
-        text (str): The text to count tokens for.
-        model (Optional, str): The model to use for tokenization. Default is "gpt-3.5-turbo-0613".
-        return_tokens_per_name_and_message (Optional, bool): Whether to return the number of tokens per name and per
-            message. Default is False.
-        custom_token_count_function (Optional, Callable): A custom function to count tokens. Default is None.
-    Returns:
-        int: The number of tokens used by the text.
-        int: The number of tokens per message. Only returned if return_tokens_per_name_and_message is True.
-        int: The number of tokens per name. Only returned if return_tokens_per_name_and_message is True.
-    """
-    if isinstance(custom_token_count_function, Callable):
-        token_count, tokens_per_message, tokens_per_name = custom_token_count_function(text)
-    else:
-        # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-        try:
-            encoding = tiktoken.encoding_for_model(model)
-        except KeyError:
-            logger.debug("Warning: model not found. Using cl100k_base encoding.")
-            encoding = tiktoken.get_encoding("cl100k_base")
-        known_models = {
-            "gpt-3.5-turbo": (3, 1),
-            "gpt-35-turbo": (3, 1),
-            "gpt-3.5-turbo-0613": (3, 1),
-            "gpt-3.5-turbo-16k-0613": (3, 1),
-            "gpt-3.5-turbo-0301": (4, -1),
-            "gpt-4": (3, 1),
-            "gpt-4-0314": (3, 1),
-            "gpt-4-32k-0314": (3, 1),
-            "gpt-4-0613": (3, 1),
-            "gpt-4-32k-0613": (3, 1),
-        }
-        tokens_per_message, tokens_per_name = known_models.get(model, (3, 1))
-        token_count = len(encoding.encode(text))
-
-    if return_tokens_per_name_and_message:
-        return token_count, tokens_per_message, tokens_per_name
-    else:
-        return token_count
-
-
-def num_tokens_from_messages(
-    messages: dict,
-    model: str = "gpt-3.5-turbo-0613",
-    custom_token_count_function: Callable = None,
-    custom_prime_count: int = 3,
-):
-    """Return the number of tokens used by a list of messages."""
-    num_tokens = 0
-    for message in messages:
-        for key, value in message.items():
-            _num_tokens, tokens_per_message, tokens_per_name = num_tokens_from_text(
-                value,
-                model=model,
-                return_tokens_per_name_and_message=True,
-                custom_token_count_function=custom_token_count_function,
-            )
-            num_tokens += _num_tokens
-            if key == "name":
-                num_tokens += tokens_per_name
-        num_tokens += tokens_per_message
-    num_tokens += custom_prime_count  # With ChatGPT, every reply is primed with <|start|>assistant<|message|>
-    return num_tokens
-
-
 def split_text_to_chunks(
     text: str,
     max_tokens: int = 4000,
@@ -125,7 +51,7 @@ def split_text_to_chunks(
     must_break_at_empty_line = False
     chunks = []
     lines = text.split("\n")
-    lines_tokens = [num_tokens_from_text(line) for line in lines]
+    lines_tokens = [count_token(line) for line in lines]
     sum_tokens = sum(lines_tokens)
     while sum_tokens > max_tokens:
         if chunk_mode == "one_line":
@@ -148,7 +74,7 @@
                 split_len = int(max_tokens / lines_tokens[0] * 0.9 * len(lines[0]))
                 prev = lines[0][:split_len]
                 lines[0] = lines[0][split_len:]
-                lines_tokens[0] = num_tokens_from_text(lines[0])
+                lines_tokens[0] = count_token(lines[0])
             else:
                 logger.warning("Failed to split docs with must_break_at_empty_line being True, set to False.")
                 must_break_at_empty_line = False
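For callers, the refactor is mechanical: `num_tokens_from_text(text)` and `num_tokens_from_messages(messages)` are gone from `autogen.retrieve_utils`, and `count_token` covers both cases. A before/after sketch, assuming a build that includes this PR:

```python
from autogen.token_count_utils import count_token

# Before (removed in this PR):
#   from autogen.retrieve_utils import num_tokens_from_text, num_tokens_from_messages
#   n = num_tokens_from_text("This is a sample text.")
#   m = num_tokens_from_messages([{"role": "user", "content": "Hi"}])

# After: one function dispatches on the input type.
n = count_token("This is a sample text.")             # str -> _num_token_from_text
m = count_token([{"role": "user", "content": "Hi"}])  # list/dict -> _num_token_from_messages
```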
autogen/token_count_utils.py (new file: 182 additions)

@@ -0,0 +1,182 @@
import tiktoken
from typing import List, Union, Dict, Tuple
import logging
import json


logger = logging.getLogger(__name__)


def get_max_token_limit(model="gpt-3.5-turbo-0613"):
    max_token_limit = {
        "gpt-3.5-turbo": 4096,
        "gpt-3.5-turbo-0301": 4096,
        "gpt-3.5-turbo-0613": 4096,
        "gpt-3.5-turbo-instruct": 4096,
        "gpt-3.5-turbo-16k": 16384,
        "gpt-35-turbo": 4096,
        "gpt-35-turbo-16k": 16384,
        "gpt-35-turbo-instruct": 4096,
        "gpt-4": 8192,
        "gpt-4-32k": 32768,
        "gpt-4-32k-0314": 32768,  # deprecate in Sep
        "gpt-4-0314": 8192,  # deprecate in Sep
        "gpt-4-0613": 8192,
        "gpt-4-32k-0613": 32768,
    }
    return max_token_limit[model]


def percentile_used(input, model="gpt-3.5-turbo-0613"):
    return count_token(input) / get_max_token_limit(model)


def token_left(input: Union[str, List, Dict], model="gpt-3.5-turbo-0613") -> int:
    """Count number of tokens left for an OpenAI model.
    Args:
        input: (str, list, dict): Input to the model.
        model: (str): Model name.
    Returns:
        int: Number of tokens left that the model can use for completion.
    """
    return get_max_token_limit(model) - count_token(input, model=model)


def count_token(input: Union[str, List, Dict], model: str = "gpt-3.5-turbo-0613") -> int:
    """Count number of tokens used by an OpenAI model.
    Args:
        input: (str, list, dict): Input to the model.
        model: (str): Model name.
    Returns:
        int: Number of tokens from the input.
    """
    if isinstance(input, str):
        return _num_token_from_text(input, model=model)
    elif isinstance(input, list) or isinstance(input, dict):
        return _num_token_from_messages(input, model=model)
    else:
        raise ValueError("input must be str, list or dict")


def _num_token_from_text(text: str, model: str = "gpt-3.5-turbo-0613"):
    """Return the number of tokens used by a string."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))


def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0613"):
    """Return the number of tokens used by a list of messages.
    retrieved from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    """
    if isinstance(messages, dict):
        messages = [messages]

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    if model in {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
    }:
        tokens_per_message = 3
        tokens_per_name = 1
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif "gpt-3.5-turbo" in model:
        logger.info("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
        return _num_token_from_messages(messages, model="gpt-3.5-turbo-0613")
    elif "gpt-4" in model:
        logger.info("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
        return _num_token_from_messages(messages, model="gpt-4-0613")
    else:
        raise NotImplementedError(
            f"""_num_token_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            if value is None:
                continue

            # function calls
            if not isinstance(value, str):
                try:
                    value = json.dumps(value)
                except TypeError:
                    logger.warning(
                        f"Value {value} is not a string and cannot be converted to json. It is a type: {type(value)} Skipping."
                    )
                    continue

            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens


def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
    """Return the number of tokens used by a list of functions.
    Args:
        functions: (list): List of function descriptions that will be passed in model.
        model: (str): Model name.
    Returns:
        int: Number of tokens from the function descriptions.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    num_tokens = 0
    for function in functions:
        function_tokens = len(encoding.encode(function["name"]))
        function_tokens += len(encoding.encode(function["description"]))
        function_tokens -= 2
        if "parameters" in function:
            parameters = function["parameters"]
            if "properties" in parameters:
                for propertiesKey in parameters["properties"]:
                    function_tokens += len(encoding.encode(propertiesKey))
                    v = parameters["properties"][propertiesKey]
                    for field in v:
                        if field == "type":
                            function_tokens += 2
                            function_tokens += len(encoding.encode(v["type"]))
                        elif field == "description":
                            function_tokens += 2
                            function_tokens += len(encoding.encode(v["description"]))
                        elif field == "enum":
                            function_tokens -= 3
                            for o in v["enum"]:
                                function_tokens += 3
                                function_tokens += len(encoding.encode(o))
                        else:
                            print(f"Warning: not supported field {field}")
                function_tokens += 11
                if len(parameters["properties"]) == 0:
                    function_tokens -= 2

        num_tokens += function_tokens

    num_tokens += 12
    return num_tokens
test/test_retrieve_utils.py (26 changes: 2 additions & 24 deletions)

@@ -11,10 +11,9 @@
     is_url,
     create_vector_db_from_dir,
     query_vector_db,
-    num_tokens_from_text,
-    num_tokens_from_messages,
     TEXT_FORMATS,
 )
+from autogen.token_count_utils import count_token
 
 import os
 import sys
@@ -31,31 +30,10 @@


 class TestRetrieveUtils:
-    def test_num_tokens_from_text_custom_token_count_function(self):
-        def custom_token_count_function(text):
-            return len(text), 1, 2
-
-        text = "This is a sample text."
-        assert num_tokens_from_text(
-            text, return_tokens_per_name_and_message=True, custom_token_count_function=custom_token_count_function
-        ) == (22, 1, 2)
-
-    def test_num_tokens_from_text(self):
-        text = "This is a sample text."
-        assert num_tokens_from_text(text) == len(tiktoken.get_encoding("cl100k_base").encode(text))
-
-    def test_num_tokens_from_messages(self):
-        messages = [{"content": "This is a sample text."}, {"content": "Another sample text."}]
-        # Review the implementation of num_tokens_from_messages
-        # and adjust the expected_tokens accordingly.
-        actual_tokens = num_tokens_from_messages(messages)
-        expected_tokens = actual_tokens  # Adjusted to make the test pass temporarily.
-        assert actual_tokens == expected_tokens
-
     def test_split_text_to_chunks(self):
         long_text = "A" * 10000
         chunks = split_text_to_chunks(long_text, max_tokens=1000)
-        assert all(num_tokens_from_text(chunk) <= 1000 for chunk in chunks)
+        assert all(count_token(chunk) <= 1000 for chunk in chunks)
 
     def test_split_text_to_chunks_raises_on_invalid_chunk_mode(self):
         with pytest.raises(AssertionError):
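The removed token-count tests have no direct replacement in this diff; a counterpart for the new API could look like the sketch below (not part of the PR; the class and test names are illustrative):

```python
import tiktoken

from autogen.token_count_utils import count_token


class TestTokenCountUtils:
    def test_count_token_str_matches_tiktoken(self):
        text = "This is a sample text."
        # The default model resolves to the cl100k_base encoding.
        expected = len(tiktoken.get_encoding("cl100k_base").encode(text))
        assert count_token(text) == expected

    def test_count_token_messages_exceeds_content_tokens(self):
        messages = [{"role": "user", "content": "This is a sample text."}]
        content_tokens = count_token("This is a sample text.")
        # Message counting adds per-message overhead plus the reply primer.
        assert count_token(messages) > content_tokens
```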