Commit 07646d4
Support custom text formats and recursive search (#496)
* Add custom text types and recursive search
* Fix format
* Update qdrant, add pdf to unstructured
* Use unstructured as the default text extractor if installed
* Add tests for unstructured
* Update tests env for unstructured
* Fix error if the last message is a function call (issue #569)
* Remove csv, md and tsv from UNSTRUCTURED_FORMATS
* Update docstring of docs_path
* Update test for get_files_from_dir
* Update docstring of custom_text_types
* Fix missing search_string in update_context
* Add custom_text_types to notebook example
1 parent ef1c3d3 commit 07646d4
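Before the per-file diffs, a minimal usage sketch of the new options. It is not taken from this commit verbatim: the agent name, docs paths and type list are illustrative, while the `retrieve_config` keys (`docs_path`, `custom_text_types`, `recursive`) and their defaults come from the docstrings in the diffs below.

```python
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

# A minimal sketch, assuming an OpenAI config is set up elsewhere; only the
# retrieve_config keys and their defaults are taken from this commit.
ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",  # hypothetical agent name
    retrieve_config={
        "task": "qa",
        "docs_path": ["./docs", "https://example.com/notes.md"],  # now accepts a list of dirs, files and urls
        "custom_text_types": ["txt", "md", "rst"],  # only these types are collected under docs_path directories
        "recursive": True,  # search docs_path directories recursively (the default)
    },
)
```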

7 files changed: +516 −269 lines changed


.github/workflows/build.yml

−4
```diff
@@ -42,10 +42,6 @@ jobs:
           python -c "import autogen"
           pip install -e. pytest mock
           pip uninstall -y openai
-      - name: Install unstructured if not windows
-        if: matrix.os != 'windows-2019'
-        run: |
-          pip install "unstructured[all-docs]"
       - name: Test with pytest
         if: matrix.python-version != '3.10'
         run: |
```

.github/workflows/contrib-tests.yml

+6 −2
```diff
@@ -34,10 +34,14 @@ jobs:
         run: |
           python -m pip install --upgrade pip wheel
           pip install pytest
-      - name: Install qdrant_client when python-version is 3.10
-        if: matrix.python-version == '3.10' || matrix.python-version == '3.8'
+      - name: Install qdrant_client when python-version is 3.8 and 3.10
+        if: matrix.python-version == '3.8' || matrix.python-version == '3.10'
         run: |
           pip install qdrant_client[fastembed]
+      - name: Install unstructured when python-version is 3.9 and 3.11 and not windows
+        if: (matrix.python-version == '3.9' || matrix.python-version == '3.11') && matrix.os != 'windows-2019'
+        run: |
+          pip install unstructured[all-docs]
       - name: Install packages and dependencies for RetrieveChat
         run: |
           pip install -e .[retrievechat]
```

autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py

+31 −13
```diff
@@ -1,7 +1,7 @@
 from typing import Callable, Dict, List, Optional
 
 from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
-from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks
+from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks, TEXT_FORMATS
 import logging
 
 logger = logging.getLogger(__name__)
@@ -45,8 +45,8 @@ def __init__(
                 prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
             - client (Optional, qdrant_client.QdrantClient(":memory:")): A QdrantClient instance. If not provided, an in-memory instance will be assigned. Not recommended for production.
                 will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function.
-            - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
-                or the url to a single file. Default is None, which works only if the collection is already created.
+            - docs_path (Optional, Union[str, List[str]]): the path to the docs directory. It can also be the path to a single file,
+                the url to a single file or a list of directories, files and urls. Default is None, which works only if the collection is already created.
             - collection_name (Optional, str): the name of the collection.
                 If key not provided, a default name `autogen-docs` will be used.
             - model (Optional, str): the model to use for the retrieve chat.
@@ -66,11 +66,14 @@ def __init__(
             - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
                 If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
             - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
-            - custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
+            - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
                 The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
                 Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
-            - custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
+            - custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings.
                 Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
+            - custom_text_types (Optional, List[str]): a list of file types to be processed. Default is `autogen.retrieve_utils.TEXT_FORMATS`.
+                This only applies to files under the directories in `docs_path`. Explicitly included files and urls will be chunked regardless of their types.
+            - recursive (Optional, bool): whether to search documents recursively in the docs_path. Default is True.
             - parallel (Optional, int): How many parallel workers to use for embedding. Defaults to the number of CPU cores.
             - on_disk (Optional, bool): Whether to store the collection on disk. Default is False.
             - quantization_config: Quantization configuration. If None, quantization will be disabled.
@@ -111,6 +114,8 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
             must_break_at_empty_line=self._must_break_at_empty_line,
             embedding_model=self._embedding_model,
             custom_text_split_function=self.custom_text_split_function,
+            custom_text_types=self._custom_text_types,
+            recursive=self._recursive,
             parallel=self._parallel,
             on_disk=self._on_disk,
             quantization_config=self._quantization_config,
@@ -139,15 +144,17 @@ def create_qdrant_from_dir(
     must_break_at_empty_line: bool = True,
     embedding_model: str = "BAAI/bge-small-en-v1.5",
     custom_text_split_function: Callable = None,
+    custom_text_types: List[str] = TEXT_FORMATS,
+    recursive: bool = True,
     parallel: int = 0,
     on_disk: bool = False,
     quantization_config: Optional[models.QuantizationConfig] = None,
     hnsw_config: Optional[models.HnswConfigDiff] = None,
     payload_indexing: bool = False,
     qdrant_client_options: Optional[Dict] = {},
 ):
-    """Create a Qdrant collection from all the files in a given directory, the directory can also be a single file or a url to
-    a single file.
+    """Create a Qdrant collection from all the files in a given directory, the directory can also be a single file or a
+    url to a single file.
 
     Args:
         dir_path (str): the path to the directory, file or url.
@@ -156,24 +163,35 @@ def create_qdrant_from_dir(
         collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
         chunk_mode (Optional, str): the chunk mode. Default is "multi_lines".
         must_break_at_empty_line (Optional, bool): Whether to break at empty line. Default is True.
-        embedding_model (Optional, str): the embedding model to use. Default is "BAAI/bge-small-en-v1.5". The list of all the available models can be at https://qdrant.github.io/fastembed/examples/Supported_Models/.
+        embedding_model (Optional, str): the embedding model to use. Default is "BAAI/bge-small-en-v1.5".
+            The list of all the available models can be at https://qdrant.github.io/fastembed/examples/Supported_Models/.
+        custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings.
+            Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
+        custom_text_types (Optional, List[str]): a list of file types to be processed. Default is TEXT_FORMATS.
+        recursive (Optional, bool): whether to search documents recursively in the dir_path. Default is True.
         parallel (Optional, int): How many parallel workers to use for embedding. Defaults to the number of CPU cores
         on_disk (Optional, bool): Whether to store the collection on disk. Default is False.
-        quantization_config: Quantization configuration. If None, quantization will be disabled. Ref: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection
-        hnsw_config: HNSW configuration. If None, default configuration will be used. Ref: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection
+        quantization_config: Quantization configuration. If None, quantization will be disabled.
+            Ref: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection
+        hnsw_config: HNSW configuration. If None, default configuration will be used.
+            Ref: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection
         payload_indexing: Whether to create a payload index for the document field. Default is False.
-        qdrant_client_options: (Optional, dict): the options for instantiating the qdrant client. Reference: https://github.com/qdrant/qdrant-client/blob/master/qdrant_client/qdrant_client.py#L36-L58.
+        qdrant_client_options: (Optional, dict): the options for instantiating the qdrant client.
+            Ref: https://github.com/qdrant/qdrant-client/blob/master/qdrant_client/qdrant_client.py#L36-L58.
     """
     if client is None:
         client = QdrantClient(**qdrant_client_options)
         client.set_model(embedding_model)
 
     if custom_text_split_function is not None:
         chunks = split_files_to_chunks(
-            get_files_from_dir(dir_path), custom_text_split_function=custom_text_split_function
+            get_files_from_dir(dir_path, custom_text_types, recursive),
+            custom_text_split_function=custom_text_split_function,
         )
     else:
-        chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
+        chunks = split_files_to_chunks(
+            get_files_from_dir(dir_path, custom_text_types, recursive), max_tokens, chunk_mode, must_break_at_empty_line
+        )
     logger.info(f"Found {len(chunks)} chunks.")
 
     # Check if collection by same name exists, if not, create it with custom options
```
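For context, a hedged sketch of calling the updated helper directly. The keyword names follow the signature in the diff above; the directory path and type list are illustrative:

```python
from qdrant_client import QdrantClient

from autogen.agentchat.contrib.qdrant_retrieve_user_proxy_agent import create_qdrant_from_dir

# A minimal sketch, assuming a local ./docs tree exists; values are illustrative.
client = QdrantClient(":memory:")  # in-memory instance; not recommended for production
create_qdrant_from_dir(
    dir_path="./docs",                # hypothetical directory of mixed file types
    client=client,
    collection_name="all-my-documents",
    custom_text_types=["txt", "md"],  # collect only these types from directories
    recursive=False,                  # index top-level files only, overriding the default
)
```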

autogen/agentchat/contrib/retrieve_user_proxy_agent.py

+22 −7
```diff
@@ -6,7 +6,7 @@
     raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
 from autogen.agentchat.agent import Agent
 from autogen.agentchat import UserProxyAgent
-from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
+from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, TEXT_FORMATS
 from autogen.token_count_utils import count_token
 from autogen.code_utils import extract_code
 
@@ -97,8 +97,8 @@ def __init__(
                 prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
             - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()`
                 will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function.
-            - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
-                or the url to a single file. Default is None, which works only if the collection is already created.
+            - docs_path (Optional, Union[str, List[str]]): the path to the docs directory. It can also be the path to a single file,
+                the url to a single file or a list of directories, files and urls. Default is None, which works only if the collection is already created.
             - collection_name (Optional, str): the name of the collection.
                 If key not provided, a default name `autogen-docs` will be used.
             - model (Optional, str): the model to use for the retrieve chat.
@@ -124,11 +124,14 @@ def __init__(
             - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
             - get_or_create (Optional, bool): if True, will create/return a collection for the retrieve chat. This is the same as that used in chromadb.
                 Default is False. Will raise ValueError if the collection already exists and get_or_create is False. Will be set to True if docs_path is None.
-            - custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
+            - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
                 The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function.
                 Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models.
-            - custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
+            - custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings.
                 Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
+            - custom_text_types (Optional, List[str]): a list of file types to be processed. Default is `autogen.retrieve_utils.TEXT_FORMATS`.
+                This only applies to files under the directories in `docs_path`. Explicitly included files and urls will be chunked regardless of their types.
+            - recursive (Optional, bool): whether to search documents recursively in the docs_path. Default is True.
             **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
 
         Example of overriding retrieve_docs:
@@ -181,6 +184,8 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
         self._get_or_create = self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else True
         self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
         self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
+        self._custom_text_types = self._retrieve_config.get("custom_text_types", TEXT_FORMATS)
+        self._recursive = self._retrieve_config.get("recursive", True)
         self._context_max_tokens = self._max_tokens * 0.8
         self._collection = True if self._docs_path is None else False  # whether the collection is created
         self._ipython = get_ipython()
@@ -189,6 +194,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
         self._intermediate_answers = set()  # the intermediate answers
         self._doc_contents = []  # the contents of the current used doc
         self._doc_ids = []  # the ids of the current used doc
+        self._search_string = ""  # the search string used in the current query
         # update the termination message function
         self._is_termination_msg = (
             self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
@@ -282,6 +288,8 @@ def _generate_message(self, doc_contents, task="default"):
     def _check_update_context(self, message):
         if isinstance(message, dict):
             message = message.get("content", "")
+        elif not isinstance(message, str):
+            message = ""
         update_context_case1 = "UPDATE CONTEXT" in message[-20:].upper() or "UPDATE CONTEXT" in message[:20].upper()
         update_context_case2 = self.customized_answer_prefix and self.customized_answer_prefix not in message.upper()
         return update_context_case1, update_context_case2
@@ -320,7 +328,9 @@ def _generate_retrieve_user_reply(
             if not doc_contents:
                 for _tmp_retrieve_count in range(1, 5):
                     self._reset(intermediate=True)
-                    self.retrieve_docs(self.problem, self.n_results * (2 * _tmp_retrieve_count + 1))
+                    self.retrieve_docs(
+                        self.problem, self.n_results * (2 * _tmp_retrieve_count + 1), self._search_string
+                    )
                     doc_contents = self._get_context(self._results)
                     if doc_contents:
                         break
@@ -329,7 +339,9 @@ def _generate_retrieve_user_reply(
             # docs in the retrieved doc results to the context.
             for _tmp_retrieve_count in range(5):
                 self._reset(intermediate=True)
-                self.retrieve_docs(_intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1))
+                self.retrieve_docs(
+                    _intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1), self._search_string
+                )
                 self._get_context(self._results)
             doc_contents = "\n".join(self._doc_contents)  # + "\n" + "\n".join(self._intermediate_answers)
             if doc_contents:
@@ -371,6 +383,8 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
                 get_or_create=self._get_or_create,
                 embedding_function=self._embedding_function,
                 custom_text_split_function=self.custom_text_split_function,
+                custom_text_types=self._custom_text_types,
+                recursive=self._recursive,
             )
             self._collection = True
             self._get_or_create = True
@@ -384,6 +398,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
             embedding_model=self._embedding_model,
             embedding_function=self._embedding_function,
         )
+        self._search_string = search_string
         self._results = results
         print("doc_ids: ", results["ids"])
```

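The `_check_update_context` guard above is the fix referenced by issue #569: a trailing function-call message is not a plain string, and slicing it previously raised a TypeError. A standalone sketch of the guard logic (the real method lives on `RetrieveUserProxyAgent`; the free function and demo values here are illustrative):

```python
# Mirrors the guard added in the diff above, extracted for illustration.
def check_update_context(message, customized_answer_prefix=""):
    if isinstance(message, dict):
        message = message.get("content", "")
    elif not isinstance(message, str):
        message = ""  # e.g. None from a function-call message (issue #569)
    update_context_case1 = "UPDATE CONTEXT" in message[-20:].upper() or "UPDATE CONTEXT" in message[:20].upper()
    update_context_case2 = bool(customized_answer_prefix) and customized_answer_prefix not in message.upper()
    return update_context_case1, update_context_case2

print(check_update_context("Please UPDATE CONTEXT"))  # (True, False)
print(check_update_context(None))  # (False, False) instead of a TypeError
```

The companion change, storing `self._search_string` and passing it back into `retrieve_docs` during the `Update Context` retries, keeps the original keyword filter applied as the search is widened.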