
Commit 3c51ad5

Add support to customized vectordb and embedding functions (microsoft#161)
* Add custom embedding function
* Add support to custom vector db
* Improve docstring
* Improve docstring
* Improve docstring
* Add support to customized is_termination_msg function
* Add a test for customized vector db with lancedb
* Fix tests
* Add test for embedding_function
* Update docstring
1 parent be2b02b commit 3c51ad5

6 files changed: +192 −15 lines

.github/workflows/build.yml

+1 −1
```diff
@@ -40,7 +40,7 @@ jobs:
         python -m pip install --upgrade pip wheel
         pip install -e .
         python -c "import autogen"
-        pip install -e.[mathchat,retrievechat] datasets pytest
+        pip install -e.[mathchat,retrievechat,test] datasets pytest
         pip uninstall -y openai
     - name: Test with pytest
       if: matrix.python-version != '3.10'
```

autogen/agentchat/contrib/retrieve_user_proxy_agent.py

+66 −9
````diff
@@ -67,6 +67,7 @@ def __init__(
         self,
         name="RetrieveChatAgent",  # default set to RetrieveChatAgent
         human_input_mode: Optional[str] = "ALWAYS",
+        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
         retrieve_config: Optional[Dict] = None,  # config for the retrieve agent
         **kwargs,
     ):
@@ -82,14 +83,17 @@ def __init__(
                 the number of auto reply reaches the max_consecutive_auto_reply.
                 (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
                 when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
+            is_termination_msg (function): a function that takes a message in the form of a dictionary
+                and returns a boolean value indicating if this received message is a termination message.
+                The dict can contain the following keys: "content", "role", "name", "function_call".
             retrieve_config (dict or None): config for the retrieve agent.
                 To use default config, set to None. Otherwise, set to a dictionary with the following keys:
                 - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System
                     prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
-                - client (Optional, chromadb.Client): the chromadb client.
-                    If key not provided, a default client `chromadb.Client()` will be used.
+                - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()`
+                    will be used. If you want to use another vector db, extend this class and override the `retrieve_docs` function.
                 - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
-                    or the url to a single file. If key not provided, a default path `./docs` will be used.
+                    or the url to a single file. Default is None, which works only if the collection is already created.
                 - collection_name (Optional, str): the name of the collection.
                     If key not provided, a default name `autogen-docs` will be used.
                 - model (Optional, str): the model to use for the retrieve chat.
@@ -106,16 +110,45 @@ def __init__(
                     If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available models
                     can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a
                     fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.
+                - embedding_function (Optional, Callable): the embedding function for creating the vector db. Default is None;
+                    SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or
+                    other embedding functions, you can pass it here, following the examples in `https://docs.trychroma.com/embeddings`.
                 - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
                 - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
                     If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
                 - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
                 - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
-                    This is the same as that used in chromadb. Default is False.
+                    This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None.
                 - custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
                     The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
                     Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
             **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
+
+        Example of overriding retrieve_docs:
+        If you have set up a customized vector db and it's not compatible with chromadb, you can easily plug it in with the code below.
+        ```python
+        class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
+            def query_vector_db(
+                self,
+                query_texts: List[str],
+                n_results: int = 10,
+                search_string: str = "",
+                **kwargs,
+            ) -> Dict[str, Union[List[str], List[List[str]]]]:
+                # define your own query function here
+                pass
+
+            def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = "", **kwargs):
+                results = self.query_vector_db(
+                    query_texts=[problem],
+                    n_results=n_results,
+                    search_string=search_string,
+                    **kwargs,
+                )
+
+                self._results = results
+                print("doc_ids: ", results["ids"])
+        ```
         """
         super().__init__(
             name=name,
@@ -126,28 +159,34 @@ def __init__(
         self._retrieve_config = {} if retrieve_config is None else retrieve_config
         self._task = self._retrieve_config.get("task", "default")
         self._client = self._retrieve_config.get("client", chromadb.Client())
-        self._docs_path = self._retrieve_config.get("docs_path", "./docs")
+        self._docs_path = self._retrieve_config.get("docs_path", None)
         self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
         self._model = self._retrieve_config.get("model", "gpt-4")
         self._max_tokens = self.get_max_tokens(self._model)
         self._chunk_token_size = int(self._retrieve_config.get("chunk_token_size", self._max_tokens * 0.4))
         self._chunk_mode = self._retrieve_config.get("chunk_mode", "multi_lines")
         self._must_break_at_empty_line = self._retrieve_config.get("must_break_at_empty_line", True)
         self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2")
+        self._embedding_function = self._retrieve_config.get("embedding_function", None)
         self.customized_prompt = self._retrieve_config.get("customized_prompt", None)
         self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
         self.update_context = self._retrieve_config.get("update_context", True)
-        self._get_or_create = self._retrieve_config.get("get_or_create", False)
+        self._get_or_create = (
+            self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
+        )
         self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
         self._context_max_tokens = self._max_tokens * 0.8
-        self._collection = False  # the collection is not created
+        self._collection = True if self._docs_path is None else False  # whether the collection is created
        self._ipython = get_ipython()
        self._doc_idx = -1  # the index of the current used doc
        self._results = {}  # the results of the current query
        self._intermediate_answers = set()  # the intermediate answers
        self._doc_contents = []  # the contents of the current used doc
        self._doc_ids = []  # the ids of the current used doc
-        self._is_termination_msg = self._is_termination_msg_retrievechat  # update the termination message function
+        # update the termination message function
+        self._is_termination_msg = (
+            self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
+        )
         self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=1)

     def _is_termination_msg_retrievechat(self, message):
@@ -188,7 +227,7 @@ def _reset(self, intermediate=False):
         self._doc_contents = []  # the contents of the current used doc
         self._doc_ids = []  # the ids of the current used doc

-    def _get_context(self, results):
+    def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
         doc_contents = ""
         current_tokens = 0
         _doc_idx = self._doc_idx
@@ -297,6 +336,22 @@ def _generate_retrieve_user_reply(
             return False, None

     def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
+        """Retrieve docs based on the given problem and assign the results to the class property `_results`.
+        In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
+        compatible with chromadb or filtering results with metadata, you can override this function. Just keep the current
+        parameters, add your own parameters with default values, and keep the results in the type below.
+
+        Type of the results: Dict[str, List[List[Any]]]; it should have the keys "ids" and "documents": "ids" for the ids of
+        the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer
+        to `chromadb.api.types.QueryResult` as an example.
+            ids: List[string]
+            documents: List[List[string]]
+
+        Args:
+            problem (str): the problem to be solved.
+            n_results (int): the number of results to be retrieved.
+            search_string (str): only docs containing this string will be retrieved.
+        """
         if not self._collection or self._get_or_create:
             print("Trying to create collection.")
             create_vector_db_from_dir(
@@ -308,6 +363,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
                 must_break_at_empty_line=self._must_break_at_empty_line,
                 embedding_model=self._embedding_model,
                 get_or_create=self._get_or_create,
+                embedding_function=self._embedding_function,
             )
             self._collection = True
             self._get_or_create = False
@@ -319,6 +375,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
             client=self._client,
             collection_name=self._collection_name,
             embedding_model=self._embedding_model,
+            embedding_function=self._embedding_function,
         )
         self._results = results
         print("doc_ids: ", results["ids"])
````

autogen/retrieve_utils.py

+57 −5
```diff
@@ -6,6 +6,7 @@
 import tiktoken
 import chromadb
 from chromadb.api import API
+from chromadb.api.types import QueryResult
 import chromadb.utils.embedding_functions as ef
 import logging
 import pypdf
@@ -263,12 +264,36 @@ def create_vector_db_from_dir(
     chunk_mode: str = "multi_lines",
     must_break_at_empty_line: bool = True,
     embedding_model: str = "all-MiniLM-L6-v2",
+    embedding_function: Callable = None,
 ):
-    """Create a vector db from all the files in a given directory."""
+    """Create a vector db from all the files in a given directory; the directory can also be a single file or a url to
+    a single file. We support chromadb compatible APIs to create the vector db; this function is not required if
+    you have prepared your own vector db.
+
+    Args:
+        dir_path (str): the path to the directory, file or url.
+        max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000.
+        client (Optional, API): the chromadb client. Default is None.
+        db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db".
+        collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
+        get_or_create (Optional, bool): whether to get or create the collection. Default is False. If True, the collection
+            will be recreated if it already exists.
+        chunk_mode (Optional, str): the chunk mode. Default is "multi_lines".
+        must_break_at_empty_line (Optional, bool): whether to break at an empty line. Default is True.
+        embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
+            embedding_function is not None.
+        embedding_function (Optional, Callable): the embedding function to use. Default is None; SentenceTransformer with
+            the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
+            functions, you can pass it here, following the examples in `https://docs.trychroma.com/embeddings`.
+    """
     if client is None:
         client = chromadb.PersistentClient(path=db_path)
     try:
-        embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model)
+        embedding_function = (
+            ef.SentenceTransformerEmbeddingFunction(embedding_model)
+            if embedding_function is None
+            else embedding_function
+        )
         collection = client.create_collection(
             collection_name,
             get_or_create=get_or_create,
@@ -300,14 +325,41 @@ def query_vector_db(
     collection_name: str = "all-my-documents",
     search_string: str = "",
     embedding_model: str = "all-MiniLM-L6-v2",
-) -> Dict[str, List[str]]:
-    """Query a vector db."""
+    embedding_function: Callable = None,
+) -> QueryResult:
+    """Query a vector db. We support chromadb compatible APIs; this function is not required if you have prepared
+    your own vector db and query function.
+
+    Args:
+        query_texts (List[str]): the query texts.
+        n_results (Optional, int): the number of results to return. Default is 10.
+        client (Optional, API): the chromadb compatible client. Default is None; a chromadb client will be used.
+        db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db".
+        collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
+        search_string (Optional, str): the search string. Default is "".
+        embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
+            embedding_function is not None.
+        embedding_function (Optional, Callable): the embedding function to use. Default is None; SentenceTransformer with
+            the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
+            functions, you can pass it here, following the examples in `https://docs.trychroma.com/embeddings`.
+
+    Returns:
+        QueryResult: the query result. The format is:
+            class QueryResult(TypedDict):
+                ids: List[IDs]
+                embeddings: Optional[List[List[Embedding]]]
+                documents: Optional[List[List[Document]]]
+                metadatas: Optional[List[List[Metadata]]]
+                distances: Optional[List[List[float]]]
+    """
     if client is None:
         client = chromadb.PersistentClient(path=db_path)
     # the collection's embedding function is always the default one, but we want to use the one we used to create the
     # collection. So we compute the embeddings ourselves and pass it to the query function.
     collection = client.get_collection(collection_name)
-    embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model)
+    embedding_function = (
+        ef.SentenceTransformerEmbeddingFunction(embedding_model) if embedding_function is None else embedding_function
+    )
     query_embeddings = embedding_function(query_texts)
     # Query/search n most similar results. You can also .get by id
     results = collection.query(
```
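As a usage reference, here is a short sketch of how the two helpers compose once `embedding_function` is threaded through both; the paths, collection name, and query text below are placeholder assumptions:

```python
import chromadb.utils.embedding_functions as ef

from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db

# Use one embedding function for both creation and querying, since stored
# embeddings are only comparable to query embeddings from the same model.
custom_ef = ef.SentenceTransformerEmbeddingFunction("all-mpnet-base-v2")

create_vector_db_from_dir(
    dir_path="./docs",  # placeholder path
    db_path="/tmp/chromadb.db",
    collection_name="my-docs",
    embedding_function=custom_ef,  # embedding_model is ignored when this is set
)

results = query_vector_db(
    query_texts=["What does autogen do?"],  # placeholder query
    n_results=5,
    db_path="/tmp/chromadb.db",
    collection_name="my-docs",
    embedding_function=custom_ef,
)
print(results["ids"], results["documents"])
```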

setup.py

+1
```diff
@@ -40,6 +40,7 @@
     extras_require={
         "test": [
             "chromadb",
+            "lancedb",
             "coverage>=5.3",
             "datasets",
             "ipykernel",
```

test/agentchat/test_retrievechat.py

+3
```diff
@@ -12,6 +12,7 @@
     )
     from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
     import chromadb
+    from chromadb.utils import embedding_functions as ef

     skip_test = False
 except ImportError:
@@ -49,6 +50,7 @@ def test_retrievechat():
         },
     )

+    sentence_transformer_ef = ef.SentenceTransformerEmbeddingFunction()
     ragproxyagent = RetrieveUserProxyAgent(
         name="ragproxyagent",
         human_input_mode="NEVER",
@@ -58,6 +60,7 @@ def test_retrievechat():
             "chunk_token_size": 2000,
             "model": config_list[0]["model"],
             "client": chromadb.PersistentClient(path="/tmp/chromadb"),
+            "embedding_function": sentence_transformer_ef,
         },
     )
```
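The commit message also mentions a lancedb-based test for the customized vector db path. A rough sketch of what such an override could look like, following the `retrieve_docs` docstring above; the lancedb connection path, table name, and column names are assumptions for illustration, not the actual test code:

```python
import lancedb

from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent


class LanceDBRetrieveUserProxyAgent(RetrieveUserProxyAgent):
    def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = "", **kwargs):
        # Assumption: a lancedb table "my-docs" was created elsewhere with an
        # embedding function registered, plus "id" and "text" columns.
        db = lancedb.connect("/tmp/lancedb")
        table = db.open_table("my-docs")
        docs = table.search(problem).limit(n_results).to_df()
        # Keep the documented result shape: keys "ids" and "documents".
        self._results = {
            "ids": [docs["id"].tolist()],
            "documents": [docs["text"].tolist()],
        }
        print("doc_ids: ", self._results["ids"])
```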
