Skip to content

Commit

Permalink
Fix custom client registration (#1653)
Browse files Browse the repository at this point in the history
* fix custom client registration

* fix

* add test with extra args
  • Loading branch information
olgavrou authored Feb 13, 2024
1 parent b05b148 commit 71829f3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
7 changes: 3 additions & 4 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,10 +418,9 @@ def register_model_client(self, model_client_cls: ModelClient, **kwargs):
if isinstance(client, PlaceHolderClient):
placeholder_config = client.config

if placeholder_config in self._config_list:
if placeholder_config.get("model_client_cls") == model_client_cls.__name__:
self._clients[i] = model_client_cls(placeholder_config, **kwargs)
return
if placeholder_config.get("model_client_cls") == model_client_cls.__name__:
self._clients[i] = model_client_cls(placeholder_config, **kwargs)
return
elif isinstance(client, model_client_cls):
existing_client_class = True

Expand Down
43 changes: 40 additions & 3 deletions test/oai/test_custom_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
skip = False


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_custom_model_client():
TEST_COST = 20000000
TEST_CUSTOM_RESPONSE = "This is a custom response."
Expand Down Expand Up @@ -87,7 +86,6 @@ def get_usage(response) -> Dict:
assert test_hook["max_length"] == TEST_MAX_LENGTH


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_registering_with_wrong_class_name_raises_error():
class CustomModel:
def __init__(self, config: Dict):
Expand Down Expand Up @@ -118,7 +116,6 @@ def get_usage(response) -> Dict:
client.register_model_client(model_client_cls=CustomModel)


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_not_all_clients_registered_raises_error():
class CustomModel:
def __init__(self, config: Dict):
Expand Down Expand Up @@ -164,3 +161,43 @@ def get_usage(response) -> Dict:

with pytest.raises(RuntimeError):
client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)


def test_registering_with_extra_config_args():
class CustomModel:
def __init__(self, config: Dict, test_hook):
self.test_hook = test_hook
self.test_hook["called"] = True

def create(self, params):
from types import SimpleNamespace

response = SimpleNamespace()
response.choices = []
return response

def message_retrieval(self, response):
return []

def cost(self, response) -> float:
"""Calculate the cost of the response."""
return 0

@staticmethod
def get_usage(response) -> Dict:
return {}

config_list = [
{
"model": "local_model_name",
"model_client_cls": "CustomModel",
"device": "test_device",
},
]

test_hook = {"called": False}

client = OpenAIWrapper(config_list=config_list, cache_seed=None)
client.register_model_client(model_client_cls=CustomModel, test_hook=test_hook)
client.create(messages=[{"role": "user", "content": "2+2="}])
assert test_hook["called"]

0 comments on commit 71829f3

Please sign in to comment.