Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cost calculation to client #769

Merged
merged 7 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import os
import sys
from typing import List, Optional, Dict, Callable
from typing import List, Optional, Dict, Callable, Union
import logging
import inspect
from flaml.automl.logger import logger_formatter

from autogen.oai.openai_utils import get_key
from autogen.oai.openai_utils import get_key, oai_price1k
from autogen.token_count_utils import count_token

try:
Expand Down Expand Up @@ -240,7 +240,7 @@ def yes_or_no_filter(context, response):
# Return the response if it passes the filter or it is the last client
response.config_id = i
response.pass_filter = pass_filter
# TODO: add response.cost
response.cost = self.cost(response)
return response
continue # filter is not passed; try the next config
try:
Expand All @@ -261,10 +261,25 @@ def yes_or_no_filter(context, response):
# Return the response if it passes the filter or it is the last client
response.config_id = i
response.pass_filter = pass_filter
# TODO: add response.cost
response.cost = self.cost(response)
return response
continue # filter is not passed; try the next config

def cost(self, response: Union[ChatCompletion, Completion]) -> float:
    """Calculate the dollar cost of a completion response.

    Looks up the per-1K-token price for ``response.model`` in ``oai_price1k``
    and combines it with the token usage reported by the API.

    Args:
        response: A chat or text completion response object returned by the
            OpenAI client; must carry ``model`` and ``usage`` attributes.

    Returns:
        The estimated cost in USD. Returns 0.0 (with a warning) when the
        model has no known pricing entry.
    """
    model = response.model
    if model not in oai_price1k:
        # Unknown model: warn instead of failing so callers still get a response.
        logging.getLogger(__name__).warning("Model %s is not found. The cost will be 0.", model)
        return 0.0

    n_input_tokens = response.usage.prompt_tokens
    n_output_tokens = response.usage.completion_tokens
    price_per_1k = oai_price1k[model]
    if isinstance(price_per_1k, tuple):
        # Tuple pricing: (input-token rate, output-token rate) per 1K tokens.
        return (price_per_1k[0] * n_input_tokens + price_per_1k[1] * n_output_tokens) / 1000
    # Scalar pricing: one flat rate applies to every token.
    return price_per_1k * (n_input_tokens + n_output_tokens) / 1000

def _completions_create(self, client, params):
completions = client.chat.completions if "messages" in params else client.completions
# If streaming is enabled, has messages, and does not have functions, then
Expand Down
30 changes: 30 additions & 0 deletions autogen/oai/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,36 @@

NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]

# Price per 1K tokens for OpenAI / Azure OpenAI models, in USD.
# A scalar value is a single flat rate applied to all tokens; a tuple is
# (input-token rate, output-token rate) — see the cost() consumer in client.py.
oai_price1k = {
    "text-ada-001": 0.0004,
    "text-babbage-001": 0.0005,
    "text-curie-001": 0.002,
    "code-cushman-001": 0.024,
    "code-davinci-002": 0.1,
    "text-davinci-002": 0.02,
    "text-davinci-003": 0.02,
    "gpt-3.5-turbo-instruct": (0.0015, 0.002),
    "gpt-3.5-turbo-0301": (0.0015, 0.002),  # deprecate in Sep
    "gpt-3.5-turbo-0613": (0.0015, 0.002),
    "gpt-3.5-turbo-16k": (0.003, 0.004),
    "gpt-3.5-turbo-16k-0613": (0.003, 0.004),
    # Azure OpenAI deployment names use "gpt-35-..." (no dot allowed).
    "gpt-35-turbo": (0.0015, 0.002),
    "gpt-35-turbo-16k": (0.003, 0.004),
    "gpt-35-turbo-instruct": (0.0015, 0.002),
    "gpt-4": (0.03, 0.06),
    "gpt-4-32k": (0.06, 0.12),
    "gpt-4-0314": (0.03, 0.06),  # deprecate in Sep
    "gpt-4-32k-0314": (0.06, 0.12),  # deprecate in Sep
    "gpt-4-0613": (0.03, 0.06),
    "gpt-4-32k-0613": (0.06, 0.12),
    # 11-06
    "gpt-3.5-turbo": (0.001, 0.002),
    "gpt-3.5-turbo-1106": (0.001, 0.002),
    "gpt-35-turbo-1106": (0.001, 0.002),
    "gpt-4-1106-preview": (0.01, 0.03),
    "gpt-4-1106-vision-preview": (0.01, 0.03),  # TODO: support vision pricing of images
}


def get_key(config):
"""Get a unique identifier of a configuration.
Expand Down
17 changes: 17 additions & 0 deletions test/oai/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,24 @@ def test_completion():
print(client.extract_text_or_function_call(response))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.parametrize(
    "cache_seed, model",
    [
        (None, "gpt-3.5-turbo-instruct"),
        (42, "gpt-3.5-turbo-instruct"),
        (None, "text-ada-001"),
    ],
)
def test_cost(cache_seed, model):
    """Smoke-test that completion responses expose a cost attribute."""
    # Build a wrapper from the locally configured OpenAI/AOAI credentials.
    wrapper = OpenAIWrapper(
        config_list=config_list_openai_aoai(KEY_LOC),
        cache_seed=cache_seed,
    )
    completion = wrapper.create(prompt="1+3=", model=model)
    print(completion.cost)


if __name__ == "__main__":
    # Run the client smoke tests directly (without pytest).
    test_aoai_chat_completion()
    test_chat_completion()
    test_completion()
    # test_cost is pytest-parametrized and has no default arguments, so a
    # bare test_cost() call would raise TypeError; supply one concrete pair.
    test_cost(None, "gpt-3.5-turbo-instruct")
2 changes: 2 additions & 0 deletions website/docs/Use-Cases/enhanced_inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ client = OpenAIWrapper()
response = client.create(messages=[{"role": "user", "content": "2+2="}], model="gpt-3.5-turbo")
# extract the response text
print(client.extract_text_or_function_call(response))
# get cost of this completion
print(response.cost)
# Azure OpenAI endpoint
client = OpenAIWrapper(api_key=..., base_url=..., api_version=..., api_type="azure")
# Completion
Expand Down
Loading