Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cost summary to client.py #812

Merged
merged 11 commits into from
Dec 3, 2023
Merged
123 changes: 106 additions & 17 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class OpenAIWrapper:
cache_path_root: str = ".cache"
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"}
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
total_usage_summary: Dict = None
actual_usage_summary: Dict = None

def __init__(self, *, config_list: List[Dict] = None, **base_config):
"""
Expand Down Expand Up @@ -233,14 +235,15 @@ def yes_or_no_filter(context, response):
# Try to get the response from cache
key = get_key(params)
response = cache.get(key, None)
if response is not None:
self._update_usage_summary(response, use_cache=True)
if response is not None:
# check the filter
pass_filter = filter_func is None or filter_func(context=context, response=response)
if pass_filter or i == last:
# Return the response if it passes the filter or it is the last client
response.config_id = i
response.pass_filter = pass_filter
response.cost = self.cost(response)
return response
continue # filter is not passed; try the next config
try:
Expand All @@ -250,6 +253,9 @@ def yes_or_no_filter(context, response):
if i == last:
raise
else:
# add cost calculation before caching, no matter whether the filter is passed or not
response.cost = self.cost(response)
self._update_usage_summary(response, use_cache=False)
if cache_seed is not None:
# Cache the response
with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
Expand All @@ -261,25 +267,9 @@ def yes_or_no_filter(context, response):
# Return the response if it passes the filter or it is the last client
response.config_id = i
response.pass_filter = pass_filter
response.cost = self.cost(response)
return response
continue # filter is not passed; try the next config

def cost(self, response: Union[ChatCompletion, Completion]) -> float:
    """Calculate the USD cost of a completion response.

    Looks up the per-1k-token price for ``response.model`` in ``oai_price1k``
    and applies it to the prompt and completion token counts from
    ``response.usage``.

    Args:
        response: A chat or text completion response carrying ``model`` and
            ``usage`` attributes.

    Returns:
        The computed cost, or 0 when the model has no entry in the price table.
    """
    model = response.model
    if model not in oai_price1k:
        # TODO: add logging to warn that the model is not found
        return 0

    n_input_tokens = response.usage.prompt_tokens
    n_output_tokens = response.usage.completion_tokens
    tmp_price1K = oai_price1k[model]
    # First value is input token rate, second value is output token rate
    # (prices in the table are per 1000 tokens, hence the /1000 below).
    if isinstance(tmp_price1K, tuple):
        return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
    return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000

def _completions_create(self, client, params):
completions = client.chat.completions if "messages" in params else client.completions
# If streaming is enabled, has messages, and does not have functions, then
Expand Down Expand Up @@ -342,6 +332,105 @@ def _completions_create(self, client, params):
response = completions.create(**params)
return response

def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None:
"""Update the usage summary.

Usage is calculated no mattter filter is passed or not.
"""

def update_usage(usage_summary):
if usage_summary is None:
usage_summary = {"total_cost": response.cost}
else:
usage_summary["total_cost"] += response.cost

usage_summary[response.model] = {
"cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost,
"prompt_tokens": usage_summary.get(response.model, {}).get("prompt_tokens", 0)
+ response.usage.prompt_tokens,
"completion_tokens": usage_summary.get(response.model, {}).get("completion_tokens", 0)
+ response.usage.completion_tokens,
"total_tokens": usage_summary.get(response.model, {}).get("total_tokens", 0)
+ response.usage.total_tokens,
}
return usage_summary

self.total_usage_summary = update_usage(self.total_usage_summary)
if not use_cache:
self.actual_usage_summary = update_usage(self.actual_usage_summary)

def print_usage_summary(self, mode: Union[str, List[str]] = ("actual", "total")) -> None:
    """Print the usage summary (cost and token counts) to stdout.

    Args:
        mode: ``"actual"`` (excluding cached usage), ``"total"`` (including
            cached usage), or a list/tuple containing one or both of them.
            Defaults to both. (The default is a tuple rather than a list to
            avoid a shared mutable default argument.)

    Raises:
        ValueError: If ``mode`` is empty, has more than two entries, or is
            not one of the accepted values.
    """

    def print_usage(usage_summary, usage_type="total"):
        word_from_type = "including" if usage_type == "total" else "excluding"
        if usage_summary is None:
            # Only reachable for the "actual" summary: the "total" summary
            # is guarded against None before this helper is called.
            print("No actual cost incurred (all completions are using cache).", flush=True)
            return

        print(f"Usage summary {word_from_type} cached usage: ", flush=True)
        print(f"Total cost: {round(usage_summary['total_cost'], 5)}", flush=True)
        for model, counts in usage_summary.items():
            if model == "total_cost":
                continue  # not a per-model bucket
            print(
                f"* Model '{model}': cost: {round(counts['cost'], 5)}, prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}",
                flush=True,
            )

    if self.total_usage_summary is None:
        print('No usage summary. Please call "create" first.', flush=True)
        return

    # Normalize a list/tuple mode into one of the string modes.
    if isinstance(mode, (list, tuple)):
        if len(mode) == 0 or len(mode) > 2:
            raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
        if "actual" in mode and "total" in mode:
            mode = "both"
        elif "actual" in mode:
            mode = "actual"
        elif "total" in mode:
            mode = "total"

    print("-" * 100, flush=True)
    if mode == "both":
        print_usage(self.actual_usage_summary, "actual")
        print()
        if self.total_usage_summary != self.actual_usage_summary:
            print_usage(self.total_usage_summary, "total")
        else:
            print(
                "All completions are non-cached: the total cost with cached completions is the same as actual cost.",
                flush=True,
            )
    elif mode == "total":
        print_usage(self.total_usage_summary, "total")
    elif mode == "actual":
        print_usage(self.actual_usage_summary, "actual")
    else:
        # Any unrecognized string, or a list that normalization left intact.
        raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
    print("-" * 100, flush=True)

def clear_usage_summary(self) -> None:
    """Reset both usage summaries so accounting starts from scratch."""
    self.total_usage_summary, self.actual_usage_summary = None, None

def cost(self, response: Union[ChatCompletion, Completion]) -> float:
    """Calculate the cost of the response.

    Uses the per-1k-token rates in ``oai_price1k``; returns 0 for models
    that have no entry in the price table.
    """
    model = response.model
    if model not in oai_price1k:
        # TODO: add logging to warn that the model is not found
        return 0

    usage = response.usage
    price_per_1k = oai_price1k[model]
    # A tuple price is (input rate, output rate); a scalar price applies
    # one rate to every token. Rates are per 1000 tokens.
    if isinstance(price_per_1k, tuple):
        input_rate, output_rate = price_per_1k
        return (input_rate * usage.prompt_tokens + output_rate * usage.completion_tokens) / 1000
    return price_per_1k * (usage.prompt_tokens + usage.completion_tokens) / 1000

@classmethod
def extract_text_or_function_call(cls, response: ChatCompletion | Completion) -> List[str]:
"""Extract the text or function calls from a completion or chat response.
Expand Down
Loading
Loading