Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f6e32f5
Fix cache when calling EncodeToIds
tarekgh Feb 17, 2024
0553922
Make EnglishRoberta _mergeRanks thread safe
tarekgh Feb 17, 2024
a4cb1f5
Delete Trainer
tarekgh Feb 19, 2024
6a13025
Remove the setters on the Bpe properties
tarekgh Feb 19, 2024
3278aff
Remove Roberta and Tiktoken special casing in the Tokenizer and suppo…
tarekgh Feb 19, 2024
b5f7fa2
Support text-embedding-3-small/large embedding
tarekgh Feb 19, 2024
a11f4e0
Remove redundant TokenToId abstraction and keep the one with the extr…
tarekgh Feb 19, 2024
865068a
Enable creating Tiktoken asynchronously or directly using the tokeniz…
tarekgh Feb 20, 2024
4077de0
Add cancellationToken support in CreateAsync APIs
tarekgh Feb 21, 2024
5aaf849
Rename sequence to text and Tokenize to Encode
tarekgh Feb 21, 2024
b5e0927
Rename skipSpecialTokens to considerSpecialTokens
tarekgh Feb 21, 2024
5e26b3e
Rename TokenizerResult to EncodingResult
tarekgh Feb 21, 2024
985de8a
Make Token publicly immutable
tarekgh Feb 21, 2024
b551e7d
Change offset tuples from (Index, End) to (Index, Length)
tarekgh Feb 21, 2024
7ea7f70
Rename NormalizedString method's parameters
tarekgh Feb 21, 2024
b0c8244
Rename Model's methods to start with verb
tarekgh Feb 21, 2024
450418a
Convert Model.GetVocab() method to a Vocab property
tarekgh Feb 21, 2024
6f53de8
Some method's parameters and variable renaming
tarekgh Feb 22, 2024
62334c6
Remove Vocab and VocabSize from the abstraction
tarekgh Feb 22, 2024
d48b32d
Cleanup normalization support
tarekgh Feb 22, 2024
191ab03
Minor Bpe cleanup
tarekgh Feb 22, 2024
b9b0f58
Resolve rebase change
tarekgh Feb 23, 2024
1ad157f
Address the feedback
tarekgh Feb 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 29 additions & 29 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public string? UnknownToken


/// <summary>
/// Construct a new Bpe model object to use for sentence tokenization.
/// Construct a new Bpe model object to use for text encoding.
/// </summary>
/// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergesFile">The file path containing the tokens' pairs list.</param>
Expand All @@ -85,7 +85,7 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
}

/// <summary>
/// Construct a new Bpe model object to use for sentence tokenization.
/// Construct a new Bpe model object to use for text encoding.
/// </summary>
/// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param>
/// <param name="mergesStream">The stream containing the tokens's pairs list.</param>
Expand Down Expand Up @@ -171,39 +171,39 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
public static TokenizerDecoder Decoder { get; } = new BpeDecoder();

/// <summary>
/// Tokenize a sequence string to a list of tokens.
/// Encode a text string to a list of tokens.
/// </summary>
/// <param name="sequence">The sequence to tokenize.</param>
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <returns>The list of tokens generated from the sequence tokenization.</returns>
public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken = false)
/// <returns>The list of tokens generated from the text tokenization.</returns>
public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false)
{
if (sequence.Length == 0)
if (text.Length == 0)
{
return EmptyTokensList;
}

return TokenizeWithCache(sequence);
return EncodeWithCache(text);
}

/// <summary>
/// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
/// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
/// </summary>
/// <param name="sequence">The sequence to split.</param>
/// <param name="text">The text to split.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param>
public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds) => TokenizeToIdsWithCache(sequence, accumulatedIds);
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);

/// <summary>
/// Get the number of tokens that the input sequence will be encoded to.
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="sequence">The text to tokenize.</param>
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is special token.</param>
/// <returns>The number of tokens that the input sequence will be encoded to.</returns>
public override int CountTokens(string sequence, bool isSpecialToken) => TokenizeToIdsWithCache(sequence, null);
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokens(string text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);

/// <summary>
/// Map the token to tokenized Id.
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the encoding.</param>
Expand All @@ -219,7 +219,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
}

