-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Tokenizer's Interfaces Cleanup #7001
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,6 +95,7 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st | |
|
|
||
| (Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadFile(vocabFile, mergesFile); | ||
| Vocab = vocab1 ?? new Dictionary<string, int>(); | ||
| Cache = new Cache<string, Word>(); | ||
|
|
||
| VocabReverse = new(); | ||
|
|
||
|
|
@@ -146,23 +147,33 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st | |
| /// Tokenize a sequence string to a list of tokens. | ||
| /// </summary> | ||
| /// <param name="sequence">The sequence to tokenize.</param> | ||
| /// <param name="isSpecialToken">Indicate if the token is a special token.</param> | ||
| /// <returns>The list of tokens generated from the sequence tokenization.</returns> | ||
| public override IReadOnlyList<Token> Tokenize(string sequence) | ||
| public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken = false) | ||
| { | ||
| if (sequence.Length == 0) | ||
| { | ||
| return EmptyTokensList; | ||
| } | ||
|
|
||
| if (!Dropout.HasValue) | ||
| { | ||
| return TokenizeWithCache(sequence); | ||
| } | ||
| return TokenizeWithCache(sequence); | ||
| } | ||
|
|
||
| Word word = MergeWord(sequence); | ||
| /// <summary> | ||
| /// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list. | ||
| /// </summary> | ||
| /// <param name="sequence">The sequence to split.</param> | ||
| /// <param name="isSpecialToken">Indicate if the token is a special token.</param> | ||
| /// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param> | ||
| public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds) => TokenizeToIdsWithCache(sequence, accumulatedIds); | ||
|
|
||
| return WordToTokens(ref word); | ||
| } | ||
| /// <summary> | ||
| /// Get the number of token's Ids that the input sequence will be encoded to. | ||
| /// </summary> | ||
| /// <param name="sequence">The text to tokenize.</param> | ||
| /// <param name="isSpecialToken">Indicate if the token is a special token.</param> | ||
| /// <returns>The number of token's Ids that the input sequence will be encoded to.</returns> | ||
| public override int GetTokenizedIdsCount(string sequence, bool isSpecialToken) => TokenizeToIdsWithCache(sequence, null); | ||
|
|
||
| /// <summary> | ||
| /// Map the token to tokenized Id. | ||
|
|
@@ -195,14 +206,6 @@ public override IReadOnlyList<Token> Tokenize(string sequence) | |
| return null; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Map the tokenized Id to the token. | ||
| /// </summary> | ||
| /// <param name="id">The Id to map to the token.</param> | ||
| /// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param> | ||
| /// <returns>The mapped token of the Id.</returns> | ||
| public override string? IdToString(int id, bool skipSpecialTokens = false) => throw new NotImplementedException(); | ||
|
|
||
| /// <summary> | ||
| /// Gets the dictionary mapping tokens to Ids. | ||
| /// </summary> | ||
|
|
@@ -332,7 +335,7 @@ internal string CharToString(char c) | |
|
|
||
| internal Word MergeWord(string w) | ||
| { | ||
| Word word = Word.WithCapacity((int)w.Length); | ||
| Word word = Word.WithCapacity(w.Length); | ||
| (int Id, int Len)? unk = null; | ||
| int i = 0; | ||
|
|
||
|
|
@@ -344,7 +347,7 @@ internal Word MergeWord(string w) | |
| if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1])) | ||
| { | ||
| length = 2; | ||
| s = w.Substring(i, (int)length); | ||
| s = w.Substring(i, length); | ||
| } | ||
| else | ||
| { | ||
|
|
@@ -403,7 +406,7 @@ internal Word MergeWord(string w) | |
| } | ||
| } | ||
|
|
||
| i += (int)length; | ||
| i += length; | ||
| } | ||
|
|
||
| if (unk.HasValue) | ||
|
|
@@ -415,45 +418,59 @@ internal Word MergeWord(string w) | |
| return word; | ||
| } | ||
|
|
||
| // internal Word.Enumerator WordToTokens(Word word) => word.GetIterator(VocabReverse); | ||
| internal List<Token> WordToTokens(ref Word word) | ||
| internal List<Token> WordToTokens(ref Word word) => word.ToTokens(VocabReverse); | ||
|
|
||
| internal List<Token> TokenizeWithCache(string sequence) | ||
| { | ||
| List<Token> tokens = new(word.SymbolsCount); | ||
| Word word; | ||
| if (Cache is not null) | ||
| { | ||
| if (Cache.TryGet(sequence, out word)) | ||
| { | ||
| return WordToTokens(ref word); | ||
| } | ||
|
|
||
| foreach (Token token in word.GetIterator(VocabReverse)) | ||
| word = MergeWord(sequence); | ||
| Cache.Set(sequence, word); | ||
| } | ||
| else | ||
| { | ||
| tokens.Add(token); | ||
| word = MergeWord(sequence); | ||
| } | ||
|
|
||
| return tokens; | ||
| return WordToTokens(ref word); | ||
|
Comment on lines
+423
to
+441
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Can this whole method just be:

```csharp
List<Token> result = new();
TokenizeToIdsWithCache(sequence, result);
return result;
```

?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No, |
||
| } | ||
|
|
||
| internal List<Token> TokenizeWithCache(string sequence) | ||
| internal int WordToIds(ref Word word, IList<int>? accumulatedIds) | ||
| { | ||
| if (Cache is not null) | ||
| if (accumulatedIds is not null) | ||
| { | ||
| Word? hit = Cache.Get(sequence); | ||
| if (hit.HasValue) | ||
| { | ||
| Word w = hit.Value; | ||
| return WordToTokens(ref w); | ||
| } | ||
| word.PopulateIds(accumulatedIds); | ||
| } | ||
|
|
||
| Word word = MergeWord(sequence); | ||
| List<Token> tokens = WordToTokens(ref word); | ||
| return word.SymbolsCount; | ||
| } | ||
|
|
||
| internal int TokenizeToIdsWithCache(string sequence, IList<int>? accumulatedIds) | ||
| { | ||
| Word word; | ||
|
|
||
| if (Cache is not null) | ||
| { | ||
| if (Cache.TryGet(sequence, out Word hit)) | ||
| { | ||
| return WordToIds(ref hit, accumulatedIds); | ||
| } | ||
|
|
||
| word = MergeWord(sequence); | ||
| Cache.Set(sequence, word); | ||
| } | ||
| else | ||
| { | ||
| word = MergeWord(sequence); | ||
| } | ||
|
|
||
| return tokens; | ||
| } | ||
|
|
||
| public override bool IsValidChar(char ch) | ||
| { | ||
| throw new NotImplementedException(); | ||
| return WordToIds(ref word, accumulatedIds); | ||
| } | ||
|
|
||
| internal static readonly List<Token> EmptyTokensList = new(); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.