230 changes: 132 additions & 98 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs

Large diffs are not rendered by default.

97 changes: 19 additions & 78 deletions src/Microsoft.ML.Tokenizers/Model/Cache.cs
@@ -4,112 +4,53 @@

 using System;
 using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading;

 namespace Microsoft.ML.Tokenizers
 {
     internal sealed class Cache<TKey, TValue> where TKey : notnull where TValue : notnull
     {
+        private readonly int _capacity;
+        private readonly Dictionary<TKey, TValue> _map;
+        private object SyncObj => _map;
+
         internal Cache() : this(Bpe.DefaultCacheCapacity) { }

         internal Cache(int capacity)
         {
-            Capacity = capacity;
-            Map = new Dictionary<TKey, TValue>(Capacity);
+            _capacity = capacity;
+            _map = new Dictionary<TKey, TValue>(capacity);
         }

-        private readonly object _lock = new();
-
-        internal Dictionary<TKey, TValue> Map { get; set; }
-
-        internal int Capacity { get; set; }
-
-        internal void Fresh() => Map = new Dictionary<TKey, TValue>(Capacity);
-
-        internal void Clear()
+        internal bool TryGetValue(TKey key, out TValue value)
         {
-            lock (_lock)
+            lock (SyncObj)
             {
-                Map.Clear();
+                return _map.TryGetValue(key, out value!);
             }
         }

-        internal List<TValue> GetValues(IEnumerable<TKey> keys)
-        {
-            List<TValue> values = new();
-            lock (_lock)
-            {
-                foreach (TKey key in keys)
-                {
-                    if (Map.TryGetValue(key, out TValue? value))
-                    {
-                        values.Add(value);
-                    }
-                }
-            }
-
-            return values;
-        }
-
-        internal bool TryGet(TKey key, out TValue value)
-        {
-            lock (_lock)
-            {
-                return Map.TryGetValue(key, out value!);
-            }
-        }
-
-        internal void SetValues(IEnumerable<(TKey, TValue)> entries)
-        {
-            lock (_lock)
-            {
-                foreach ((TKey, TValue) entry in entries)
-                {
-                    if (Capacity <= Map.Count)
-                    {
-                        break;
-                    }
-                    Map[entry.Item1] = entry.Item2;
-                }
-            }
-        }
-
-        internal void Set(TKey k, TValue v)
+        internal TValue GetOrAdd(TKey key, TValue value)
         {
-            lock (_lock)
+            lock (SyncObj)
            {
-                if (Capacity > Map.Count)
+                if (_map.TryGetValue(key, out TValue? v))
                 {
-                    Map[k] = v;
+                    return v!;
                 }
-            }
-        }

-        internal KeyValuePair<TKey, TValue>[] ToArray()
-        {
-            lock (_lock)
-            {
-                return Map.ToArray();
+                _map[key] = value;
+                return value;
             }
         }

-        internal TValue GetOrAdd(TKey key, TValue value)
+        internal void Set(TKey key, TValue value)
         {
-            lock (_lock)
+            lock (SyncObj)
             {
-                if (Map.TryGetValue(key, out TValue? v))
+                if (_map.Count < _capacity)
                 {
-                    return v;
+                    _map[key] = value;
                 }
-
-                if (Capacity > Map.Count)
-                {
-                    Map[key] = value;
-                }
-
-                return value;
             }
         }
     }
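With this commit the cache surface shrinks to three members, TryGetValue, GetOrAdd, and Set, all serialized by locking on the dictionary itself (SyncObj => _map), and Set silently drops writes once _capacity is reached. A minimal sketch of the read-through pattern those members support; the helper and callback names are illustrative, not from the PR:

using System;
using System.Collections.Generic;

namespace Microsoft.ML.Tokenizers
{
    internal static class CacheUsageSketch
    {
        // Read-through lookup: probe the cache, compute on a miss, publish the
        // result. Cache<TKey, TValue>.Set silently drops the write once the
        // cache is at capacity, so a full cache keeps serving existing entries.
        internal static List<int> GetOrComputeIds(
            Cache<string, List<int>> cache,        // the internal type from this diff
            string text,
            Func<string, List<int>> encode)        // hypothetical encoder callback
        {
            if (cache.TryGetValue(text, out List<int>? ids))
            {
                return ids;
            }

            ids = encode(text);
            cache.Set(text, ids);
            return ids;
        }
    }
}

Locking on the map avoids allocating a dedicated lock object and is safe as long as the dictionary instance never escapes the class, which the private field guarantees.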
89 changes: 58 additions & 31 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
@@ -18,13 +18,14 @@ namespace Microsoft.ML.Tokenizers
     public sealed class EnglishRoberta : Model
     {
         private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence;
-        private readonly IReadOnlyDictionary<string, int> _vocab;
-        private readonly SortedDictionary<int, string> _vocabReverse;
+        private readonly IReadOnlyDictionary<StringSpanOrdinalKey, int> _vocab;
+        private Dictionary<string, int>? _vocabOriginal;
+        private readonly SortedDictionary<int, StringSpanOrdinalKey> _vocabReverse;
         private readonly Cache<(string, string), int> _mergeRanks;
         private readonly IReadOnlyDictionary<char, char> _byteToUnicode;
         private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
         private readonly string[] _charToString;
-        private readonly Cache<string, List<Token>> _cache;
+        private readonly StringSpanOrdinalKeyCache<List<Token>> _cache;

         /// <summary>
         /// Indicate if want to filter the unsupported characters during the decoding.
@@ -77,7 +78,7 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc
             }

             _unicodeToByte = _byteToUnicode.Reverse();
-            _cache = new Cache<string, List<Token>>();
+            _cache = new StringSpanOrdinalKeyCache<List<Token>>();
         }

         /// <summary>
@@ -118,13 +119,13 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
             }

             _unicodeToByte = _byteToUnicode.Reverse();
-            _cache = new Cache<string, List<Token>>();
+            _cache = new StringSpanOrdinalKeyCache<List<Token>>();
         }

         /// <summary>
         /// Gets the dictionary mapping tokens to Ids.
         /// </summary>
-        public IReadOnlyDictionary<string, int> Vocab => _vocab;
+        public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);

         //
         // Public Model interfaces implementation
@@ -145,14 +146,15 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes

             if (_vocabReverse.TryGetValue(id, out var value))
             {
+                string v = value.Data!;
                 if (FilterUnsupportedChars)
                 {
-                    char[] buffer = ArrayPool<char>.Shared.Rent(value.Length);
+                    char[] buffer = ArrayPool<char>.Shared.Rent(v.Length);
                     int i = 0;

-                    for (int j = 0; j < value.Length; j++)
+                    for (int j = 0; j < v.Length; j++)
                     {
-                        if (_unicodeToByte.TryGetValue(value[j], out var c))
+                        if (_unicodeToByte.TryGetValue(v[j], out var c))
                         {
                             buffer[i++] = c;
                         }
@@ -164,7 +166,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
                 }
                 else
                 {
-                    return value;
+                    return v;
                 }
             }

@@ -205,7 +207,7 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
                 return Array.Empty<Token>();
             }

-            if (_cache.TryGet(text, out List<Token>? hit))
+            if (_cache.TryGetValue(text, out List<Token>? hit))
             {
                 ArrayPool<char>.Shared.Return(token);
                 ArrayPool<int>.Shared.Return(indexMapping);
@@ -225,24 +227,24 @@
         /// <param name="text">The text to split.</param>
         /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
         /// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
-        public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, accumulatedIds);
+        public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, accumulatedIds);

         /// <summary>
         /// Get the number of tokens that the input text will be encoded to.
         /// </summary>
         /// <param name="text">The text to encode.</param>
         /// <param name="isSpecialToken">Indicate if the token is special token.</param>
         /// <returns>The number of tokens that the input text will be encoded to.</returns>
-        public override int CountTokens(string text, bool isSpecialToken) => EncodeToIds(text, null);
+        public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIds(text, null);

-        private int EncodeToIds(string text, IList<int>? accumulatedIds)
+        private int EncodeToIds(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
         {
-            if (string.IsNullOrEmpty(text))
+            if (text.IsEmpty)
             {
                 return 0;
             }

-            if (_cache.TryGet(text, out List<Token>? hit))
+            if (_cache.TryGetValue(text, out List<Token>? hit))
             {
                 if (accumulatedIds is not null)
                 {
@@ -255,17 +257,41 @@ private int EncodeToIds(string text, IList<int>? accumulatedIds)
                 return hit.Count;
             }

-            // If the cache doesn't have the text, then encode it and add it to the cache
-            IReadOnlyList<Token> tokens = Encode(text);
+            char[] token = ArrayPool<char>.Shared.Rent(text.Length);
+            int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
+
+            int newTokenIndex = 0;
+            for (int i = 0; i < text.Length; i++)
+            {
+                if (_byteToUnicode.TryGetValue(text[i], out var value))
+                {
+                    token[newTokenIndex] = value;
+                    indexMapping[newTokenIndex] = i;
+                    newTokenIndex++;
+                }
+            }
+
+            if (newTokenIndex == 0)
+            {
+                ArrayPool<char>.Shared.Return(token);
+                ArrayPool<int>.Shared.Return(indexMapping);
+                return 0;
+            }
+
+            List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
+            _cache.Set(text.ToString(), result);
+            ArrayPool<char>.Shared.Return(token);
+            ArrayPool<int>.Shared.Return(indexMapping);
+
             if (accumulatedIds is not null)
             {
-                foreach (var t in tokens)
+                foreach (var t in result)
                 {
                     accumulatedIds.Add(t.Id);
                 }
             }

-            return tokens.Count;
+            return result.Count;
         }

         /// <summary>
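Worth noting on the rewrite above: EncodeToIds rents its two scratch buffers from ArrayPool and returns them explicitly on every exit path rather than wrapping the work in try/finally. A self-contained sketch of the same rent/filter/return pattern with the finally-based alternative (hypothetical helper, not code from this PR); ArrayPool tolerates a buffer that is never returned, so the PR's explicit returns trade a rare leak on exception for no finally overhead on the hot path:

using System;
using System.Buffers;

internal static class PooledBufferSketch
{
    // Rent, filter into the buffer, and return it, with the return guaranteed
    // by finally even if the work in between throws. Illustrative example only.
    internal static int CountUpperCase(ReadOnlySpan<char> text)
    {
        char[] buffer = ArrayPool<char>.Shared.Rent(text.Length);
        try
        {
            int n = 0;
            for (int i = 0; i < text.Length; i++)
            {
                if (char.IsUpper(text[i]))
                {
                    buffer[n++] = text[i]; // filtered copy, like the _byteToUnicode loop above
                }
            }

            return n;
        }
        finally
        {
            ArrayPool<char>.Shared.Return(buffer);
        }
    }
}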
Expand All @@ -274,7 +300,7 @@ private int EncodeToIds(string text, IList<int>? accumulatedIds)
/// <param name="token">The token to map to the Id.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public override int? MapTokenToId(string token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out var value) ? value : null;
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValueUnsafe(token, out int value) ? value : null;

/// <summary>
/// Convert a list of tokens Ids to highest occurrence rankings.
@@ -397,12 +423,13 @@ private IReadOnlyList<Token> ModifyTokenListOffsets(IReadOnlyList<Token> tokens,
         private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
             HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);

-        private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
+        private Dictionary<StringSpanOrdinalKey, int> GetVocabulary(Stream vocabularyStream)
         {
-            Dictionary<string, int>? vocab;
+            Dictionary<StringSpanOrdinalKey, int>? vocab;
             try
             {
-                vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabularyStream) as Dictionary<string, int>;
+                JsonSerializerOptions options = new() { Converters = { new StringSpanOrdinalKeyConverter() } };
+                vocab = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocabularyStream, options) as Dictionary<StringSpanOrdinalKey, int>;
             }
             catch (Exception e)
             {
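StringSpanOrdinalKeyConverter is registered above but not defined in this diff. Assuming the key simply wraps a string, which is consistent with the new StringSpanOrdinalKey(...) and value.Data usages elsewhere in the diff, a System.Text.Json converter for it could look like the sketch below (illustrative class name; relies on the property-name converter hooks added in System.Text.Json 7):

using System;
using System.Text.Json;
using System.Text.Json.Serialization;

internal sealed class StringKeyConverterSketch : JsonConverter<StringSpanOrdinalKey>
{
    public override StringSpanOrdinalKey Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        => new StringSpanOrdinalKey(reader.GetString()!);

    public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options)
        => writer.WriteStringValue(value.Data);

    // Vocabulary files deserialize into Dictionary<StringSpanOrdinalKey, int>,
    // so the key is read from the JSON property-name position.
    public override StringSpanOrdinalKey ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        => new StringSpanOrdinalKey(reader.GetString()!);

    public override void WriteAsPropertyName(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options)
        => writer.WritePropertyName(value.Data!);
}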
@@ -416,22 +443,22 @@ private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)

             if (_vocabIdToHighestOccurrence.BosWord is not null)
             {
-                vocab[_vocabIdToHighestOccurrence.BosWord] = -_vocabIdToHighestOccurrence.BosIndex;
+                vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.BosWord)] = -_vocabIdToHighestOccurrence.BosIndex;
             }

             if (_vocabIdToHighestOccurrence.EosWord is not null)
             {
-                vocab[_vocabIdToHighestOccurrence.EosWord] = -_vocabIdToHighestOccurrence.EosIndex;
+                vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.EosWord)] = -_vocabIdToHighestOccurrence.EosIndex;
             }

             if (_vocabIdToHighestOccurrence.UnkWord is not null)
             {
-                vocab[_vocabIdToHighestOccurrence.UnkWord] = -_vocabIdToHighestOccurrence.UnkIndex;
+                vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.UnkWord)] = -_vocabIdToHighestOccurrence.UnkIndex;
             }

             if (_vocabIdToHighestOccurrence.PadWord is not null)
             {
-                vocab[_vocabIdToHighestOccurrence.PadWord] = -_vocabIdToHighestOccurrence.PadIndex;
+                vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.PadWord)] = -_vocabIdToHighestOccurrence.PadIndex;
             }

             return vocab;
@@ -510,7 +537,7 @@ private List<Token> EncodeToTokens(Span<char> token, Span<int> indexMapping)
             if (token.Length == 1)
             {
                 string tokenValue = _charToString[token[0]];
-                return new List<Token> { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], 1)) };
+                return new List<Token> { new Token(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, (indexMapping[0], 1)) };
             }

             List<string> word = new(token.Length);
@@ -539,7 +566,7 @@ private List<Token> EncodeToTokens(Span<char> token, Span<int> indexMapping)

                 // get the most frequent bi-gram pair
                 var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue));
-                if (!_mergeRanks.TryGet((first, second), out int _))
+                if (!_mergeRanks.TryGetValue((first, second), out int _))
                 {
                     break;
                 }
@@ -599,7 +626,7 @@ private List<Token> EncodeToTokens(Span<char> token, Span<int> indexMapping)

             foreach (string w in word)
             {
-                tokens.Add(new Token(_vocab[w], w, (indexMapping[index], w.Length)));
+                tokens.Add(new Token(_vocab[new StringSpanOrdinalKey(w)], w, (indexMapping[index], w.Length)));
                 index += w.Length;
             }

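Two types this diff depends on, StringSpanOrdinalKey and the TryGetValueUnsafe span lookup, are defined elsewhere in the PR. The core idea, sketched here under assumptions rather than taken from the PR's source: a key that hashes its characters ordinally, so a ReadOnlySpan<char> probe such as MapTokenToId's can be matched against stored string keys without allocating a temporary string:

using System;

// Illustrative sketch, not the PR's implementation: a dictionary key wrapping
// a string, with ordinal, content-based hashing. Because the hash is computed
// over the characters, a ReadOnlySpan<char> probe (as in TryGetValueUnsafe)
// can produce the same hash code without first allocating a string.
internal readonly struct OrdinalKeySketch : IEquatable<OrdinalKeySketch>
{
    public string? Data { get; }

    public OrdinalKeySketch(string data) => Data = data;

    public bool Equals(OrdinalKeySketch other)
        => string.Equals(Data, other.Data, StringComparison.Ordinal);

    public override bool Equals(object? obj) => obj is OrdinalKeySketch key && Equals(key);

    // string.GetHashCode(ReadOnlySpan<char>) hashes content, so the same value
    // is produced whether the probe starts from a string or from a span.
    public override int GetHashCode() => string.GetHashCode(Data.AsSpan());
}

With content-based hashing in place, a span probe can compute the hash and compare characters against each stored key's Data without calling ToString(); that is what lets the new MapTokenToId(ReadOnlySpan<char>) overload avoid a per-call allocation.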