@@ -67,6 +67,7 @@ def __init__(
67
67
self ,
68
68
name = "RetrieveChatAgent" , # default set to RetrieveChatAgent
69
69
human_input_mode : Optional [str ] = "ALWAYS" ,
70
+ is_termination_msg : Optional [Callable [[Dict ], bool ]] = None ,
70
71
retrieve_config : Optional [Dict ] = None , # config for the retrieve agent
71
72
** kwargs ,
72
73
):
@@ -82,14 +83,17 @@ def __init__(
82
83
the number of auto reply reaches the max_consecutive_auto_reply.
83
84
(3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
84
85
when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
86
+ is_termination_msg (function): a function that takes a message in the form of a dictionary
87
+ and returns a boolean value indicating if this received message is a termination message.
88
+ The dict can contain the following keys: "content", "role", "name", "function_call".
85
89
retrieve_config (dict or None): config for the retrieve agent.
86
90
To use default config, set to None. Otherwise, set to a dictionary with the following keys:
87
91
- task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System
88
92
prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
89
- - client (Optional, chromadb.Client): the chromadb client.
90
- If key not provided, a default client `chromadb.Client()` will be used .
93
+ - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()`
94
+ will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function .
91
95
- docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
92
- or the url to a single file. If key not provided, a default path `./docs` will be used .
96
+ or the url to a single file. Default is None, which works only if the collection is already created .
93
97
- collection_name (Optional, str): the name of the collection.
94
98
If key not provided, a default name `autogen-docs` will be used.
95
99
- model (Optional, str): the model to use for the retrieve chat.
@@ -106,16 +110,45 @@ def __init__(
106
110
If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available models
107
111
can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a
108
112
fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.
113
+ - embedding_function (Optional, Callable): the embedding function for creating the vector db. Default is None,
114
+ SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or
115
+ other embedding functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`.
109
116
- customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
110
117
- customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
111
118
If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
112
119
- update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
113
120
- get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
114
- This is the same as that used in chromadb. Default is False.
121
+ This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None.
115
122
- custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
116
123
The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
117
124
Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
118
125
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
126
+
127
+ Example of overriding retrieve_docs:
128
+ If you have set up a customized vector db, and it's not compatible with chromadb, you can easily plug in it with below code.
129
+ ```python
130
+ class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
131
+ def query_vector_db(
132
+ self,
133
+ query_texts: List[str],
134
+ n_results: int = 10,
135
+ search_string: str = "",
136
+ **kwargs,
137
+ ) -> Dict[str, Union[List[str], List[List[str]]]]:
138
+ # define your own query function here
139
+ pass
140
+
141
+ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = "", **kwargs):
142
+ results = self.query_vector_db(
143
+ query_texts=[problem],
144
+ n_results=n_results,
145
+ search_string=search_string,
146
+ **kwargs,
147
+ )
148
+
149
+ self._results = results
150
+ print("doc_ids: ", results["ids"])
151
+ ```
119
152
"""
120
153
super ().__init__ (
121
154
name = name ,
@@ -126,28 +159,34 @@ def __init__(
126
159
self ._retrieve_config = {} if retrieve_config is None else retrieve_config
127
160
self ._task = self ._retrieve_config .get ("task" , "default" )
128
161
self ._client = self ._retrieve_config .get ("client" , chromadb .Client ())
129
- self ._docs_path = self ._retrieve_config .get ("docs_path" , "./docs" )
162
+ self ._docs_path = self ._retrieve_config .get ("docs_path" , None )
130
163
self ._collection_name = self ._retrieve_config .get ("collection_name" , "autogen-docs" )
131
164
self ._model = self ._retrieve_config .get ("model" , "gpt-4" )
132
165
self ._max_tokens = self .get_max_tokens (self ._model )
133
166
self ._chunk_token_size = int (self ._retrieve_config .get ("chunk_token_size" , self ._max_tokens * 0.4 ))
134
167
self ._chunk_mode = self ._retrieve_config .get ("chunk_mode" , "multi_lines" )
135
168
self ._must_break_at_empty_line = self ._retrieve_config .get ("must_break_at_empty_line" , True )
136
169
self ._embedding_model = self ._retrieve_config .get ("embedding_model" , "all-MiniLM-L6-v2" )
170
+ self ._embedding_function = self ._retrieve_config .get ("embedding_function" , None )
137
171
self .customized_prompt = self ._retrieve_config .get ("customized_prompt" , None )
138
172
self .customized_answer_prefix = self ._retrieve_config .get ("customized_answer_prefix" , "" ).upper ()
139
173
self .update_context = self ._retrieve_config .get ("update_context" , True )
140
- self ._get_or_create = self ._retrieve_config .get ("get_or_create" , False )
174
+ self ._get_or_create = (
175
+ self ._retrieve_config .get ("get_or_create" , False ) if self ._docs_path is not None else False
176
+ )
141
177
self .custom_token_count_function = self ._retrieve_config .get ("custom_token_count_function" , None )
142
178
self ._context_max_tokens = self ._max_tokens * 0.8
143
- self ._collection = False # the collection is not created
179
+ self ._collection = True if self . _docs_path is None else False # whether the collection is created
144
180
self ._ipython = get_ipython ()
145
181
self ._doc_idx = - 1 # the index of the current used doc
146
182
self ._results = {} # the results of the current query
147
183
self ._intermediate_answers = set () # the intermediate answers
148
184
self ._doc_contents = [] # the contents of the current used doc
149
185
self ._doc_ids = [] # the ids of the current used doc
150
- self ._is_termination_msg = self ._is_termination_msg_retrievechat # update the termination message function
186
+ # update the termination message function
187
+ self ._is_termination_msg = (
188
+ self ._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
189
+ )
151
190
self .register_reply (Agent , RetrieveUserProxyAgent ._generate_retrieve_user_reply , position = 1 )
152
191
153
192
def _is_termination_msg_retrievechat (self , message ):
@@ -188,7 +227,7 @@ def _reset(self, intermediate=False):
188
227
self ._doc_contents = [] # the contents of the current used doc
189
228
self ._doc_ids = [] # the ids of the current used doc
190
229
191
- def _get_context (self , results ):
230
+ def _get_context (self , results : Dict [ str , Union [ List [ str ], List [ List [ str ]]]] ):
192
231
doc_contents = ""
193
232
current_tokens = 0
194
233
_doc_idx = self ._doc_idx
@@ -297,6 +336,22 @@ def _generate_retrieve_user_reply(
297
336
return False , None
298
337
299
338
def retrieve_docs (self , problem : str , n_results : int = 20 , search_string : str = "" ):
339
+ """Retrieve docs based on the given problem and assign the results to the class property `_results`.
340
+ In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
341
+ compatible with chromadb or filter results with metadata, you can override this function. Just keep the current
342
+ parameters and add your own parameters with default values, and keep the results in below type.
343
+
344
+ Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of
345
+ the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer
346
+ to `chromadb.api.types.QueryResult` as an example.
347
+ ids: List[string]
348
+ documents: List[List[string]]
349
+
350
+ Args:
351
+ problem (str): the problem to be solved.
352
+ n_results (int): the number of results to be retrieved.
353
+ search_string (str): only docs containing this string will be retrieved.
354
+ """
300
355
if not self ._collection or self ._get_or_create :
301
356
print ("Trying to create collection." )
302
357
create_vector_db_from_dir (
@@ -308,6 +363,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
308
363
must_break_at_empty_line = self ._must_break_at_empty_line ,
309
364
embedding_model = self ._embedding_model ,
310
365
get_or_create = self ._get_or_create ,
366
+ embedding_function = self ._embedding_function ,
311
367
)
312
368
self ._collection = True
313
369
self ._get_or_create = False
@@ -319,6 +375,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
319
375
client = self ._client ,
320
376
collection_name = self ._collection_name ,
321
377
embedding_model = self ._embedding_model ,
378
+ embedding_function = self ._embedding_function ,
322
379
)
323
380
self ._results = results
324
381
print ("doc_ids: " , results ["ids" ])
0 commit comments