/// <summary>
/// Map the tokenized Id to the token.
/// Map the encoded Id to the token.
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
Expand Down Expand Up @@ -264,7 +264,7 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(

internal static readonly int DefaultCacheCapacity = 10_000;

/// Reversed vocabulary, to rebuild sentences.
/// Reversed vocabulary, to rebuild the text.
internal SortedDictionary<int, string> VocabReverse { get; set; }

/// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
Expand Down Expand Up @@ -415,22 +415,22 @@ internal Word MergeWord(string w)

internal List<Token> WordToTokens(ref Word word) => word.ToTokens(VocabReverse);

internal List<Token> TokenizeWithCache(string sequence)
internal List<Token> EncodeWithCache(string text)
{
Word word;
if (Cache is not null)
{
if (Cache.TryGet(sequence, out word))
if (Cache.TryGet(text, out word))
{
return WordToTokens(ref word);
}

word = MergeWord(sequence);
Cache.Set(sequence, word);
word = MergeWord(text);
Cache.Set(text, word);
}
else
{
word = MergeWord(sequence);
word = MergeWord(text);
}

return WordToTokens(ref word);
Expand All @@ -446,23 +446,23 @@ internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
return word.SymbolsCount;
}

internal int TokenizeToIdsWithCache(string sequence, IList<int>? accumulatedIds)
internal int EncodeToIdsWithCache(string text, IList<int>? accumulatedIds)
{
Word word;

if (Cache is not null)
{
if (Cache.TryGet(sequence, out Word hit))
if (Cache.TryGet(text, out Word hit))
{
return WordToIds(ref hit, accumulatedIds);
}

word = MergeWord(sequence);
Cache.Set(sequence, word);
word = MergeWord(text);
Cache.Set(text, word);
}
else
{
word = MergeWord(sequence);
word = MergeWord(text);
}

return WordToIds(ref word, accumulatedIds);
Expand Down
56 changes: 28 additions & 28 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public sealed class EnglishRoberta : Model
private readonly Cache<string, List<Token>> _cache;

/// <summary>
/// Construct tokenizer object to use with the English Robert model.
/// Construct tokenizer's model object to use with the English RoBERTa model.
/// </summary>
/// <param name="vocabularyPath">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergePath">The file path containing the tokens' pairs list.</param>
Expand Down Expand Up @@ -73,7 +73,7 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc
}

/// <summary>
/// Construct tokenizer object to use with the English Robert model.
/// Construct tokenizer's model object to use with the English RoBERTa model.
/// </summary>
/// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
/// <param name="mergeStream">The stream of a file containing the tokens' pairs list.</param>
Expand Down Expand Up @@ -125,7 +125,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
public override int GetVocabSize() => _vocab.Count;

/// <summary>
/// Map the tokenized Id to the token.
/// Map the encoded Id to the token.
/// </summary>
/// <param name="id">The Id to map to the string.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
Expand Down Expand Up @@ -157,25 +157,25 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
}

/// <summary>
/// Tokenize a sequence string to a list of tokens.
/// Encode a text string to a list of tokens.
/// </summary>
/// <param name="sequence">The sequence to tokenize.</param>
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <returns>The list of tokens generated from the sequence tokenization.</returns>
public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken = false)
/// <returns>The list of tokens generated from the text tokenization.</returns>
public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false)
{
if (string.IsNullOrEmpty(sequence))
if (string.IsNullOrEmpty(text))
{
return Bpe.EmptyTokensList;
}

char[] token = ArrayPool<char>.Shared.Rent(sequence.Length);
int[] indexMapping = ArrayPool<int>.Shared.Rent(sequence.Length);
char[] token = ArrayPool<char>.Shared.Rent(text.Length);
int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);

int newTokenIndex = 0;
for (int i = 0; i < sequence.Length; i++)
for (int i = 0; i < text.Length; i++)
{
if (_byteToUnicode.TryGetValue(sequence[i], out var value))
if (_byteToUnicode.TryGetValue(text[i], out var value))
{
token[newTokenIndex] = value;
indexMapping[newTokenIndex] = i;
Expand All @@ -190,44 +190,44 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
return Array.Empty<Token>();
}

if (_cache.TryGet(sequence, out List<Token>? hit))
if (_cache.TryGet(text, out List<Token>? hit))
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
return ModifyTokenListOffsets(hit, indexMapping);
}

