Add support for customized vectordb and embedding functions (#161)
* Add custom embedding function

* Add support for custom vector db

* Improve docstring

* Improve docstring

* Improve docstring

* Add support for customized is_termination_msg function

* Add a test for customized vector db with lancedb

* Fix tests

* Add test for embedding_function

* Update docstring
thinkall authored Oct 10, 2023
1 parent 37a07a8 commit fa6e2a5
Showing 6 changed files with 192 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -40,7 +40,7 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e .
python -c "import autogen"
pip install -e.[mathchat,retrievechat] datasets pytest
pip install -e.[mathchat,retrievechat,test] datasets pytest
pip uninstall -y openai
- name: Test with pytest
if: matrix.python-version != '3.10'
75 changes: 66 additions & 9 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -67,6 +67,7 @@ def __init__(
self,
name="RetrieveChatAgent", # default set to RetrieveChatAgent
human_input_mode: Optional[str] = "ALWAYS",
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
retrieve_config: Optional[Dict] = None, # config for the retrieve agent
**kwargs,
):
@@ -82,14 +83,17 @@ def __init__(
the number of auto reply reaches the max_consecutive_auto_reply.
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
is_termination_msg (function): a function that takes a message in the form of a dictionary
and returns a boolean value indicating if this received message is a termination message.
The dict can contain the following keys: "content", "role", "name", "function_call".
retrieve_config (dict or None): config for the retrieve agent.
To use default config, set to None. Otherwise, set to a dictionary with the following keys:
- task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System
prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
- client (Optional, chromadb.Client): the chromadb client.
If key not provided, a default client `chromadb.Client()` will be used.
- client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()`
will be used. If you want to use another vector db, extend this class and override the `retrieve_docs` function.
- docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
or the url to a single file. If key not provided, a default path `./docs` will be used.
or the url to a single file. Default is None, which works only if the collection is already created.
- collection_name (Optional, str): the name of the collection.
If key not provided, a default name `autogen-docs` will be used.
- model (Optional, str): the model to use for the retrieve chat.
@@ -106,16 +110,45 @@ def __init__(
If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available models
can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a
fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.
- embedding_function (Optional, Callable): the embedding function for creating the vector db. Default is None,
in which case SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere,
HuggingFace, or other embedding functions, you can pass one here; follow the examples in `https://docs.trychroma.com/embeddings`.
- customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
- customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
- update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
- get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
This is the same as that used in chromadb. Default is False.
This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None.
- custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
Default is None, in which case tiktoken will be used; it may not be accurate for non-OpenAI models.
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
Example of overriding retrieve_docs:
If you have set up a customized vector db that is not compatible with chromadb, you can easily plug it in with the code below.
```python
class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
def query_vector_db(
self,
query_texts: List[str],
n_results: int = 10,
search_string: str = "",
**kwargs,
) -> Dict[str, Union[List[str], List[List[str]]]]:
# define your own query function here
pass
def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = "", **kwargs):
results = self.query_vector_db(
query_texts=[problem],
n_results=n_results,
search_string=search_string,
**kwargs,
)
self._results = results
print("doc_ids: ", results["ids"])
```
"""
super().__init__(
name=name,
@@ -126,28 +159,34 @@ def __init__(
self._retrieve_config = {} if retrieve_config is None else retrieve_config
self._task = self._retrieve_config.get("task", "default")
self._client = self._retrieve_config.get("client", chromadb.Client())
self._docs_path = self._retrieve_config.get("docs_path", "./docs")
self._docs_path = self._retrieve_config.get("docs_path", None)
self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
self._model = self._retrieve_config.get("model", "gpt-4")
self._max_tokens = self.get_max_tokens(self._model)
self._chunk_token_size = int(self._retrieve_config.get("chunk_token_size", self._max_tokens * 0.4))
self._chunk_mode = self._retrieve_config.get("chunk_mode", "multi_lines")
self._must_break_at_empty_line = self._retrieve_config.get("must_break_at_empty_line", True)
self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2")
self._embedding_function = self._retrieve_config.get("embedding_function", None)
self.customized_prompt = self._retrieve_config.get("customized_prompt", None)
self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
self.update_context = self._retrieve_config.get("update_context", True)
self._get_or_create = self._retrieve_config.get("get_or_create", False)
self._get_or_create = (
self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
)
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
self._context_max_tokens = self._max_tokens * 0.8
self._collection = False # the collection is not created
self._collection = True if self._docs_path is None else False # whether the collection is created
self._ipython = get_ipython()
self._doc_idx = -1 # the index of the current used doc
self._results = {} # the results of the current query
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
self._is_termination_msg = self._is_termination_msg_retrievechat # update the termination message function
# update the termination message function
self._is_termination_msg = (
self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
)
self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=1)

def _is_termination_msg_retrievechat(self, message):
@@ -188,7 +227,7 @@ def _reset(self, intermediate=False):
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc

def _get_context(self, results):
def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
doc_contents = ""
current_tokens = 0
_doc_idx = self._doc_idx
@@ -297,6 +336,22 @@ def _generate_retrieve_user_reply(
return False, None

def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
"""Retrieve docs based on the given problem and assign the results to the class property `_results`.
In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
compatible with chromadb, or filtering results by metadata, you can override this function. Just keep the current
parameters, add your own parameters with default values, and keep the results in the type described below.
Type of the results: Dict[str, List[List[Any]]]; it must have the keys "ids" and "documents", "ids" for the ids of
the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer
to `chromadb.api.types.QueryResult` as an example.
ids: List[string]
documents: List[List[string]]
Args:
problem (str): the problem to be solved.
n_results (int): the number of results to be retrieved.
search_string (str): only docs containing this string will be retrieved.
"""
if not self._collection or self._get_or_create:
print("Trying to create collection.")
create_vector_db_from_dir(
@@ -308,6 +363,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
must_break_at_empty_line=self._must_break_at_empty_line,
embedding_model=self._embedding_model,
get_or_create=self._get_or_create,
embedding_function=self._embedding_function,
)
self._collection = True
self._get_or_create = False
@@ -319,6 +375,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
client=self._client,
collection_name=self._collection_name,
embedding_model=self._embedding_model,
embedding_function=self._embedding_function,
)
self._results = results
print("doc_ids: ", results["ids"])
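For reference, the two new options introduced above — the `is_termination_msg` constructor argument and the `embedding_function` key in `retrieve_config` — can be combined as in the minimal sketch below. The docs path, embedding model, and termination condition are illustrative assumptions, not values from this commit.

```python
import chromadb
from chromadb.utils import embedding_functions as ef

from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

# Any callable implementing chromadb's embedding-function interface works here;
# the model choice is an assumption, not part of this commit.
sentence_transformer_ef = ef.SentenceTransformerEmbeddingFunction("all-mpnet-base-v2")

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    # New in this commit: a user-supplied termination check replaces the built-in one.
    is_termination_msg=lambda msg: "TERMINATE" in (msg.get("content") or ""),
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",  # hypothetical docs directory
        "client": chromadb.PersistentClient(path="/tmp/chromadb"),
        # New in this commit: takes precedence over `embedding_model` when provided.
        "embedding_function": sentence_transformer_ef,
    },
)
```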
62 changes: 57 additions & 5 deletions autogen/retrieve_utils.py
@@ -6,6 +6,7 @@
import tiktoken
import chromadb
from chromadb.api import API
from chromadb.api.types import QueryResult
import chromadb.utils.embedding_functions as ef
import logging
import pypdf
@@ -263,12 +264,36 @@ def create_vector_db_from_dir(
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
embedding_model: str = "all-MiniLM-L6-v2",
embedding_function: Callable = None,
):
"""Create a vector db from all the files in a given directory."""
"""Create a vector db from all the files in a given directory, the directory can also be a single file or a url to
a single file. We support chromadb compatible APIs to create the vector db, this function is not required if
you prepared your own vector db.
Args:
dir_path (str): the path to the directory, file or url.
max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000.
client (Optional, API): the chromadb client. Default is None.
db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db".
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection
will be recreated if it already exists.
chunk_mode (Optional, str): the chunk mode. Default is "multi_lines".
must_break_at_empty_line (Optional, bool): Whether to break at empty line. Default is True.
embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
embedding_function is not None.
embedding_function (Optional, Callable): the embedding function to use. Default is None, in which case SentenceTransformer
with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace, or other embedding
functions, you can pass one here; follow the examples in `https://docs.trychroma.com/embeddings`.
"""
if client is None:
client = chromadb.PersistentClient(path=db_path)
try:
embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model)
embedding_function = (
ef.SentenceTransformerEmbeddingFunction(embedding_model)
if embedding_function is None
else embedding_function
)
collection = client.create_collection(
collection_name,
get_or_create=get_or_create,
@@ -300,14 +325,41 @@ def query_vector_db(
collection_name: str = "all-my-documents",
search_string: str = "",
embedding_model: str = "all-MiniLM-L6-v2",
) -> Dict[str, List[str]]:
"""Query a vector db."""
embedding_function: Callable = None,
) -> QueryResult:
"""Query a vector db. We support chromadb compatible APIs, it's not required if you prepared your own vector db
and query function.
Args:
query_texts (List[str]): the query texts.
n_results (Optional, int): the number of results to return. Default is 10.
client (Optional, API): the chromadb compatible client. Default is None, a chromadb client will be used.
db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db".
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
search_string (Optional, str): the search string. Default is "".
embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
embedding_function is not None.
embedding_function (Optional, Callable): the embedding function to use. Default is None, in which case SentenceTransformer
with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace, or other embedding
functions, you can pass one here; follow the examples in `https://docs.trychroma.com/embeddings`.
Returns:
QueryResult: the query result. The format is:
class QueryResult(TypedDict):
ids: List[IDs]
embeddings: Optional[List[List[Embedding]]]
documents: Optional[List[List[Document]]]
metadatas: Optional[List[List[Metadata]]]
distances: Optional[List[List[float]]]
"""
if client is None:
client = chromadb.PersistentClient(path=db_path)
# the collection's embedding function is always the default one, but we want to use the one that was used to create
# the collection. So we compute the embeddings ourselves and pass them to the query function.
collection = client.get_collection(collection_name)
embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model)
embedding_function = (
ef.SentenceTransformerEmbeddingFunction(embedding_model) if embedding_function is None else embedding_function
)
query_embeddings = embedding_function(query_texts)
# Query/search n most similar results. You can also .get by id
results = collection.query(
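As a usage sketch for the new `embedding_function` parameter of both helpers (assuming a local `./docs` directory and a writable `/tmp` path): the same function should be passed at creation and query time, since the stored collection does not remember which embedding function built it — the comment in `query_vector_db` above makes the same point.

```python
import chromadb.utils.embedding_functions as ef

from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db

# Build the collection with a non-default embedding function.
my_ef = ef.SentenceTransformerEmbeddingFunction("all-mpnet-base-v2")
create_vector_db_from_dir(
    dir_path="./docs",  # hypothetical directory of documents
    db_path="/tmp/chromadb.db",
    collection_name="all-my-documents",
    embedding_function=my_ef,
)

# Query with the same function so the query embeddings live in the same
# vector space as the stored document embeddings.
results = query_vector_db(
    query_texts=["How do I plug in a custom vector db?"],
    n_results=5,
    db_path="/tmp/chromadb.db",
    collection_name="all-my-documents",
    embedding_function=my_ef,
)
print(results["ids"][0])  # ids of the top matches for the first query text
```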
1 change: 1 addition & 0 deletions setup.py
@@ -40,6 +40,7 @@
extras_require={
"test": [
"chromadb",
"lancedb",
"coverage>=5.3",
"datasets",
"ipykernel",
3 changes: 3 additions & 0 deletions test/agentchat/test_retrievechat.py
@@ -12,6 +12,7 @@
)
from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
import chromadb
from chromadb.utils import embedding_functions as ef

skip_test = False
except ImportError:
@@ -49,6 +50,7 @@ def test_retrievechat():
},
)

sentence_transformer_ef = ef.SentenceTransformerEmbeddingFunction()
ragproxyagent = RetrieveUserProxyAgent(
name="ragproxyagent",
human_input_mode="NEVER",
Expand All @@ -58,6 +60,7 @@ def test_retrievechat():
"chunk_token_size": 2000,
"model": config_list[0]["model"],
"client": chromadb.PersistentClient(path="/tmp/chromadb"),
"embedding_function": sentence_transformer_ef,
},
)

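The commit message also mentions a lancedb-based test of the customized vector db path; that test file is not shown in this view. Below is a minimal sketch of the pattern it exercises — subclassing `RetrieveUserProxyAgent` and overriding `retrieve_docs` to query a non-chromadb store. The table name, column names, and the lancedb calls (`connect`, `open_table`, `search().limit().to_pandas()`) are assumptions about a particular setup and client version, not code from this commit.

```python
import lancedb
from chromadb.utils import embedding_functions as ef

from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent


class LanceDBRetrieveUserProxyAgent(RetrieveUserProxyAgent):
    """Plugs a lancedb table in place of chromadb by overriding `retrieve_docs`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Must match the embedding function used to populate the table (assumption).
        self._embed = ef.SentenceTransformerEmbeddingFunction()
        self._table = lancedb.connect("/tmp/lancedb").open_table("my-docs")  # hypothetical table

    def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = "", **kwargs):
        query_embedding = self._embed([problem])[0]
        hits = self._table.search(query_embedding).limit(n_results).to_pandas()
        # Reshape into the chromadb-style dict the base class expects: one inner
        # list per query text. The "id" and "documents" columns are schema assumptions.
        self._results = {
            "ids": [hits["id"].tolist()],
            "documents": [hits["documents"].tolist()],
        }
        print("doc_ids: ", self._results["ids"])
```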