
Commit 5636a88

Add cost summary to client.py (microsoft#812)
* init commit
* add doc, notebook and test
* fix test
* update
* update
* update
* update
1 parent c2386bf commit 5636a88

File tree

6 files changed: +485 -17 lines changed


autogen/oai/client.py: +106 -17
@@ -36,6 +36,8 @@ class OpenAIWrapper:
     cache_path_root: str = ".cache"
     extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"}
     openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
+    total_usage_summary: Dict = None
+    actual_usage_summary: Dict = None
 
     def __init__(self, *, config_list: List[Dict] = None, **base_config):
         """
@@ -233,14 +235,15 @@ def yes_or_no_filter(context, response):
                     # Try to get the response from cache
                     key = get_key(params)
                     response = cache.get(key, None)
+                    if response is not None:
+                        self._update_usage_summary(response, use_cache=True)
                     if response is not None:
                         # check the filter
                         pass_filter = filter_func is None or filter_func(context=context, response=response)
                         if pass_filter or i == last:
                             # Return the response if it passes the filter or it is the last client
                             response.config_id = i
                             response.pass_filter = pass_filter
-                            response.cost = self.cost(response)
                             return response
                         continue  # filter is not passed; try the next config
             try:
@@ -250,6 +253,9 @@ def yes_or_no_filter(context, response):
                 if i == last:
                     raise
             else:
+                # add cost calculation before caching, no matter whether the filter is passed or not
+                response.cost = self.cost(response)
+                self._update_usage_summary(response, use_cache=False)
                 if cache_seed is not None:
                     # Cache the response
                     with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
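
With this hunk, a fresh response has its cost computed exactly once, before caching and regardless of whether the filter passes, and it is counted in both summaries (`use_cache=False`); a cache hit, per the previous hunk, bumps only the total summary (`use_cache=True`). A runnable sketch of how the two summaries diverge, using an invented stand-in for a ChatCompletion rather than a real API call (`OpenAIWrapper.__new__` is used only to skip `__init__` for the demo):

from types import SimpleNamespace

from autogen import OpenAIWrapper

# Invented stand-in for a ChatCompletion with cost already attached.
fake = SimpleNamespace(
    model="gpt-3.5-turbo",
    cost=0.002,
    usage=SimpleNamespace(prompt_tokens=700, completion_tokens=300, total_tokens=1000),
)

wrapper = OpenAIWrapper.__new__(OpenAIWrapper)  # bypass __init__ for the sketch
wrapper.total_usage_summary = None
wrapper.actual_usage_summary = None

wrapper._update_usage_summary(fake, use_cache=False)  # fresh call: both summaries grow
wrapper._update_usage_summary(fake, use_cache=True)   # cache hit: total only

print(wrapper.total_usage_summary["total_cost"])   # 0.004
print(wrapper.actual_usage_summary["total_cost"])  # 0.002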
@@ -261,25 +267,9 @@ def yes_or_no_filter(context, response):
                 # Return the response if it passes the filter or it is the last client
                 response.config_id = i
                 response.pass_filter = pass_filter
-                response.cost = self.cost(response)
                 return response
             continue  # filter is not passed; try the next config
 
-    def cost(self, response: Union[ChatCompletion, Completion]) -> float:
-        """Calculate the cost of the response."""
-        model = response.model
-        if model not in oai_price1k:
-            # TODO: add logging to warn that the model is not found
-            return 0
-
-        n_input_tokens = response.usage.prompt_tokens
-        n_output_tokens = response.usage.completion_tokens
-        tmp_price1K = oai_price1k[model]
-        # First value is input token rate, second value is output token rate
-        if isinstance(tmp_price1K, tuple):
-            return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
-        return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000
-
     def _completions_create(self, client, params):
         completions = client.chat.completions if "messages" in params else client.completions
         # If streaming is enabled, has messages, and does not have functions, then
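
The `cost` method itself is unchanged; this hunk only removes it here so it can be re-added below the new summary helpers in the next hunk. As a worked example of its per-1K-token formula, with illustrative rates standing in for the real `oai_price1k` entries:

# Tuple prices are (input rate, output rate) per 1K tokens; these numbers
# are assumptions for the arithmetic, not the actual price table.
price_1k = (0.03, 0.06)
prompt_tokens, completion_tokens = 1000, 500
cost = (price_1k[0] * prompt_tokens + price_1k[1] * completion_tokens) / 1000
print(cost)  # (0.03 * 1000 + 0.06 * 500) / 1000 = 0.06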
@@ -342,6 +332,105 @@ def _completions_create(self, client, params):
             response = completions.create(**params)
         return response
 
+    def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None:
+        """Update the usage summary.
+
+        Usage is calculated no matter whether the filter is passed or not.
+        """
+
+        def update_usage(usage_summary):
+            if usage_summary is None:
+                usage_summary = {"total_cost": response.cost}
+            else:
+                usage_summary["total_cost"] += response.cost
+
+            usage_summary[response.model] = {
+                "cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost,
+                "prompt_tokens": usage_summary.get(response.model, {}).get("prompt_tokens", 0)
+                + response.usage.prompt_tokens,
+                "completion_tokens": usage_summary.get(response.model, {}).get("completion_tokens", 0)
+                + response.usage.completion_tokens,
+                "total_tokens": usage_summary.get(response.model, {}).get("total_tokens", 0)
+                + response.usage.total_tokens,
+            }
+            return usage_summary
+
+        self.total_usage_summary = update_usage(self.total_usage_summary)
+        if not use_cache:
+            self.actual_usage_summary = update_usage(self.actual_usage_summary)
+
+    def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
+        """Print the usage summary."""
+
+        def print_usage(usage_summary, usage_type="total"):
+            word_from_type = "including" if usage_type == "total" else "excluding"
+            if usage_summary is None:
+                print("No actual cost incurred (all completions are using cache).", flush=True)
+                return
+
+            print(f"Usage summary {word_from_type} cached usage: ", flush=True)
+            print(f"Total cost: {round(usage_summary['total_cost'], 5)}", flush=True)
+            for model, counts in usage_summary.items():
+                if model == "total_cost":
+                    continue
+                print(
+                    f"* Model '{model}': cost: {round(counts['cost'], 5)}, prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}",
+                    flush=True,
+                )
+
+        if self.total_usage_summary is None:
+            print('No usage summary. Please call "create" first.', flush=True)
+            return
+
+        if isinstance(mode, list):
+            if len(mode) == 0 or len(mode) > 2:
+                raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
+            if "actual" in mode and "total" in mode:
+                mode = "both"
+            elif "actual" in mode:
+                mode = "actual"
+            elif "total" in mode:
+                mode = "total"
+
+        print("-" * 100, flush=True)
+        if mode == "both":
+            print_usage(self.actual_usage_summary, "actual")
+            print()
+            if self.total_usage_summary != self.actual_usage_summary:
+                print_usage(self.total_usage_summary, "total")
+            else:
+                print(
+                    "All completions are non-cached: the total cost with cached completions is the same as actual cost.",
+                    flush=True,
+                )
+        elif mode == "total":
+            print_usage(self.total_usage_summary, "total")
+        elif mode == "actual":
+            print_usage(self.actual_usage_summary, "actual")
+        else:
+            raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
+        print("-" * 100, flush=True)
+
+    def clear_usage_summary(self) -> None:
+        """Clear the usage summary."""
+        self.total_usage_summary = None
+        self.actual_usage_summary = None
+
+    def cost(self, response: Union[ChatCompletion, Completion]) -> float:
+        """Calculate the cost of the response."""
+        model = response.model
+        if model not in oai_price1k:
+            # TODO: add logging to warn that the model is not found
+            return 0
+
+        n_input_tokens = response.usage.prompt_tokens
+        n_output_tokens = response.usage.completion_tokens
+        tmp_price1K = oai_price1k[model]
+        # First value is input token rate, second value is output token rate
+        if isinstance(tmp_price1K, tuple):
+            return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
+        return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000
+
     @classmethod
     def extract_text_or_function_call(cls, response: ChatCompletion | Completion) -> List[str]:
         """Extract the text or function calls from a completion or chat response.