List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
_cache.Set(sequence, result);
_cache.Set(text, result);
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
return result;
}

/// <summary>
/// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
/// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
/// </summary>
/// <param name="sequence">The sequence to split.</param>
/// <param name="text">The text to split.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param>
public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds) => TokenizeToIds(sequence, accumulatedIds);
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, accumulatedIds);

/// <summary>
/// Get the number of tokens that the input sequence will be encoded to.
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="sequence">The text to tokenize.</param>
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is special token.</param>
/// <returns>The number of tokens that the input sequence will be encoded to.</returns>
public override int CountTokens(string sequence, bool isSpecialToken) => TokenizeToIds(sequence, null);
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokens(string text, bool isSpecialToken) => EncodeToIds(text, null);

private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
private int EncodeToIds(string text, IList<int>? accumulatedIds)
{
if (string.IsNullOrEmpty(sequence))
if (string.IsNullOrEmpty(text))
{
return 0;
}

if (_cache.TryGet(sequence, out List<Token>? hit))
if (_cache.TryGet(text, out List<Token>? hit))
{
if (accumulatedIds is not null)
{
Expand All @@ -240,8 +240,8 @@ private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
return hit.Count;
}

// If the cache doesn't have the sequence, then tokenize it and add it to the cache
IReadOnlyList<Token> tokens = Tokenize(sequence);
// If the cache doesn't have the text, then encode it and add it to the cache
IReadOnlyList<Token> tokens = Encode(text);
if (accumulatedIds is not null)
{
foreach (var t in tokens)
Expand All @@ -254,7 +254,7 @@ private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
}

/// <summary>
/// Map the token to tokenized Id.
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the encoding.</param>
Expand Down
36 changes: 18 additions & 18 deletions src/Microsoft.ML.Tokenizers/Model/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,64 +14,64 @@ namespace Microsoft.ML.Tokenizers
public abstract class Model
{
/// <summary>
/// Tokenize a split sequence string to a list of tokens.
/// Encode a split text string to a list of tokens.
/// </summary>
/// <param name="sequence">The text to tokenize.</param>
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <returns>The list of tokens generated from the sequence tokenization.</returns>
public abstract IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken = false);
/// <returns>The list of tokens generated from the text tokenization.</returns>
public abstract IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false);

/// <summary>
/// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
/// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
/// </summary>
/// <param name="sequence">The sequence to split.</param>
/// <param name="text">The text to split.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
/// <remarks>
/// This method does the default implementation that uses the Tokenize method to get the token's Ids.
/// This method does the default implementation that uses the Encode method to get the token's Ids.
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
/// </remarks>
public virtual void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
public virtual void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds)
{
if (accumulatedIds is null)
{
throw new ArgumentNullException(nameof(accumulatedIds));
}

var tokens = Tokenize(sequence);
var tokens = Encode(text);
foreach (var token in tokens)
{
accumulatedIds.Add(token.Id);
}
}

/// <summary>
/// Get the number of tokens that the input sequence will be encoded to.
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="sequence">The text to tokenize.</param>
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is special token.</param>
/// <returns>The number of tokens that the input sequence will be encoded to.</returns>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
/// <remarks>
/// This method does the default implementation that uses the TokenizeToIds method to get the number of token's Ids.
/// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids.
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
/// </remarks>
public virtual int CountTokens(string sequence, bool isSpecialToken)
public virtual int CountTokens(string text, bool isSpecialToken)
{
var ids = new List<int>();
TokenizeToIds(sequence, isSpecialToken, ids);
EncodeToIds(text, isSpecialToken, ids);
return ids.Count;
}

/// <summary>
/// Map the token to tokenized id with the option to skip the special tokens.
/// Map the token to encoded id with the option to skip the special tokens.
/// </summary>
/// <param name="token">The token to map to Id</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public abstract int? TokenToId(string token, bool skipSpecialTokens = false);

/// <summary>
/// Map the tokenized Id to the token.
/// Map the encoded Id to the token.
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
Expand Down
Loading