
Commit 5a96dc2

Add source to the answer for default prompt (#2289)
* Add source to the answer for default prompt
* Fix qdrant
* Fix tests
* Update docstring
* Fix check files
* Fix qdrant test error
1 parent 5292024 commit 5a96dc2

File tree

4 files changed: +39 −17 lines changed


autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py

+3 −2

@@ -190,12 +190,12 @@ def create_qdrant_from_dir(
     client.set_model(embedding_model)

     if custom_text_split_function is not None:
-        chunks = split_files_to_chunks(
+        chunks, sources = split_files_to_chunks(
             get_files_from_dir(dir_path, custom_text_types, recursive),
             custom_text_split_function=custom_text_split_function,
         )
     else:
-        chunks = split_files_to_chunks(
+        chunks, sources = split_files_to_chunks(
             get_files_from_dir(dir_path, custom_text_types, recursive), max_tokens, chunk_mode, must_break_at_empty_line
         )
     logger.info(f"Found {len(chunks)} chunks.")

@@ -298,5 +298,6 @@ class QueryResponse(BaseModel, extra="forbid"):  # type: ignore
     data = {
         "ids": [[result.id for result in sublist] for sublist in results],
         "documents": [[result.document for result in sublist] for sublist in results],
+        "metadatas": [[result.metadata for result in sublist] for sublist in results],
     }
     return data
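For orientation, the query results assembled above now carry per-document metadata alongside the ids and documents. A minimal sketch of that shape with made-up values (not taken from the commit), and how downstream code can read the sources back out:

```python
# Illustrative only: hypothetical query results in the shape built above.
results = {
    "ids": [["doc_0", "doc_1"]],
    "documents": [["AutoGen is an advanced tool ...", "Retrieval augmentation ..."]],
    "metadatas": [[{"source": "./docs/example.txt"}, {"source": "https://example.com/page.html"}]],
}

# Downstream code (e.g. _get_context) can recover the source of each retrieved chunk:
sources = [m.get("source", "") for m in results["metadatas"][0]]
print(sources)  # ['./docs/example.txt', 'https://example.com/page.html']
```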

autogen/agentchat/contrib/retrieve_user_proxy_agent.py

+14 −2

@@ -34,6 +34,10 @@
 User's question is: {input_question}

 Context is: {input_context}
+
+The source of the context is: {input_sources}
+
+If you can answer the question, in the end of your answer, add the source of the context in the format of `Sources: source1, source2, ...`.
 """

 PROMPT_CODE = """You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the

@@ -101,7 +105,8 @@ def __init__(
                 following keys:
                 - `task` (Optional, str) - the task of the retrieve chat. Possible values are
                     "code", "qa" and "default". System prompt will be different for different tasks.
-                    The default value is `default`, which supports both code and qa.
+                    The default value is `default`, which supports both code and qa, and provides
+                    source information in the end of the response.
                 - `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a
                     default client `chromadb.Client()` will be used. If you want to use other
                     vector db, extend this class and override the `retrieve_docs` function.

@@ -243,6 +248,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
         self._intermediate_answers = set()  # the intermediate answers
         self._doc_contents = []  # the contents of the current used doc
         self._doc_ids = []  # the ids of the current used doc
+        self._current_docs_in_context = []  # the ids of the current context sources
         self._search_string = ""  # the search string used in the current query
         # update the termination message function
         self._is_termination_msg = (

@@ -290,6 +296,7 @@ def _reset(self, intermediate=False):

     def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
         doc_contents = ""
+        self._current_docs_in_context = []
         current_tokens = 0
         _doc_idx = self._doc_idx
         _tmp_retrieve_count = 0

@@ -310,6 +317,9 @@ def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
             print(colored(func_print, "green"), flush=True)
             current_tokens += _doc_tokens
             doc_contents += doc + "\n"
+            _metadatas = results.get("metadatas")
+            if isinstance(_metadatas, list) and isinstance(_metadatas[0][idx], dict):
+                self._current_docs_in_context.append(results["metadatas"][0][idx].get("source", ""))
             self._doc_idx = idx
             self._doc_ids.append(results["ids"][0][idx])
             self._doc_contents.append(doc)

@@ -329,7 +339,9 @@ def _generate_message(self, doc_contents, task="default"):
         elif task.upper() == "QA":
             message = PROMPT_QA.format(input_question=self.problem, input_context=doc_contents)
         elif task.upper() == "DEFAULT":
-            message = PROMPT_DEFAULT.format(input_question=self.problem, input_context=doc_contents)
+            message = PROMPT_DEFAULT.format(
+                input_question=self.problem, input_context=doc_contents, input_sources=self._current_docs_in_context
+            )
         else:
             raise NotImplementedError(f"task {task} is not implemented.")
         return message
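To illustrate the new `{input_sources}` placeholder, here is a standalone sketch that uses an abbreviated stand-in for PROMPT_DEFAULT (the real template is longer) with hypothetical question, context, and source values:

```python
# Abbreviated stand-in for PROMPT_DEFAULT; only the parts touched by this commit are kept.
PROMPT_DEFAULT_SKETCH = """User's question is: {input_question}

Context is: {input_context}

The source of the context is: {input_sources}

If you can answer the question, in the end of your answer, add the source of the context in the format of `Sources: source1, source2, ...`.
"""

# Mirrors the new _generate_message branch for task="default"; values are hypothetical.
message = PROMPT_DEFAULT_SKETCH.format(
    input_question="What is AutoGen?",
    input_context="AutoGen is an advanced tool designed to assist developers ...",
    input_sources=["./docs/example.txt"],  # what _get_context collects into self._current_docs_in_context
)
print(message)
```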

autogen/retrieve_utils.py

+17 −8

@@ -1,7 +1,7 @@
 import glob
 import os
 import re
-from typing import Callable, List, Union
+from typing import Callable, List, Tuple, Union
 from urllib.parse import urlparse

 import chromadb

@@ -160,8 +160,14 @@ def split_files_to_chunks(
     """Split a list of files into chunks of max_tokens."""

     chunks = []
+    sources = []

     for file in files:
+        if isinstance(file, tuple):
+            url = file[1]
+            file = file[0]
+        else:
+            url = None
         _, file_extension = os.path.splitext(file)
         file_extension = file_extension.lower()

@@ -179,11 +185,13 @@
             continue  # Skip to the next file if no text is available

         if custom_text_split_function is not None:
-            chunks += custom_text_split_function(text)
+            tmp_chunks = custom_text_split_function(text)
         else:
-            chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
+            tmp_chunks = split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
+        chunks += tmp_chunks
+        sources += [{"source": url if url else file}] * len(tmp_chunks)

-    return chunks
+    return chunks, sources


 def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):

@@ -267,7 +275,7 @@ def parse_html_to_markdown(html: str, url: str = None) -> str:
     return webpage_text


-def get_file_from_url(url: str, save_path: str = None):
+def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]:
     """Download a file from a URL."""
     if save_path is None:
         save_path = "tmp/chromadb"

@@ -303,7 +311,7 @@ def get_file_from_url(url: str, save_path: str = None):
     with open(save_path, "wb") as f:
         for chunk in response.iter_content(chunk_size=8192):
             f.write(chunk)
-    return save_path
+    return save_path, url


 def is_url(string: str):

@@ -383,12 +391,12 @@ def create_vector_db_from_dir(
     length = len(collection.get()["ids"])

     if custom_text_split_function is not None:
-        chunks = split_files_to_chunks(
+        chunks, sources = split_files_to_chunks(
             get_files_from_dir(dir_path, custom_text_types, recursive),
             custom_text_split_function=custom_text_split_function,
         )
     else:
-        chunks = split_files_to_chunks(
+        chunks, sources = split_files_to_chunks(
             get_files_from_dir(dir_path, custom_text_types, recursive),
             max_tokens,
             chunk_mode,

@@ -401,6 +409,7 @@
             collection.upsert(
                 documents=chunks[i:end_idx],
                 ids=[f"doc_{j+length}" for j in range(i, end_idx)],  # unique for each doc
+                metadatas=sources[i:end_idx],
             )
         except ValueError as e:
             logger.warning(f"{e}")
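As a quick orientation on the new return shape, here is a minimal usage sketch (not from the commit); it assumes autogen and chromadb are installed, that ./docs contains supported text files, and that the collection name is made up:

```python
# Minimal sketch of the new (chunks, sources) return value feeding vector-db metadata.
import chromadb

from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks

files = get_files_from_dir("./docs")  # entries may be plain paths or (path, url) tuples for downloaded files
chunks, sources = split_files_to_chunks(files)  # sources[i] is {"source": <path or url>} for chunks[i]

client = chromadb.Client()
collection = client.create_collection("demo")  # hypothetical collection name
collection.upsert(
    documents=chunks,
    ids=[f"doc_{i}" for i in range(len(chunks))],
    metadatas=sources,  # the same pattern create_vector_db_from_dir now uses
)
```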

test/test_retrieve_utils.py

+5 −5

@@ -69,7 +69,7 @@ def test_extract_text_from_pdf(self):
     def test_split_files_to_chunks(self):
         pdf_file_path = os.path.join(test_dir, "example.pdf")
         txt_file_path = os.path.join(test_dir, "example.txt")
-        chunks = split_files_to_chunks([pdf_file_path, txt_file_path])
+        chunks, _ = split_files_to_chunks([pdf_file_path, txt_file_path])
         assert all(
             isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip()
             for chunk in chunks

@@ -81,7 +81,7 @@ def test_get_files_from_dir(self):
         pdf_file_path = os.path.join(test_dir, "example.pdf")
         txt_file_path = os.path.join(test_dir, "example.txt")
         files = get_files_from_dir([pdf_file_path, txt_file_path])
-        assert all(os.path.isfile(file) for file in files)
+        assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files)
         files = get_files_from_dir(
             [
                 pdf_file_path,

@@ -91,7 +91,7 @@ def test_get_files_from_dir(self):
             ],
             recursive=True,
         )
-        assert all(os.path.isfile(file) for file in files)
+        assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files)
         files = get_files_from_dir(
             [
                 pdf_file_path,

@@ -102,7 +102,7 @@ def test_get_files_from_dir(self):
             ],
             recursive=True,
             types=["pdf", "txt"],
         )
-        assert all(os.path.isfile(file) for file in files)
+        assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files)
         assert len(files) == 3

     def test_is_url(self):

@@ -243,7 +243,7 @@ def test_unstructured(self):
         pdf_file_path = os.path.join(test_dir, "example.pdf")
         txt_file_path = os.path.join(test_dir, "example.txt")
         word_file_path = os.path.join(test_dir, "example.docx")
-        chunks = split_files_to_chunks([pdf_file_path, txt_file_path, word_file_path])
+        chunks, _ = split_files_to_chunks([pdf_file_path, txt_file_path, word_file_path])
         assert all(
             isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip()
             for chunk in chunks

0 commit comments
