From 0ae2f6130655a47a190482b2c394015243d944e2 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Fri, 27 Oct 2023 09:24:14 +0800 Subject: [PATCH] Fix client creation error --- autogen/agentchat/contrib/retrieve_user_proxy_agent.py | 2 +- autogen/retrieve_utils.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 40a146e93edb..3ca7a3baba72 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -361,7 +361,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = """ if not self._collection or self._get_or_create: print("Trying to create collection.") - create_vector_db_from_dir( + self._client = create_vector_db_from_dir( dir_path=self._docs_path, max_tokens=self._chunk_token_size, client=self._client, diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index 4fc8e6edd2fe..0b47b6fe0d80 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -291,7 +291,7 @@ def create_vector_db_from_dir( embedding_model: str = "all-MiniLM-L6-v2", embedding_function: Callable = None, custom_text_split_function: Callable = None, -): +) -> API: """Create a vector db from all the files in a given directory, the directory can also be a single file or a url to a single file. We support chromadb compatible APIs to create the vector db, this function is not required if you prepared your own vector db. @@ -311,6 +311,9 @@ def create_vector_db_from_dir( embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`. + + Returns: + API: the chromadb client. """ if client is None: client = chromadb.PersistentClient(path=db_path) @@ -348,6 +351,7 @@ def create_vector_db_from_dir( ) except ValueError as e: logger.warning(f"{e}") + return client def query_vector_db(