Skip to content

Commit 625e2b4

Browse files
yiranwu0, davorrunje, and sonichi
authored and committed
Allow passing in custom pricing in config_list (#2902)
* update * update * TODO comment removed * update --------- Co-authored-by: Yiran Wu <[email protected]> Co-authored-by: Davor Runje <[email protected]> Co-authored-by: Chi Wang <[email protected]>
1 parent 1de3357 commit 625e2b4

File tree

3 files changed

+93
-6
lines changed

3 files changed

+93
-6
lines changed

autogen/oai/client.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,10 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
290290
"""Calculate the cost of the response."""
291291
model = response.model
292292
if model not in OAI_PRICE1K:
293-
# TODO: add logging to warn that the model is not found
294-
logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
293+
# log warning that the model is not found
294+
logger.warning(
295+
f'Model {model} is not found. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.'
296+
)
295297
return 0
296298

297299
n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
@@ -328,6 +330,7 @@ class OpenAIWrapper:
328330
"api_version",
329331
"api_type",
330332
"tags",
333+
"price",
331334
}
332335

333336
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
@@ -592,6 +595,14 @@ def yes_or_no_filter(context, response):
592595
filter_func = extra_kwargs.get("filter_func")
593596
context = extra_kwargs.get("context")
594597
agent = extra_kwargs.get("agent")
598+
price = extra_kwargs.get("price", None)
599+
if isinstance(price, list):
600+
price = tuple(price)
601+
elif isinstance(price, float) or isinstance(price, int):
602+
logger.warning(
603+
"Input price is a float/int. Using the same price for prompt and completion tokens. Use a list/tuple if prompt and completion token prices are different."
604+
)
605+
price = (price, price)
595606

596607
total_usage = None
597608
actual_usage = None
@@ -678,7 +689,10 @@ def yes_or_no_filter(context, response):
678689
raise
679690
else:
680691
# add cost calculation before caching no matter filter is passed or not
681-
response.cost = client.cost(response)
692+
if price is not None:
693+
response.cost = self._cost_with_customized_price(response, price)
694+
else:
695+
response.cost = client.cost(response)
682696
actual_usage = client.get_usage(response)
683697
total_usage = actual_usage.copy() if actual_usage is not None else total_usage
684698
self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
@@ -712,6 +726,17 @@ def yes_or_no_filter(context, response):
712726
continue # filter is not passed; try the next config
713727
raise RuntimeError("Should not reach here.")
714728

729+
@staticmethod
def _cost_with_customized_price(
    response: ModelClient.ModelClientResponseProtocol, price_1k: Tuple[float, float]
) -> float:
    """Compute the cost of a response from a user-supplied price pair.

    Args:
        response: The completion response; token counts are read from ``response.usage``.
        price_1k: ``(prompt_price, completion_price)`` taken from the ``"price"``
            field of the config entry.

    Returns:
        ``prompt_tokens * price_1k[0] + completion_tokens * price_1k[1]``.

    NOTE(review): despite the ``_1k`` name, raw token counts are multiplied by the
    price with no division by 1000 — kept as-is to preserve behavior (the shipped
    test asserts ``cost >= 4`` for price ``[1, 1]``); confirm the intended units.
    """
    # Missing usage info is treated as zero tokens rather than raising.
    n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0  # type: ignore [union-attr]
    n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0  # type: ignore [union-attr]
    # completion_tokens can itself be None even when usage exists.
    if n_output_tokens is None:
        n_output_tokens = 0
    return n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]
