
Commit 1108818

thinkall and qingyun-wu authored
Improve RetrieveChat (#6)
* Upsert in batch
* Improve update context, support customized answer prefix
* Update tests
* Update intermediate answer
* Fix duplicate intermediate answer, add example 6 to notebook
* Add notebook results
* Works better without intermediate answers in the context
* Bump version to 0.1.2
* Remove commented code and add descriptions to _generate_retrieve_user_reply

---------

Co-authored-by: Qingyun Wu <[email protected]>
1 parent f619ecc commit 1108818

File tree

5 files changed: +538 -31 lines changed


.github/workflows/build.yml (+1 -1)

@@ -40,7 +40,7 @@ jobs:
           python -m pip install --upgrade pip wheel
           pip install -e .
           python -c "import autogen"
-          pip install -e.[mathchat] datasets pytest
+          pip install -e.[mathchat,retrievechat] datasets pytest
           pip uninstall -y openai
       - name: Test with pytest
         if: matrix.python-version != '3.10'

autogen/agentchat/contrib/retrieve_user_proxy_agent.py (+68 -21)
@@ -1,3 +1,4 @@
+import re
 import chromadb
 from autogen.agentchat.agent import Agent
 from autogen.agentchat import UserProxyAgent
@@ -122,6 +123,9 @@ def __init__(
                 can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a
                 fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.
             - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
+            - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
+                If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
+            - no_update_context (Optional, bool): if True, will not apply `Update Context` for interactive retrieval. Default is False.
             **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
         """
         super().__init__(
@@ -143,11 +147,16 @@ def __init__(
         self._must_break_at_empty_line = self._retrieve_config.get("must_break_at_empty_line", True)
         self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2")
         self.customized_prompt = self._retrieve_config.get("customized_prompt", None)
+        self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
+        self.no_update_context = self._retrieve_config.get("no_update_context", False)
         self._context_max_tokens = self._max_tokens * 0.8
         self._collection = False  # the collection is not created
         self._ipython = get_ipython()
         self._doc_idx = -1  # the index of the current used doc
         self._results = {}  # the results of the current query
+        self._intermediate_answers = set()  # the intermediate answers
+        self._doc_contents = []  # the contents of the current used doc
+        self._doc_ids = []  # the ids of the current used doc
         self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply)

     @staticmethod
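The hunk above shows where the two new options enter the agent: both are read from the `retrieve_config` dict. As a minimal sketch of how a caller might set them, assuming `RetrieveUserProxyAgent` accepts a `retrieve_config` constructor argument as these reads suggest (the docs path and prefix string are hypothetical values, not from this commit):

from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

# Sketch only. The keys are the ones __init__ reads via
# self._retrieve_config.get(...); the values are hypothetical.
ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",  # hypothetical path
        "customized_answer_prefix": "the answer is",  # upper-cased internally
        "no_update_context": False,  # keep interactive `Update Context` enabled
    },
)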
@@ -161,17 +170,24 @@ def get_max_tokens(model="gpt-3.5-turbo"):
         else:
             return 4000

-    def _reset(self):
+    def _reset(self, intermediate=False):
         self._doc_idx = -1  # the index of the current used doc
         self._results = {}  # the results of the current query
+        if not intermediate:
+            self._intermediate_answers = set()  # the intermediate answers
+            self._doc_contents = []  # the contents of the current used doc
+            self._doc_ids = []  # the ids of the current used doc

     def _get_context(self, results):
         doc_contents = ""
         current_tokens = 0
         _doc_idx = self._doc_idx
+        _tmp_retrieve_count = 0
         for idx, doc in enumerate(results["documents"][0]):
             if idx <= _doc_idx:
                 continue
+            if results["ids"][0][idx] in self._doc_ids:
+                continue
             _doc_tokens = num_tokens_from_text(doc)
             if _doc_tokens > self._context_max_tokens:
                 func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
@@ -185,14 +201,19 @@ def _get_context(self, results):
             current_tokens += _doc_tokens
             doc_contents += doc + "\n"
             self._doc_idx = idx
+            self._doc_ids.append(results["ids"][0][idx])
+            self._doc_contents.append(doc)
+            _tmp_retrieve_count += 1
+            if _tmp_retrieve_count >= self.n_results:
+                break
         return doc_contents

     def _generate_message(self, doc_contents, task="default"):
         if not doc_contents:
             print(colored("No more context, will terminate.", "green"), flush=True)
             return "TERMINATE"
         if self.customized_prompt:
-            message = self.customized_prompt + "\nUser's question is: " + self.problem + "\nContext is: " + doc_contents
+            message = self.customized_prompt.format(input_question=self.problem, input_context=doc_contents)
         elif task.upper() == "CODE":
             message = PROMPT_CODE.format(input_question=self.problem, input_context=doc_contents)
         elif task.upper() == "QA":
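Before this change, `customized_prompt` was concatenated with fixed "User's question is:" / "Context is:" lines; now it is a `str.format` template, so a caller's prompt must carry the two placeholders itself. An illustrative template (the wording is hypothetical; only the placeholder names come from the code above):

# Hypothetical template; {input_question} and {input_context} are the
# names _generate_message passes to str.format.
customized_prompt = """You're a retrieval-augmented assistant. Answer using only the given context.
If the context is not enough to answer, reply exactly `UPDATE CONTEXT`.

User's question is: {input_question}

Context is: {input_context}
"""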
@@ -209,24 +230,64 @@ def _generate_retrieve_user_reply(
         sender: Optional[Agent] = None,
         config: Optional[Any] = None,
     ) -> Tuple[bool, Union[str, Dict, None]]:
+        """In this function, we will update the context and reset the conversation based on different conditions.
+        We'll update the context and reset the conversation if no_update_context is False and either of the following:
+        (1) the last message contains "UPDATE CONTEXT",
+        (2) the last message doesn't contain "UPDATE CONTEXT" and the customized_answer_prefix is not in the message.
+        """
         if config is None:
             config = self
         if messages is None:
             messages = self._oai_messages[sender]
         message = messages[-1]
-        if (
+        update_context_case1 = (
             "UPDATE CONTEXT" in message.get("content", "")[-20:].upper()
             or "UPDATE CONTEXT" in message.get("content", "")[:20].upper()
-        ):
+        )
+        update_context_case2 = (
+            self.customized_answer_prefix and self.customized_answer_prefix not in message.get("content", "").upper()
+        )
+        if (update_context_case1 or update_context_case2) and not self.no_update_context:
             print(colored("Updating context and resetting conversation.", "green"), flush=True)
+            # extract the first sentence in the response as the intermediate answer
+            _message = message.get("content", "").split("\n")[0].strip()
+            _intermediate_info = re.split(r"(?<=[.!?])\s+", _message)
+            self._intermediate_answers.add(_intermediate_info[0])
+
+            if update_context_case1:
+                # try to get more context from the current retrieved doc results because the results may be too long to fit
+                # in the LLM context.
+                doc_contents = self._get_context(self._results)
+
+                # Always use self.problem as the query text to retrieve docs, but each time we replace the context with the
+                # next similar docs in the retrieved doc results.
+                if not doc_contents:
+                    for _tmp_retrieve_count in range(1, 5):
+                        self._reset(intermediate=True)
+                        self.retrieve_docs(self.problem, self.n_results * (2 * _tmp_retrieve_count + 1))
+                        doc_contents = self._get_context(self._results)
+                        if doc_contents:
+                            break
+            elif update_context_case2:
+                # Use the current intermediate info as the query text to retrieve docs, and each time we append the top similar
+                # docs in the retrieved doc results to the context.
+                for _tmp_retrieve_count in range(5):
+                    self._reset(intermediate=True)
+                    self.retrieve_docs(_intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1))
+                    self._get_context(self._results)
+                    doc_contents = "\n".join(self._doc_contents)  # + "\n" + "\n".join(self._intermediate_answers)
+                    if doc_contents:
+                        break
+
             self.clear_history()
             sender.clear_history()
-            doc_contents = self._get_context(self._results)
             return True, self._generate_message(doc_contents, task=self._task)
-        return False, None
+        else:
+            return False, None

     def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
         if not self._collection:
+            print("Trying to create collection.")
             create_vector_db_from_dir(
                 dir_path=self._docs_path,
                 max_tokens=self._chunk_token_size,
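The reply logic above has two triggers: case 1 re-reads deeper into the already-retrieved results, widening the retrieval to 3x, 5x, 7x, then 9x `n_results` before giving up, while case 2 re-queries the vector DB using the model's intermediate answer. The intermediate-answer extraction itself is just two string operations; a standalone, runnable sketch of that step (the sample reply is made up):

import re

content = "The answer is Paris. More detail follows.\nSecond line of the reply."
_message = content.split("\n")[0].strip()                  # first line of the reply
_intermediate_info = re.split(r"(?<=[.!?])\s+", _message)  # split that line into sentences
print(_intermediate_info[0])  # -> "The answer is Paris."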
@@ -263,6 +324,7 @@ def generate_init_message(self, problem: str, n_results: int = 20, search_string
         self._reset()
         self.retrieve_docs(problem, n_results, search_string)
         self.problem = problem
+        self.n_results = n_results
         doc_contents = self._get_context(self._results)
         message = self._generate_message(doc_contents, self._task)
         return message
@@ -278,21 +340,6 @@ def run_code(self, code, **kwargs):
         if self._ipython is None or lang != "python":
             return super().run_code(code, **kwargs)
         else:
-            # # capture may not work as expected
-            # result = self._ipython.run_cell("%%capture --no-display cap\n" + code)
-            # log = self._ipython.ev("cap.stdout")
-            # log += self._ipython.ev("cap.stderr")
-            # if result.result is not None:
-            #     log += str(result.result)
-            # exitcode = 0 if result.success else 1
-            # if result.error_before_exec is not None:
-            #     log += f"\n{result.error_before_exec}"
-            #     exitcode = 1
-            # if result.error_in_exec is not None:
-            #     log += f"\n{result.error_in_exec}"
-            #     exitcode = 1
-            # return exitcode, log, None
-
             result = self._ipython.run_cell(code)
             log = str(result.result)
             exitcode = 0 if result.success else 1
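For context on the branch that remains: `run_cell` returns an IPython `ExecutionResult`, and the kept code derives the log and exit code from its `result` and `success` fields. A minimal sketch of the same pattern outside the agent, using `InteractiveShell.instance()` as a stand-in since `get_ipython()` returns None in a plain script:

from IPython.core.interactiveshell import InteractiveShell

shell = InteractiveShell.instance()  # stand-in for the agent's self._ipython
result = shell.run_cell("1 + 1")     # returns an ExecutionResult
log = str(result.result)             # "2"
exitcode = 0 if result.success else 1
print(exitcode, log)                 # 0 2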

autogen/retrieve_utils.py (+11 -3)
@@ -207,10 +207,18 @@ def create_vector_db_from_dir(
         )

         chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
-        # updates existing items, or adds them if they don't yet exist.
+        print(f"Found {len(chunks)} chunks.")
+        # upsert in batch of 40000
+        for i in range(0, len(chunks), 40000):
+            collection.upsert(
+                documents=chunks[
+                    i : i + 40000
+                ],  # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
+                ids=[f"doc_{i}" for i in range(i, i + 40000)],  # unique for each doc
+            )
         collection.upsert(
-            documents=chunks,  # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
-            ids=[f"doc_{i}" for i in range(len(chunks))],  # unique for each doc
+            documents=chunks[i : len(chunks)],
+            ids=[f"doc_{i}" for i in range(i, len(chunks))],  # unique for each doc
         )
     except ValueError as e:
         logger.warning(f"{e}")
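Chroma's `collection.upsert` requires `documents` and `ids` of equal length, and Python slices clamp at the end of the list while the id range in the loop above does not, so the two can diverge on a final partial batch. A clamped sketch of the same batching pattern (the helper name is hypothetical, not from the repo):

def upsert_in_batches(collection, chunks, batch_size=40000):
    # Upsert updates existing ids and inserts new ones; ids stay aligned
    # with documents because both ranges are clamped to len(chunks).
    for i in range(0, len(chunks), batch_size):
        end = min(i + batch_size, len(chunks))
        collection.upsert(
            documents=chunks[i:end],
            ids=[f"doc_{j}" for j in range(i, end)],  # unique per doc
        )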
