adding tool_choice to ModelSettings #825

Open · wants to merge 16 commits into main
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -7,6 +7,7 @@
 from json import JSONDecodeError, loads as json_loads
 from typing import Any, Literal, Union, cast, overload
 
+
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
28 changes: 27 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/cohere.py
@@ -3,7 +3,7 @@
 from collections.abc import Iterable
 from dataclasses import dataclass, field
 from itertools import chain
-from typing import Literal, Union, cast
+from typing import Literal, Union, Any, cast
 
 from cohere import TextAssistantMessageContentItem
 from httpx import AsyncClient as AsyncHTTPClient
@@ -71,10 +71,12 @@

 CohereModelName = Union[NamedCohereModels, str]
 
+V2ChatRequestToolChoice = Union[Literal["REQUIRED", "NONE"], Any]
+
 class CohereModelSettings(ModelSettings):
     """Settings used for a Cohere model request."""
 
     # This class is a placeholder for any future cohere-specific settings
@@ -166,6 +168,29 @@ async def request(
         response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
+    def _get_tool_choice(self, model_settings: CohereModelSettings) -> V2ChatRequestToolChoice | None:
+        """Determine the tool_choice setting for the request.
+
+        Allowed values in model_settings:
+        - 'REQUIRED': the model must use at least one tool.
+        - 'NONE': the model is forced not to use a tool.
+        If not provided, a default is derived:
+        - if no tools are available, leave unspecified;
+        - if text responses are disallowed, force tool usage ('REQUIRED');
+        - otherwise leave unspecified, so the model is free to choose.
+        """
+        # ModelSettings is a TypedDict, so read it with .get() rather than getattr().
+        tool_choice: V2ChatRequestToolChoice | None = model_settings.get('tool_choice')
+
+        if tool_choice is None:
+            if not self.tools:
+                tool_choice = None
+            elif not self.allow_text_result:
+                tool_choice = 'REQUIRED'
+            else:
+                tool_choice = None
+
+        return tool_choice
 
     async def _chat(
         self,
         messages: list[ModelMessage],
@@ -176,6 +201,7 @@ async def _chat(
             model=self.model_name,
             messages=cohere_messages,
             tools=self.tools or OMIT,
+            tool_choice=self._get_tool_choice(model_settings) or OMIT,
             max_tokens=model_settings.get('max_tokens', OMIT),
             temperature=model_settings.get('temperature', OMIT),
             p=model_settings.get('top_p', OMIT),
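To make the intended behavior concrete: a minimal sketch, assuming this PR's API lands as written. The model name and tool are illustrative, not part of the PR.

from datetime import datetime

from pydantic_ai import Agent

agent = Agent('cohere:command-r-plus')

@agent.tool_plain
def current_time() -> str:
    """Return the current time as an ISO 8601 string."""
    return datetime.now().isoformat()

result = agent.run_sync(
    'What time is it?',
    # Cohere-specific value: force at least one tool call.
    model_settings={'tool_choice': 'REQUIRED'},
)
print(result.data)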
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -53,10 +53,10 @@
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
 """
 
+FunctionCallConfigMode = Literal["ANY", "NONE", "AUTO"]
 
 class GeminiModelSettings(ModelSettings):
     """Settings used for a Gemini model request."""
-
     # This class is a placeholder for any future gemini-specific settings
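The PR stops at defining the FunctionCallConfigMode alias; nothing in gemini.py consumes it yet. As a rough sketch of where it would plug in (an assumption, not code from this PR), Gemini's generateContent API accepts these modes under tool_config.function_calling_config:

from typing import Literal

FunctionCallConfigMode = Literal["ANY", "NONE", "AUTO"]

def tool_config(mode: FunctionCallConfigMode) -> dict:
    # Hypothetical helper: Gemini's generateContent request nests the mode
    # under tool_config.function_calling_config.
    return {'function_calling_config': {'mode': mode}}

assert tool_config('ANY') == {'function_calling_config': {'mode': 'ANY'}}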
36 changes: 27 additions & 9 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -5,10 +5,10 @@
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, cast, overload
+from typing import Literal, Dict, Any, cast, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
-from typing_extensions import assert_never
+from typing_extensions import TypedDict, assert_never
 
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
@@ -63,6 +63,10 @@
 See [the Groq docs](https://console.groq.com/docs/models) for a full list.
 """
 
+class ChatCompletionNamedToolChoiceParam(TypedDict):
+    type: Literal["named"]
+    name: str
+    parameters: Dict[str, Any]
 
 class GroqModelSettings(ModelSettings):
     """Settings used for a Groq model request."""
@@ -180,16 +184,30 @@ async def _completions_create(
     ) -> chat.ChatCompletion:
         pass
 
+    def _get_tool_choice(self, model_settings: GroqModelSettings) -> Literal['none', 'required', 'auto'] | None:
+        """Get the tool_choice for the request.
+
+        - 'auto': default; the model decides whether to use a tool.
+        - 'none': prevents tool use.
+        - 'required': forces tool use.
+        """
+        # ModelSettings is a TypedDict, so read it with .get() rather than getattr().
+        tool_choice: Literal['none', 'required', 'auto'] | None = model_settings.get('tool_choice')
+
+        if tool_choice is None:
+            if not self.tools:
+                tool_choice = None
+            elif not self.allow_text_result:
+                tool_choice = 'required'
+            else:
+                tool_choice = 'auto'
+
+        return tool_choice
+
     async def _completions_create(
         self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
-        if not self.tools:
-            tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not self.allow_text_result:
-            tool_choice = 'required'
-        else:
-            tool_choice = 'auto'
-
         groq_messages = list(chain(*(self._map_message(m) for m in messages)))
@@ -199,7 +217,7 @@ async def _completions_create(
             n=1,
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
-            tool_choice=tool_choice or NOT_GIVEN,
+            tool_choice=self._get_tool_choice(model_settings) or NOT_GIVEN,
             stream=stream,
             max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
             temperature=model_settings.get('temperature', NOT_GIVEN),
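The defaulting logic in _get_tool_choice is the same across the OpenAI-style providers; stripped of the class context it reduces to the following (function and argument names are invented for this sketch):

from typing import Literal, Optional

ToolChoice = Optional[Literal['none', 'required', 'auto']]

def resolve_tool_choice(explicit: ToolChoice, has_tools: bool, allow_text_result: bool) -> ToolChoice:
    if explicit is not None:
        return explicit    # an explicit model_settings value always wins
    if not has_tools:
        return None        # no tools registered: leave the field unset
    if not allow_text_result:
        return 'required'  # a plain-text reply is not acceptable, so force a tool call
    return 'auto'          # tools exist and text is fine: let the model decide

assert resolve_tool_choice(None, has_tools=True, allow_text_result=False) == 'required'
assert resolve_tool_choice('none', has_tools=True, allow_text_result=False) == 'none'
assert resolve_tool_choice(None, has_tools=False, allow_text_result=True) is None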
41 changes: 29 additions & 12 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -5,10 +5,9 @@
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, Union, cast, overload
-
+from typing import Literal, Union, cast, overload, Any, Dict
 from httpx import AsyncClient as AsyncHTTPClient
-from typing_extensions import assert_never
+from typing_extensions import TypedDict, assert_never
 
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
@@ -53,6 +52,11 @@

 OpenAISystemPromptRole = Literal['system', 'developer', 'user']
 
+class ChatCompletionNamedToolChoiceParam(TypedDict):
+    type: Literal["named"]
+    name: str
+    parameters: Dict[str, Any]
+
 
 class OpenAIModelSettings(ModelSettings):
     """Settings used for an OpenAI model request."""
@@ -182,17 +186,30 @@ async def _completions_create(
     ) -> chat.ChatCompletion:
         pass
 
+    def _get_tool_choice(self, model_settings: OpenAIModelSettings) -> Literal['none', 'required', 'auto'] | None:
+        """Get the tool_choice for the request.
+
+        - 'auto': default; the model decides whether to use a tool.
+        - 'none': prevents tool use.
+        - 'required': forces tool use.
+        """
+        # ModelSettings is a TypedDict, so read it with .get() rather than getattr().
+        tool_choice: Literal['none', 'required', 'auto'] | None = model_settings.get('tool_choice')
+
+        if tool_choice is None:
+            if not self.tools:
+                tool_choice = None
+            elif not self.allow_text_result:
+                tool_choice = 'required'
+            else:
+                tool_choice = 'auto'
+
+        return tool_choice
+
     async def _completions_create(
         self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
-        if not self.tools:
-            tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not self.allow_text_result:
-            tool_choice = 'required'
-        else:
-            tool_choice = 'auto'
-
         openai_messages = list(chain(*(self._map_message(m) for m in messages)))
 
         return await self.client.chat.completions.create(
@@ -201,7 +218,7 @@ async def _completions_create(
             n=1,
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
-            tool_choice=tool_choice or NOT_GIVEN,
+            tool_choice=self._get_tool_choice(model_settings) or NOT_GIVEN,
             stream=stream,
             stream_options={'include_usage': True} if stream else NOT_GIVEN,
             max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
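The named-tool variant defined by ChatCompletionNamedToolChoiceParam is declared but never translated for the wire. As an observation (not from the PR), OpenAI's chat completions API expresses a forced tool as {'type': 'function', 'function': {'name': ...}} rather than type 'named', so a mapping step like the hypothetical one below would still be needed:

from typing import Any, Dict, Literal

from typing_extensions import TypedDict

class ChatCompletionNamedToolChoiceParam(TypedDict):
    type: Literal["named"]
    name: str
    parameters: Dict[str, Any]

def to_openai_tool_choice(choice: ChatCompletionNamedToolChoiceParam) -> dict:
    # OpenAI's wire format for forcing a specific function.
    return {'type': 'function', 'function': {'name': choice['name']}}

choice: ChatCompletionNamedToolChoiceParam = {'type': 'named', 'name': 'current_time', 'parameters': {}}
assert to_openai_tool_choice(choice) == {'type': 'function', 'function': {'name': 'current_time'}}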
26 changes: 25 additions & 1 deletion pydantic_ai_slim/pydantic_ai/settings.py
@@ -1,13 +1,17 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal, Dict, Union, Any
 from httpx import Timeout
 from typing_extensions import TypedDict
 
 if TYPE_CHECKING:
     pass
 
+class ChatCompletionNamedToolChoiceParam(TypedDict):
+    type: Literal["named"]
+    name: str
+    parameters: Dict[str, Any]
 
 class ModelSettings(TypedDict, total=False):
     """Settings to configure an LLM.
@@ -131,6 +135,26 @@ class ModelSettings(TypedDict, total=False):
"""


tool_choice: Union[
Literal["none", "auto", "required"],
ChatCompletionNamedToolChoiceParam
]
"""Whether to require a specific tool to be used.

Supported by:

* Gemini
* Anthropic
* OpenAI
* Groq
* Cohere
* Mistral
Comment on lines +146 to +151

Member: The indication that it's supported by all of these models needs to be backed up with code changes - in all of the corresponding model files, we need to check model_settings.tool_choice like you've done for groq and openai.

Member: This could also use more documentation - we'll want to document what each of these specifically means.

Additionally, I'm concerned about having this on the top ModelSettings level - anthropic supports "auto", "any", or a specific tool name, which is different than the above. Thus, I think we should implement tool_choice on individual model settings (like AnthropicModelSettings, OpenAIModelSettings) with the appropriate options.

Author: Yea, I agree - I'll rebase onto each individual one, since they all seem to do something slightly different.

Author: Looks like Mistral already supports tool_choice: when creating the chat completion, it has a method self._get_tool_choice() at L251 that does the conditional check I added inline in groq and openai. Should I refactor those to match it? It seems to be a cleaner pattern.

Author: Also, Mistral has tool_choice set already in MistralToolChoiceEnum, which was generated by Speakeasy - not sure I like this pattern. I feel your rationale makes the most sense: have it inside each specific provider's model settings; it seems out of place in its current state.

Author: It also seems like Anthropic already has it supported via ToolChoiceParam, which was likewise generated by Speakeasy. Not sure how to proceed, as it's already on Anthropic and Mistral so far.
 
 def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
     """Merge two sets of model settings, preferring the overrides.
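For reference, the per-provider direction suggested in the review thread would look something like the sketch below: a tool_choice field on each provider's own settings class, typed with that provider's vocabulary. Class names and value sets here are illustrative assumptions, not code from this PR.

from typing import Literal

from typing_extensions import TypedDict

class AnthropicModelSettings(TypedDict, total=False):
    # Anthropic distinguishes 'auto', 'any' (some tool must be used),
    # or a specific tool selected by name.
    anthropic_tool_choice: Literal['auto', 'any']

class GroqModelSettings(TypedDict, total=False):
    # OpenAI-compatible APIs (OpenAI, Groq) use 'none' / 'auto' / 'required'.
    groq_tool_choice: Literal['none', 'auto', 'required']

class CohereModelSettings(TypedDict, total=False):
    # Cohere's v2 chat API uses upper-case 'REQUIRED' / 'NONE'.
    cohere_tool_choice: Literal['REQUIRED', 'NONE']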