From 71829f3fdf2389f59d63c7376aa4b055427294e8 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 13 Feb 2024 23:41:20 +0200 Subject: [PATCH] Fix custom client registration (#1653) * fix custom client registration * fix * add test with extra args --- autogen/oai/client.py | 7 +++--- test/oai/test_custom_client.py | 43 +++++++++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 1c46a01e87fd..5970c87126d4 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -418,10 +418,9 @@ def register_model_client(self, model_client_cls: ModelClient, **kwargs): if isinstance(client, PlaceHolderClient): placeholder_config = client.config - if placeholder_config in self._config_list: - if placeholder_config.get("model_client_cls") == model_client_cls.__name__: - self._clients[i] = model_client_cls(placeholder_config, **kwargs) - return + if placeholder_config.get("model_client_cls") == model_client_cls.__name__: + self._clients[i] = model_client_cls(placeholder_config, **kwargs) + return elif isinstance(client, model_client_cls): existing_client_class = True diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py index 8e536921795f..04669a3e02ff 100644 --- a/test/oai/test_custom_client.py +++ b/test/oai/test_custom_client.py @@ -11,7 +11,6 @@ skip = False -@pytest.mark.skipif(skip, reason="openai>=1 not installed") def test_custom_model_client(): TEST_COST = 20000000 TEST_CUSTOM_RESPONSE = "This is a custom response." @@ -87,7 +86,6 @@ def get_usage(response) -> Dict: assert test_hook["max_length"] == TEST_MAX_LENGTH -@pytest.mark.skipif(skip, reason="openai>=1 not installed") def test_registering_with_wrong_class_name_raises_error(): class CustomModel: def __init__(self, config: Dict): @@ -118,7 +116,6 @@ def get_usage(response) -> Dict: client.register_model_client(model_client_cls=CustomModel) -@pytest.mark.skipif(skip, reason="openai>=1 not installed") def test_not_all_clients_registered_raises_error(): class CustomModel: def __init__(self, config: Dict): @@ -164,3 +161,43 @@ def get_usage(response) -> Dict: with pytest.raises(RuntimeError): client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) + + +def test_registering_with_extra_config_args(): + class CustomModel: + def __init__(self, config: Dict, test_hook): + self.test_hook = test_hook + self.test_hook["called"] = True + + def create(self, params): + from types import SimpleNamespace + + response = SimpleNamespace() + response.choices = [] + return response + + def message_retrieval(self, response): + return [] + + def cost(self, response) -> float: + """Calculate the cost of the response.""" + return 0 + + @staticmethod + def get_usage(response) -> Dict: + return {} + + config_list = [ + { + "model": "local_model_name", + "model_client_cls": "CustomModel", + "device": "test_device", + }, + ] + + test_hook = {"called": False} + + client = OpenAIWrapper(config_list=config_list, cache_seed=None) + client.register_model_client(model_client_cls=CustomModel, test_hook=test_hook) + client.create(messages=[{"role": "user", "content": "2+2="}]) + assert test_hook["called"]