Skip to content

Commit

Permalink
Add token_count_util (#421)
Browse files Browse the repository at this point in the history
* add token_count_util

* remove token_count from retrieval util

* format

* update dependency

* update test
  • Loading branch information
yiranwu0 authored Oct 27, 2023
1 parent cde99e0 commit 8c3401d
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 106 deletions.
11 changes: 6 additions & 5 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
from autogen.agentchat.agent import Agent
from autogen.agentchat import UserProxyAgent
from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, num_tokens_from_text
from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
from autogen.token_count_utils import count_token
from autogen.code_utils import extract_code

from typing import Callable, Dict, Optional, Union, List, Tuple, Any
Expand Down Expand Up @@ -124,8 +125,8 @@ def __init__(
- get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None.
- custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function.
Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models.
- custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
Expand Down Expand Up @@ -180,7 +181,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._get_or_create = (
self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
)
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
self._context_max_tokens = self._max_tokens * 0.8
self._collection = True if self._docs_path is None else False # whether the collection is created
Expand Down Expand Up @@ -244,7 +245,7 @@ def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
continue
if results["ids"][0][idx] in self._doc_ids:
continue
_doc_tokens = num_tokens_from_text(doc, custom_token_count_function=self.custom_token_count_function)
_doc_tokens = self.custom_token_count_function(doc, self._model)
if _doc_tokens > self._context_max_tokens:
func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
print(colored(func_print, "green"), flush=True)
Expand Down
80 changes: 3 additions & 77 deletions autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import chromadb.utils.embedding_functions as ef
import logging
import pypdf

from autogen.token_count_utils import count_token

logger = logging.getLogger(__name__)
TEXT_FORMATS = [
Expand All @@ -37,80 +37,6 @@
VALID_CHUNK_MODES = frozenset({"one_line", "multi_lines"})


def num_tokens_from_text(
text: str,
model: str = "gpt-3.5-turbo-0613",
return_tokens_per_name_and_message: bool = False,
custom_token_count_function: Callable = None,
) -> Union[int, Tuple[int, int, int]]:
"""Return the number of tokens used by a text.
Args:
text (str): The text to count tokens for.
model (Optional, str): The model to use for tokenization. Default is "gpt-3.5-turbo-0613".
return_tokens_per_name_and_message (Optional, bool): Whether to return the number of tokens per name and per
message. Default is False.
custom_token_count_function (Optional, Callable): A custom function to count tokens. Default is None.
Returns:
int: The number of tokens used by the text.
int: The number of tokens per message. Only returned if return_tokens_per_name_and_message is True.
int: The number of tokens per name. Only returned if return_tokens_per_name_and_message is True.
"""
if isinstance(custom_token_count_function, Callable):
token_count, tokens_per_message, tokens_per_name = custom_token_count_function(text)
else:
# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.debug("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
known_models = {
"gpt-3.5-turbo": (3, 1),
"gpt-35-turbo": (3, 1),
"gpt-3.5-turbo-0613": (3, 1),
"gpt-3.5-turbo-16k-0613": (3, 1),
"gpt-3.5-turbo-0301": (4, -1),
"gpt-4": (3, 1),
"gpt-4-0314": (3, 1),
"gpt-4-32k-0314": (3, 1),
"gpt-4-0613": (3, 1),
"gpt-4-32k-0613": (3, 1),
}
tokens_per_message, tokens_per_name = known_models.get(model, (3, 1))
token_count = len(encoding.encode(text))

if return_tokens_per_name_and_message:
return token_count, tokens_per_message, tokens_per_name
else:
return token_count


def num_tokens_from_messages(
messages: dict,
model: str = "gpt-3.5-turbo-0613",
custom_token_count_function: Callable = None,
custom_prime_count: int = 3,
):
"""Return the number of tokens used by a list of messages."""
num_tokens = 0
for message in messages:
for key, value in message.items():
_num_tokens, tokens_per_message, tokens_per_name = num_tokens_from_text(
value,
model=model,
return_tokens_per_name_and_message=True,
custom_token_count_function=custom_token_count_function,
)
num_tokens += _num_tokens
if key == "name":
num_tokens += tokens_per_name
num_tokens += tokens_per_message
num_tokens += custom_prime_count # With ChatGPT, every reply is primed with <|start|>assistant<|message|>
return num_tokens


def split_text_to_chunks(
text: str,
max_tokens: int = 4000,
Expand All @@ -125,7 +51,7 @@ def split_text_to_chunks(
must_break_at_empty_line = False
chunks = []
lines = text.split("\n")
lines_tokens = [num_tokens_from_text(line) for line in lines]
lines_tokens = [count_token(line) for line in lines]
sum_tokens = sum(lines_tokens)
while sum_tokens > max_tokens:
if chunk_mode == "one_line":
Expand All @@ -148,7 +74,7 @@ def split_text_to_chunks(
split_len = int(max_tokens / lines_tokens[0] * 0.9 * len(lines[0]))
prev = lines[0][:split_len]
lines[0] = lines[0][split_len:]
lines_tokens[0] = num_tokens_from_text(lines[0])
lines_tokens[0] = count_token(lines[0])
else:
logger.warning("Failed to split docs with must_break_at_empty_line being True, set to False.")
must_break_at_empty_line = False
Expand Down
182 changes: 182 additions & 0 deletions autogen/token_count_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import tiktoken
from typing import List, Union, Dict, Tuple
import logging
import json


logger = logging.getLogger(__name__)


def get_max_token_limit(model="gpt-3.5-turbo-0613"):
max_token_limit = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-instruct": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-35-turbo": 4096,
"gpt-35-turbo-16k": 16384,
"gpt-35-turbo-instruct": 4096,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0314": 32768, # deprecate in Sep
"gpt-4-0314": 8192, # deprecate in Sep
"gpt-4-0613": 8192,
"gpt-4-32k-0613": 32768,
}
return max_token_limit[model]


def percentile_used(input, model="gpt-3.5-turbo-0613"):
return count_token(input) / get_max_token_limit(model)


def token_left(input: Union[str, List, Dict], model="gpt-3.5-turbo-0613") -> int:
"""Count number of tokens left for an OpenAI model.
Args:
input: (str, list, dict): Input to the model.
model: (str): Model name.
Returns:
int: Number of tokens left that the model can use for completion.
"""
return get_max_token_limit(model) - count_token(input, model=model)


def count_token(input: Union[str, List, Dict], model: str = "gpt-3.5-turbo-0613") -> int:
"""Count number of tokens used by an OpenAI model.
Args:
input: (str, list, dict): Input to the model.
model: (str): Model name.
Returns:
int: Number of tokens from the input.
"""
if isinstance(input, str):
return _num_token_from_text(input, model=model)
elif isinstance(input, list) or isinstance(input, dict):
return _num_token_from_messages(input, model=model)
else:
raise ValueError("input must be str, list or dict")


def _num_token_from_text(text: str, model: str = "gpt-3.5-turbo-0613"):
"""Return the number of tokens used by a string."""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(text))


def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0613"):
"""Return the number of tokens used by a list of messages.
retrieved from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb/
"""
if isinstance(messages, dict):
messages = [messages]

try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model:
logger.info("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return _num_token_from_messages(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
logger.info("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return _num_token_from_messages(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""_num_token_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
)
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
if value is None:
continue

# function calls
if not isinstance(value, str):
try:
value = json.dumps(value)
except TypeError:
logger.warning(
f"Value {value} is not a string and cannot be converted to json. It is a type: {type(value)} Skipping."
)
continue

num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens


def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
"""Return the number of tokens used by a list of functions.
Args:
functions: (list): List of function descriptions that will be passed in model.
model: (str): Model name.
Returns:
int: Number of tokens from the function descriptions.
"""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")

num_tokens = 0
for function in functions:
function_tokens = len(encoding.encode(function["name"]))
function_tokens += len(encoding.encode(function["description"]))
function_tokens -= 2
if "parameters" in function:
parameters = function["parameters"]
if "properties" in parameters:
for propertiesKey in parameters["properties"]:
function_tokens += len(encoding.encode(propertiesKey))
v = parameters["properties"][propertiesKey]
for field in v:
if field == "type":
function_tokens += 2
function_tokens += len(encoding.encode(v["type"]))
elif field == "description":
function_tokens += 2
function_tokens += len(encoding.encode(v["description"]))
elif field == "enum":
function_tokens -= 3
for o in v["enum"]:
function_tokens += 3
function_tokens += len(encoding.encode(o))
else:
print(f"Warning: not supported field {field}")
function_tokens += 11
if len(parameters["properties"]) == 0:
function_tokens -= 2

num_tokens += function_tokens

num_tokens += 12
return num_tokens
26 changes: 2 additions & 24 deletions test/test_retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
is_url,
create_vector_db_from_dir,
query_vector_db,
num_tokens_from_text,
num_tokens_from_messages,
TEXT_FORMATS,
)
from autogen.token_count_utils import count_token

import os
import sys
Expand All @@ -31,31 +30,10 @@


class TestRetrieveUtils:
def test_num_tokens_from_text_custom_token_count_function(self):
def custom_token_count_function(text):
return len(text), 1, 2

text = "This is a sample text."
assert num_tokens_from_text(
text, return_tokens_per_name_and_message=True, custom_token_count_function=custom_token_count_function
) == (22, 1, 2)

def test_num_tokens_from_text(self):
text = "This is a sample text."
assert num_tokens_from_text(text) == len(tiktoken.get_encoding("cl100k_base").encode(text))

def test_num_tokens_from_messages(self):
messages = [{"content": "This is a sample text."}, {"content": "Another sample text."}]
# Review the implementation of num_tokens_from_messages
# and adjust the expected_tokens accordingly.
actual_tokens = num_tokens_from_messages(messages)
expected_tokens = actual_tokens # Adjusted to make the test pass temporarily.
assert actual_tokens == expected_tokens

def test_split_text_to_chunks(self):
long_text = "A" * 10000
chunks = split_text_to_chunks(long_text, max_tokens=1000)
assert all(num_tokens_from_text(chunk) <= 1000 for chunk in chunks)
assert all(count_token(chunk) <= 1000 for chunk in chunks)

def test_split_text_to_chunks_raises_on_invalid_chunk_mode(self):
with pytest.raises(AssertionError):
Expand Down
Loading

0 comments on commit 8c3401d

Please sign in to comment.