diff --git a/src/Microsoft.ML.Tokenizers/TokenizerResult.cs b/src/Microsoft.ML.Tokenizers/EncodingResult.cs
similarity index 88%
rename from src/Microsoft.ML.Tokenizers/TokenizerResult.cs
rename to src/Microsoft.ML.Tokenizers/EncodingResult.cs
index 6b8e434878..9401cfa490 100644
--- a/src/Microsoft.ML.Tokenizers/TokenizerResult.cs
+++ b/src/Microsoft.ML.Tokenizers/EncodingResult.cs
@@ -11,16 +11,16 @@ namespace Microsoft.ML.Tokenizers
///
/// The Encoding represents the output of a Tokenizer.
///
- public sealed class TokenizerResult
+ public sealed class EncodingResult
{
///
- /// Create a new object of the TokenizerResult object.
+ /// Create a new instance of the EncodingResult class.
///
/// The list of tokens to merge.
/// The list of tokens to merge.
/// The list of tokens to merge.
/// Indicate whether the offsets is mapped to the original string or the normalized string.
- public TokenizerResult(string originalString, string normalizedString, IEnumerable splits, bool offsetsMappedToOriginalString)
+ public EncodingResult(string originalString, string normalizedString, IEnumerable splits, bool offsetsMappedToOriginalString)
{
OriginalString = originalString;
NormalizedString = normalizedString;
@@ -47,7 +47,7 @@ public TokenizerResult(string originalString, string normalizedString, IEnumerab
private List? _tokens;
private List? _tokensWords;
private List? _ids;
- private List<(int Index, int End)>? _offsets;
+ private List<(int Index, int Length)>? _offsets;
internal void AddTokens(IReadOnlyList addedTokens)
{
@@ -121,10 +121,10 @@ public IReadOnlyList Tokens
}
///
- /// Gets The list of offsets. These offsets let’s you slice the input string, and thus retrieve
+ /// Gets the list of offsets. These offsets let you slice the input string, and thus retrieve
/// the original part that led to producing the corresponding token.
///
- public IReadOnlyList<(int Index, int End)> Offsets
+ public IReadOnlyList<(int Index, int Length)> Offsets
{
get
{
@@ -138,7 +138,7 @@ public IReadOnlyList Tokens
return Array.Empty<(int, int)>();
}
- _offsets = new List<(int Index, int End)>(_tokens.Count);
+ _offsets = new List<(int Index, int Length)>(_tokens.Count);
foreach (var token in _tokens)
{
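With the rename, token offsets are exposed as (Index, Length) pairs rather than (Index, End). A minimal, illustrative sketch of slicing the original input under the new shape; `encoding` stands for any EncodingResult instance and is not part of this diff:

    foreach ((int index, int length) in encoding.Offsets)
    {
        // Slice with Index/Length rather than the old Index/End convention.
        string piece = encoding.OriginalString.Substring(index, length);
    }
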
diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs
index 008dacb573..d799d45a39 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs
@@ -30,7 +30,7 @@ public string? UnknownToken
return _unknownToken;
}
- set
+ private set
{
_unknownToken = value;
@@ -39,102 +39,134 @@ public string? UnknownToken
if (VocabReverse.TryGetValue(0, out string? v))
{
VocabReverse.Remove(0);
- if (Vocab.TryGetValue(v, out int id))
+ if (_vocab.TryGetValue(v, out int id))
{
- Vocab.Remove(v);
+ _vocab.Remove(v);
}
}
}
else
{
- Vocab[value] = 0;
+ _vocab[value] = 0;
VocabReverse[0] = value;
}
}
}
///
- /// An optional prefix to use on any sub-word that exist only behind another one
+ /// A prefix to be used for every subword that is not a beginning-of-word
///
- public string? ContinuingSubwordPrefix { get; set; }
+ public string? ContinuingSubwordPrefix { get; }
///
/// An optional suffix to characterize and end-of-word sub-word
///
- public string? EndOfWordSuffix { get; set; }
+ public string? EndOfWordSuffix { get; }
///
/// Gets or sets whether allowing multiple unknown tokens get fused
///
- public bool FuseUnknownTokens { get; set; }
+ public bool FuseUnknownTokens { get; }
+
///
- /// Construct a new Bpe model object with no tokenization vocabulary. This constructor is useful only in the training scenario.
+ /// Construct a new Bpe model object to use for text encoding.
///
- public Bpe()
+ /// The JSON file path containing the dictionary of string keys and their ids.
+ /// The file path containing the token pairs list.
+ /// The unknown token to be used by the model.
+ /// The prefix to attach to sub-word units that don’t represent a beginning of word.
+ /// The suffix to attach to sub-word units that represent an end of word.
+ /// Indicate whether multiple unknown tokens should be fused together.
+ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
+ this(vocabFile is null ? throw new ArgumentNullException(nameof(vocabFile)) : File.Open(vocabFile, FileMode.Open, FileAccess.Read),
+ mergesFile is null ? null : File.Open(mergesFile, FileMode.Open, FileAccess.Read), unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: true)
{
- Vocab = new();
- VocabReverse = new();
- Merges = new();
-
- UnknownToken = "[Unk]";
}
///
- /// Construct a new Bpe model object to use for sentence tokenization and tokenizer training.
+ /// Construct a new Bpe model object to use for text encoding.
///
- /// The JSON file path containing the dictionary of string keys and their ids.
- /// The file path containing the tokens's pairs list.
+ /// The JSON stream containing the dictionary of string keys and their ids.
+ /// The stream containing the token pairs list.
/// The unknown token to be used by the model.
/// The prefix to attach to sub-word units that don’t represent a beginning of word.
/// The suffix to attach to sub-word units that represent an end of word.
- public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null)
+ /// Indicate whether multiple unknown tokens should be fused together.
+ public Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
+ this(vocabStream, mergesStream, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: false)
{
- ContinuingSubwordPrefix = continuingSubwordPrefix;
- EndOfWordSuffix = endOfWordSuffix;
-
- (Dictionary? vocab1, Vec<(string, string)> merges) = ReadFile(vocabFile, mergesFile);
- Vocab = vocab1 ?? new Dictionary();
- Cache = new Cache();
-
- VocabReverse = new();
-
- foreach (KeyValuePair kvp in Vocab)
- {
- VocabReverse.Add(kvp.Value, kvp.Key);
- }
+ }
- if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
+ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens, bool disposeStreams)
+ {
+ try
{
- unknownToken = unkToken;
- }
+ if (vocabStream is null)
+ {
+ throw new ArgumentNullException(nameof(vocabStream));
+ }
- UnknownToken = unknownToken;
+ FuseUnknownTokens = fuseUnknownTokens;
+ ContinuingSubwordPrefix = continuingSubwordPrefix;
+ EndOfWordSuffix = endOfWordSuffix;
- int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;
+ (Dictionary? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
+ _vocab = vocab1 ?? new Dictionary();
+ Cache = new Cache();
- Merges = new();
- for (int i = 0; i < merges.Count; i++)
- {
- (string a, string b) mergeValues = merges[i];
+ VocabReverse = new();
- if (!Vocab.TryGetValue(mergeValues.a, out int aId))
+ foreach (KeyValuePair kvp in Vocab)
{
- throw new InvalidOperationException($"Trying to merge a token {mergeValues.a} which not exist in the vocabulary.");
+ VocabReverse.Add(kvp.Value, kvp.Key);
}
- if (!Vocab.TryGetValue(mergeValues.b, out int bId))
+ if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
{
- throw new InvalidOperationException($"Trying to merge a token {mergeValues.b} which not exist in the vocabulary.");
+ unknownToken = unkToken;
}
- string newToken = $"{mergeValues.a}{mergeValues.b.Substring(prefixLen)}";
- if (!Vocab.TryGetValue(newToken, out int newId))
+ UnknownToken = unknownToken;
+
+ int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;
+
+ Merges = new();
+ for (int i = 0; i < merges.Count; i++)
{
- throw new InvalidOperationException($"Trying to merge a token {newToken} which not exist in the vocabulary.");
- }
+ (string a, string b) mergeValues = merges[i];
+
+ if (!_vocab.TryGetValue(mergeValues.a, out int aId))
+ {
+ throw new InvalidOperationException($"Trying to merge a token '{mergeValues.a}' that does not exist in the vocabulary.");
+ }
+
+ if (!_vocab.TryGetValue(mergeValues.b, out int bId))
+ {
+ throw new InvalidOperationException($"Trying to merge a token '{mergeValues.b}' that does not exist in the vocabulary.");
+ }
+
+ if (mergeValues.b.Length <= prefixLen)
+ {
+ throw new InvalidOperationException($"The merge value '{mergeValues.b}' is too short to be merged with a prefix of length {prefixLen}. This implies that the merge file is either damaged or missing the prefix in its entries.");
+ }
- Merges.Add(new Pair(aId, bId), (i, newId));
+ string newToken = $"{mergeValues.a}{mergeValues.b.Substring(prefixLen)}";
+ if (!_vocab.TryGetValue(newToken, out int newId))
+ {
+ throw new InvalidOperationException($"Trying to merge a token '{newToken}' that does not exist in the vocabulary.");
+ }
+
+ Merges.Add(new Pair(aId, bId), (i, newId));
+ }
+ }
+ finally
+ {
+ if (disposeStreams)
+ {
+ vocabStream.Dispose();
+ mergesStream?.Dispose();
+ }
}
}
@@ -144,45 +176,46 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
public static TokenizerDecoder Decoder { get; } = new BpeDecoder();
///
- /// Tokenize a sequence string to a list of tokens.
+ /// Encode a text string to a list of tokens.
///
- /// The sequence to tokenize.
+ /// The text to encode.
/// Indicate if the token is a special token.
- /// The list of tokens generated from the sequence tokenization.
- public override IReadOnlyList Tokenize(string sequence, bool isSpecialToken = false)
+ /// The list of tokens generated from the text tokenization.
+ public override IReadOnlyList Encode(string text, bool isSpecialToken = false)
{
- if (sequence.Length == 0)
+ if (text.Length == 0)
{
return EmptyTokensList;
}
- return TokenizeWithCache(sequence);
+ return EncodeWithCache(text);
}
///
- /// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
+ /// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
///
- /// The sequence to split.
+ /// The text to encode.
/// Indicate if the token is a special token.
- /// The list of accumulated tokenized Ids.
- public override void TokenizeToIds(string sequence, bool isSpecialToken, IList accumulatedIds) => TokenizeToIdsWithCache(sequence, accumulatedIds);
+ /// The list of accumulated encoded Ids.
+ public override void EncodeToIds(string text, bool isSpecialToken, IList accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
///
- /// Get the number of tokens that the input sequence will be encoded to.
+ /// Get the number of tokens that the input text will be encoded to.
///
- /// The text to tokenize.
+ /// The text to encode.
/// Indicate if the token is special token.
- /// The number of tokens that the input sequence will be encoded to.
- public override int CountTokens(string sequence, bool isSpecialToken) => TokenizeToIdsWithCache(sequence, null);
+ /// The number of tokens that the input text will be encoded to.
+ public override int CountTokens(string text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
///
- /// Map the token to tokenized Id.
+ /// Map the token to encoded Id.
///
/// The token to map to the Id.
+ /// Indicate whether to consider the special tokens during the encoding.
/// The mapped Id of the token.
- public override int? TokenToId(string token)
+ public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
{
- if (Vocab.TryGetValue(token, out int value))
+ if (_vocab.TryGetValue(token, out int value))
{
return value;
}
@@ -191,12 +224,12 @@ public override IReadOnlyList Tokenize(string sequence, bool isSpecialTok
}
///
- /// Map the tokenized Id to the token.
+ /// Map the encoded Id to the token.
///
/// The Id to map to the token.
- /// Indicate if want to skip the special tokens during the decoding.
+ /// Indicate whether to consider the special tokens during the decoding.
/// The mapped token of the Id.
- public override string? IdToToken(int id, bool skipSpecialTokens = false)
+ public override string? MapIdToToken(int id, bool considerSpecialTokens = true)
{
if (VocabReverse.TryGetValue(id, out string? value))
{
@@ -209,101 +242,62 @@ public override IReadOnlyList Tokenize(string sequence, bool isSpecialTok
///
/// Gets the dictionary mapping tokens to Ids.
///
- public override IReadOnlyDictionary GetVocab() => Vocab;
-
- ///
- /// Gets the dictionary size that map tokens to Ids.
- ///
- public override int GetVocabSize() => Vocab.Count;
-
- ///
- /// Gets a trainer object to use in training the model and generate the vocabulary and merges data.
- ///
- public override Trainer? GetTrainer() => new BpeTrainer();
-
- ///
- /// Save the model data into the vocabulary and merges files.
- ///
- /// The file system path to store the generated files at.
- /// Optional prefix for the generated file names.
- /// The list of all saved files.
- public override string[] Save(string path, string? prefix = null)
- {
- // Write vocab.json
- string vocabFileNname = prefix is null ? "vocab.json" : $"{prefix}-vocab.json";
- string vocabPath = Path.Combine(path, vocabFileNname);
- string serialized = JsonSerializer.Serialize(VocabReverse, new JsonSerializerOptions { Converters = { new DictReversingConverter() } });
- File.WriteAllText(vocabPath, serialized, System.Text.Encoding.UTF8);
-
- // Write merges.txt
- string mergeFileName = prefix is null ? "merges.txt" : $"{prefix}-merges.txt";
- string mergePath = Path.Combine(path, mergeFileName);
- (Pair pair, int rank)[] pairsArray = new (Pair, int)[Merges.Count];
- int i = 0;
- foreach (var p in Merges)
- {
- pairsArray[i++] = (p.Key, p.Value.Item1 /* rank */);
- }
- Array.Sort(pairsArray, (x, y) => x.rank.CompareTo(y.rank));
- using StreamWriter file = new(mergePath, append: false, System.Text.Encoding.UTF8);
- file.WriteLine("#version: 0.2 - Trained by `huggingface/tokenizers`");
- foreach (var p in pairsArray)
- {
- file.WriteLine($"{VocabReverse[p.pair.First]} {VocabReverse[p.pair.Second]}");
- }
-
- return new string[] { vocabPath, mergePath };
- }
+ public IReadOnlyDictionary Vocab => _vocab;
/// Read the given files to extract the vocab and merges
- internal static (Dictionary?, Vec<(string, string)>) ReadFile(string vocab, string? merges)
+ internal static (Dictionary?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
{
- Dictionary? dic;
- using (Stream stream = File.OpenRead(vocab))
- {
- dic = JsonSerializer.Deserialize>(stream) as Dictionary;
- }
+ Dictionary? dic = JsonSerializer.Deserialize>(vocab) as Dictionary;
return (dic, ConvertMergesToHashmap(merges));
}
/// The vocabulary assigns a number to each token.
- internal Dictionary Vocab { get; set; }
+ private readonly Dictionary _vocab;
/// Contains the mapping between Pairs and their (rank, newId).
- internal Dictionary, (int, int)> Merges { get; set; }
+ internal Dictionary, (int, int)> Merges { get; }
/// Contains the cache for optimizing the encoding step.
- internal Cache? Cache { get; set; }
+ internal Cache? Cache { get; }
internal static readonly int DefaultCacheCapacity = 10_000;
- /// Reversed vocabulary, to rebuild sentences.
- internal SortedDictionary VocabReverse { get; set; }
+ /// Reversed vocabulary, to rebuild the text.
+ internal SortedDictionary VocabReverse { get; }
/// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
/// perform no merges, so the result will just be characters.
- internal float? Dropout { get; set; }
+ internal float? Dropout { get; }
/// Converts the merges strings (for example from `merges.txt` file) with the format
/// "{pair_a} {pair_b}" into the format expected by the BPE struct
- internal static Vec<(string, string)> ConvertMergesToHashmap(string? mergesFile)
+ internal static Vec<(string, string)> ConvertMergesToHashmap(Stream? mergesStream)
{
- if (mergesFile is null)
+ if (mergesStream is null)
{
return new Vec<(string, string)>();
}
+ using StreamReader reader = new StreamReader(mergesStream);
+
Vec<(string, string)> merges = new(1000);
int lineNumber = 0;
- foreach (string line in System.IO.File.ReadLines(mergesFile))
+ while (true)
{
+ string? line = reader.ReadLine();
+ if (line is null)
+ {
+ break;
+ }
+
lineNumber++;
if (line.StartsWith("#version", StringComparison.Ordinal) || line.Length == 0)
{
continue;
}
+
int index = line.IndexOf(' ');
if (index < 0 || index == line.Length - 1 || line.IndexOf(' ', index + 1) >= 0)
{
@@ -367,7 +361,7 @@ internal Word MergeWord(string w)
s = $"{s}{EndOfWordSuffix}";
}
- if (Vocab.TryGetValue(s, out int id))
+ if (_vocab.TryGetValue(s, out int id))
{
if (unk.HasValue)
{
@@ -389,7 +383,7 @@ internal Word MergeWord(string w)
{
// Do not fuse unk, add the previous one
word.Add(unk.Value.Id, unk.Value.Len);
- if (!Vocab.TryGetValue(UnknownToken, out int value))
+ if (!_vocab.TryGetValue(UnknownToken, out int value))
{
throw new InvalidOperationException($"Unknown Token Out Of Vocabulary.");
}
@@ -398,7 +392,7 @@ internal Word MergeWord(string w)
}
else
{
- if (!Vocab.TryGetValue(UnknownToken, out int value))
+ if (!_vocab.TryGetValue(UnknownToken, out int value))
{
throw new InvalidOperationException($"Unknown Token Out Of Vocabulary.");
}
@@ -420,22 +414,22 @@ internal Word MergeWord(string w)
internal List WordToTokens(ref Word word) => word.ToTokens(VocabReverse);
- internal List TokenizeWithCache(string sequence)
+ internal List EncodeWithCache(string text)
{
Word word;
if (Cache is not null)
{
- if (Cache.TryGet(sequence, out word))
+ if (Cache.TryGet(text, out word))
{
return WordToTokens(ref word);
}
- word = MergeWord(sequence);
- Cache.Set(sequence, word);
+ word = MergeWord(text);
+ Cache.Set(text, word);
}
else
{
- word = MergeWord(sequence);
+ word = MergeWord(text);
}
return WordToTokens(ref word);
@@ -451,23 +445,23 @@ internal int WordToIds(ref Word word, IList? accumulatedIds)
return word.SymbolsCount;
}
- internal int TokenizeToIdsWithCache(string sequence, IList? accumulatedIds)
+ internal int EncodeToIdsWithCache(string text, IList? accumulatedIds)
{
Word word;
if (Cache is not null)
{
- if (Cache.TryGet(sequence, out Word hit))
+ if (Cache.TryGet(text, out Word hit))
{
return WordToIds(ref hit, accumulatedIds);
}
- word = MergeWord(sequence);
- Cache.Set(sequence, word);
+ word = MergeWord(text);
+ Cache.Set(text, word);
}
else
{
- word = MergeWord(sequence);
+ word = MergeWord(text);
}
return WordToIds(ref word, accumulatedIds);
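The Bpe model is now constructed directly from vocabulary and merges data (file paths or streams) and exposes Encode, EncodeToIds, and CountTokens in place of the old Tokenize* names. An illustrative sketch of the new surface; the file paths, unknown token, and input text are placeholders rather than values taken from this change:

    using System.Collections.Generic;
    using Microsoft.ML.Tokenizers;

    Bpe bpe = new Bpe("vocab.json", "merges.txt", unknownToken: "[UNK]");
    IReadOnlyList<Token> tokens = bpe.Encode("hello world");

    var ids = new List<int>();
    bpe.EncodeToIds("hello world", isSpecialToken: false, ids);
    int count = bpe.CountTokens("hello world", isSpecialToken: false);
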
diff --git a/src/Microsoft.ML.Tokenizers/Model/BpeTrainer.cs b/src/Microsoft.ML.Tokenizers/Model/BpeTrainer.cs
deleted file mode 100644
index fa65a8d540..0000000000
--- a/src/Microsoft.ML.Tokenizers/Model/BpeTrainer.cs
+++ /dev/null
@@ -1,534 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.Collections.Generic;
-using System.Runtime.CompilerServices;
-using System.Text;
-
-namespace Microsoft.ML.Tokenizers
-{
- ///
- /// The Bpe trainer responsible to train the Bpe model.
- ///
- public sealed class BpeTrainer : Trainer
- {
- ///
- /// Gets the size of the final vocabulary, including all tokens and alphabet.
- ///
- public int VocabSize { get; }
-
- ///
- /// Gets the minimum frequency a pair should have in order to be merged.
- ///
- public int MinFrequency { get; }
-
- ///
- /// Gets the list of special tokens the model should know of.
- ///
- public IReadOnlyList? SpecialTokens { get; }
-
- ///
- /// Gets the maximum different characters to keep in the alphabet.
- ///
- public int? LimitAlphabet { get; }
-
- ///
- /// Gets the list of characters to include in the initial alphabet, even if not seen in the training dataset.
- /// If the strings contain more than one character, only the first one is kept.
- ///
- public HashSet? InitialAlphabet { get; }
-
- ///
- /// Gets the prefix to be used for every sub-word that is not a beginning-of-word.
- ///
- public string? ContinuingSubwordPrefix { get; }
-
- ///
- /// Gets the suffix to be used for every sub-word that is a end-of-word.
- ///
- public string? EndOfWordSuffix { get; }
-
- private Dictionary Words { get; set; }
-
- ///
- /// Construct a new BpeTrainer object using the default values.
- ///
- public BpeTrainer() : this(null)
- {
- }
-
- ///
- /// Construct a new BpeTrainer object.
- ///
- /// The list of special tokens the model should know of.
- /// The minimum frequency a pair should have in order to be merged.
- /// the size of the final vocabulary, including all tokens and alphabet.
- /// Callback for the training progress updates.
- /// The list of characters to include in the initial alphabet.
- /// The JSON file path containing the dictionary of string keys and their ids
- /// the prefix to be used for every sub-word that is not a beginning-of-word.
- /// the suffix to be used for every sub-word that is a end-of-word.
- public BpeTrainer(
- IEnumerable? specialTokens,
- int minFrequency = 0,
- int vocabSize = 30000,
- ReportProgress? progress = null,
- int? limitAlphabet = null,
- HashSet? initialAlphabet = null,
- string? continuingSubwordPrefix = null,
- string? endOfWordSuffix = null)
- {
- MinFrequency = minFrequency;
- VocabSize = vocabSize;
- Progress = progress;
-
- if (specialTokens is not null)
- {
- SpecialTokens = new List(specialTokens);
- }
-
- LimitAlphabet = limitAlphabet;
- InitialAlphabet = initialAlphabet;
- ContinuingSubwordPrefix = continuingSubwordPrefix;
- EndOfWordSuffix = endOfWordSuffix;
- Words = new();
- }
-
- /// Add the provided special tokens to the initial vocabulary
- private void AddSpecialTokens(Dictionary w2Id, ref Vec id2W)
- {
- if (SpecialTokens is not null)
- {
- foreach (var token in SpecialTokens)
- {
- if (!w2Id.ContainsKey(token.Content))
- {
- id2W.Push(token.Content);
- w2Id.Add(token.Content, id2W.Count - 1);
- }
- }
- }
- }
-
- private void ComputeAlphabet(Dictionary wc, Dictionary w2Id, ref Vec id2W)
- {
- // Compute the alphabet from seen words
- Dictionary alphabet = new();
- foreach (KeyValuePair kvp in wc)
- {
- foreach (char c in kvp.Key)
- {
- if (alphabet.ContainsKey(c))
- {
- alphabet[c] = alphabet[c] + kvp.Value;
- }
- else
- {
- alphabet[c] = kvp.Value;
- }
- }
- }
-
- // Also include anything from the provided initial alphabet
- if (InitialAlphabet is not null)
- {
- foreach (char c in InitialAlphabet)
- {
- alphabet[c] = int.MaxValue;
- }
- }
-
- List> kept = new List>(alphabet.Count);
- foreach (KeyValuePair kvp in alphabet)
- {
- kept.Add(kvp);
- }
-
- // Compute the number of chars to remove from the alphabet
- // If `limit_alphabet < initial_alphabet.len()`, some of these initial characters
- // will be removed
- int toRemove = LimitAlphabet.HasValue && alphabet.Count > LimitAlphabet.Value ? (int)(alphabet.Count - LimitAlphabet.Value) : 0;
-
- // Remove the unwanted chars
- if (toRemove > 0)
- {
- kept.Sort((x, y) => (int)x.Value - (int)y.Value);
- kept.RemoveRange(0, toRemove);
- }
-
- // Keep the initial alphabet (sorted for determinism)
- kept.Sort((x, y) => (int)x.Key - (int)y.Key);
-
- foreach (KeyValuePair kvp in kept)
- {
- string s = kvp.Key.ToString();
- if (!w2Id.ContainsKey(s))
- {
- id2W.Push(s);
- w2Id[s] = (int)id2W.Count - 1;
- }
- }
- }
-
- private readonly Dictionary _charToString = new Dictionary();
-
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- internal string CharToString(char c)
- {
- if (_charToString.TryGetValue(c, out string? v))
- {
- return v;
- }
-
- string s = c.ToString();
- _charToString[c] = s;
- return s;
- }
-
- /// Tokenize words and add sub-words to the vocabulary when relevant
- private (Vec, Vec) TokenizeWords(Dictionary wc, Dictionary w2Id, ref Vec id2w)
- {
- Vec words = new Vec(wc.Count);
- Vec counts = new Vec(wc.Count);
-
- foreach (KeyValuePair kvp in wc)
- {
- Word currentWord = new Word();
- counts.Push(kvp.Value);
-
- for (int i = 0; i < kvp.Key.Length; i++)
- {
- char c = kvp.Key[i]; ;
- string s = CharToString(c);
- if (w2Id.ContainsKey(s))
- {
- // Found the initial char in the authorized alphabet
- // Add the `continuing_subword_prefix` if relevant
- if (i != 0 && ContinuingSubwordPrefix is not null)
- {
- s = $"{ContinuingSubwordPrefix}{s}";
- }
-
- // Add the `end_of_word_suffix` if relevant
- if (i == kvp.Key.Length - 1 && EndOfWordSuffix is not null)
- {
- s = $"{s}{EndOfWordSuffix}";
- }
-
- // Insert the new formed string if necessary
- if (!w2Id.ContainsKey(s))
- {
- id2w.Push(s);
- w2Id[s] = (int)(id2w.Count - 1);
- }
-
- currentWord.Add(w2Id[s], 1); // We do not care about the len here
- }
- }
-
- words.Push(currentWord);
- Progress?.Invoke(new Progress(ProgressState.Increment, null, 1));
- }
-
- return (words, counts);
- }
-
- private (Dictionary, int>, Dictionary, HashSet>) CountPairs(ref Vec words, ref Vec counts)
- {
- if (words.Count <= 0)
- {
- return (new(), new());
- }
-
- Dictionary, int> pairCounts = new Dictionary, int>();
- Dictionary, HashSet> whereToUpdate = new();
-
- for (int i = 0; i < words.Count; i++)
- {
- ref Word word = ref words[i];
-
- int j = 0;
- Vec chars = word.GetChars();
- while (j < chars.Count - 1)
- {
- Pair curPair = new Pair(chars[j], chars[j + 1]);
-
- // Initialize pair_counts and where_to_update for this pair if we just saw it
- if (!pairCounts.ContainsKey(curPair))
- {
- pairCounts[curPair] = 0;
- }
-
- // Then update counts
- int count = counts[i];
-
- if (!whereToUpdate.TryGetValue(curPair, out HashSet? h))
- {
- h = new HashSet();
- whereToUpdate[curPair] = h;
- }
-
- h.Add(i);
-
- pairCounts[curPair] = pairCounts[curPair] + (int)count;
-
- j++;
- }
-
- Progress?.Invoke(new Progress(ProgressState.Increment, null, 1));
- }
-
- return (pairCounts, whereToUpdate);
- }
-
- private IReadOnlyList? DoTrain(Dictionary wordCounts, Bpe model)
- {
- Dictionary wordToId = new(VocabSize);
- Vec idToWord = new(VocabSize);
-
- //
- // 1. Add all special tokens to the vocabulary
- //
- AddSpecialTokens(wordToId, ref idToWord);
-
- //
- // 2. Compute the initial alphabet
- //
- ComputeAlphabet(wordCounts, wordToId, ref idToWord);
-
- //
- // 3. Tokenize words
- //
- Progress?.Invoke(new Progress(ProgressState.Start, "Tokenize words", wordCounts.Count));
- (Vec words, Vec counts) = TokenizeWords(wordCounts, wordToId, ref idToWord);
- Progress?.Invoke(new Progress(ProgressState.End, null, wordCounts.Count));
-
- //
- // 4. Count pairs in words
- //
- Progress?.Invoke(new Progress(ProgressState.Start, "Count pairs", wordCounts.Count));
- (Dictionary, int> pairCounts, Dictionary, HashSet> whereToUpdate) = CountPairs(ref words, ref counts);
-
- // Insert them in the queue
- PriorityQueue queue = new(pairCounts.Count);
-
- foreach (KeyValuePair, HashSet> kvp in whereToUpdate)
- {
- int count = pairCounts[kvp.Key];
- if (count > 0)
- {
- queue.Enqueue(new BpeTrainerMerge(kvp.Key, count, kvp.Value));
- }
- }
-
- whereToUpdate.Clear();
- Progress?.Invoke(new Progress(ProgressState.End, null, words.Count));
-
- //
- // 5. Do merges
- //
- Progress?.Invoke(new Progress(ProgressState.End, "Compute merges", VocabSize));
- Vec<(Pair, int)> merges = new();
-
- while (true)
- {
- // Stop as soon as we have a big enough vocabulary
- if (wordToId.Count >= VocabSize || queue.Count == 0)
- {
- break;
- }
-
- BpeTrainerMerge top = queue.Dequeue();
-
- if (top.Count != pairCounts[top.Pair])
- {
- top.Count = pairCounts[top.Pair];
- queue.Enqueue(top);
- continue;
- }
-
- if (top.Count < 1 || MinFrequency > top.Count)
- {
- break;
- }
-
- string partA = idToWord[(int)top.Pair.First];
- string partB = idToWord[(int)top.Pair.Second];
-
- // Build new token
- if (ContinuingSubwordPrefix is not null)
- {
- if (partB.StartsWith(ContinuingSubwordPrefix, StringComparison.Ordinal))
- {
- partB = partB.Substring(ContinuingSubwordPrefix.Length);
- }
- }
-
- string newToken = $"{partA}{partB}";
-
- // Insert new token if it does not already exist
- if (!wordToId.TryGetValue(newToken, out int newTokenId))
- {
- newTokenId = (int)idToWord.Count;
- idToWord.Push(newToken);
- wordToId[newToken] = newTokenId;
- }
-
- merges.Push((top.Pair, newTokenId));
-
- Vec<((Pair, int), int)> changes = new();
-
- // Merge the new pair in every words
- foreach (int i in top.Pos)
- {
- ref Word w = ref words[(int)i];
-
- Vec<(Pair, int)> m = w.Merge(top.Pair.First, top.Pair.Second, newTokenId);
-
- for (int j = 0; j < m.Count; j++)
- {
- changes.Push((m[j], i));
- }
- }
-
- // Introduce new formed pairs
- for (int j = 0; j < changes.Count; j++)
- {
- ((Pair p, int change), int iw) = changes[j];
- int count = (int)(change * counts[(int)iw]);
-
- pairCounts[p] = pairCounts.TryGetValue(p, out int c) ? c + count : count;
-
- if (change > 0)
- {
- if (!whereToUpdate.TryGetValue(p, out HashSet? h))
- {
- h = new();
- whereToUpdate[p] = h;
- }
- h.Add(iw);
- }
- }
-
- foreach (KeyValuePair, HashSet> kvp in whereToUpdate)
- {
- int count = pairCounts[kvp.Key];
- if (count > 0)
- {
- queue.Enqueue(new BpeTrainerMerge(kvp.Key, count, kvp.Value));
- }
- }
- whereToUpdate.Clear();
-
- Progress?.Invoke(new Progress(ProgressState.Increment, null, 1));
- }
-
- Progress?.Invoke(new Progress(ProgressState.End, null, merges.Count));
-
- // Transfer new vocab & options to model
- model.Vocab = wordToId;
-
- if (SpecialTokens is not null)
- {
- for (int i = 0; i < SpecialTokens.Count; i++)
- {
- model.Vocab[SpecialTokens[(int)i].Content] = i;
- }
-
- if (SpecialTokens.Count > 0)
- {
- model.UnknownToken = SpecialTokens[0].Content;
- }
- }
-
- model.VocabReverse = new();
-
- foreach (KeyValuePair kvp in model.Vocab)
- {
- model.VocabReverse[kvp.Value] = kvp.Key;
- }
-
- model.Merges = new();
-
- for (int i = 0; i < merges.Count; i++)
- {
- (Pair p, int v) = merges[i];
- model.Merges[p] = ((int)i, v);
- }
-
- model.ContinuingSubwordPrefix = ContinuingSubwordPrefix;
- model.EndOfWordSuffix = EndOfWordSuffix;
-
- return SpecialTokens;
- }
-
- ///
- /// Process the input sequences and feed the result to the model.
- ///
- /// The list of sequences to feed the trainer.
- /// Optional process callback for reporting the training progress update.
- public override void Feed(IEnumerable sequences, Func> process)
- {
- foreach (string s in sequences)
- {
- foreach (string word in process(s))
- {
- Words[word] = Words.TryGetValue(word, out int value) ? value + 1 : 1;
- }
- }
- }
-
- ///
- /// Perform the actual training and update the input model with the new vocabularies and merges data.
- ///
- /// The model to train. This has to be BpeModel to work with BpeTrainer.
- /// The list of the added tokens.
- public override IReadOnlyList? Train(Model model)
- {
- if (model is Bpe bpeModel)
- {
- return DoTrain(Words, bpeModel);
- }
-
- throw new Exception($"BpeTrainer work only with Bpe Model.");
- }
- }
-
- internal struct BpeTrainerMerge : IEquatable, IComparable
- {
- public BpeTrainerMerge(Pair pair, int count, HashSet pos)
- {
- Pair = pair;
- Count = count;
- Pos = pos;
- }
-
- public Pair Pair { get; set; }
- public int Count { get; set; }
- public HashSet Pos { get; set; }
-
- public int CompareTo(BpeTrainerMerge other)
- {
- if (Count != other.Count)
- {
- // return Count.CompareTo(other.Count);
- return other.Count.CompareTo(Count);
- }
-
- return Pair.CompareTo(other.Pair);
- }
-
- public override int GetHashCode()
- {
- int hashcode = 23;
- hashcode = (hashcode * 37) + Count.GetHashCode();
- hashcode = (hashcode * 37) + Pair.First.GetHashCode();
- hashcode = (hashcode * 37) + Pair.Second.GetHashCode();
- return hashcode;
- }
-
- public bool Equals(BpeTrainerMerge other) => Count == other.Count && Pair.Equals(other.Pair);
- }
-}
diff --git a/src/Microsoft.ML.Tokenizers/Model/Cache.cs b/src/Microsoft.ML.Tokenizers/Model/Cache.cs
index 1fcfa849ec..b10d211ea6 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Cache.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Cache.cs
@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
using System.Threading;
@@ -19,7 +20,7 @@ internal Cache(int capacity)
Map = new Dictionary(Capacity);
}
- private readonly ReaderWriterLockSlim _cacheLock = new ReaderWriterLockSlim();
+ private readonly object _lock = new();
internal Dictionary Map { get; set; }
@@ -29,19 +30,16 @@ internal Cache(int capacity)
internal void Clear()
{
- _cacheLock.EnterWriteLock();
- try
+ lock (_lock)
{
Map.Clear();
}
- finally { _cacheLock.ExitWriteLock(); }
}
internal List GetValues(IEnumerable keys)
{
List values = new();
- _cacheLock.EnterReadLock();
- try
+ lock (_lock)
{
foreach (TKey key in keys)
{
@@ -51,25 +49,21 @@ internal List GetValues(IEnumerable keys)
}
}
}
- finally { _cacheLock.ExitReadLock(); }
return values;
}
internal bool TryGet(TKey key, out TValue value)
{
- _cacheLock.EnterReadLock();
- try
+ lock (_lock)
{
return Map.TryGetValue(key, out value!);
}
- finally { _cacheLock.ExitReadLock(); }
}
internal void SetValues(IEnumerable<(TKey, TValue)> entries)
{
- _cacheLock.EnterWriteLock();
- try
+ lock (_lock)
{
foreach ((TKey, TValue) entry in entries)
{
@@ -80,20 +74,43 @@ internal void SetValues(IEnumerable<(TKey, TValue)> entries)
Map[entry.Item1] = entry.Item2;
}
}
- finally { _cacheLock.ExitWriteLock(); }
}
internal void Set(TKey k, TValue v)
{
- _cacheLock.EnterWriteLock();
- try
+ lock (_lock)
{
if (Capacity > Map.Count)
{
Map[k] = v;
}
}
- finally { _cacheLock.ExitWriteLock(); }
+ }
+
+ internal KeyValuePair[] ToArray()
+ {
+ lock (_lock)
+ {
+ return Map.ToArray();
+ }
+ }
+
+ internal TValue GetOrAdd(TKey key, TValue value)
+ {
+ lock (_lock)
+ {
+ if (Map.TryGetValue(key, out TValue? v))
+ {
+ return v;
+ }
+
+ if (Capacity > Map.Count)
+ {
+ Map[key] = value;
+ }
+
+ return value;
+ }
}
}
}
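Cache is internal, so the snippet below only illustrates the contract: GetOrAdd returns the cached value when the key is present, otherwise stores the supplied value (capacity permitting) and returns it, all under the single lock that replaces the previous ReaderWriterLockSlim. The key/value types mirror the merge-ranks usage elsewhere in this change:

    var mergeRanks = new Cache<(string, string), int>(60_000);
    mergeRanks.Set(("he", "llo"), 5);
    int hit = mergeRanks.GetOrAdd(("he", "llo"), int.MaxValue); // returns 5 (already cached)
    int miss = mergeRanks.GetOrAdd(("fo", "o"), int.MaxValue);  // returns int.MaxValue and caches it
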
diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
index ad98ed917c..ea9fa884a8 100644
--- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
@@ -20,19 +20,25 @@ public sealed class EnglishRoberta : Model
private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence;
private readonly IReadOnlyDictionary _vocab;
private readonly SortedDictionary _vocabReverse;
- private readonly Dictionary<(string, string), int> _mergeRanks;
+ private readonly Cache<(string, string), int> _mergeRanks;
private readonly IReadOnlyDictionary _byteToUnicode;
private readonly IReadOnlyDictionary _unicodeToByte;
private readonly string[] _charToString;
private readonly Cache> _cache;
///
- /// Construct tokenizer object to use with the English Robert model.
+ /// Indicate whether to filter the unsupported characters during the decoding.
+ ///
+ public bool FilterUnsupportedChars { get; }
+
+ ///
+ /// Construct a tokenizer model object to use with the English Roberta model.
///
/// The JSON file path containing the dictionary of string keys and their ids.
/// The file path containing the tokens's pairs list.
/// Remap the original GPT-2 model Ids to high occurrence ranks and values.
- public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath)
+ /// Indicate whether to filter the unsupported characters during the decoding.
+ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath, bool filterUnsupportedChars = true)
{
if (vocabularyPath is null)
{
@@ -49,6 +55,8 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc
throw new ArgumentNullException(nameof(highestOccurrenceMappingPath));
}
+ FilterUnsupportedChars = filterUnsupportedChars;
+
using Stream vocabularyStream = File.OpenRead(vocabularyPath);
using Stream mergeStream = File.OpenRead(mergePath);
using Stream highestOccurrenceMappingStream = File.OpenRead(highestOccurrenceMappingPath);
@@ -73,12 +81,13 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc
}
///
- /// Construct tokenizer object to use with the English Robert model.
+ /// Construct a tokenizer model object to use with the English Roberta model.
///
/// The stream of a JSON file containing the dictionary of string keys and their ids.
/// The stream of a file containing the tokens's pairs list.
/// Remap the original GPT-2 model Ids to high occurrence ranks and values.
- public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream)
+ /// Indicate whether to filter the unsupported characters during the decoding.
+ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, bool filterUnsupportedChars = true)
{
if (vocabularyStream is null)
{
@@ -95,6 +104,8 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
throw new ArgumentNullException(nameof(highestOccurrenceMappingStream));
}
+ FilterUnsupportedChars = filterUnsupportedChars;
+
_vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
_vocab = GetVocabulary(vocabularyStream);
_vocabReverse = _vocab.ReverseSorted();
@@ -110,108 +121,76 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
_cache = new Cache>();
}
- //
- // Public Model interfaces implementation
- //
-
///
/// Gets the dictionary mapping tokens to Ids.
///
- public override IReadOnlyDictionary GetVocab() => _vocab;
-
- ///
- /// Gets the dictionary size that map tokens to Ids.
- ///
- public override int GetVocabSize() => _vocab.Count;
+ public IReadOnlyDictionary Vocab => _vocab;
- ///
- /// Map the tokenized Id to the token.
- ///
- /// The Id to map to the token.
- /// Indicate if want to skip the special tokens during the decoding.
- /// The mapped token of the Id.
- public override string? IdToToken(int id, bool skipSpecialTokens = false) =>
- skipSpecialTokens && id < 0 ? null : _vocabReverse.TryGetValue(id, out var value) ? value : null;
+ //
+ // Public Model interfaces implementation
+ //
///
- /// Map the tokenized Id to the original string while filtering out unsupported characters.
+ /// Map the encoded Id to the token.
///
/// The Id to map to the string.
- /// Indicate if want to skip the special tokens during the decoding.
+ /// Indicate whether to consider the special tokens during the decoding.
/// The mapped token of the Id.
- public string? IdToFilteredToken(int id, bool skipSpecialTokens = false)
+ public override string? MapIdToToken(int id, bool considerSpecialTokens = true)
{
- if (skipSpecialTokens && id < 0)
- return null;
- if (_vocabReverse.TryGetValue(id, out var value))
+ if (!considerSpecialTokens && id < 0)
{
- var textChars = string.Join("", value)
- .Where(c => _unicodeToByte.ContainsKey(c))
- .Select(c => _unicodeToByte[c]);
- var text = new string(textChars.ToArray());
- return text;
+ return null;
}
- return null;
- }
-
- ///
- /// Save the model data into the vocabulary, merges, and occurrence mapping files.
- ///
- /// The file system path to store the generated files at.
- /// Optional prefix for the generated file names.
- /// The list of all saved files.
- public override string[] Save(string path, string? prefix = null)
- {
- // Write vocab.json
- string vocabFileName = prefix is null ? "vocab.json" : $"{prefix}-vocab.json";
- string vocabPath = Path.Combine(path, vocabFileName);
- string serialized = JsonSerializer.Serialize(_vocabReverse, new JsonSerializerOptions { Converters = { new DictReversingConverter() } });
- File.WriteAllText(vocabPath, serialized, System.Text.Encoding.UTF8);
-
- // Write merges.txt
- string mergeFileName = prefix is null ? "merges.txt" : $"{prefix}-merges.txt";
- string mergePath = Path.Combine(path, mergeFileName);
+ if (_vocabReverse.TryGetValue(id, out var value))
+ {
+ if (FilterUnsupportedChars)
+ {
+ char[] buffer = ArrayPool.Shared.Rent(value.Length);
+ int i = 0;
- KeyValuePair<(string, string), int>[] mergeArray = _mergeRanks.ToArray();
- Array.Sort(mergeArray, (x, y) => x.Value.CompareTo(y.Value));
+ for (int j = 0; j < value.Length; j++)
+ {
+ if (_unicodeToByte.TryGetValue(value[j], out var c))
+ {
+ buffer[i++] = c;
+ }
+ }
- using StreamWriter file = new(mergePath, append: false, System.Text.Encoding.UTF8);
- file.WriteLine("#version: 0.2");
- foreach (var p in mergeArray)
- {
- if (p.Value == int.MaxValue)
+ string result = new string(buffer, 0, i);
+ ArrayPool.Shared.Return(buffer);
+ return result;
+ }
+ else
{
- // Skip the entries which we added during the runs.
- continue;
+ return value;
}
- file.WriteLine($"{p.Key.Item1} {p.Key.Item2}");
}
- // Write high occurrence mapping file
- string highOccurrenceFileName = prefix is null ? "dict.txt" : $"{prefix}-dict.txt";
- string highOccurrencePath = Path.Combine(path, highOccurrenceFileName);
- using StreamWriter file1 = new(highOccurrencePath, append: false, System.Text.Encoding.UTF8);
- _vocabIdToHighestOccurrence.Save(file1);
-
- return new string[] { vocabPath, mergePath, highOccurrencePath };
+ return null;
}
///
- /// Tokenize a sequence string to a list of tokens.
+ /// Encode a text string to a list of tokens.
///
- /// The sequence to tokenize.
+ /// The text to encode.
/// Indicate if the token is a special token.
- /// The list of tokens generated from the sequence tokenization.
- public override IReadOnlyList Tokenize(string sequence, bool isSpecialToken = false)
+ /// The list of tokens generated from the text tokenization.
+ public override IReadOnlyList Encode(string text, bool isSpecialToken = false)
{
- char[] token = ArrayPool.Shared.Rent(sequence.Length);
- int[] indexMapping = ArrayPool.Shared.Rent(sequence.Length);
+ if (string.IsNullOrEmpty(text))
+ {
+ return Bpe.EmptyTokensList;
+ }
+
+ char[] token = ArrayPool.Shared.Rent(text.Length);
+ int[] indexMapping = ArrayPool.Shared.Rent(text.Length);
int newTokenIndex = 0;
- for (int i = 0; i < sequence.Length; i++)
+ for (int i = 0; i < text.Length; i++)
{
- if (_byteToUnicode.TryGetValue(sequence[i], out var value))
+ if (_byteToUnicode.TryGetValue(text[i], out var value))
{
token[newTokenIndex] = value;
indexMapping[newTokenIndex] = i;
@@ -226,7 +205,7 @@ public override IReadOnlyList Tokenize(string sequence, bool isSpecialTok
return Array.Empty();
}
- if (_cache.TryGet(sequence, out List? hit))
+ if (_cache.TryGet(text, out List? hit))
{
ArrayPool.Shared.Return(token);
ArrayPool.Shared.Return(indexMapping);
@@ -234,31 +213,36 @@ public override IReadOnlyList Tokenize(string sequence, bool isSpecialTok
}
List result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
- _cache.Set(sequence, result);
+ _cache.Set(text, result);
ArrayPool.Shared.Return(token);
ArrayPool.Shared.Return(indexMapping);
return result;
}
///
- /// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
+ /// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
///
- /// The sequence to split.
+ /// The text to encode.
/// Indicate if the token is a special token.
- /// The list of accumulated tokenized Ids.
- public override void TokenizeToIds(string sequence, bool isSpecialToken, IList accumulatedIds) => TokenizeToIds(sequence, accumulatedIds);
+ /// The list of accumulated encoded Ids.
+ public override void EncodeToIds(string text, bool isSpecialToken, IList accumulatedIds) => EncodeToIds(text, accumulatedIds);
///
- /// Get the number of tokens that the input sequence will be encoded to.
+ /// Get the number of tokens that the input text will be encoded to.
///
- /// The text to tokenize.
+ /// The text to encode.
/// Indicate if the token is special token.
- /// The number of tokens that the input sequence will be encoded to.
- public override int CountTokens(string sequence, bool isSpecialToken) => TokenizeToIds(sequence, null);
+ /// The number of tokens that the input text will be encoded to.
+ public override int CountTokens(string text, bool isSpecialToken) => EncodeToIds(text, null);
- private int TokenizeToIds(string sequence, IList? accumulatedIds)
+ private int EncodeToIds(string text, IList? accumulatedIds)
{
- if (_cache.TryGet(sequence, out List? hit))
+ if (string.IsNullOrEmpty(text))
+ {
+ return 0;
+ }
+
+ if (_cache.TryGet(text, out List? hit))
{
if (accumulatedIds is not null)
{
@@ -271,57 +255,33 @@ private int TokenizeToIds(string sequence, IList? accumulatedIds)
return hit.Count;
}
- Span token = stackalloc char[100];
- Span indexMapping = stackalloc int[100];
-
- if (sequence.Length > 100)
+ // If the cache doesn't have the text, then encode it and add it to the cache
+ IReadOnlyList tokens = Encode(text);
+ if (accumulatedIds is not null)
{
- token = new char[sequence.Length].AsSpan();
- indexMapping = new int[sequence.Length].AsSpan();
- }
-
- int newTokenIndex = 0;
- for (int i = 0; i < sequence.Length; i++)
- {
- if (_byteToUnicode.TryGetValue(sequence[i], out var value))
+ foreach (var t in tokens)
{
- token[newTokenIndex] = value;
- indexMapping[newTokenIndex] = i;
- newTokenIndex++;
+ accumulatedIds.Add(t.Id);
}
}
- if (newTokenIndex == 0)
- {
- return 0;
- }
-
- List result = EncodeToTokens(token.Slice(0, newTokenIndex), indexMapping);
- _cache.Set(sequence, result);
- return result.Count;
+ return tokens.Count;
}
///
- /// Map the token to tokenized Id.
+ /// Map the token to encoded Id.
///
/// The token to map to the Id.
+ /// Indicate whether to consider the special tokens during the encoding.
/// The mapped Id of the token.
- public override int? TokenToId(string token) => _vocab.TryGetValue(token, out var value) ? value : null;
-
- ///
- /// Gets a trainer object to use in training the model and generate the vocabulary and merges data.
- ///
- ///
- /// This tokenizer doesn't support training so this method will return null. Consider using Bpe.GetTrainer() for training.
- ///
- public override Trainer? GetTrainer() => null;
+ public override int? MapTokenToId(string token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out var value) ? value : null;
///
/// Convert a list of tokens Ids to highest occurrence rankings.
///
/// The Ids list to map to the high occurrence rank.
/// The list of ranks mapped from the list of Ids.
- public IReadOnlyList IdsToOccurrenceRanks(IReadOnlyList ids)
+ public IReadOnlyList ConvertIdsToOccurrenceRanks(IReadOnlyList ids)
{
if (ids is null)
{
@@ -343,7 +303,7 @@ public IReadOnlyList IdsToOccurrenceRanks(IReadOnlyList ids)
///
/// The Ids list to map to the high occurrence values.
/// The list of occurrence values mapped from the list of Ids.
- public IReadOnlyList IdsToOccurrenceValues(IReadOnlyList ids)
+ public IReadOnlyList ConvertIdsToOccurrenceValues(IReadOnlyList ids)
{
if (ids is null)
{
@@ -365,7 +325,7 @@ public IReadOnlyList IdsToOccurrenceValues(IReadOnlyList ids)
///
/// The high occurrence ranks list to map to the Ids list.
/// The list of Ids mapped from the list of ranks.
- public IReadOnlyList OccurrenceRanksIds(IReadOnlyList ranks)
+ public IReadOnlyList ConvertOccurrenceRanksToIds(IReadOnlyList ranks)
{
if (ranks is null)
{
@@ -376,7 +336,7 @@ public IReadOnlyList OccurrenceRanksIds(IReadOnlyList ranks)
foreach (int rank in ranks)
{
- list.Add(_vocabIdToHighestOccurrence.OccurrenceRankToId(rank));
+ list.Add(_vocabIdToHighestOccurrence.ConvertOccurrenceRankToId(rank));
}
return list;
@@ -411,7 +371,7 @@ private IReadOnlyList ModifyTokenListOffsets(IReadOnlyList tokens,
{
Debug.Assert(index + tokens[i].Value.Length <= indexMapping.Length);
- if (tokens[i].Offset != (indexMapping[index], indexMapping[index + tokens[i].Value.Length - 1] + 1))
+ if (tokens[i].Offset != (indexMapping[index], tokens[i].Value.Length))
{
List list = new List(tokens.Count);
for (int j = 0; j < i; j++)
@@ -421,7 +381,7 @@ private IReadOnlyList ModifyTokenListOffsets(IReadOnlyList tokens,
for (int j = i; j < tokens.Count; j++)
{
- list.Add(new Token(tokens[j].Id, tokens[j].Value, (indexMapping[index], indexMapping[index + tokens[j].Value.Length - 1] + 1)));
+ list.Add(new Token(tokens[j].Id, tokens[j].Value, (indexMapping[index], tokens[j].Value.Length)));
index += tokens[j].Value.Length;
}
@@ -477,9 +437,9 @@ private Dictionary GetVocabulary(Stream vocabularyStream)
return vocab;
}
- private Dictionary<(string, string), int> GetMergeRanks(Stream mergeStream)
+ private Cache<(string, string), int> GetMergeRanks(Stream mergeStream)
{
- var mergeRanks = new Dictionary<(string, string), int>();
+ var mergeRanks = new Cache<(string, string), int>(60_000);
try
{
using StreamReader reader = new StreamReader(mergeStream);
@@ -500,7 +460,7 @@ private Dictionary GetVocabulary(Stream vocabularyStream)
throw new Exception($"Invalid format of merge file: \"{line}\"");
}
- mergeRanks.Add((line.Substring(0, index), line.Substring(index + 1)), rank++);
+ mergeRanks.Set((line.Substring(0, index), line.Substring(index + 1)), rank++);
}
}
catch (Exception e)
@@ -538,26 +498,19 @@ private static int GetByteToUnicode(out IReadOnlyDictionary byteToUn
}
///
- /// Encode a token into BPE-ed Ids. E.g., "playing" into ["play", "ing"].
+ /// Encode a token into BPE-ed sub-tokens. E.g., "playing" into ["play", "ing"].
///
- /// The token to encode.
- /// The list of Ids to encode the token into.
- /// The number of encoded ids.
- private int EncodeToIds(Span token, IList? ids)
+ private List EncodeToTokens(Span token, Span indexMapping)
{
if (token.Length == 0)
{
- return 0;
+ return Bpe.EmptyTokensList;
}
if (token.Length == 1)
{
- if (ids is not null)
- {
- ids.Add(_vocab[_charToString[token[0]]]);
- }
-
- return 1;
+ string tokenValue = _charToString[token[0]];
+ return new List { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], 1)) };
}
List word = new(token.Length);
@@ -586,7 +539,7 @@ private int EncodeToIds(Span token, IList? ids)
// get the most frequent bi-gram pair
var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue));
- if (!_mergeRanks.ContainsKey((first, second)))
+ if (!_mergeRanks.TryGet((first, second), out int _))
{
break;
}
@@ -605,6 +558,7 @@ private int EncodeToIds(Span token, IList? ids)
{
newWord.Add(word[k]);
}
+
break;
}
else
@@ -614,104 +568,7 @@ private int EncodeToIds(Span token, IList? ids)
{
newWord.Add(word[k]);
}
- i = j;
- }
-
- // check the next element is {second} or not
- if (i < word.Count - 1 && word[i + 1] == second)
- {
- newWord.Add(first + second);
- i += 2;
- }
- else
- {
- newWord.Add(word[i]);
- i += 1;
- }
- }
- List temp = word;
- word = newWord;
- newWord = temp;
- newWord.Clear();
-
- // otherwise, continue merging
- WordToPairs(word, pairs);
- }
-
- if (ids is not null)
- {
- foreach (string w in word)
- {
- ids.Add(_vocab[w]);
- }
- }
-
- return word.Count;
- }
-
- ///
- /// Encode a token into BPE-ed sub-tokens. E.g., "playing" into ["play", "ing"].
- ///
- private List EncodeToTokens(Span token, Span indexMapping)
- {
- if (token.Length == 0)
- {
- return Bpe.EmptyTokensList;
- }
-
- if (token.Length == 1)
- {
- string tokenValue = _charToString[token[0]];
- return new List { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], indexMapping[0] + 1)) };
- }
-
- List word = new(token.Length);
- foreach (char c in token)
- {
- Debug.Assert(c < _charToString.Length);
- word.Add(_charToString[c]);
- }
-
- HashSet<(string, string)> pairs = new();
-
- WordToPairs(word, pairs);
-
- var newWord = new List();
-
- Debug.Assert(pairs.Count != 0, "Pairs should not be empty.");
-
- while (true)
- {
- /* while conditions */
- // if only one element left, merge is finished (with the whole word merged)
- if (word.Count == 1)
- {
- break;
- }
-
- // get the most frequent bi-gram pair
- var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue));
- if (!_mergeRanks.ContainsKey((first, second)))
- {
- break;
- }
- /* end while conditions */
-
- // search and merge all (first, second) pairs in {word}
- var i = 0;
- while (i < word.Count)
- {
- // find the next occurrence of {first} and add the elements before into {newWord}
- var j = word.IndexOf(first, i);
- if (j == -1)
- {
- newWord.AddRange(word.Skip(i));
- break;
- }
- else
- {
- newWord.AddRange(word.Skip(i).Take(j - i));
i = j;
}
@@ -742,7 +599,7 @@ private List EncodeToTokens(Span token, Span indexMapping)
foreach (string w in word)
{
- tokens.Add(new Token(_vocab[w], w, (indexMapping[index], indexMapping[index + w.Length - 1] + 1)));
+ tokens.Add(new Token(_vocab[w], w, (indexMapping[index], w.Length)));
index += w.Length;
}
@@ -770,7 +627,12 @@ private static void WordToPairs(IReadOnlyList word, HashSet<(string, str
}
}
- public bool CharInSupportedRange(char ch)
+ ///
+ /// Check if the character is supported by the tokenizer's model.
+ ///
+ /// The character to check.
+ /// True if the character is supported, otherwise false.
+ public bool IsSupportedChar(char ch)
{
return _byteToUnicode.ContainsKey(ch);
}
@@ -847,7 +709,7 @@ public int IdToOccurrenceValue(int id)
return 0;
}
- public int OccurrenceRankToId(int rank)
+ public int ConvertOccurrenceRankToId(int rank)
{
if ((uint)rank >= _symbols.Count)
{
@@ -874,12 +736,12 @@ private int ReserveStringSymbolSlot(string symbol, int defaultOccurrence = -1)
return idx;
}
- public int AddSymbol(int id, int highOccuranceScore)
+ public int AddSymbol(int id, int highOccurrenceScore)
{
if (!_idToIndex.TryGetValue(id, out int idx))
{
idx = _symbols.Count;
- _symbols.Add((id, highOccuranceScore));
+ _symbols.Add((id, highOccurrenceScore));
_idToIndex[id] = idx;
}
@@ -970,25 +832,5 @@ public void AddFromStream(Stream stream)
}
}
}
-
- public void Save(StreamWriter file)
- {
- for (int i = NumSpecialSymbols; i < _symbols.Count; i++)
- {
- (int id, int occurrenceScore) symbol = _symbols[i];
- if (symbol.id >= 0 && symbol.occurrenceScore >= 0)
- {
- file.WriteLine($"{symbol.id} {symbol.occurrenceScore}");
- }
- }
-
- foreach (KeyValuePair kvp in _stringSymbolToIndexMapping)
- {
- if (_symbols[kvp.Value].OccurrenceScore >= 0)
- {
- file.WriteLine($"{kvp.Key} {_symbols[kvp.Value].OccurrenceScore}");
- }
- }
- }
}
}
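EnglishRoberta now takes an optional filterUnsupportedChars flag and uses the MapTokenToId/MapIdToToken names from the reworked Model base class. A usage sketch; the file paths and the sample token are placeholders:

    var model = new EnglishRoberta("vocab.json", "merges.txt", "dict.txt", filterUnsupportedChars: true);
    int? id = model.MapTokenToId("hello");
    string? token = id.HasValue ? model.MapIdToToken(id.Value) : null;
    bool supported = model.IsSupportedChar('h');
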
diff --git a/src/Microsoft.ML.Tokenizers/Model/Model.cs b/src/Microsoft.ML.Tokenizers/Model/Model.cs
index c8bd01cc06..16eecc4aa4 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Model.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Model.cs
@@ -14,31 +14,31 @@ namespace Microsoft.ML.Tokenizers
public abstract class Model
{
///
- /// Tokenize a split sequence string to a list of tokens.
+ /// Encode a text to a list of tokens.
///
- /// The text to tokenize.
+ /// The text to encode.
/// Indicate if the token is a special token.
- /// The list of tokens generated from the sequence tokenization.
- public abstract IReadOnlyList Tokenize(string sequence, bool isSpecialToken = false);
+ /// The list of tokens generated from the text tokenization.
+ public abstract IReadOnlyList Encode(string text, bool isSpecialToken = false);
///
- /// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
+ /// Encode a text to a list of Ids and add them to the accumulatedIds list.
///
- /// The sequence to split.
+ /// The text to encode.
/// Indicate if the token is a special token.
- /// The list of accumulated tokenized Ids.
+ /// The list of accumulated encoded Ids.
///
- /// This method does the default implementation that uses the Tokenize method to get the token's Ids.
+ /// This method provides a default implementation that uses the Encode method to get the tokens' Ids.
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
///
- public virtual void TokenizeToIds(string sequence, bool isSpecialToken, IList accumulatedIds)
+ public virtual void EncodeToIds(string text, bool isSpecialToken, IList accumulatedIds)
{
if (accumulatedIds is null)
{
throw new ArgumentNullException(nameof(accumulatedIds));
}
- var tokens = Tokenize(sequence);
+ var tokens = Encode(text);
foreach (var token in tokens)
{
accumulatedIds.Add(token.Id);
@@ -46,66 +46,58 @@ public virtual void TokenizeToIds(string sequence, bool isSpecialToken, IList
- /// Get the number of tokens that the input sequence will be encoded to.
+ /// Get the number of tokens that the input text will be encoded to.
///
- /// The text to tokenize.
+ /// The text to encode.
/// Indicate if the token is special token.
- /// The number of tokens that the input sequence will be encoded to.
+ /// The number of tokens that the input text will be encoded to.
///
- /// This method does the default implementation that uses the TokenizeToIds method to get the number of token's Ids.
+ /// This method provides a default implementation that uses the EncodeToIds method to count the token Ids.
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
///
- public virtual int CountTokens(string sequence, bool isSpecialToken)
+ public virtual int CountTokens(string text, bool isSpecialToken)
{
var ids = new List();
- TokenizeToIds(sequence, isSpecialToken, ids);
+ EncodeToIds(text, isSpecialToken, ids);
return ids.Count;
}
///
- /// Map the token to tokenized Id.
- ///
- /// The token to map to the Id.
- /// The mapped Id of the token.
- public abstract int? TokenToId(string token);
-
- ///
- /// Map the token to tokenized id with the option to skip the special tokens.
+ /// Map the token to encoded id with the option to skip the special tokens.
///
/// The token to map to Id
- /// Indicate if want to skip the special tokens during the encoding.
+ /// Indicate whether to consider the special tokens during the encoding.
/// The mapped Id of the token.
- public virtual int? TokenToId(string token, bool skipSpecialTokens) => TokenToId(token);
+ public abstract int? MapTokenToId(string token, bool considerSpecialTokens = true);
///
- /// Map the tokenized Id to the token.
+ /// Map the encoded Id to the token.
///
/// The Id to map to the token.
- /// Indicate if want to skip the special tokens during the decoding.
+ /// Indicate whether to consider the special tokens during the decoding.
/// The mapped token of the Id.
- public abstract string? IdToToken(int id, bool skipSpecialTokens = false);
+ public abstract string? MapIdToToken(int id, bool considerSpecialTokens = true);
///
- /// Gets the dictionary mapping tokens to Ids.
+ /// Decode the given ids back to a string.
///
- public abstract IReadOnlyDictionary GetVocab();
-
- ///
- /// Gets the dictionary size that map tokens to Ids.
- ///
- public abstract int GetVocabSize();
+ /// The list of ids that we want to decode.
+ /// Whether the special tokens should be kept in the decoded string.
+ /// The optional Decoder to merge the given list of tokens in a string.
+ /// The decoded string.
+ public virtual string? Decode(IEnumerable ids, TokenizerDecoder? decoder = null, bool considerSpecialTokens = true)
+ {
+ List tokens = new List();
- ///
- /// Save the model data into the vocabulary and merges files.
- ///
- /// The file system path to store the generated files at.
- /// Optional prefix for the generated file names.
- /// The list of all saved files.
- public abstract string[] Save(string path, string? prefix = null);
+ foreach (int id in ids)
+ {
+ if (MapIdToToken(id, considerSpecialTokens) is string s)
+ {
+ tokens.Add(s);
+ }
+ }
- ///
- /// Gets a trainer object to use in training the model.
- ///
- public abstract Trainer? GetTrainer();
+ return decoder?.Decode(tokens) ?? string.Concat(tokens);
+ }
}
}
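With Decode, EncodeToIds, and CountTokens now carrying virtual default implementations in the base class, a derived model only has to supply Encode, MapTokenToId, and MapIdToToken. A minimal sketch of such a model (a toy whitespace splitter over a hypothetical fixed vocabulary; the Token constructor shape of id, value, and offset is an assumption based on the surrounding sources):

    using System.Collections.Generic;
    using System.Linq;
    using Microsoft.ML.Tokenizers;

    // Toy model: every whitespace-separated word is a token; unknown words map to id 0.
    public sealed class WhitespaceModel : Model
    {
        private readonly Dictionary<string, int> _vocab;
        private readonly Dictionary<int, string> _reverse;

        public WhitespaceModel(IEnumerable<string> words)
        {
            _vocab = words.Select((word, index) => (word, index + 1))
                          .ToDictionary(pair => pair.word, pair => pair.Item2);
            _reverse = _vocab.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
        }

        public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false)
        {
            var tokens = new List<Token>();
            int offset = 0;
            foreach (string word in text.Split(' '))
            {
                int id = _vocab.TryGetValue(word, out int known) ? known : 0;
                // Token(id, value, offset) is an assumption about the Token constructor.
                tokens.Add(new Token(id, word, (offset, word.Length)));
                offset += word.Length + 1;
            }
            return tokens;
        }

        public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
            => _vocab.TryGetValue(token, out int id) ? id : (int?)null;

        public override string? MapIdToToken(int id, bool considerSpecialTokens = true)
            => _reverse.TryGetValue(id, out string? token) ? token : null;

        // EncodeToIds, CountTokens, and Decode are inherited from the virtual defaults above.
    }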
diff --git a/src/Microsoft.ML.Tokenizers/Model/Progress.cs b/src/Microsoft.ML.Tokenizers/Model/Progress.cs
deleted file mode 100644
index 50aef06e57..0000000000
--- a/src/Microsoft.ML.Tokenizers/Model/Progress.cs
+++ /dev/null
@@ -1,64 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-
-namespace Microsoft.ML.Tokenizers
-{
- public delegate void ReportProgress(Progress progress);
-
- ///
- /// Represent the state of the reported progress.
- ///
- public enum ProgressState
- {
- ///
- /// The progress is started. The reported value in the Progress structure will have the max number progressing toward.
- ///
- Start,
-
- ///
- /// The progress is ended. The reported value in the Progress structure will have the final max processed number.
- ///
- End,
-
- ///
- /// The progress is incremented. The reported value in increment value in the progress.
- ///
- Increment
- }
-
- public readonly struct Progress
- {
- ///
- /// Construct the Progress object using the progress state, message and the value.
- ///
- public Progress(ProgressState state, string? message, int value)
- {
- State = state;
- Message = message;
- Value = value;
- }
-
- ///
- /// The progress state.
- ///
- public ProgressState State { get; }
-
- ///
- /// The message of the progress.
- ///
- public string? Message { get; }
-
- ///
- /// The Value of the progress.
- ///
- ///
- /// The value is the max number the progress can reach if the progress state is `Start`.
- /// The value is the max number the progress reached if the progress state is `End`.
- /// The value is the incremented value in the progress if the progress state is `Increment`.
- ///
- public int Value { get; }
- }
-}
\ No newline at end of file
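The Tiktoken.cs diff below renames the constructor parameters: tikTokenBpeFile/tikTokenBpeFileStream become vocabFilePath/vocabStream and specialTokensEncoder becomes specialTokens. A sketch of constructing the model against the renamed signature; the vocab file name, special-token text, and id are placeholders, not values taken from this change:

    using System.Collections.Generic;
    using Microsoft.ML.Tokenizers;

    internal static class TiktokenConstructionSketch
    {
        public static Tiktoken Create()
        {
            // Placeholder special-token mapping; a real caller supplies its own values.
            var specialTokens = new Dictionary<string, int>
            {
                ["<|endoftext|>"] = 100_257,
            };

            // vocabFilePath, specialTokens, and cacheSize follow the renamed parameters.
            return new Tiktoken("cl100k_base.tiktoken", specialTokens, cacheSize: 8192);
        }
    }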
diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
index f7eeed8b7e..0696efd9b0 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -20,67 +20,73 @@ namespace Microsoft.ML.Tokenizers
public sealed class Tiktoken : Model
{
private readonly Dictionary, int> _encoder = null!;
- private readonly IReadOnlyDictionary _decoder = null!;
+ private readonly Dictionary> _decoder = null!;
private readonly LruCache? _cache;
private readonly IReadOnlyDictionary? _specialTokensEncoder;
private readonly Dictionary? _specialTokensDecoder;
private readonly Dictionary _vocab = null!;
///
- /// Create a new Tiktoken tokenizer object.
+ /// Create a new Tiktoken tokenizer model object.
///
- /// The path to the BPE rank file.
- /// The dictionary mapping special tokens to Ids.
+ /// The path to the BPE vocab file.
+ /// The dictionary mapping special tokens to Ids.
/// The size of the cache to use.
- /// Thrown when is null or empty.
- /// Thrown when failed to load the BPE rank file.
- public Tiktoken(string tikTokenBpeFile, IReadOnlyDictionary? specialTokensEncoder = null, int cacheSize = LruCache.DefaultCacheSize) :
- this(string.IsNullOrEmpty(tikTokenBpeFile) ? throw new ArgumentNullException(nameof(tikTokenBpeFile)) : File.OpenRead(tikTokenBpeFile), specialTokensEncoder, cacheSize, disposeStream: true)
+ /// Thrown when is null or empty.
+ /// Thrown when failed to load the BPE vocab file.
+ public Tiktoken(string vocabFilePath, IReadOnlyDictionary? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) :
+ this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), specialTokens, cacheSize, disposeStream: true)
{
}
///
- /// Create a new Tiktoken tokenizer object.
+ /// Create a new Tiktoken tokenizer model object.
///
- /// The stream to the BPE rank file.
- /// The dictionary mapping special tokens to Ids.
+ /// The stream to the BPE vocab file.
+ /// The dictionary mapping special tokens to Ids.
/// The size of the cache to use.
- /// Thrown when is null or empty.
- /// Thrown when failed to load the BPE rank file.
- public Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary? specialTokensEncoder = null, int cacheSize = LruCache.DefaultCacheSize) :
- this(tikTokenBpeFileStream ?? throw new ArgumentNullException(nameof(tikTokenBpeFileStream)), specialTokensEncoder, cacheSize, disposeStream: false)
+ /// Thrown when is null or empty.
+ /// Thrown when failed to load the BPE vocab file.
+ public Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) :
+ this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), specialTokens, cacheSize, disposeStream: false)
{
}
+ ///
+ /// Create a new Tiktoken tokenizer model object.
+ ///
+ /// The dictionary mapping token utf-8 bytes to Ids.
+ /// The dictionary mapping Ids to token utf-8 bytes.
+ /// The dictionary mapping string tokens to Ids.
+ /// The dictionary mapping special tokens to Ids.
+ /// The max size of the cache to use.
internal Tiktoken(
Dictionary, int> encoder,
- IReadOnlyDictionary decoder,
+ Dictionary> decoder,
Dictionary vocab,
- IReadOnlyDictionary? specialTokensEncoder = null,
+ IReadOnlyDictionary? specialTokens,
int cacheSize = LruCache.DefaultCacheSize) : this(cacheSize)
{
- Debug.Assert(encoder is not null);
- Debug.Assert(decoder is not null);
- Debug.Assert(vocab is not null);
+ _encoder = encoder ?? throw new ArgumentNullException(nameof(encoder));
+ _decoder = decoder ?? throw new ArgumentNullException(nameof(decoder));
+ _vocab = vocab ?? throw new ArgumentNullException(nameof(vocab));
- _encoder = encoder!;
- _vocab = vocab!;
- _decoder = decoder!;
+ Debug.Assert(encoder.Count == decoder.Count);
- _specialTokensEncoder = specialTokensEncoder;
+ _specialTokensEncoder = specialTokens;
if (_specialTokensEncoder is not null)
{
_specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
}
}
- private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary? specialTokensEncoder, int cacheSize, bool disposeStream) : this(cacheSize)
+ private Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTokens, int cacheSize, bool disposeStream) : this(cacheSize)
{
try
{
- (_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(tikTokenBpeFileStream, useAsync: false).GetAwaiter().GetResult();
+ (_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
- _specialTokensEncoder = specialTokensEncoder;
+ _specialTokensEncoder = specialTokens;
if (_specialTokensEncoder is not null)
{
_specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
@@ -90,7 +96,7 @@ private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary