From 5a7c7b527d530809cb226ef6edd8d8a1fa3cb0c8 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 8 Feb 2024 10:45:59 -0500 Subject: [PATCH] check response usage is not None --- autogen/oai/client.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 70251353f654..cf07d0696fc0 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -126,9 +126,11 @@ def message_retrieval( if TOOL_ENABLED: return [ # type: ignore [return-value] - choice.message # type: ignore [union-attr] - if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr] - else choice.message.content # type: ignore [union-attr] + ( + choice.message # type: ignore [union-attr] + if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr] + else choice.message.content + ) # type: ignore [union-attr] for choice in choices ] else: @@ -276,8 +278,8 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float: logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True) return 0 - n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr] - n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr] + n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr] + n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr] tmp_price1K = OAI_PRICE1K[model] # First value is input token rate, second value is output token rate if isinstance(tmp_price1K, tuple): @@ -287,10 +289,10 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float: @staticmethod def get_usage(response: Union[ChatCompletion, Completion]) -> Dict: return { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - "cost": response.cost, + "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0, + "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0, + "total_tokens": response.usage.total_tokens if response.usage is not None else 0, + "cost": response.cost if hasattr(response, "cost") else 0, "model": response.model, } @@ -471,12 +473,14 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: elif context: # Instantiate the messages params["messages"] = [ - { - **m, - "content": self.instantiate(m["content"], context, allow_format_str_template), - } - if m.get("content") - else m + ( + { + **m, + "content": self.instantiate(m["content"], context, allow_format_str_template), + } + if m.get("content") + else m + ) for m in messages # type: ignore [union-attr] ] return params