diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
index e5f38ea5b8a..95f8e8dd8dd 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -262,7 +262,7 @@ def __init__(
         self._system_messages = [SystemMessage(content=system_message)]
         self._tools: List[Tool] = []
         if tools is not None:
-            if model_client.capabilities["function_calling"] is False:
+            if model_client.model_info["function_calling"] is False:
                 raise ValueError("The model does not support function calling.")
             for tool in tools:
                 if isinstance(tool, Tool):
@@ -283,7 +283,7 @@ def __init__(
         self._handoff_tools: List[Tool] = []
         self._handoffs: Dict[str, HandoffBase] = {}
         if handoffs is not None:
-            if model_client.capabilities["function_calling"] is False:
+            if model_client.model_info["function_calling"] is False:
                 raise ValueError("The model does not support function calling, which is needed for handoffs.")
             for handoff in handoffs:
                 if isinstance(handoff, str):
@@ -331,7 +331,7 @@ async def on_messages_stream(
     ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
         # Add messages to the model context.
         for msg in messages:
-            if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
+            if isinstance(msg, MultiModalMessage) and self._model_client.model_info["vision"] is False:
                 raise ValueError("The model does not support vision.")
             await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))

diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
index 48f51c4712e..ca079ce407b 100644
--- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py
+++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
@@ -19,6 +19,7 @@
 from autogen_core import Image
 from autogen_core.model_context import BufferedChatCompletionContext
 from autogen_core.models import LLMMessage
+from autogen_core.models._model_client import ModelFamily
 from autogen_core.tools import FunctionTool
 from autogen_ext.models.openai import OpenAIChatCompletionClient
 from openai.resources.chat.completions import AsyncCompletions
@@ -387,11 +388,7 @@ async def test_invalid_model_capabilities() -> None:
     model_client = OpenAIChatCompletionClient(
         model=model,
         api_key="",
-        model_capabilities={
-            "vision": False,
-            "function_calling": False,
-            "json_output": False,
-        },
+        model_info={"vision": False, "function_calling": False, "json_output": False, "family": ModelFamily.UNKNOWN},
     )

     with pytest.raises(ValueError):

diff --git a/python/packages/autogen-core/src/autogen_core/models/__init__.py b/python/packages/autogen-core/src/autogen_core/models/__init__.py
index 9b12aa702ed..c9fa23ffa11 100644
--- a/python/packages/autogen-core/src/autogen_core/models/__init__.py
+++ b/python/packages/autogen-core/src/autogen_core/models/__init__.py
@@ -1,4 +1,4 @@
-from ._model_client import ChatCompletionClient, ModelCapabilities
+from ._model_client import ChatCompletionClient, ModelCapabilities, ModelFamily, ModelInfo  # type: ignore
 from ._types import (
     AssistantMessage,
     ChatCompletionTokenLogprob,
@@ -27,4 +27,6 @@
     "CreateResult",
     "TopLogprob",
     "ChatCompletionTokenLogprob",
+    "ModelFamily",
+    "ModelInfo",
 ]
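The two names added to the public surface above are enough to describe a custom model. A minimal usage sketch, assuming only the exports in this diff (the literal values are illustrative):

    from autogen_core.models import ModelFamily, ModelInfo

    # ModelInfo is a TypedDict, so a plain dict literal with all four
    # required keys satisfies the type checker.
    info: ModelInfo = {
        "vision": False,
        "function_calling": True,
        "json_output": True,
        "family": ModelFamily.GPT_4,
    }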
diff --git a/python/packages/autogen-core/src/autogen_core/models/_model_client.py b/python/packages/autogen-core/src/autogen_core/models/_model_client.py
index 7a0762a46e9..a952ad43458 100644
--- a/python/packages/autogen-core/src/autogen_core/models/_model_client.py
+++ b/python/packages/autogen-core/src/autogen_core/models/_model_client.py
@@ -1,15 +1,10 @@
 from __future__ import annotations

+import warnings
 from abc import ABC, abstractmethod
-from typing import Mapping, Optional, Sequence
+from typing import Literal, Mapping, Optional, Sequence, TypeAlias

-from typing_extensions import (
-    Any,
-    AsyncGenerator,
-    Required,
-    TypedDict,
-    Union,
-)
+from typing_extensions import Any, AsyncGenerator, Required, TypedDict, Union, deprecated

 from .. import CancellationToken
 from .._component_config import ComponentLoader
@@ -17,12 +12,41 @@
 from ._types import CreateResult, LLMMessage, RequestUsage


+class ModelFamily:
+    """A model family is a group of models that share similar characteristics from a capabilities perspective. This is different from discrete supported features such as vision, function calling, and JSON output.
+
+    This namespace class holds constants for the model families that AutoGen understands. Other families definitely exist and can be represented by a string; however, AutoGen will treat them as unknown."""
+
+    GPT_4O = "gpt-4o"
+    O1 = "o1"
+    GPT_4 = "gpt-4"
+    GPT_35 = "gpt-35"
+    UNKNOWN = "unknown"
+
+    ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "unknown"]
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
+        raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
+
+
+@deprecated("Use the ModelInfo class instead of ModelCapabilities.")
 class ModelCapabilities(TypedDict, total=False):
     vision: Required[bool]
     function_calling: Required[bool]
     json_output: Required[bool]


+class ModelInfo(TypedDict, total=False):
+    vision: Required[bool]
+    """True if the model supports vision, aka image input, otherwise False."""
+    function_calling: Required[bool]
+    """True if the model supports function calling, otherwise False."""
+    json_output: Required[bool]
+    """True if the model supports JSON output, otherwise False. Note: this is different from structured JSON."""
+    family: Required[ModelFamily.ANY | str]
+    """Model family should be one of the constants from :py:class:`ModelFamily` or a string representing an unknown model family."""
+
+
 class ChatCompletionClient(ABC, ComponentLoader):
     # Caching has to be handled internally as they can depend on the create args that were stored in the constructor
     @abstractmethod
@@ -63,6 +87,18 @@ def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool |
     @abstractmethod
     def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...

+    # Deprecated
+    @property
+    @abstractmethod
+    def capabilities(self) -> ModelCapabilities: ...  # type: ignore
+
     @property
     @abstractmethod
-    def capabilities(self) -> ModelCapabilities: ...
+    def model_info(self) -> ModelInfo:
+        warnings.warn(
+            "Model client in use does not implement model_info property. Falling back to capabilities property. The capabilities property is deprecated and will be removed soon; please implement model_info in the model client class instead.",
+            stacklevel=2,
+        )
+        base_info: ModelInfo = self.capabilities  # type: ignore
+        base_info["family"] = ModelFamily.UNKNOWN
+        return base_info
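During the deprecation window a client is expected to expose both properties, with the base class falling back from model_info to capabilities as above. A minimal sketch of the pattern, assuming the definitions in this file (a real subclass must also implement create, create_stream, and the token-counting methods; the class name here is hypothetical). The MockChatCompletionClient in the test diff below follows the same shape:

    from autogen_core.models import ModelCapabilities, ModelFamily, ModelInfo

    class DualPropertyClientFragment:  # hypothetical fragment, not part of this diff
        @property
        def model_info(self) -> ModelInfo:
            return ModelInfo(vision=False, function_calling=True, json_output=True, family=ModelFamily.UNKNOWN)

        @property
        def capabilities(self) -> ModelCapabilities:  # type: ignore  # deprecated shim
            return ModelCapabilities(vision=False, function_calling=True, json_output=True)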
diff --git a/python/packages/autogen-core/tests/test_tool_agent.py b/python/packages/autogen-core/tests/test_tool_agent.py
index c93815d5dfd..85fcd3892c9 100644
--- a/python/packages/autogen-core/tests/test_tool_agent.py
+++ b/python/packages/autogen-core/tests/test_tool_agent.py
@@ -11,10 +11,11 @@
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
     LLMMessage,
-    ModelCapabilities,
+    ModelCapabilities,  # type: ignore
     RequestUsage,
     UserMessage,
 )
+from autogen_core.models._model_client import ModelFamily, ModelInfo
 from autogen_core.tool_agent import (
     InvalidToolArgumentsException,
     ToolAgent,
@@ -138,8 +139,12 @@ def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[To
             return 0

         @property
-        def capabilities(self) -> ModelCapabilities:
-            return ModelCapabilities(vision=False, function_calling=True, json_output=False)
+        def capabilities(self) -> ModelCapabilities:  # type: ignore
+            return ModelCapabilities(vision=False, function_calling=True, json_output=False)  # type: ignore
+
+        @property
+        def model_info(self) -> ModelInfo:
+            return ModelInfo(vision=False, function_calling=True, json_output=False, family=ModelFamily.UNKNOWN)

     client = MockChatCompletionClient()
     tools: List[Tool] = [FunctionTool(_pass_function, name="pass", description="Pass function")]

diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py
index 8f59964731d..cbf59e6b0eb 100644
--- a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py
+++ b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py
@@ -183,11 +183,11 @@ def __init__(
             raise ValueError(
                 "Cannot save screenshots without a debug directory. Set it using the 'debug_dir' parameter. The debug directory is created if it does not exist."
             )
-        if model_client.capabilities["function_calling"] is False:
+        if model_client.model_info["function_calling"] is False:
             raise ValueError(
                 "The model does not support function calling. MultimodalWebSurfer requires a model that supports function calling."
             )
-        if model_client.capabilities["vision"] is False:
+        if model_client.model_info["vision"] is False:
             raise ValueError("The model is not multimodal. MultimodalWebSurfer requires a multimodal model.")
         self._model_client = model_client
         self.headless = headless
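With both checks now routed through model_info, any client whose info reports vision and function calling passes validation. An illustrative construction, assuming a model whose registry entry enables both (the API key is a placeholder and the web-surfer extras, including playwright, must be installed):

    from autogen_ext.agents.web_surfer import MultimodalWebSurfer
    from autogen_ext.models.openai import OpenAIChatCompletionClient

    client = OpenAIChatCompletionClient(model="gpt-4o", api_key="placeholder")
    assert client.model_info["vision"] and client.model_info["function_calling"]
    surfer = MultimodalWebSurfer(name="web_surfer", model_client=client)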
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_model_info.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_model_info.py
index 43ca65ac72d..d67beb44e1d 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_model_info.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_model_info.py
@@ -1,6 +1,6 @@
 from typing import Dict

-from autogen_core.models import ModelCapabilities
+from autogen_core.models import ModelFamily, ModelInfo

 # Based on: https://platform.openai.com/docs/models/continuous-model-upgrades
 # This is a moving target, so correctness is checked by the model value returned by openai against expected values at runtime``
@@ -17,86 +17,102 @@
     "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
 }

-_MODEL_CAPABILITIES: Dict[str, ModelCapabilities] = {
+_MODEL_INFO: Dict[str, ModelInfo] = {
     "o1-preview-2024-09-12": {
         "vision": False,
         "function_calling": False,
         "json_output": False,
+        "family": ModelFamily.O1,
     },
     "o1-mini-2024-09-12": {
         "vision": False,
         "function_calling": False,
         "json_output": False,
+        "family": ModelFamily.O1,
     },
     "gpt-4o-2024-08-06": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_4O,
     },
     "gpt-4o-2024-05-13": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_4O,
     },
     "gpt-4o-mini-2024-07-18": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_4O,
     },
     "gpt-4-turbo-2024-04-09": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_4,
     },
     "gpt-4-0125-preview": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_4,
     },
     "gpt-4-1106-preview": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_4,
     },
     "gpt-4-1106-vision-preview": {
         "vision": True,
         "function_calling": False,
         "json_output": False,
+        "family": ModelFamily.GPT_4,
     },
     "gpt-4-0613": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_4,
     },
     "gpt-4-32k-0613": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_4,
     },
     "gpt-3.5-turbo-0125": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_35,
     },
     "gpt-3.5-turbo-1106": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_35,
     },
     "gpt-3.5-turbo-instruct": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_35,
     },
     "gpt-3.5-turbo-0613": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_35,
     },
     "gpt-3.5-turbo-16k-0613": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
+        "family": ModelFamily.GPT_35,
     },
 }
@@ -126,9 +142,9 @@ def resolve_model(model: str) -> str:
     return model


-def get_capabilities(model: str) -> ModelCapabilities:
+def get_info(model: str) -> ModelInfo:
     resolved_model = resolve_model(model)
-    return _MODEL_CAPABILITIES[resolved_model]
+    return _MODEL_INFO[resolved_model]


 def get_token_limit(model: str) -> int:
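A sketch of how the renamed helper behaves, assuming the alias table above still points gpt-4o at the 2024-08-06 snapshot (these are private helpers, so the import path may change):

    from autogen_ext.models.openai._model_info import get_info, resolve_model

    assert resolve_model("gpt-4o") == "gpt-4o-2024-08-06"  # alias -> pinned snapshot
    assert get_info("gpt-4o")["family"] == "gpt-4o"        # ModelFamily.GPT_4O
    assert get_info("o1-mini-2024-09-12")["function_calling"] is False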
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index 2a0bfcf247e..db5fd911302 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -37,7 +37,9 @@
     CreateResult,
     FunctionExecutionResultMessage,
     LLMMessage,
-    ModelCapabilities,
+    ModelCapabilities,  # type: ignore
+    ModelFamily,
+    ModelInfo,
     RequestUsage,
     SystemMessage,
     TopLogprob,
@@ -331,16 +333,24 @@ def __init__(
         client: Union[AsyncOpenAI, AsyncAzureOpenAI],
         *,
         create_args: Dict[str, Any],
-        model_capabilities: Optional[ModelCapabilities] = None,
+        model_capabilities: Optional[ModelCapabilities] = None,  # type: ignore
+        model_info: Optional[ModelInfo] = None,
     ):
         self._client = client
-        if model_capabilities is None:
+        if model_capabilities is None and model_info is None:
             try:
-                self._model_capabilities = _model_info.get_capabilities(create_args["model"])
+                self._model_info = _model_info.get_info(create_args["model"])
             except KeyError as err:
-                raise ValueError("model_capabilities is required when model name is not a valid OpenAI model") from err
-        else:
-            self._model_capabilities = model_capabilities
+                raise ValueError("model_info is required when model name is not a valid OpenAI model") from err
+        elif model_capabilities is not None and model_info is not None:
+            raise ValueError("model_capabilities and model_info are mutually exclusive")
+        elif model_capabilities is not None and model_info is None:
+            warnings.warn("model_capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
+            info = cast(ModelInfo, model_capabilities)
+            info["family"] = ModelFamily.UNKNOWN
+            self._model_info = info
+        elif model_capabilities is None and model_info is not None:
+            self._model_info = model_info

         self._resolved_model: Optional[str] = None
         if "model" in create_args:
@@ -349,7 +359,7 @@ def __init__(
         if (
             "response_format" in create_args
             and create_args["response_format"]["type"] == "json_object"
-            and not self._model_capabilities["json_output"]
+            and not self._model_info["json_output"]
         ):
             raise ValueError("Model does not support JSON output")

@@ -870,8 +880,13 @@ def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[To
         return token_limit - self.count_tokens(messages, tools=tools)

     @property
-    def capabilities(self) -> ModelCapabilities:
-        return self._model_capabilities
+    def capabilities(self) -> ModelCapabilities:  # type: ignore
+        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
+        return self._model_info
+
+    @property
+    def model_info(self) -> ModelInfo:
+        return self._model_info


 class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenAIClientConfigurationConfigModel]):
@@ -941,16 +956,23 @@ def __init__(self, **kwargs: Unpack[OpenAIClientConfiguration]):
         if "model" not in kwargs:
             raise ValueError("model is required for OpenAIChatCompletionClient")

-        model_capabilities: Optional[ModelCapabilities] = None
+        model_capabilities: Optional[ModelCapabilities] = None  # type: ignore
         copied_args = dict(kwargs).copy()
         if "model_capabilities" in kwargs:
             model_capabilities = kwargs["model_capabilities"]
             del copied_args["model_capabilities"]

+        model_info: Optional[ModelInfo] = None
+        if "model_info" in kwargs:
+            model_info = kwargs["model_info"]
+            del copied_args["model_info"]
+
         client = _openai_client_from_config(copied_args)
         create_args = _create_args_from_config(copied_args)
         self._raw_config: Dict[str, Any] = copied_args
-        super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities)
+        super().__init__(
+            client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info
+        )

     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
@@ -1026,16 +1048,23 @@ class AzureOpenAIChatCompletionClient(
     component_provider_override = "autogen_ext.models.openai.AzureOpenAIChatCompletionClient"

     def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]):
-        model_capabilities: Optional[ModelCapabilities] = None
+        model_capabilities: Optional[ModelCapabilities] = None  # type: ignore
         copied_args = dict(kwargs).copy()
         if "model_capabilities" in kwargs:
             model_capabilities = kwargs["model_capabilities"]
             del copied_args["model_capabilities"]

+        model_info: Optional[ModelInfo] = None
+        if "model_info" in kwargs:
+            model_info = kwargs["model_info"]
+            del copied_args["model_info"]
+
         client = _azure_openai_client_from_config(copied_args)
         create_args = _create_args_from_config(copied_args)
         self._raw_config: Dict[str, Any] = copied_args
-        super().__init__(client=client, create_args=create_args, model_capabilities=model_capabilities)
+        super().__init__(
+            client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info
+        )

     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
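The new constructor contract in brief, using the same placeholder endpoint as the tests below: an unknown model name requires model_info, passing both keywords raises, and the deprecated accessor still works but warns. A sketch, assuming the behavior in this diff:

    import warnings

    from autogen_core.models import ModelFamily
    from autogen_ext.models.openai import OpenAIChatCompletionClient

    client = OpenAIChatCompletionClient(
        model="dummy_model",  # not a known OpenAI model, so model_info is required
        base_url="https://api.dummy.com/v0",
        api_key="api_key",
        model_info={"vision": False, "function_calling": False, "json_output": False, "family": ModelFamily.UNKNOWN},
    )

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        _ = client.capabilities  # deprecated accessor still returns the same info
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)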
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py
index c2938cb7c85..b98158504a8 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py
@@ -1,7 +1,7 @@
 from typing import Awaitable, Callable, Dict, List, Literal, Optional, Union

 from autogen_core import ComponentModel
-from autogen_core.models import ModelCapabilities
+from autogen_core.models import ModelCapabilities, ModelInfo  # type: ignore
 from pydantic import BaseModel
 from typing_extensions import Required, TypedDict

@@ -34,7 +34,8 @@ class BaseOpenAIClientConfiguration(CreateArguments, total=False):
     api_key: str
     timeout: Union[float, None]
     max_retries: int
-    model_capabilities: ModelCapabilities
+    model_capabilities: ModelCapabilities  # type: ignore
+    model_info: ModelInfo
     """What functionality the model supports, determined by default from model name but is overriden if value passed."""


@@ -83,7 +84,8 @@ class BaseOpenAIClientConfigurationConfigModel(CreateArgumentsConfigModel):
     api_key: str | None = None
     timeout: float | None = None
     max_retries: int | None = None
-    model_capabilities: ModelCapabilities | None = None
+    model_capabilities: ModelCapabilities | None = None  # type: ignore
+    model_info: ModelInfo | None = None


 # See OpenAI docs for explanation of these parameters
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py b/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py
index 1167be59b8a..b62084b646b 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import logging
+import warnings
 from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union

 from autogen_core import EVENT_LOGGER_NAME, CancellationToken
@@ -8,7 +9,9 @@
     ChatCompletionClient,
     CreateResult,
     LLMMessage,
-    ModelCapabilities,
+    ModelCapabilities,  # type: ignore
+    ModelFamily,
+    ModelInfo,
     RequestUsage,
 )
 from autogen_core.tools import Tool, ToolSchema
@@ -119,7 +122,9 @@ def __init__(
     ):
         self.chat_completions = list(chat_completions)
         self.provided_message_count = len(self.chat_completions)
-        self._model_capabilities = ModelCapabilities(vision=False, function_calling=False, json_output=False)
+        self._model_info = ModelInfo(
+            vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
+        )
         self._total_available_tokens = 10000
         self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
         self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
@@ -231,9 +236,14 @@ def _update_total_usage(self) -> None:
         self._total_usage.prompt_tokens += self._cur_usage.prompt_tokens

     @property
-    def capabilities(self) -> ModelCapabilities:
+    def capabilities(self) -> ModelCapabilities:  # type: ignore
         """Return mock capabilities."""
-        return self._model_capabilities
+        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
+        return self._model_info
+
+    @property
+    def model_info(self) -> ModelInfo:
+        return self._model_info

     def reset(self) -> None:
         """Reset the client state and usage to its initial state."""

diff --git a/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py b/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py
index 94ecc708ace..65c6ef0ba2e 100644
--- a/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py
+++ b/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py
@@ -124,7 +124,7 @@ def __init__(self, client: ChatCompletionClient, hil_mode: bool = False):
         super().__init__(agents, model_client=client)

     def _validate_client_capabilities(self, client: ChatCompletionClient) -> None:
-        capabilities = client.capabilities
+        capabilities = client.model_info
         required_capabilities = ["vision", "function_calling", "json_output"]

         if not all(capabilities.get(cap) for cap in required_capabilities):
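The replay client now reports a fixed ModelInfo with an unknown family. A sketch, assuming the single-positional-argument construction used elsewhere in the test suite:

    from autogen_ext.models.replay import ReplayChatCompletionClient

    replay = ReplayChatCompletionClient(["Hello, world!"])
    assert replay.model_info["family"] == "unknown"  # ModelFamily.UNKNOWN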
api_key="api_key", api_version="2020-08-04", azure_endpoint="https://dummy.com", - model_capabilities={"vision": True, "function_calling": True, "json_output": True}, + model_info={"vision": True, "function_calling": True, "json_output": True, "family": ModelFamily.GPT_4O}, ) assert client diff --git a/python/packages/autogen-magentic-one/src/autogen_magentic_one/utils.py b/python/packages/autogen-magentic-one/src/autogen_magentic_one/utils.py index e8df40bae59..0537142f0ff 100644 --- a/python/packages/autogen-magentic-one/src/autogen_magentic_one/utils.py +++ b/python/packages/autogen-magentic-one/src/autogen_magentic_one/utils.py @@ -9,7 +9,7 @@ from autogen_core.logging import LLMCallEvent from autogen_core.models import ( ChatCompletionClient, - ModelCapabilities, + ModelCapabilities, # type: ignore ) from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient @@ -54,7 +54,7 @@ def create_completion_client_from_env(env: Dict[str, str] | None = None, **kwarg # If model capabilities were provided, deserialize them as well if "model_capabilities" in _kwargs: - _kwargs["model_capabilities"] = ModelCapabilities( + _kwargs["model_capabilities"] = ModelCapabilities( # type: ignore vision=_kwargs["model_capabilities"].get("vision"), function_calling=_kwargs["model_capabilities"].get("function_calling"), json_output=_kwargs["model_capabilities"].get("json_output"),