diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Chunkers/DocumentTokenChunker.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Chunkers/DocumentTokenChunker.cs
new file mode 100644
index 00000000000..57025773049
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Chunkers/DocumentTokenChunker.cs
@@ -0,0 +1,125 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Text;
+using System.Threading;
+using Microsoft.ML.Tokenizers;
+using Microsoft.Shared.Diagnostics;
+
+namespace Microsoft.Extensions.DataIngestion.Chunkers
+{
+    /// <summary>
+    /// Processes a document by tokenizing its content and dividing it into overlapping chunks of tokens.
+    /// </summary>
+    /// <remarks>
+    /// This class uses a tokenizer to convert the document's content into tokens and then splits the
+    /// tokens into chunks of a specified size, with a configurable overlap between consecutive chunks.
+    /// Note that tables may be split mid-row.
+    /// </remarks>
+    public sealed class DocumentTokenChunker : IngestionChunker<string>
+    {
+        private readonly Tokenizer _tokenizer;
+        private readonly int _maxTokensPerChunk;
+        private readonly int _chunkOverlap;
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="DocumentTokenChunker"/> class with the specified options.
+        /// </summary>
+        /// <param name="options">The options used to configure the chunker, including tokenizer and chunk sizes.</param>
+        public DocumentTokenChunker(IngestionChunkerOptions options)
+        {
+            _ = Throw.IfNull(options);
+
+            _tokenizer = options.Tokenizer;
+            _maxTokensPerChunk = options.MaxTokensPerChunk;
+            _chunkOverlap = options.OverlapTokens;
+        }
+
+        /// <inheritdoc/>
+        public override async IAsyncEnumerable<IngestionChunk<string>> ProcessAsync(IngestionDocument document, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            _ = Throw.IfNull(document);
+
+            int stringBuilderTokenCount = 0;
+            StringBuilder stringBuilder = new();
+            foreach (IngestionDocumentElement element in document.EnumerateContent())
+            {
+                cancellationToken.ThrowIfCancellationRequested();
+                string? elementContent = element.GetSemanticContent();
+                if (string.IsNullOrEmpty(elementContent))
+                {
+                    continue;
+                }
+
+                int contentToProcessTokenCount = _tokenizer.CountTokens(elementContent!, considerNormalization: false);
+                ReadOnlyMemory<char> contentToProcess = elementContent.AsMemory();
+                while (stringBuilderTokenCount + contentToProcessTokenCount >= _maxTokensPerChunk)
+                {
+                    int index = _tokenizer.GetIndexByTokenCount(
+                        text: contentToProcess.Span,
+                        maxTokenCount: _maxTokensPerChunk - stringBuilderTokenCount,
+                        out string? _,
+                        out int _,
+                        considerNormalization: false);
+
+                    unsafe
+                    {
+                        fixed (char* ptr = &MemoryMarshal.GetReference(contentToProcess.Span))
+                        {
+                            _ = stringBuilder.Append(ptr, index);
+                        }
+                    }
+
+                    yield return FinalizeChunk();
+
+                    contentToProcess = contentToProcess.Slice(index);
+                    contentToProcessTokenCount = _tokenizer.CountTokens(contentToProcess.Span, considerNormalization: false);
+                }
+
+                _ = stringBuilder.Append(contentToProcess);
+                stringBuilderTokenCount += contentToProcessTokenCount;
+            }
+
+            if (stringBuilder.Length > 0)
+            {
+                yield return FinalizeChunk();
+            }
+
+            yield break;
+
+            IngestionChunk<string> FinalizeChunk()
+            {
+                IngestionChunk<string> chunk = new IngestionChunk<string>(
+                    content: stringBuilder.ToString(),
+                    document: document,
+                    context: string.Empty);
+                _ = stringBuilder.Clear();
+                stringBuilderTokenCount = 0;
+
+                if (_chunkOverlap > 0)
+                {
+                    int index = _tokenizer.GetIndexByTokenCountFromEnd(
+                        text: chunk.Content,
+                        maxTokenCount: _chunkOverlap,
+                        out string? _,
+                        out stringBuilderTokenCount,
+                        considerNormalization: false);
+
+                    ReadOnlySpan<char> overlapContent = chunk.Content.AsSpan().Slice(index);
+                    unsafe
+                    {
+                        fixed (char* ptr = &MemoryMarshal.GetReference(overlapContent))
+                        {
+                            _ = stringBuilder.Append(ptr, overlapContent.Length);
+                        }
+                    }
+                }
+
+                return chunk;
+            }
+        }
+    }
+}
diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionDocumentElementExtensions.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionDocumentElementExtensions.cs
new file mode 100644
index 00000000000..a8c98dd2c02
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionDocumentElementExtensions.cs
@@ -0,0 +1,24 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.Extensions.DataIngestion;
+
+/// <summary>
+/// Extension methods for <see cref="IngestionDocumentElement"/>.
+/// </summary>
+internal static class IngestionDocumentElementExtensions
+{
+    /// <summary>
+    /// Gets the semantic content of the element if available.
+    /// </summary>
+    /// <param name="element">The element to get semantic content from.</param>
+    /// <returns>The semantic content suitable for embedding generation.</returns>
+    internal static string? GetSemanticContent(this IngestionDocumentElement element)
+    {
+        return element switch
+        {
+            IngestionDocumentImage image => image.AlternativeText ?? image.Text,
+            _ => element.GetMarkdown()
+        };
+    }
+}
diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj b/src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj
index 8df401df7fd..e2ccd34747a 100644
--- a/src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj
+++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj
@@ -13,6 +13,7 @@
     false
     75
     75
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
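
A minimal usage sketch of the new chunker, assuming the same public surface the tests below exercise (TiktokenTokenizer.CreateForModel, IngestionChunkerOptions built from a tokenizer with MaxTokensPerChunk/OverlapTokens, and ProcessAsync streaming IngestionChunk<string> instances); the document shape mirrors the test fixtures rather than a confirmed minimal API:

    using System;
    using Microsoft.Extensions.DataIngestion;
    using Microsoft.Extensions.DataIngestion.Chunkers;
    using Microsoft.ML.Tokenizers;

    Tokenizer tokenizer = TiktokenTokenizer.CreateForModel("gpt-4o");
    DocumentTokenChunker chunker = new(new(tokenizer)
    {
        MaxTokensPerChunk = 512, // hard cap on tokens per emitted chunk
        OverlapTokens = 64,      // tokens repeated at the start of the next chunk
    });

    IngestionDocument doc = new("sample");
    doc.Sections.Add(new IngestionDocumentSection
    {
        Elements = { new IngestionDocumentParagraph("Some long text to split into chunks...") }
    });

    await foreach (IngestionChunk<string> chunk in chunker.ProcessAsync(doc))
    {
        Console.WriteLine(chunk.Content);
    }

The unsafe/fixed StringBuilder.Append(char*, int) calls in the implementation copy span slices into the builder without allocating intermediate substrings, which is what the AllowUnsafeBlocks addition to the csproj above enables.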
diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Chunkers/DocumentTokenChunkerTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Chunkers/DocumentTokenChunkerTests.cs
new file mode 100644
index 00000000000..aa9b0ced0bd
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Chunkers/DocumentTokenChunkerTests.cs
@@ -0,0 +1,33 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace Microsoft.Extensions.DataIngestion.Chunkers.Tests
+{
+    public abstract class DocumentTokenChunkerTests : DocumentChunkerTests
+    {
+        [Fact]
+        public async Task SingleChunkText()
+        {
+            string text = "This is a short document that fits within a single chunk.";
+            IngestionDocument doc = new IngestionDocument("singleChunkDoc");
+            doc.Sections.Add(new IngestionDocumentSection
+            {
+                Elements =
+                {
+                    new IngestionDocumentParagraph(text)
+                }
+            });
+
+            IngestionChunker<string> chunker = CreateDocumentChunker();
+            IReadOnlyList<IngestionChunk<string>> chunks = await chunker.ProcessAsync(doc).ToListAsync();
+
+            IngestionChunk<string> chunk = Assert.Single(chunks);
+            Assert.Equal(text, chunk.Content, ignoreLineEndingDifferences: true);
+        }
+    }
+}
diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Chunkers/NoOverlapTokenChunkerTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Chunkers/NoOverlapTokenChunkerTests.cs
new file mode 100644
index 00000000000..04553aade44
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Chunkers/NoOverlapTokenChunkerTests.cs
@@ -0,0 +1,64 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using Microsoft.ML.Tokenizers;
+using Xunit;
+
+namespace Microsoft.Extensions.DataIngestion.Chunkers.Tests
+{
+    public class NoOverlapTokenChunkerTests : DocumentTokenChunkerTests
+    {
+        protected override IngestionChunker<string> CreateDocumentChunker(int maxTokensPerChunk = 2_000, int overlapTokens = 500)
+        {
+            var tokenizer = TiktokenTokenizer.CreateForModel("gpt-4o");
+            return new DocumentTokenChunker(new(tokenizer) { MaxTokensPerChunk = maxTokensPerChunk, OverlapTokens = 0 });
+        }
+
+        [Fact]
+        public async Task TwoChunks()
+        {
+            string text = string.Join(" ", Enumerable.Repeat("word", 600)); // each word is 1 token
+            IngestionDocument doc = new IngestionDocument("twoChunksNoOverlapDoc");
+            doc.Sections.Add(new IngestionDocumentSection
+            {
+                Elements =
+                {
+                    new IngestionDocumentParagraph(text)
+                }
+            });
+            IngestionChunker<string> chunker = CreateDocumentChunker(maxTokensPerChunk: 512);
+            IReadOnlyList<IngestionChunk<string>> chunks = await chunker.ProcessAsync(doc).ToListAsync();
+            Assert.Equal(2, chunks.Count);
+            Assert.True(chunks[0].Content.Split(' ').Length <= 512);
+            Assert.True(chunks[1].Content.Split(' ').Length <= 512);
+            Assert.Equal(text, string.Join("", chunks.Select(c => c.Content)));
+        }
+
+        [Fact]
+        public async Task ManyChunks()
+        {
+            string text = string.Join(" ", Enumerable.Repeat("word", 1500)); // each word is 1 token
+            IngestionDocument doc = new IngestionDocument("smallChunksNoOverlapDoc");
+            doc.Sections.Add(new IngestionDocumentSection
+            {
+                Elements =
+                {
+                    new IngestionDocumentParagraph(text)
+                }
+            });
+
+            IngestionChunker<string> chunker = CreateDocumentChunker(maxTokensPerChunk: 200, overlapTokens: 0);
+            IReadOnlyList<IngestionChunk<string>> chunks = await chunker.ProcessAsync(doc).ToListAsync();
+            Assert.Equal(8, chunks.Count);
+            foreach (var chunk in chunks)
+            {
+                Assert.True(chunk.Content.Split(' ').Count(str => str.Contains("word")) <= 200);
+            }
+
+            Assert.Equal(text, string.Join("", chunks.Select(c => c.Content)));
+        }
+    }
+}
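
The overlap-enabled tests that follow depend on the seeding step in FinalizeChunk: after a chunk is emitted, the last OverlapTokens tokens of its content are appended back into the builder, so the next chunk starts with them. A small sketch of that invariant, assuming the same gpt-4o Tiktoken vocabulary the tests use (token boundaries, and therefore the exact overlap text, are model-specific):

    using System;
    using Microsoft.ML.Tokenizers;

    Tokenizer tokenizer = TiktokenTokenizer.CreateForModel("gpt-4o");
    string previous = "The quick brown fox";
    string next = " fox jumps over the";

    // Index where the last OverlapTokens (= 1 here) tokens of the previous chunk begin.
    int overlapStart = tokenizer.GetIndexByTokenCountFromEnd(
        previous, maxTokenCount: 1, out string? _, out int _, considerNormalization: false);

    Console.WriteLine(previous.Substring(overlapStart));                  // " fox"
    Console.WriteLine(next.StartsWith(previous.Substring(overlapStart))); // True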
diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Chunkers/OverlapTokenChunkerTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Chunkers/OverlapTokenChunkerTests.cs
new file mode 100644
index 00000000000..2d7e247dce2
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Chunkers/OverlapTokenChunkerTests.cs
@@ -0,0 +1,62 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using Microsoft.ML.Tokenizers;
+using Xunit;
+
+namespace Microsoft.Extensions.DataIngestion.Chunkers.Tests
+{
+    public class OverlapTokenChunkerTests : DocumentTokenChunkerTests
+    {
+        protected override IngestionChunker<string> CreateDocumentChunker(int maxTokensPerChunk = 2_000, int overlapTokens = 500)
+        {
+            var tokenizer = TiktokenTokenizer.CreateForModel("gpt-4o");
+            return new DocumentTokenChunker(new(tokenizer) { MaxTokensPerChunk = maxTokensPerChunk, OverlapTokens = overlapTokens });
+        }
+
+        [Fact]
+        public async Task TokenChunking_WithOverlap()
+        {
+            string text = "The quick brown fox jumps over the lazy dog";
+            var tokenizer = TiktokenTokenizer.CreateForModel("gpt-4o");
+            int chunkSize = 4; // Small chunk size to demonstrate overlap
+            int chunkOverlap = 1;
+
+            var chunker = new DocumentTokenChunker(new(tokenizer) { MaxTokensPerChunk = chunkSize, OverlapTokens = chunkOverlap });
+            IngestionDocument doc = new IngestionDocument("overlapExample");
+            doc.Sections.Add(new IngestionDocumentSection
+            {
+                Elements =
+                {
+                    new IngestionDocumentParagraph(text)
+                }
+            });
+
+            IReadOnlyList<IngestionChunk<string>> chunks = await chunker.ProcessAsync(doc).ToListAsync();
+            Assert.Equal(3, chunks.Count);
+            Assert.Equal("The quick brown fox", chunks[0].Content, ignoreLineEndingDifferences: true);
+            Assert.Equal(" fox jumps over the", chunks[1].Content, ignoreLineEndingDifferences: true);
+            Assert.Equal(" the lazy dog", chunks[2].Content, ignoreLineEndingDifferences: true);
+
+            Assert.True(tokenizer.CountTokens(chunks.Last().Content) <= chunkSize);
+
+            for (int i = 0; i < chunks.Count - 1; i++)
+            {
+                var currentChunk = chunks[i];
+                var nextChunk = chunks[i + 1];
+
+                var currentWords = currentChunk.Content.Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
+                var nextWords = nextChunk.Content.Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
+
+                bool hasOverlap = currentWords.Intersect(nextWords).Any();
+                Assert.True(hasOverlap, $"Chunks {i} and {i + 1} should have overlapping content");
+            }
+
+            Assert.NotEmpty(string.Concat(chunks.Select(c => c.Content)));
+        }
+    }
+}
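
For sizing intuition across these tests: each chunk after the first begins with OverlapTokens carried-over tokens, so each subsequent chunk adds only MaxTokensPerChunk - OverlapTokens tokens of new content. Ignoring boundary edge cases, an N-token document therefore yields roughly 1 + ceil((N - MaxTokensPerChunk) / (MaxTokensPerChunk - OverlapTokens)) chunks when N > MaxTokensPerChunk, which matches the assertions above: 600 tokens at 512/0 overlap gives 1 + ceil(88/512) = 2 chunks, 1500 tokens at 200/0 gives 1 + ceil(1300/200) = 8, and the 9-token sentence at 4/1 gives 1 + ceil(5/3) = 3.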