Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: MarkdownChunker, retain subsection headers #323

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions goldenverba/components/chunking/MarkdownChunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,38 @@

with contextlib.suppress(Exception):
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_core.documents import Document as LangChainDocument

from goldenverba.components.chunk import Chunk
from goldenverba.components.interfaces import Chunker
from goldenverba.components.document import Document
from goldenverba.components.interfaces import Embedding


# Markdown header levels to split on, paired with the metadata key that
# MarkdownHeaderTextSplitter records for each level.
HEADERS_TO_SPLIT_ON = [
    ("#", "Header 1"),
    ("##", "Header 2"),
    ("###", "Header 3"),
]


def get_header_values(
    # NOTE: the annotation is a quoted forward reference on purpose — the
    # `from langchain_core.documents import Document as LangChainDocument`
    # above is wrapped in contextlib.suppress(Exception), so the name may be
    # undefined when langchain is not installed. An unquoted annotation is
    # evaluated at def time and would raise NameError on module import.
    split_doc: "LangChainDocument",
) -> list[str]:
    """
    Get the text values of the headers in the LangChain Document resulting from a split.

    The values are returned in header-level order (Header 1 first), skipping
    any level that is absent from the split's metadata.
    """
    # Iterate the explicit header list rather than split_doc.metadata itself:
    # the metadata dictionary may carry arbitrary non-header entries, and the
    # explicit list also fixes the output order.
    return [
        header_value
        for _, header_key in HEADERS_TO_SPLIT_ON
        if (header_value := split_doc.metadata.get(header_key)) is not None
    ]


class MarkdownChunker(Chunker):
"""
MarkdownChunker for Verba using LangChain.
Expand All @@ -31,11 +56,7 @@ async def chunk(
) -> list[Document]:

text_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=[
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
]
headers_to_split_on=HEADERS_TO_SPLIT_ON
)

char_end_i = -1
Expand All @@ -45,16 +66,17 @@ async def chunk(
if len(document.chunks) > 0:
continue

for i, chunk in enumerate(text_splitter.split_text(document.content)):
for i, split_doc in enumerate(text_splitter.split_text(document.content)):

chunk_text = ""

# append title and page content (should only be one header as we are splitting at header so index at 0), if a header is found
if len(chunk.metadata) > 0:
chunk_text += list(chunk.metadata.values())[0] + "\n"
# Add header content to retain context and improve retrieval
header_values = get_header_values(split_doc)
for header_value in header_values:
chunk_text += header_value + "\n"

# append page content (always there)
chunk_text += chunk.page_content
chunk_text += split_doc.page_content

char_start_i = char_end_i + 1
char_end_i = char_start_i + len(chunk_text)
Expand Down