From 8f33d650b169c1c064235bbb20b649c83eed4500 Mon Sep 17 00:00:00 2001 From: Alex Chao Date: Fri, 8 Nov 2024 12:25:18 -0800 Subject: [PATCH] fix: MarkdownChunker, retain subsection headers Update MarkdownChunker to retain level 2 and level 3 Markdown headers in the chunk content for better retrieval. Closes #322 --- .../components/chunking/MarkdownChunker.py | 44 ++++++++++++++----- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/goldenverba/components/chunking/MarkdownChunker.py b/goldenverba/components/chunking/MarkdownChunker.py index 17c932964..113d27411 100644 --- a/goldenverba/components/chunking/MarkdownChunker.py +++ b/goldenverba/components/chunking/MarkdownChunker.py @@ -2,6 +2,7 @@ with contextlib.suppress(Exception): from langchain_text_splitters import MarkdownHeaderTextSplitter + from langchain_core.documents import Document as LangChainDocument from goldenverba.components.chunk import Chunk from goldenverba.components.interfaces import Chunker @@ -9,6 +10,30 @@ from goldenverba.components.interfaces import Embedding +HEADERS_TO_SPLIT_ON = [ + ("#", "Header 1"), + ("##", "Header 2"), + ("###", "Header 3"), +] + + +def get_header_values( + split_doc: LangChainDocument, +) -> list[str]: + """ + Get the text values of the headers in the LangChain Document resulting from a split. + """ + # This function uses an explicit list of header keys because the LangChain Document + # metadata is a dictionary with arbitrary entries, some of which may not be headers. + header_keys = [header_key for _, header_key in HEADERS_TO_SPLIT_ON] + + return [ + header_value + for header_key in header_keys + if (header_value := split_doc.metadata.get(header_key)) is not None + ] + + class MarkdownChunker(Chunker): """ MarkdownChunker for Verba using LangChain. @@ -31,11 +56,7 @@ async def chunk( ) -> list[Document]: text_splitter = MarkdownHeaderTextSplitter( - headers_to_split_on=[ - ("#", "Header 1"), - ("##", "Header 2"), - ("###", "Header 3"), - ] + headers_to_split_on=HEADERS_TO_SPLIT_ON ) char_end_i = -1 @@ -45,16 +66,17 @@ async def chunk( if len(document.chunks) > 0: continue - for i, chunk in enumerate(text_splitter.split_text(document.content)): - + for i, split_doc in enumerate(text_splitter.split_text(document.content)): + chunk_text = "" - # append title and page content (should only be one header as we are splitting at header so index at 0), if a header is found - if len(chunk.metadata) > 0: - chunk_text += list(chunk.metadata.values())[0] + "\n" + # Add header content to retain context and improve retrieval + header_values = get_header_values(split_doc) + for header_value in header_values: + chunk_text += header_value + "\n" # append page content (always there) - chunk_text += chunk.page_content + chunk_text += split_doc.page_content char_start_i = char_end_i + 1 char_end_i = char_start_i + len(chunk_text)