739+
715740
@staticmethod
716741
def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int:
717742
"""Update the dict from the chunk.

notebook/agentchat_cost_token_tracking.ipynb

+53-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@
3131
"\n",
3232
"To gather usage data for a list of agents, we provide a utility function `autogen.gather_usage_summary(agents)` where you pass in a list of agents and gather the usage summary.\n",
3333
"\n",
34+
"## 3. Custom token price for up-to-date cost estimation\n",
35+
"AutoGen tries to keep the token prices up-to-date. However, you can pass in a `price` field in `config_list` if the token price is not listed or up-to-date. Please create an issue or pull request to help us keep the token prices up-to-date!\n",
36+
"\n",
37+
"Note: in json files, the price should be a list of two floats.\n",
38+
"\n",
39+
"Example Usage:\n",
40+
"```python\n",
41+
"{\n",
42+
" \"model\": \"gpt-3.5-turbo-xxxx\",\n",
43+
" \"api_key\": \"YOUR_API_KEY\",\n",
44+
" \"price\": [0.0005, 0.0015]\n",
45+
"}\n",
46+
"```\n",
47+
"\n",
3448
"## Caution when using Azure OpenAI!\n",
3549
"If you are using azure OpenAI, the model returned from completion doesn't have the version information. The returned model is either 'gpt-35-turbo' or 'gpt-4'. From there, we are calculating the cost based on gpt-3.5-turbo-0125: (0.0005, 0.0015) per 1k prompt and completion tokens and gpt-4-0613: (0.03, 0.06). This means the cost can be wrong if you are using a different version from azure OpenAI.\n",
3650
"\n",
@@ -55,7 +69,7 @@
5569
},
5670
{
5771
"cell_type": "code",
58-
"execution_count": 1,
72+
"execution_count": null,
5973
"metadata": {},
6074
"outputs": [],
6175
"source": [
@@ -65,7 +79,7 @@
6579
"config_list = autogen.config_list_from_json(\n",
6680
" \"OAI_CONFIG_LIST\",\n",
6781
" filter_dict={\n",
68-
" \"tags\": [\"gpt-3.5-turbo\", \"gpt-3.5-turbo-16k\"], # comment out to get all\n",
82+
" \"model\": [\"gpt-3.5-turbo\", \"gpt-3.5-turbo-16k\"], # comment out to get all\n",
6983
" },\n",
7084
")"
7185
]
@@ -127,6 +141,42 @@
127141
"print(response.cost)"
128142
]
129143
},
144+
{
145+
"cell_type": "markdown",
146+
"metadata": {},
147+
"source": [
148+
"## OpenAIWrapper with custom token price"
149+
]
150+
},
151+
{
152+
"cell_type": "code",
153+
"execution_count": 7,
154+
"metadata": {},
155+
"outputs": [
156+
{
157+
"name": "stdout",
158+
"output_type": "stream",
159+
"text": [
160+
"Price: 109\n"
161+
]
162+
}
163+
],
164+
"source": [
165+
"# Adding price to the config_list\n",
166+
"for i in range(len(config_list)):\n",
167+
" config_list[i][\"price\"] = [\n",
168+
" 1,\n",
169+
" 1,\n",
170+
" ] # Note: This price is just for demonstration purposes. Please replace it with the actual price of the model.\n",
171+
"\n",
172+
"client = OpenAIWrapper(config_list=config_list)\n",
173+
"messages = [\n",
174+
" {\"role\": \"user\", \"content\": \"Can you give me 3 useful tips on learning Python? Keep it simple and short.\"},\n",
175+
"]\n",
176+
"response = client.create(messages=messages, cache_seed=None)\n",
177+
"print(\"Price:\", response.cost)"
178+
]
179+
},
130180
{
131181
"cell_type": "markdown",
132182
"metadata": {},
@@ -504,7 +554,7 @@
504554
"name": "python",
505555
"nbconvert_exporter": "python",
506556
"pygments_lexer": "ipython3",
507-
"version": "3.10.13"
557+
"version": "3.9.19"
508558
}
509559
},
510560
"nbformat": 4,

test/oai/test_client.py

+12
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,18 @@ def test_cost(cache_seed):
130130
print(response.cost)
131131

132132

133+
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_customized_cost():
    """A custom ``price`` in config_list should override the built-in pricing table."""
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"tags": ["gpt-3.5-turbo-instruct"]}
    )
    # Inject a deliberately large custom price so the resulting cost is
    # unmistakably driven by the user-supplied value, not OAI_PRICE1K.
    for config in config_list:
        config.update({"price": [1, 1]})
    client = OpenAIWrapper(config_list=config_list, cache_seed=None)
    response = client.create(prompt="1+3=")
    # With price [1, 1] even a tiny prompt yields a cost far above any real pricing.
    assert response.cost >= 4, "Due to customized pricing, cost should be greater than 4"
143+
144+
133145
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
134146
def test_usage_summary():
135147
config_list = config_list_from_json(

0 commit comments

Comments
 (0)