Skip to content

Commit

Permalink
Allow passing in custom pricing in config_list (#2902)
Browse files Browse the repository at this point in the history
* update

* update

* TODO comment removed

* update

---------

Co-authored-by: Yiran Wu <[email protected]>
Co-authored-by: Davor Runje <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
4 people authored Jun 13, 2024
1 parent cd8f437 commit 39f6887
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 6 deletions.
31 changes: 28 additions & 3 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,10 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
"""Calculate the cost of the response."""
model = response.model
if model not in OAI_PRICE1K:
# TODO: add logging to warn that the model is not found
logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
# log warning that the model is not found
logger.warning(
f'Model {model} is not found. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.'
)
return 0

n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
Expand Down Expand Up @@ -328,6 +330,7 @@ class OpenAIWrapper:
"api_version",
"api_type",
"tags",
"price",
}

openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
Expand Down Expand Up @@ -592,6 +595,14 @@ def yes_or_no_filter(context, response):
filter_func = extra_kwargs.get("filter_func")
context = extra_kwargs.get("context")
agent = extra_kwargs.get("agent")
price = extra_kwargs.get("price", None)
if isinstance(price, list):
price = tuple(price)
elif isinstance(price, float) or isinstance(price, int):
logger.warning(
"Input price is a float/int. Using the same price for prompt and completion tokens. Use a list/tuple if prompt and completion token prices are different."
)
price = (price, price)

total_usage = None
actual_usage = None
Expand Down Expand Up @@ -678,7 +689,10 @@ def yes_or_no_filter(context, response):
raise
else:
# add cost calculation before caching no matter filter is passed or not
response.cost = client.cost(response)
if price is not None:
response.cost = self._cost_with_customized_price(response, price)
else:
response.cost = client.cost(response)
actual_usage = client.get_usage(response)
total_usage = actual_usage.copy() if actual_usage is not None else total_usage
self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
Expand Down Expand Up @@ -712,6 +726,17 @@ def yes_or_no_filter(context, response):
continue # filter is not passed; try the next config
raise RuntimeError("Should not reach here.")

@staticmethod
def _cost_with_customized_price(
    response: ModelClient.ModelClientResponseProtocol, price_1k: Tuple[float, float]
) -> float:
    """Compute the response cost from a caller-supplied token price pair.

    Used when a ``price`` field is present in the config, overriding the
    built-in OAI_PRICE1K table lookup.

    Args:
        response: A completion response exposing an optional ``usage`` object
            with ``prompt_tokens`` / ``completion_tokens`` counts.
        price_1k: ``(prompt_price, completion_price)`` pair.
            NOTE(review): despite the ``_1k`` suffix, the product below is
            NOT divided by 1000 (the repository's own test asserts the
            undivided value) — confirm against the documented
            "price per 1k tokens" config semantics.

    Returns:
        The cost as a float; 0.0 when no usage information is reported.
    """
    n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0  # type: ignore [union-attr]
    n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0  # type: ignore [union-attr]
    # Either counter may itself be None even when `usage` exists; guard both
    # symmetrically (the original guarded only the completion counter).
    if n_input_tokens is None:
        n_input_tokens = 0
    if n_output_tokens is None:
        n_output_tokens = 0
    return n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]

@staticmethod
def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int:
"""Update the dict from the chunk.
Expand Down
56 changes: 53 additions & 3 deletions notebook/agentchat_cost_token_tracking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@
"\n",
"To gather usage data for a list of agents, we provide an utility function `autogen.gather_usage_summary(agents)` where you pass in a list of agents and gather the usage summary.\n",
"\n",
"## 3. Custom token price for up-to-date cost estimation\n",
"AutoGen tries to keep the token prices up-to-date. However, you can pass in a `price` field in `config_list` if the token price is not listed or up-to-date. Please creating an issue or pull request to help us keep the token prices up-to-date!\n",
"\n",
"Note: in json files, the price should be a list of two floats.\n",
"\n",
"Example Usage:\n",
"```python\n",
"{\n",
" \"model\": \"gpt-3.5-turbo-xxxx\",\n",
" \"api_key\": \"YOUR_API_KEY\",\n",
" \"price\": [0.0005, 0.0015]\n",
"}\n",
"```\n",
"\n",
"## Caution when using Azure OpenAI!\n",
"If you are using azure OpenAI, the model returned from completion doesn't have the version information. The returned model is either 'gpt-35-turbo' or 'gpt-4'. From there, we are calculating the cost based on gpt-3.5-turbo-0125: (0.0005, 0.0015) per 1k prompt and completion tokens and gpt-4-0613: (0.03, 0.06). This means the cost can be wrong if you are using a different version from azure OpenAI.\n",
"\n",
Expand All @@ -55,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -65,7 +79,7 @@
"config_list = autogen.config_list_from_json(\n",
" \"OAI_CONFIG_LIST\",\n",
" filter_dict={\n",
" \"tags\": [\"gpt-3.5-turbo\", \"gpt-3.5-turbo-16k\"], # comment out to get all\n",
" \"model\": [\"gpt-3.5-turbo\", \"gpt-3.5-turbo-16k\"], # comment out to get all\n",
" },\n",
")"
]
Expand Down Expand Up @@ -127,6 +141,42 @@
"print(response.cost)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## OpenAIWrapper with custom token price"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Price: 109\n"
]
}
],
"source": [
"# Adding price to the config_list\n",
"for i in range(len(config_list)):\n",
" config_list[i][\"price\"] = [\n",
" 1,\n",
" 1,\n",
" ] # Note: This price is just for demonstration purposes. Please replace it with the actual price of the model.\n",
"\n",
"client = OpenAIWrapper(config_list=config_list)\n",
"messages = [\n",
" {\"role\": \"user\", \"content\": \"Can you give me 3 useful tips on learning Python? Keep it simple and short.\"},\n",
"]\n",
"response = client.create(messages=messages, cache_seed=None)\n",
"print(\"Price:\", response.cost)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -504,7 +554,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.9.19"
}
},
"nbformat": 4,
Expand Down
12 changes: 12 additions & 0 deletions test/oai/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,18 @@ def test_cost(cache_seed):
print(response.cost)


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_customized_cost():
    """A `price` entry in config_list should override the built-in pricing table."""
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"tags": ["gpt-3.5-turbo-instruct"]}
    )
    # Inflate the per-token price so the resulting cost is unmistakably custom.
    for config in config_list:
        config["price"] = [1, 1]
    wrapper = OpenAIWrapper(config_list=config_list, cache_seed=None)
    completion = wrapper.create(prompt="1+3=")
    assert completion.cost >= 4, "Due to customized pricing, cost should be greater than 4"


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_usage_summary():
config_list = config_list_from_json(
Expand Down

0 comments on commit 39f6887

Please sign in to comment.