Skip to content

Commit 58d77bc

Browse files
authored
check response usage is not None (#1599)
1 parent 9eab2b5 commit 58d77bc

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

autogen/oai/client.py

+19-15
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,11 @@ def message_retrieval(
126126

127127
if TOOL_ENABLED:
128128
return [ # type: ignore [return-value]
129-
choice.message # type: ignore [union-attr]
130-
if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
131-
else choice.message.content # type: ignore [union-attr]
129+
(
130+
choice.message # type: ignore [union-attr]
131+
if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
132+
else choice.message.content
133+
) # type: ignore [union-attr]
132134
for choice in choices
133135
]
134136
else:
@@ -276,8 +278,8 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
276278
logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
277279
return 0
278280

279-
n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr]
280-
n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr]
281+
n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
282+
n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr]
281283
tmp_price1K = OAI_PRICE1K[model]
282284
# First value is input token rate, second value is output token rate
283285
if isinstance(tmp_price1K, tuple):
@@ -287,10 +289,10 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
287289
@staticmethod
288290
def get_usage(response: Union[ChatCompletion, Completion]) -> Dict:
289291
return {
290-
"prompt_tokens": response.usage.prompt_tokens,
291-
"completion_tokens": response.usage.completion_tokens,
292-
"total_tokens": response.usage.total_tokens,
293-
"cost": response.cost,
292+
"prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
293+
"completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
294+
"total_tokens": response.usage.total_tokens if response.usage is not None else 0,
295+
"cost": response.cost if hasattr(response, "cost") else 0,
294296
"model": response.model,
295297
}
296298

@@ -471,12 +473,14 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
471473
elif context:
472474
# Instantiate the messages
473475
params["messages"] = [
474-
{
475-
**m,
476-
"content": self.instantiate(m["content"], context, allow_format_str_template),
477-
}
478-
if m.get("content")
479-
else m
476+
(
477+
{
478+
**m,
479+
"content": self.instantiate(m["content"], context, allow_format_str_template),
480+
}
481+
if m.get("content")
482+
else m
483+
)
480484
for m in messages # type: ignore [union-attr]
481485
]
482486
return params

0 commit comments

Comments
 (0)