 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
-from ._config_validation import validate_config_keys
+from ._validation import validate_config_keys
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -175,7 +175,7 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
         }
 
     @classmethod
-    def format_request_tool_choice(cls, tool_choice: ToolChoice) -> Union[str, dict[str, Any]]:
+    def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str, Any]:
         """Format a tool choice for OpenAI compatibility.
 
         Args:
@@ -184,16 +184,19 @@ def format_request_tool_choice(cls, tool_choice: ToolChoice) -> Union[str, dict[
         Returns:
             OpenAI compatible tool choice format.
         """
+        if not tool_choice:
+            return {}
+
         match tool_choice:
             case {"auto": _}:
-                return "auto"  # OpenAI SDK doesn't define constants for these values
+                return {"tool_choice": "auto"}  # OpenAI SDK doesn't define constants for these values
             case {"any": _}:
-                return "required"
+                return {"tool_choice": "required"}
             case {"tool": {"name": tool_name}}:
-                return {"type": "function", "function": {"name": tool_name}}
+                return {"tool_choice": {"type": "function", "function": {"name": tool_name}}}
             case _:
                 # This should not happen with proper typing, but handle gracefully
-                return "auto"
+                return {"tool_choice": "auto"}
 
     @classmethod
     def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
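A rough sketch of what the reworked helper returns for each ToolChoice variant; the class name OpenAIModel and the exact dict values passed in are assumptions for illustration, not taken from this diff:

# Illustrative only; assumes the enclosing provider class is named OpenAIModel.
OpenAIModel._format_request_tool_choice(None)                        # {}
OpenAIModel._format_request_tool_choice({"auto": {}})                # {"tool_choice": "auto"}
OpenAIModel._format_request_tool_choice({"any": {}})                 # {"tool_choice": "required"}
OpenAIModel._format_request_tool_choice({"tool": {"name": "get_weather"}})
# {"tool_choice": {"type": "function", "function": {"name": "get_weather"}}}

Returning the fully formed kwargs (or an empty dict) lets callers splat the result directly into the request instead of wrapping it themselves.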
@@ -241,7 +244,7 @@ def format_request(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
     ) -> dict[str, Any]:
         """Format an OpenAI compatible chat streaming request.
 
@@ -274,7 +277,7 @@ def format_request(
                 }
                 for tool_spec in tool_specs or []
             ],
-            **({"tool_choice": self.format_request_tool_choice(tool_choice)} if tool_choice else {}),
+            **(self._format_request_tool_choice(tool_choice)),
             **cast(dict[str, Any], self.config.get("params", {})),
         }
 
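Because the helper returns an empty dict when no tool choice is given, splatting its result into the request only adds a "tool_choice" key when one was actually requested. A hedged usage sketch, assuming model is an already-configured instance of this provider and weather_tool_spec is a hypothetical ToolSpec:

# Hypothetical call; weather_tool_spec is a placeholder ToolSpec.
request = model.format_request(
    messages=[{"role": "user", "content": [{"text": "What is the weather?"}]}],
    tool_specs=[weather_tool_spec],
    tool_choice={"any": {}},
)
assert request["tool_choice"] == "required"  # key present only because a tool_choice was passed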
@@ -356,7 +359,7 @@ async def stream(
         messages: Messages,
         tool_specs: Optional[list[ToolSpec]] = None,
         system_prompt: Optional[str] = None,
-        tool_choice: Optional[ToolChoice] = None,
+        tool_choice: ToolChoice | None = None,
         **kwargs: Any,
     ) -> AsyncGenerator[StreamEvent, None]:
         """Stream conversation with the OpenAI model.