@@ -1,3 +1,4 @@
+import re
 import chromadb
 from autogen.agentchat.agent import Agent
 from autogen.agentchat import UserProxyAgent
@@ -122,6 +123,9 @@ def __init__(
             can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a
             fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.
             - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
+            - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
+                If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
+            - no_update_context (Optional, bool): if True, will not apply `Update Context` for interactive retrieval. Default is False.
             **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
         """
         super().__init__(
@@ -143,11 +147,16 @@ def __init__(
         self._must_break_at_empty_line = self._retrieve_config.get("must_break_at_empty_line", True)
         self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2")
         self.customized_prompt = self._retrieve_config.get("customized_prompt", None)
+        self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
+        self.no_update_context = self._retrieve_config.get("no_update_context", False)
         self._context_max_tokens = self._max_tokens * 0.8
         self._collection = False  # the collection is not created
         self._ipython = get_ipython()
         self._doc_idx = -1  # the index of the current used doc
         self._results = {}  # the results of the current query
+        self._intermediate_answers = set()  # the intermediate answers
+        self._doc_contents = []  # the contents of the current used doc
+        self._doc_ids = []  # the ids of the current used doc
         self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply)

     @staticmethod
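
For orientation, the two new retrieve_config keys read above can be set like any other key in the dict. A minimal construction sketch follows; the task, docs_path, and prefix values are illustrative placeholders, and only customized_answer_prefix and no_update_context come from this diff. Note that the prefix is upper-cased on read, so the later check is case-insensitive.

    # Hypothetical construction sketch; only the two new keys come from this diff.
    ragproxyagent = RetrieveUserProxyAgent(
        name="ragproxyagent",
        retrieve_config={
            "task": "qa",                                  # existing key, placeholder value
            "docs_path": "./docs",                         # existing key, placeholder path
            "customized_answer_prefix": "the answer is",   # new: upper-cased internally
            "no_update_context": False,                    # new: keep interactive retrieval on
        },
    )
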
@@ -161,17 +170,24 @@ def get_max_tokens(model="gpt-3.5-turbo"):
         else:
             return 4000

-    def _reset(self):
+    def _reset(self, intermediate=False):
         self._doc_idx = -1  # the index of the current used doc
         self._results = {}  # the results of the current query
+        if not intermediate:
+            self._intermediate_answers = set()  # the intermediate answers
+            self._doc_contents = []  # the contents of the current used doc
+            self._doc_ids = []  # the ids of the current used doc

     def _get_context(self, results):
         doc_contents = ""
         current_tokens = 0
         _doc_idx = self._doc_idx
+        _tmp_retrieve_count = 0
         for idx, doc in enumerate(results["documents"][0]):
             if idx <= _doc_idx:
                 continue
+            if results["ids"][0][idx] in self._doc_ids:
+                continue
             _doc_tokens = num_tokens_from_text(doc)
             if _doc_tokens > self._context_max_tokens:
                 func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
@@ -185,14 +201,19 @@ def _get_context(self, results):
             current_tokens += _doc_tokens
             doc_contents += doc + "\n"
             self._doc_idx = idx
+            self._doc_ids.append(results["ids"][0][idx])
+            self._doc_contents.append(doc)
+            _tmp_retrieve_count += 1
+            if _tmp_retrieve_count >= self.n_results:
+                break
         return doc_contents

     def _generate_message(self, doc_contents, task="default"):
         if not doc_contents:
             print(colored("No more context, will terminate.", "green"), flush=True)
             return "TERMINATE"
         if self.customized_prompt:
-            message = self.customized_prompt + "\nUser's question is: " + self.problem + "\nContext is: " + doc_contents
+            message = self.customized_prompt.format(input_question=self.problem, input_context=doc_contents)
         elif task.upper() == "CODE":
             message = PROMPT_CODE.format(input_question=self.problem, input_context=doc_contents)
         elif task.upper() == "QA":
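
The `.format()` change above means customized_prompt is now treated as a template rather than a plain prefix: it should contain the {input_question} and {input_context} placeholders, matching the built-in PROMPT_CODE and QA prompts, and a prompt written for the old concatenation style will silently lose the question and context. A minimal sketch, with illustrative wording:

    # The placeholder names are fixed by the .format() call; the text is an example.
    customized_prompt = (
        "Answer the user's question using only the context below. "
        "If the context is insufficient, reply exactly `UPDATE CONTEXT`.\n"
        "Question: {input_question}\n"
        "Context: {input_context}"
    )
    message = customized_prompt.format(input_question="What is AutoGen?", input_context="...")
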
@@ -209,24 +230,64 @@ def _generate_retrieve_user_reply(
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
     ) -> Tuple[bool, Union[str, Dict, None]]:
+        """In this function, we will update the context and reset the conversation based on different conditions.
+        We'll update the context and reset the conversation if no_update_context is False and either of the following:
+        (1) the last message contains "UPDATE CONTEXT",
+        (2) the last message doesn't contain "UPDATE CONTEXT" and the customized_answer_prefix is not in the message.
+        """
         if config is None:
             config = self
         if messages is None:
             messages = self._oai_messages[sender]
         message = messages[-1]
-        if (
+        update_context_case1 = (
             "UPDATE CONTEXT" in message.get("content", "")[-20:].upper()
             or "UPDATE CONTEXT" in message.get("content", "")[:20].upper()
-        ):
+        )
+        update_context_case2 = (
+            self.customized_answer_prefix and self.customized_answer_prefix not in message.get("content", "").upper()
+        )
+        if (update_context_case1 or update_context_case2) and not self.no_update_context:
             print(colored("Updating context and resetting conversation.", "green"), flush=True)
+            # extract the first sentence in the response as the intermediate answer
+            _message = message.get("content", "").split("\n")[0].strip()
+            _intermediate_info = re.split(r"(?<=[.!?])\s+", _message)
+            self._intermediate_answers.add(_intermediate_info[0])
+
+            if update_context_case1:
+                # try to get more context from the current retrieved doc results, because the results
+                # may be too long to fit in the LLM context.
+                doc_contents = self._get_context(self._results)
+
+                # Always use self.problem as the query text to retrieve docs, but each time replace the
+                # context with the next most similar docs in the retrieved doc results.
+                if not doc_contents:
+                    for _tmp_retrieve_count in range(1, 5):
+                        self._reset(intermediate=True)
+                        self.retrieve_docs(self.problem, self.n_results * (2 * _tmp_retrieve_count + 1))
+                        doc_contents = self._get_context(self._results)
+                        if doc_contents:
+                            break
+            elif update_context_case2:
+                # Use the current intermediate info as the query text to retrieve docs, and each time append
+                # the top similar docs in the retrieved doc results to the context.
+                for _tmp_retrieve_count in range(5):
+                    self._reset(intermediate=True)
+                    self.retrieve_docs(_intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1))
+                    self._get_context(self._results)
+                    doc_contents = "\n".join(self._doc_contents)  # + "\n" + "\n".join(self._intermediate_answers)
+                    if doc_contents:
+                        break
+
             self.clear_history()
             sender.clear_history()
-            doc_contents = self._get_context(self._results)
             return True, self._generate_message(doc_contents, task=self._task)
-        return False, None
+        else:
+            return False, None

     def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
         if not self._collection:
+            print("Trying to create collection.")
             create_vector_db_from_dir(
                 dir_path=self._docs_path,
                 max_tokens=self._chunk_token_size,
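
The intermediate answer extracted in the reply handler above is the first sentence of the reply's first line, obtained with a lookbehind split on sentence-ending punctuation. A standalone check of that regex, using an invented sample reply:

    import re

    reply = "The answer is 42. More detail follows.\nA second line is ignored."
    _message = reply.split("\n")[0].strip()             # keep only the first line
    _intermediate_info = re.split(r"(?<=[.!?])\s+", _message)
    print(_intermediate_info[0])                        # -> The answer is 42.
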
@@ -263,6 +324,7 @@ def generate_init_message(self, problem: str, n_results: int = 20, search_string
         self._reset()
         self.retrieve_docs(problem, n_results, search_string)
         self.problem = problem
+        self.n_results = n_results
         doc_contents = self._get_context(self._results)
         message = self._generate_message(doc_contents, self._task)
         return message
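
Storing n_results here is what lets the retry loops in the reply handler widen the search: each attempt requests n_results * (2 * k + 1) candidates, while _get_context still adds at most n_results new docs per call. With the default of 20, the case-1 retries ask for:

    n_results = 20                                          # default in generate_init_message
    print([n_results * (2 * k + 1) for k in range(1, 5)])   # [60, 100, 140, 180]
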
@@ -278,21 +340,6 @@ def run_code(self, code, **kwargs):
         if self._ipython is None or lang != "python":
             return super().run_code(code, **kwargs)
         else:
-            # # capture may not work as expected
-            # result = self._ipython.run_cell("%%capture --no-display cap\n" + code)
-            # log = self._ipython.ev("cap.stdout")
-            # log += self._ipython.ev("cap.stderr")
-            # if result.result is not None:
-            #     log += str(result.result)
-            # exitcode = 0 if result.success else 1
-            # if result.error_before_exec is not None:
-            #     log += f"\n{result.error_before_exec}"
-            #     exitcode = 1
-            # if result.error_in_exec is not None:
-            #     log += f"\n{result.error_in_exec}"
-            #     exitcode = 1
-            # return exitcode, log, None
-
             result = self._ipython.run_cell(code)
             log = str(result.result)
             exitcode = 0 if result.success else 1
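
Putting the pieces together, interactive retrieval is driven entirely from the proxy side. A hedged end-to-end sketch, assuming an assistant agent and the ragproxyagent from the earlier construction sketch already exist (the question is invented):

    # initiate_chat forwards problem/n_results to generate_init_message.
    assistant.reset()
    ragproxyagent.initiate_chat(assistant, problem="Who proposed transformers?", n_results=20)
    # If a reply asks to UPDATE CONTEXT (case 1) or lacks the configured answer
    # prefix (case 2), the proxy re-retrieves and restarts the conversation,
    # unless no_update_context=True.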