Skip to content

Commit

Permalink
Support custom models with OpenAI client (microsoft#4808)
Browse files Browse the repository at this point in the history
  • Loading branch information
srjoglekar246 authored Dec 24, 2024
1 parent d2537ab commit 3b4dd6e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,10 @@ def __init__(
):
self._client = client
if model_capabilities is None:
self._model_capabilities = _model_info.get_capabilities(create_args["model"])
try:
self._model_capabilities = _model_info.get_capabilities(create_args["model"])
except KeyError as err:
raise ValueError("model_capabilities is required when model name is not a valid OpenAI model") from err
else:
self._model_capabilities = model_capabilities

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from autogen_core import ComponentModel
from autogen_core.models import ModelCapabilities
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from typing_extensions import Required, TypedDict

from .._azure_token_provider import AzureTokenProvider
Expand Down Expand Up @@ -79,6 +79,9 @@ class CreateArgumentsConfigModel(BaseModel):


class BaseOpenAIClientConfigurationConfigModel(CreateArgumentsConfigModel):
# To allow `model_capabilities` field without triggering pydantic warnings.
model_config = ConfigDict(protected_namespaces=())

model: str
api_key: str | None = None
timeout: float | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,24 @@ async def test_openai_chat_completion_client() -> None:
assert client


@pytest.mark.asyncio
async def test_custom_model_with_capabilities() -> None:
with pytest.raises(ValueError, match="model_capabilities is required"):
client = OpenAIChatCompletionClient(model="dummy_model", base_url="https://api.dummy.com/v0", api_key="api_key")

client = OpenAIChatCompletionClient(
model="dummy_model",
base_url="https://api.dummy.com/v0",
api_key="api_key",
model_capabilities={
"vision": False,
"function_calling": False,
"json_output": False,
},
)
assert client


@pytest.mark.asyncio
async def test_azure_openai_chat_completion_client() -> None:
client = AzureOpenAIChatCompletionClient(
Expand Down

0 comments on commit 3b4dd6e

Please sign in to comment.