Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: check response usage is not None #1599

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ def message_retrieval(

if TOOL_ENABLED:
return [ # type: ignore [return-value]
choice.message # type: ignore [union-attr]
if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
else choice.message.content # type: ignore [union-attr]
(
choice.message # type: ignore [union-attr]
if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
else choice.message.content
) # type: ignore [union-attr]
for choice in choices
]
else:
Expand Down Expand Up @@ -276,8 +278,8 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
return 0

n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr]
n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr]
n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr]
tmp_price1K = OAI_PRICE1K[model]
# First value is input token rate, second value is output token rate
if isinstance(tmp_price1K, tuple):
Expand All @@ -287,10 +289,10 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
@staticmethod
def get_usage(response: Union[ChatCompletion, Completion]) -> Dict:
return {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
"cost": response.cost,
"prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
"completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
"total_tokens": response.usage.total_tokens if response.usage is not None else 0,
"cost": response.cost if hasattr(response, "cost") else 0,
"model": response.model,
}

Expand Down Expand Up @@ -471,12 +473,14 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
elif context:
# Instantiate the messages
params["messages"] = [
{
**m,
"content": self.instantiate(m["content"], context, allow_format_str_template),
}
if m.get("content")
else m
(
jackgerrits marked this conversation as resolved.
Show resolved Hide resolved
{
**m,
"content": self.instantiate(m["content"], context, allow_format_str_template),
}
if m.get("content")
else m
)
for m in messages # type: ignore [union-attr]
]
return params
Expand Down
Loading