Skip to content

Support custom text formats and recursive #496

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
ec3f6f2
Add custom text types and recursive
thinkall Oct 31, 2023
7f6eb70
Add custom text types and recursive
thinkall Oct 31, 2023
041dfb3
Merge branch 'main' into text_formats
thinkall Nov 1, 2023
14dc63d
Merge branch 'main' into text_formats
thinkall Nov 1, 2023
a5936f7
Fix format
thinkall Nov 1, 2023
4a879b6
Merge branch 'main' into text_formats
thinkall Nov 3, 2023
8a09586
Merge main
thinkall Nov 5, 2023
b7ca956
Merge main
thinkall Nov 6, 2023
8ca010e
Update qdrant, Add pdf to unstructured
thinkall Nov 6, 2023
ae6e080
Use unstructed as the default text extractor if installed
thinkall Nov 6, 2023
a52602a
Add tests for unstructured
thinkall Nov 6, 2023
5749783
Update tests env for unstructured
thinkall Nov 6, 2023
51f4251
Fix error if last message is a function call, issue #569
thinkall Nov 6, 2023
1b7a6b6
Merge branch 'main' into text_formats
thinkall Nov 7, 2023
852a295
Merge branch 'main' into text_formats
thinkall Nov 8, 2023
6d10f63
Merge branch 'main' into text_formats
thinkall Nov 9, 2023
499ac1d
Merge branch 'main' into text_formats
thinkall Nov 10, 2023
e13b934
Merge branch 'main' into text_formats
thinkall Nov 12, 2023
25cd7a6
Merge main
thinkall Nov 13, 2023
253e086
Merge branch 'main' into text_formats
thinkall Nov 13, 2023
db28f5e
Merge branch 'main' into text_formats
thinkall Nov 14, 2023
5501996
Merge branch 'main' into text_formats
thinkall Nov 15, 2023
6059dec
Merge branch 'main' into text_formats
thinkall Nov 16, 2023
b612825
Merge branch 'main' into text_formats
thinkall Nov 17, 2023
879237a
Merge branch 'main' into text_formats
thinkall Nov 17, 2023
056fe9a
Remove csv, md and tsv from UNSTRUCTURED_FORMATS
thinkall Nov 17, 2023
fd6c62b
Update docstring of docs_path
thinkall Nov 17, 2023
5346800
Update test for get_files_from_dir
thinkall Nov 17, 2023
c9566ad
Update docstring of custom_text_types
thinkall Nov 17, 2023
c200d09
Fix missing search_string in update_context
thinkall Nov 17, 2023
f998f8c
Add custom_text_types to notebook example
thinkall Nov 17, 2023
ff6b92b
Resolve conflicts in notebook
thinkall Nov 18, 2023
8906a4f
Merge branch 'main' into text_formats
thinkall Nov 18, 2023
48bab13
Merge branch 'main' into text_formats
thinkall Nov 19, 2023
e27ebf1
Merge branch 'main' into text_formats
thinkall Nov 20, 2023
39e7718
Merge branch 'main' into text_formats
thinkall Nov 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
from autogen.agentchat.agent import Agent
from autogen.agentchat import UserProxyAgent
from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, TEXT_FORMATS
from autogen.token_count_utils import count_token
from autogen.code_utils import extract_code

Expand Down Expand Up @@ -129,6 +129,8 @@ def __init__(
Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models.
- custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
- custom_text_types(Optional, List[str]): a list of file types to be processed. Default is `autogen.retrieve_utils.TEXT_FORMATS`.
- recursive(Optional, bool): whether to search documents recursively in the docs_path. Default is True.
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).

Example of overriding retrieve_docs:
Expand Down Expand Up @@ -183,6 +185,8 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
)
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
self._custom_text_types = self._retrieve_config.get("custom_text_types", TEXT_FORMATS)
self._recursive = self._retrieve_config.get("recursive", True)
self._context_max_tokens = self._max_tokens * 0.8
self._collection = True if self._docs_path is None else False # whether the collection is created
self._ipython = get_ipython()
Expand Down Expand Up @@ -373,6 +377,8 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
get_or_create=self._get_or_create,
embedding_function=self._embedding_function,
custom_text_split_function=self.custom_text_split_function,
custom_text_types=self._custom_text_types,
recursive=self._recursive,
)
self._collection = True
self._get_or_create = False
Expand Down
14 changes: 12 additions & 2 deletions autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ def create_vector_db_from_dir(
embedding_model: str = "all-MiniLM-L6-v2",
embedding_function: Callable = None,
custom_text_split_function: Callable = None,
custom_text_types: List[str] = TEXT_FORMATS,
recursive: bool = True,
) -> API:
"""Create a vector db from all the files in a given directory, the directory can also be a single file or a url to
a single file. We support chromadb compatible APIs to create the vector db, this function is not required if
Expand All @@ -236,6 +238,10 @@ def create_vector_db_from_dir(
embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with
the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`.
custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
custom_text_types(Optional, List[str]): a list of file types to be processed. Default is TEXT_FORMATS.
recursive(Optional, bool): whether to search documents recursively in the dir_path. Default is True.

Returns:
API: the chromadb client.
Expand All @@ -260,11 +266,15 @@ def create_vector_db_from_dir(

if custom_text_split_function is not None:
chunks = split_files_to_chunks(
get_files_from_dir(dir_path), custom_text_split_function=custom_text_split_function
get_files_from_dir(dir_path, custom_text_types, recursive),
custom_text_split_function=custom_text_split_function,
)
else:
chunks = split_files_to_chunks(
get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line
get_files_from_dir(dir_path, custom_text_types, recursive),
max_tokens,
chunk_mode,
must_break_at_empty_line,
)
logger.info(f"Found {len(chunks)} chunks.")
# Upsert in batch of 40000 or less if the total number of chunks is less than 40000
Expand Down
8 changes: 7 additions & 1 deletion test/test_retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def custom_text_split_function(text):
collection_name="mytestcollection",
custom_text_split_function=custom_text_split_function,
get_or_create=True,
recursive=False,
)
results = query_vector_db(["autogen"], client=client, collection_name="mytestcollection", n_results=1)
assert (
Expand All @@ -163,7 +164,12 @@ def custom_text_split_function(text):

def test_retrieve_utils(self):
client = chromadb.PersistentClient(path="/tmp/chromadb")
create_vector_db_from_dir(dir_path="./website/docs", client=client, collection_name="autogen-docs")
create_vector_db_from_dir(
dir_path="./website/docs",
client=client,
collection_name="autogen-docs",
custom_text_types=["txt", "md", "rtf", "rst"],
)
results = query_vector_db(
query_texts=[
"How can I use AutoGen UserProxyAgent and AssistantAgent to do code generation?",
Expand Down