From fa6e2a52c00a28756b9cf30e85bc336d4c5da055 Mon Sep 17 00:00:00 2001
From: Li Jiang
Date: Tue, 10 Oct 2023 20:53:18 +0800
Subject: [PATCH] Add support for customized vectordb and embedding functions
 (#161)

* Add custom embedding function

* Add support for custom vector db

* Improve docstring

* Improve docstring

* Improve docstring

* Add support for customized is_termination_msg function

* Add a test for customized vector db with lancedb

* Fix tests

* Add test for embedding_function

* Update docstring

---
 .github/workflows/build.yml                   |  2 +-
 .../contrib/retrieve_user_proxy_agent.py      | 75 ++++++++++++++++---
 autogen/retrieve_utils.py                     | 62 +++++++++++++--
 setup.py                                      |  1 +
 test/agentchat/test_retrievechat.py           |  3 +
 test/test_retrieve_utils.py                   | 64 ++++++++++++++++
 6 files changed, 192 insertions(+), 15 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 63ca0a254609..5e5fd186beac 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -40,7 +40,7 @@ jobs:
           python -m pip install --upgrade pip wheel
           pip install -e .
           python -c "import autogen"
-          pip install -e.[mathchat,retrievechat] datasets pytest
+          pip install -e.[mathchat,retrievechat,test] datasets pytest
           pip uninstall -y openai
       - name: Test with pytest
         if: matrix.python-version != '3.10'
diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
index 0fc83bdb7593..0f29aa62d14f 100644
--- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -67,6 +67,7 @@ def __init__(
         self,
         name="RetrieveChatAgent",  # default set to RetrieveChatAgent
         human_input_mode: Optional[str] = "ALWAYS",
+        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
         retrieve_config: Optional[Dict] = None,  # config for the retrieve agent
         **kwargs,
     ):
@@ -82,14 +83,17 @@ def __init__(
                 the number of auto reply reaches the max_consecutive_auto_reply.
                 (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
                 when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
+            is_termination_msg (function): a function that takes a message in the form of a dictionary
+                and returns a boolean value indicating if this received message is a termination message.
+                The dict can contain the following keys: "content", "role", "name", "function_call".
             retrieve_config (dict or None): config for the retrieve agent.
                 To use default config, set to None. Otherwise, set to a dictionary with the following keys:
                 - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System
                     prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
-                - client (Optional, chromadb.Client): the chromadb client.
-                    If key not provided, a default client `chromadb.Client()` will be used.
+                - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()`
+                    will be used. If you want to use another vector db, extend this class and override the `retrieve_docs` function.
                 - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
-                    or the url to a single file. If key not provided, a default path `./docs` will be used.
+                    or the url to a single file. Default is None, which works only if the collection is already created.
                - collection_name (Optional, str): the name of the collection.
                    If key not provided, a default name `autogen-docs` will be used.
                - model (Optional, str): the model to use for the retrieve chat.
@@ -106,16 +110,45 @@ def __init__(
                    If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available models
                    can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a fast model. If you
                    want to use a high performance model, `all-mpnet-base-v2` is recommended.
+                - embedding_function (Optional, Callable): the embedding function for creating the vector db. Default is None;
+                    if None, SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI,
+                    Cohere, HuggingFace or other embedding functions, you can pass it here; follow the examples in
+                    `https://docs.trychroma.com/embeddings`.
                - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
                - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
                    If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
                - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
                - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
-                    This is the same as that used in chromadb. Default is False.
+                    This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None.
                - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
                    The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
                    Default is None; if None, tiktoken will be used, which may not be accurate for non-OpenAI models.
            **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
+
+            Example of overriding retrieve_docs:
+                If you have set up a customized vector db that is not compatible with chromadb, you can easily plug it in
+                with the code below.
+            ```python
+            class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
+                def query_vector_db(
+                    self,
+                    query_texts: List[str],
+                    n_results: int = 10,
+                    search_string: str = "",
+                    **kwargs,
+                ) -> Dict[str, Union[List[str], List[List[str]]]]:
+                    # define your own query function here
+                    pass
+
+                def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = "", **kwargs):
+                    results = self.query_vector_db(
+                        query_texts=[problem],
+                        n_results=n_results,
+                        search_string=search_string,
+                        **kwargs,
+                    )
+
+                    self._results = results
+                    print("doc_ids: ", results["ids"])
+            ```
         """
         super().__init__(
             name=name,
@@ -126,7 +159,7 @@ def __init__(
         self._retrieve_config = {} if retrieve_config is None else retrieve_config
         self._task = self._retrieve_config.get("task", "default")
         self._client = self._retrieve_config.get("client", chromadb.Client())
-        self._docs_path = self._retrieve_config.get("docs_path", "./docs")
+        self._docs_path = self._retrieve_config.get("docs_path", None)
         self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
         self._model = self._retrieve_config.get("model", "gpt-4")
         self._max_tokens = self.get_max_tokens(self._model)
@@ -134,20 +167,26 @@ def __init__(
         self._chunk_mode = self._retrieve_config.get("chunk_mode", "multi_lines")
         self._must_break_at_empty_line = self._retrieve_config.get("must_break_at_empty_line", True)
         self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2")
+        self._embedding_function = self._retrieve_config.get("embedding_function", None)
         self.customized_prompt = self._retrieve_config.get("customized_prompt", None)
         self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
         self.update_context = self._retrieve_config.get("update_context", True)
-        self._get_or_create = self._retrieve_config.get("get_or_create", False)
+        self._get_or_create = (
+            self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
+        )
         self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
         self._context_max_tokens = self._max_tokens * 0.8
-        self._collection = False  # the collection is not created
+        self._collection = True if self._docs_path is None else False  # whether the collection is created
         self._ipython = get_ipython()
         self._doc_idx = -1  # the index of the current used doc
         self._results = {}  # the results of the current query
         self._intermediate_answers = set()  # the intermediate answers
         self._doc_contents = []  # the contents of the current used doc
         self._doc_ids = []  # the ids of the current used doc
-        self._is_termination_msg = self._is_termination_msg_retrievechat  # update the termination message function
+        # update the termination message function
+        self._is_termination_msg = (
+            self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
+        )
         self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=1)

     def _is_termination_msg_retrievechat(self, message):
@@ -188,7 +227,7 @@ def _reset(self, intermediate=False):
         self._doc_contents = []  # the contents of the current used doc
         self._doc_ids = []  # the ids of the current used doc

-    def _get_context(self, results):
+    def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
         doc_contents = ""
         current_tokens = 0
         _doc_idx = self._doc_idx
@@ -297,6 +336,22 @@ def _generate_retrieve_user_reply(
         return False, None

     def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
+        """Retrieve docs based on the given problem and assign the results to the class property `_results`.
+        In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
+        compatible with chromadb or filtering results with metadata, you can override this function. Just keep the
+        current parameters, add your own parameters with default values, and keep the results in the type below.
+
+        Type of the results: Dict[str, List[List[Any]]]. It must have the keys "ids" and "documents": "ids" holds the
+        ids of the retrieved docs and "documents" holds their contents. Any other keys are optional. Refer to
+        `chromadb.api.types.QueryResult` as an example.
+            ids: List[string]
+            documents: List[List[string]]
+
+        Args:
+            problem (str): the problem to be solved.
+            n_results (int): the number of results to be retrieved.
+            search_string (str): only docs containing this string will be retrieved.
+        """
         if not self._collection or self._get_or_create:
             print("Trying to create collection.")
             create_vector_db_from_dir(
@@ -308,6 +363,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
                 must_break_at_empty_line=self._must_break_at_empty_line,
                 embedding_model=self._embedding_model,
                 get_or_create=self._get_or_create,
+                embedding_function=self._embedding_function,
             )
             self._collection = True
             self._get_or_create = False
@@ -319,6 +375,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
                 client=self._client,
                 collection_name=self._collection_name,
                 embedding_model=self._embedding_model,
+                embedding_function=self._embedding_function,
             )
         self._results = results
         print("doc_ids: ", results["ids"])
diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py
index fbe7c28784ae..721b1ec29a5e 100644
--- a/autogen/retrieve_utils.py
+++ b/autogen/retrieve_utils.py
@@ -6,6 +6,7 @@
 import tiktoken
 import chromadb
 from chromadb.api import API
+from chromadb.api.types import QueryResult
 import chromadb.utils.embedding_functions as ef
 import logging
 import pypdf
@@ -263,12 +264,36 @@ def create_vector_db_from_dir(
     chunk_mode: str = "multi_lines",
     must_break_at_empty_line: bool = True,
     embedding_model: str = "all-MiniLM-L6-v2",
+    embedding_function: Callable = None,
 ):
-    """Create a vector db from all the files in a given directory."""
+    """Create a vector db from all the files in a given directory; the directory can also be a single file or a url to
+    a single file. We support chromadb-compatible APIs to create the vector db; this function is not required if you
+    have prepared your own vector db.
+
+    Args:
+        dir_path (str): the path to the directory, file or url.
+        max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000.
+        client (Optional, API): the chromadb client. Default is None.
+        db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db".
+        collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
+        get_or_create (Optional, bool): whether to get or create the collection. Default is False. If True, the collection
+            will be recreated if it already exists.
+        chunk_mode (Optional, str): the chunk mode. Default is "multi_lines".
+        must_break_at_empty_line (Optional, bool): whether to break at empty line. Default is True.
+        embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
+            embedding_function is not None.
+        embedding_function (Optional, Callable): the embedding function to use. Default is None; if None, SentenceTransformer
+            with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other
+            embedding functions, you can pass it here; follow the examples in `https://docs.trychroma.com/embeddings`.
+    """
     if client is None:
         client = chromadb.PersistentClient(path=db_path)
     try:
-        embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model)
+        embedding_function = (
+            ef.SentenceTransformerEmbeddingFunction(embedding_model)
+            if embedding_function is None
+            else embedding_function
+        )
         collection = client.create_collection(
             collection_name,
             get_or_create=get_or_create,
@@ -300,14 +325,41 @@ def query_vector_db(
     collection_name: str = "all-my-documents",
     search_string: str = "",
     embedding_model: str = "all-MiniLM-L6-v2",
-) -> Dict[str, List[str]]:
-    """Query a vector db."""
+    embedding_function: Callable = None,
+) -> QueryResult:
+    """Query a vector db. We support chromadb-compatible APIs; this function is not required if you have prepared your
+    own vector db and query function.
+
+    Args:
+        query_texts (List[str]): the query texts.
+        n_results (Optional, int): the number of results to return. Default is 10.
+        client (Optional, API): the chromadb-compatible client. Default is None; if None, a chromadb client will be used.
+        db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db".
+        collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
+        search_string (Optional, str): the search string. Default is "".
+        embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
+            embedding_function is not None.
+        embedding_function (Optional, Callable): the embedding function to use. Default is None; if None, SentenceTransformer
+            with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other
+            embedding functions, you can pass it here; follow the examples in `https://docs.trychroma.com/embeddings`.
+
+    Returns:
+        QueryResult: the query result. The format is:
+            class QueryResult(TypedDict):
+                ids: List[IDs]
+                embeddings: Optional[List[List[Embedding]]]
+                documents: Optional[List[List[Document]]]
+                metadatas: Optional[List[List[Metadata]]]
+                distances: Optional[List[List[float]]]
+    """
     if client is None:
         client = chromadb.PersistentClient(path=db_path)
     # the collection's embedding function is always the default one, but we want to use the one we used to create the
     # collection. So we compute the embeddings ourselves and pass it to the query function.
     collection = client.get_collection(collection_name)
-    embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model)
+    embedding_function = (
+        ef.SentenceTransformerEmbeddingFunction(embedding_model) if embedding_function is None else embedding_function
+    )
     query_embeddings = embedding_function(query_texts)
     # Query/search n most similar results. You can also .get by id
     results = collection.query(
diff --git a/setup.py b/setup.py
index 37c9d2d883fd..a42432eb0333 100644
--- a/setup.py
+++ b/setup.py
@@ -40,6 +40,7 @@
     extras_require={
         "test": [
             "chromadb",
+            "lancedb",
             "coverage>=5.3",
             "datasets",
             "ipykernel",
diff --git a/test/agentchat/test_retrievechat.py b/test/agentchat/test_retrievechat.py
index bde5730cbbb2..99e395de5056 100644
--- a/test/agentchat/test_retrievechat.py
+++ b/test/agentchat/test_retrievechat.py
@@ -12,6 +12,7 @@
     )
     from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
     import chromadb
+    from chromadb.utils import embedding_functions as ef

     skip_test = False
 except ImportError:
@@ -49,6 +50,7 @@ def test_retrievechat():
         },
     )

+    sentence_transformer_ef = ef.SentenceTransformerEmbeddingFunction()
     ragproxyagent = RetrieveUserProxyAgent(
         name="ragproxyagent",
         human_input_mode="NEVER",
@@ -58,6 +60,7 @@ def test_retrievechat():
             "chunk_token_size": 2000,
             "model": config_list[0]["model"],
             "client": chromadb.PersistentClient(path="/tmp/chromadb"),
+            "embedding_function": sentence_transformer_ef,
         },
     )
diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py
index fdb93d26ca8d..be215facb846 100644
--- a/test/test_retrieve_utils.py
+++ b/test/test_retrieve_utils.py
@@ -100,6 +100,70 @@ def test_query_vector_db(self):
         results = query_vector_db(["autogen"], client=client)
         assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", []))

+    def test_custom_vector_db(self):
+        try:
+            import lancedb
+        except ImportError:
+            return
+        from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
+
+        db_path = "/tmp/lancedb"
+
+        def create_lancedb():
+            db = lancedb.connect(db_path)
+            data = [
+                {"vector": [1.1, 1.2], "id": 1, "documents": "This is a test document spark"},
+                {"vector": [0.2, 1.8], "id": 2, "documents": "This is another test document"},
+                {"vector": [0.1, 0.3], "id": 3, "documents": "This is a third test document spark"},
+                {"vector": [0.5, 0.7], "id": 4, "documents": "This is a fourth test document"},
+                {"vector": [2.1, 1.3], "id": 5, "documents": "This is a fifth test document spark"},
+                {"vector": [5.1, 8.3], "id": 6, "documents": "This is a sixth test document"},
+            ]
+            try:
+                db.create_table("my_table", data)
+            except OSError:
+                pass
+
+        class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
+            def query_vector_db(
+                self,
+                query_texts,
+                n_results=10,
+                search_string="",
+            ):
+                if query_texts:
+                    # use a fixed query vector for the test instead of embedding query_texts
+                    vector = [0.1, 0.3]
+                db = lancedb.connect(db_path)
+                table = db.open_table("my_table")
+                query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df()
+                return {"ids": query["id"].tolist(), "documents": query["documents"].tolist()}
+
+            def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
+                results = self.query_vector_db(
+                    query_texts=[problem],
+                    n_results=n_results,
+                    search_string=search_string,
+                )
+
+                self._results = results
+                print("doc_ids: ", results["ids"])
+
+        ragragproxyagent = MyRetrieveUserProxyAgent(
+            name="ragproxyagent",
+            human_input_mode="NEVER",
+            max_consecutive_auto_reply=2,
+            retrieve_config={
+                "task": "qa",
+                "chunk_token_size": 2000,
+                "client": "__",
+                "embedding_model": "all-mpnet-base-v2",
+            },
+        )
+
+        create_lancedb()
+        ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark")
+        assert ragragproxyagent._results["ids"] == [3, 1, 5]
+

 if __name__ == "__main__":
     pytest.main()
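
A minimal usage sketch of the main extension point this patch adds (the `embedding_function` key in `retrieve_config`) is shown below. It is not part of the patch; the docs path, db path, and model choice are illustrative.

```python
# Sketch, assuming autogen and chromadb are installed; paths and names are illustrative.
import chromadb
from chromadb.utils import embedding_functions as ef
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

# Any chromadb-style embedding function can be plugged in here, e.g. the OpenAI
# or Cohere wrappers from chromadb.utils.embedding_functions.
custom_ef = ef.SentenceTransformerEmbeddingFunction(model_name="all-mpnet-base-v2")

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",  # omit (defaults to None) to reuse an already-created collection
        "client": chromadb.PersistentClient(path="/tmp/chromadb"),
        "embedding_function": custom_ef,  # takes precedence over embedding_model
    },
)
```

Per the patch, `get_or_create` is forced to False when `docs_path` is None, so `retrieve_docs` then queries the existing collection instead of rebuilding it. For a vector db that is not chromadb-compatible, subclass `RetrieveUserProxyAgent` and override `retrieve_docs`, as in the lancedb test above.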