diff --git a/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs b/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs index 523db677ee..d6254fadf7 100644 --- a/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs +++ b/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Collections.Generic; namespace Microsoft.ML.Tokenizers @@ -19,14 +20,21 @@ public static int[] BytePairEncode(ReadOnlyMemory mergingBytes, Dictionary return [ranks[mergingBytes]]; } - var byteIndicesAndRanks = new List<(int Index, int Rank)>(); - for (int i = 0; i < mergingBytes.Length + 1; i++) + (int Index, int Rank)[]? arrayPoolArray = null; + int requiredLength = mergingBytes.Length + 1; + Span<(int Index, int Rank)> byteIndicesAndRanks = requiredLength <= 64 ? + stackalloc (int, int)[64] : + (arrayPoolArray = ArrayPool<(int, int)>.Shared.Rent(requiredLength)); + byteIndicesAndRanks = byteIndicesAndRanks.Slice(0, requiredLength); + + for (int i = 0; i < byteIndicesAndRanks.Length; i++) { - byteIndicesAndRanks.Add((i, int.MaxValue)); + byteIndicesAndRanks[i] = (i, int.MaxValue); } - int GetRank(int startIndex, int skip = 0) + + int GetRank(Span<(int Index, int Rank)> byteIndicesAndRanks, int startIndex, int skip = 0) { - if (startIndex + skip + 2 < byteIndicesAndRanks.Count) + if (startIndex + skip + 2 < byteIndicesAndRanks.Length) { var slice = mergingBytes.SliceStartEnd(byteIndicesAndRanks[startIndex].Index, byteIndicesAndRanks[startIndex + skip + 2].Index); if (ranks.TryGetValue(slice, out var rank)) @@ -34,35 +42,41 @@ int GetRank(int startIndex, int skip = 0) return rank; } } + return int.MaxValue; } - for (int i = 0; i < byteIndicesAndRanks.Count - 2; i++) + + for (int i = 0; i < byteIndicesAndRanks.Length - 2; i++) { - var rank = GetRank(i); + int rank = GetRank(byteIndicesAndRanks, i); if (rank != int.MaxValue) { - byteIndicesAndRanks[i] = (byteIndicesAndRanks[i].Index, rank); + byteIndicesAndRanks[i].Rank = rank; } } - while (byteIndicesAndRanks.Count > 1) + + while (byteIndicesAndRanks.Length > 1) { var minRank = (Index: 0, Rank: int.MaxValue); - for (int i = 0; i < byteIndicesAndRanks.Count - 1; i++) + for (int i = 0; i < byteIndicesAndRanks.Length - 1; i++) { if (byteIndicesAndRanks[i].Rank < minRank.Rank) { minRank = (i, byteIndicesAndRanks[i].Rank); } } + if (minRank.Rank != int.MaxValue) { int j = minRank.Index; - byteIndicesAndRanks[j] = (byteIndicesAndRanks[j].Index, GetRank(j, 1)); + byteIndicesAndRanks[j].Rank = GetRank(byteIndicesAndRanks, j, 1); if (j > 0) { - byteIndicesAndRanks[j - 1] = (byteIndicesAndRanks[j - 1].Index, GetRank(j - 1, 1)); + byteIndicesAndRanks[j - 1].Rank = GetRank(byteIndicesAndRanks, j - 1, 1); } - byteIndicesAndRanks.RemoveAt(j + 1); + + byteIndicesAndRanks.Slice(j + 2).CopyTo(byteIndicesAndRanks.Slice(j + 1)); + byteIndicesAndRanks = byteIndicesAndRanks.Slice(0, byteIndicesAndRanks.Length - 1); } else { @@ -70,12 +84,18 @@ int GetRank(int startIndex, int skip = 0) } } - var outList = new int[byteIndicesAndRanks.Count - 1]; - for (int i = 0; i < byteIndicesAndRanks.Count - 1; i++) + var result = new int[byteIndicesAndRanks.Length - 1]; + for (int i = 0; i < result.Length; i++) { - outList[i] = ranks[mergingBytes.SliceStartEnd(byteIndicesAndRanks[i].Index, byteIndicesAndRanks[i + 1].Index)]; + result[i] = ranks[mergingBytes.SliceStartEnd(byteIndicesAndRanks[i].Index, byteIndicesAndRanks[i + 1].Index)]; } - return outList; + + if (arrayPoolArray is not null) + { + ArrayPool<(int, int)>.Shared.Return(arrayPoolArray); + } + + return result; } private static ReadOnlyMemory SliceStartEnd(this ReadOnlyMemory memory, int start, int end) => memory.Slice(start, end - start);