diff --git a/autogen/oai/client.py b/autogen/oai/client.py index fd6742027d34..12dc9e8c3878 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -126,9 +126,11 @@ def message_retrieval( if TOOL_ENABLED: return [ # type: ignore [return-value] - choice.message # type: ignore [union-attr] - if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr] - else choice.message.content # type: ignore [union-attr] + ( + choice.message # type: ignore [union-attr] + if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr] + else choice.message.content + ) # type: ignore [union-attr] for choice in choices ] else: @@ -276,8 +278,8 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float: logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True) return 0 - n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr] - n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr] + n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr] + n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr] tmp_price1K = OAI_PRICE1K[model] # First value is input token rate, second value is output token rate if isinstance(tmp_price1K, tuple): @@ -287,10 +289,10 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float: @staticmethod def get_usage(response: Union[ChatCompletion, Completion]) -> Dict: return { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - "cost": response.cost, + "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0, + "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0, + "total_tokens": response.usage.total_tokens if response.usage is not None else 0, + "cost": response.cost if hasattr(response, "cost") else 0, "model": response.model, } @@ -471,12 +473,14 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: elif context: # Instantiate the messages params["messages"] = [ - { - **m, - "content": self.instantiate(m["content"], context, allow_format_str_template), - } - if m.get("content") - else m + ( + { + **m, + "content": self.instantiate(m["content"], context, allow_format_str_template), + } + if m.get("content") + else m + ) for m in messages # type: ignore [union-attr] ] return params