Skip to content

Commit

Permalink
Correct token count for GPT-4v images
Browse files Browse the repository at this point in the history
  • Loading branch information
BeibinLi committed Apr 17, 2024
1 parent 297904f commit 4f34f5b
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 5 deletions.
51 changes: 51 additions & 0 deletions autogen/agentchat/contrib/img_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import re
from io import BytesIO
from math import ceil
from typing import Dict, List, Tuple, Union

import requests
Expand Down Expand Up @@ -298,3 +299,53 @@ def message_formatter_pil_to_b64(messages: List[Dict]) -> List[Dict]:
new_messages.append(message)

return new_messages


def num_tokens_from_gpt_image(image_data: Union[str, Image.Image]) -> int:
    """
    Compute the GPT-4V token cost of an image from its (rescaled) dimensions.

    The image is first shrunk so its longest edge is at most 2048 px, then
    shrunk again so its shortest edge is at most 768 px (each step only when
    needed, mirroring OpenAI's documented preprocessing). The cost is a flat
    85 tokens plus 170 tokens for every 512x512 tile required to cover the
    rescaled image.

    See official details at:
    - https://openai.com/pricing
    - https://platform.openai.com/docs/guides/vision
    See community discussion of OpenAI at:
    - https://community.openai.com/t/how-do-i-calculate-image-tokens-in-gpt4-vision/

    Args:
        image_data : Union[str, Image.Image]: The image data which can either be a base64
            encoded string, a URL, a file path, or a PIL Image object.

    Returns:
        int: The total number of tokens required for processing the image.

    Examples
    --------
    >>> from PIL import Image
    >>> img = Image.new('RGB', (2500, 2500), color = 'red')
    >>> num_tokens_from_gpt_image(img)
    765
    """
    pil_img = get_pil_image(image_data)  # normalize input to a PIL Image
    w, h = pil_img.size

    # Step 1: cap the longest edge at 2048 px.
    longest = w if w >= h else h
    if longest > 2048:
        ratio = 2048.0 / longest
        w, h = int(w * ratio), int(h * ratio)

    # Step 2: cap the shortest edge (of the possibly-shrunk image) at 768 px.
    shortest = w if w <= h else h
    if shortest > 768:
        ratio = 768.0 / shortest
        w, h = int(w * ratio), int(h * ratio)

    # Step 3: base charge of 85 tokens + 170 per covering 512x512 tile.
    n_tiles = ceil(w / 512) * ceil(h / 512)
    return 85 + 170 * n_tiles
35 changes: 31 additions & 4 deletions autogen/token_count_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
import json
import logging
import re
import warnings
from typing import Dict, List, Union

import tiktoken

# Image token counting needs the optional [lmm] extras (PIL via img_utils).
# When they are missing, install a stub that charges 0 tokens per image so
# the rest of this module keeps working; img_util_imported records which
# branch was taken so callers can warn the user.
try:
    from autogen.agentchat.contrib.img_utils import num_tokens_from_gpt_image

    img_util_imported = True
except ImportError:

    def num_tokens_from_gpt_image(_):
        # Fallback stub: accepts (and ignores) the image, reports zero tokens.
        return 0

    img_util_imported = False

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -111,9 +123,6 @@ def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0
elif "gpt-4" in model:
logger.info("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return _num_token_from_messages(messages, model="gpt-4-0613")
elif "gemini" in model:
logger.info("Gemini is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
return _num_token_from_messages(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""_num_token_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
Expand All @@ -125,7 +134,25 @@ def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0
if value is None:
continue

# function calls
# handle content if images are in GPT-4-vision
if key == "content" and isinstance(value, list):
for part in value:
if not isinstance(part, dict) or "type" not in part:
continue
if part["type"] == "text":
num_tokens += len(encoding.encode(part["text"]))
if "image_url" in part:
assert "url" in part["image_url"]
if not img_util_imported:
warnings.warn(
"img_utils or PIL not imported. Skipping image token count."
"Please install autogen with [lmm] option.",
ImportWarning,
)
num_tokens += num_tokens_from_gpt_image(part["image_url"]["url"])
continue

# function calls and other objects
if not isinstance(value, str):
try:
value = json.dumps(value)
Expand Down
20 changes: 20 additions & 0 deletions test/agentchat/contrib/test_img_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
gpt4v_formatter,
llava_formatter,
message_formatter_pil_to_b64,
num_tokens_from_gpt_image,
)
except ImportError:
skip = True
Expand Down Expand Up @@ -290,5 +291,24 @@ def test_formatting(self):
self.assertEqual(result, expected_output)


@pytest.mark.skipif(skip, reason="dependency is not installed")
class ImageTokenCountTest(unittest.TestCase):
    def test_tokens(self):
        # Each case: (width, height, expected tokens). Expected values follow the
        # GPT-4V formula: 85 base tokens + 170 per 512x512 tile after rescaling.
        cases = [
            (10, 10, 85 + 170),  # tiny image -> one tile
            (512, 1025, 85 + 170 * 1 * 3),  # 1 x 3 tile grid
            (10, 1025, 85 + 170 * 1 * 3),  # narrow but tall -> same 1 x 3 grid
            (10000, 10000, 85 + 170 * 2 * 2),  # huge square -> rescaled to 768x768
            (10000, 5000, 85 + 170 * 3 * 2),  # huge wide -> rescaled to 1536x768
        ]
        for width, height, expected in cases:
            img = Image.new("RGB", (width, height), color="red")
            self.assertEqual(num_tokens_from_gpt_image(img), expected)


if __name__ == "__main__":
unittest.main()
40 changes: 39 additions & 1 deletion test/test_token_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
token_left,
)

# Probe for the optional [lmm] extras; the image-token test below is skipped
# when img_utils (and its PIL dependency) is unavailable.
try:
    from autogen.agentchat.contrib.img_utils import num_tokens_from_gpt_image

    img_util_imported = True
except ImportError:
    img_util_imported = False

func1 = {
"name": "sh",
"description": "run a shell script and return the execution result.",
Expand Down Expand Up @@ -83,7 +90,38 @@ def test_model_aliases():
assert get_max_token_limit("gpt4-32k") == get_max_token_limit("gpt-4-32k")


@pytest.mark.skipif(not img_util_imported, reason="img_utils not imported")
def test_num_tokens_from_gpt_image():
    """count_token should include image tokens for GPT-4V style multimodal messages."""
    # A 5x5 red PNG encoded as a data URL; small enough to cost a single tile.
    base64_encoded_image = (
        "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4"
        "//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
    )

    system_message = {
        "role": "system",
        "content": "you are a helpful assistant. af3758 *3 33(3)",
    }
    user_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "hello asdfjj qeweee"},
            {"type": "image_url", "image_url": {"url": base64_encoded_image}},
        ],
    }

    tokens = count_token([system_message, user_message], model="gpt-4-vision-preview")

    # The total number of tokens is text + image
    # where text = 34, as shown in the previous test case
    # the image token is: 85 + 170 = 255
    assert tokens == 34 + 255


if __name__ == "__main__":
    # Run the tests directly for quick manual checking (outside pytest).
    test_count_token()
    test_model_aliases()
    # A direct call bypasses pytest's skipif marker, so guard manually:
    # without the [lmm] extras, the image-token test would raise.
    if img_util_imported:
        test_num_tokens_from_gpt_image()

0 comments on commit 4f34f5b

Please sign in to comment.