Skip to content

Commit

Permalink
check response usage is not None (microsoft#1599)
Browse files Browse the repository at this point in the history
  • Loading branch information
olgavrou authored Feb 8, 2024
1 parent 3483171 commit 67c19c0
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ def message_retrieval(

if TOOL_ENABLED:
return [ # type: ignore [return-value]
choice.message # type: ignore [union-attr]
if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
else choice.message.content # type: ignore [union-attr]
(
choice.message # type: ignore [union-attr]
if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
else choice.message.content
) # type: ignore [union-attr]
for choice in choices
]
else:
Expand Down Expand Up @@ -276,8 +278,8 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
return 0

n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr]
n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr]
n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr]
tmp_price1K = OAI_PRICE1K[model]
# First value is input token rate, second value is output token rate
if isinstance(tmp_price1K, tuple):
Expand All @@ -287,10 +289,10 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
@staticmethod
def get_usage(response: Union[ChatCompletion, Completion]) -> Dict:
return {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
"cost": response.cost,
"prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
"completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
"total_tokens": response.usage.total_tokens if response.usage is not None else 0,
"cost": response.cost if hasattr(response, "cost") else 0,
"model": response.model,
}

Expand Down Expand Up @@ -471,12 +473,14 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
elif context:
# Instantiate the messages
params["messages"] = [
{
**m,
"content": self.instantiate(m["content"], context, allow_format_str_template),
}
if m.get("content")
else m
(
{
**m,
"content": self.instantiate(m["content"], context, allow_format_str_template),
}
if m.get("content")
else m
)
for m in messages # type: ignore [union-attr]
]
return params
Expand Down

0 comments on commit 67c19c0

Please sign in to comment.