From 27e619e46e3abac5ba6a025698137d9f635c4dd8 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Fri, 27 Oct 2023 10:52:38 +0800 Subject: [PATCH] [Blocking Issue] Add tests dependencies for qdrant and fix chromadb errors (#435) * Add tests dependencies for qdrant * Update chromadb API * Update chromadb API version * Fix typehint * Add py 3.9 condition * Fix client creation error --- .github/workflows/openai.yml | 4 ++++ .../contrib/qdrant_retrieve_user_proxy_agent.py | 8 ++++---- .../agentchat/contrib/retrieve_user_proxy_agent.py | 2 +- autogen/retrieve_utils.py | 12 ++++++++++-- setup.py | 1 + 5 files changed, 20 insertions(+), 7 deletions(-) diff --git a/.github/workflows/openai.yml b/.github/workflows/openai.yml index 0a678fd9ed10..339c36527ba6 100644 --- a/.github/workflows/openai.yml +++ b/.github/workflows/openai.yml @@ -56,6 +56,10 @@ jobs: - name: Install packages for Teachable when needed run: | pip install -e .[teachable] + - name: Install packages for RetrieveChat with QDrant when needed + if: matrix.python-version == '3.9' + run: | + pip install qdrant_client[fastembed] - name: Coverage if: matrix.python-version == '3.9' env: diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py index b348b07e0b8b..e0bb8d8216f0 100644 --- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py @@ -18,10 +18,10 @@ class QdrantRetrieveUserProxyAgent(RetrieveUserProxyAgent): def __init__( self, - name="RetrieveChatAgent", - human_input_mode: str | None = "ALWAYS", - is_termination_msg: Callable[[Dict], bool] | None = None, - retrieve_config: Dict | None = None, + name="RetrieveChatAgent", # default set to RetrieveChatAgent + human_input_mode: Optional[str] = "ALWAYS", + is_termination_msg: Optional[Callable[[Dict], bool]] = None, + retrieve_config: Optional[Dict] = None, # config for the retrieve agent **kwargs, ): """ diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 40a146e93edb..3ca7a3baba72 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -361,7 +361,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = """ if not self._collection or self._get_or_create: print("Trying to create collection.") - create_vector_db_from_dir( + self._client = create_vector_db_from_dir( dir_path=self._docs_path, max_tokens=self._chunk_token_size, client=self._client, diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index c29ced376a1e..0b47b6fe0d80 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -5,7 +5,11 @@ import glob import tiktoken import chromadb -from chromadb.api import API + +if chromadb.__version__ < "0.4.15": + from chromadb.api import API +else: + from chromadb.api import ClientAPI as API from chromadb.api.types import QueryResult import chromadb.utils.embedding_functions as ef import logging @@ -287,7 +291,7 @@ def create_vector_db_from_dir( embedding_model: str = "all-MiniLM-L6-v2", embedding_function: Callable = None, custom_text_split_function: Callable = None, -): +) -> API: """Create a vector db from all the files in a given directory, the directory can also be a single file or a url to a single file. We support chromadb compatible APIs to create the vector db, this function is not required if you prepared your own vector db. @@ -307,6 +311,9 @@ def create_vector_db_from_dir( embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`. + + Returns: + API: the chromadb client. """ if client is None: client = chromadb.PersistentClient(path=db_path) @@ -344,6 +351,7 @@ def create_vector_db_from_dir( ) except ValueError as e: logger.warning(f"{e}") + return client def query_vector_db( diff --git a/setup.py b/setup.py index b5f846984aea..891eaba17884 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,7 @@ "sympy", "tiktoken", "wolframalpha", + "qdrant_client[fastembed]", ], "blendsearch": ["flaml[blendsearch]"], "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],