Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@
<Import Project="$(RepoRoot)eng/pkg/Pack.props" />

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<TargetFrameworks>netstandard2.0;net8.0</TargetFrameworks>
<Nullable>enable</Nullable>
<PackageDescription>Microsoft.ML.Tokenizers contains the implmentation of the tokenization used in the NLP transforms.</PackageDescription>
</PropertyGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<Compile Remove="Utils/Helpers.netcoreapp.cs" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' != 'netstandard2.0'">
<Compile Remove="Utils/Helpers.netstandard.cs" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="System.Text.Json" Version="$(SystemTextJsonVersion)" />
</ItemGroup>
Expand Down
10 changes: 5 additions & 5 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public string? UnknownToken

if (value is null)
{
if (VocabReverse.TryGetValue(0, out string v))
if (VocabReverse.TryGetValue(0, out string? v))
{
VocabReverse.Remove(0);
if (Vocab.TryGetValue(v, out int id))
Expand Down Expand Up @@ -103,7 +103,7 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
VocabReverse.Add(kvp.Value, kvp.Key);
}

if (unknownToken is null && VocabReverse.TryGetValue(0, out string unkToken))
if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
{
unknownToken = unkToken;
}
Expand Down Expand Up @@ -187,7 +187,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence)
/// <returns>The mapped token of the Id.</returns>
public override string? IdToToken(int id, bool skipSpecialTokens = false)
{
if (VocabReverse.TryGetValue(id, out string value))
if (VocabReverse.TryGetValue(id, out string? value))
{
return value;
}
Expand Down Expand Up @@ -253,7 +253,7 @@ public override string[] Save(string path, string? prefix = null)
}

/// Read the given files to extract the vocab and merges
internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(string? vocab, string? merges)
internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(string vocab, string? merges)
{
Dictionary<string, int>? dic;
using (Stream stream = File.OpenRead(vocab))
Expand Down Expand Up @@ -320,7 +320,7 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(strin
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal string CharToString(char c)
{
if (_charToString.TryGetValue(c, out string v))
if (_charToString.TryGetValue(c, out string? v))
{
return v;
}
Expand Down
13 changes: 9 additions & 4 deletions src/Microsoft.ML.Tokenizers/Model/BpeTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ public BpeTrainer(
MinFrequency = minFrequency;
VocabSize = vocabSize;
Progress = progress;
SpecialTokens = new List<AddedToken>(specialTokens);

if (specialTokens is not null)
{
SpecialTokens = new List<AddedToken>(specialTokens);
}

LimitAlphabet = limitAlphabet;
InitialAlphabet = initialAlphabet;
ContinuingSubwordPrefix = continuingSubwordPrefix;
Expand Down Expand Up @@ -172,7 +177,7 @@ private void ComputeAlphabet(Dictionary<string, int> wc, Dictionary<string, int>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal string CharToString(char c)
{
if (_charToString.TryGetValue(c, out string v))
if (_charToString.TryGetValue(c, out string? v))
{
return v;
}
Expand Down Expand Up @@ -259,7 +264,7 @@ internal string CharToString(char c)
// Then update counts
int count = counts[i];

if (!whereToUpdate.TryGetValue(curPair, out HashSet<int> h))
if (!whereToUpdate.TryGetValue(curPair, out HashSet<int>? h))
{
h = new HashSet<int>();
whereToUpdate[curPair] = h;
Expand Down Expand Up @@ -398,7 +403,7 @@ internal string CharToString(char c)

if (change > 0)
{
if (!whereToUpdate.TryGetValue(p, out HashSet<int> h))
if (!whereToUpdate.TryGetValue(p, out HashSet<int>? h))
{
h = new();
whereToUpdate[p] = h;
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Tokenizers/Model/Cache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Microsoft.ML.Tokenizers
{
internal sealed class Cache<TKey, TValue>
internal sealed class Cache<TKey, TValue> where TKey : notnull
{
internal Cache() : this(Bpe.DefaultCacheCapacity) { }

Expand Down Expand Up @@ -39,13 +39,13 @@ internal void Clear()

internal List<TValue> GetValues(IEnumerable<TKey> keys)
{
List<TValue>? values = new();
List<TValue> values = new();
_cacheLock.EnterReadLock();
try
{
foreach (TKey key in keys)
{
if (Map.TryGetValue(key, out TValue value))
if (Map.TryGetValue(key, out TValue? value))
{
values.Add(value);
}
Expand All @@ -61,7 +61,7 @@ internal List<TValue> GetValues(IEnumerable<TKey> keys)
_cacheLock.EnterReadLock();
try
{
if (Map.TryGetValue(key, out TValue value))
if (Map.TryGetValue(key, out TValue? value))
{
return value;
}
Expand Down
8 changes: 6 additions & 2 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
using StreamReader reader = new StreamReader(mergeStream);
while (reader.Peek() >= 0)
{
splitContents.Add(reader.ReadLine());
splitContents.Add(reader.ReadLine()!);
}
}
catch (Exception e)
Expand Down Expand Up @@ -761,7 +761,11 @@ public void AddFromStream(Stream stream)

while (reader.Peek() >= 0)
{
string line = reader.ReadLine();
string? line = reader.ReadLine();
if (line is null)
{
continue;
}

var splitLine = line.Trim().Split(' ');
if (splitLine.Length != 2)
Expand Down
38 changes: 38 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,51 @@ public abstract class Model
/// <returns>The list of tokens generated from the sequence tokenization.</returns>
public abstract IReadOnlyList<Token> Tokenize(string sequence);

/// <summary>
/// Tokenize a split sequence string to a list of tokens.
/// </summary>
/// <param name="sequence">The text to tokenize.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <returns>The list of tokens generated from the sequence tokenization.</returns>
public virtual IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken) => Tokenize(sequence);

/// <summary>
/// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
/// </summary>
/// <param name="sequence">The sequence to split.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param>
/// <returns>True if the operation succeeded, false otherwise.</returns>
public virtual bool TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
{
if (accumulatedIds is null)
{
throw new ArgumentNullException(nameof(accumulatedIds));
}

var tokens = Tokenize(sequence);
foreach (var token in tokens)
{
accumulatedIds.Add(token.Id);
}
return true;
}

/// <summary>
/// Map the token to tokenized Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <returns>The mapped Id of the token.</returns>
public abstract int? TokenToId(string token);

/// <summary>
/// Map the token to tokenized id with the option to skip the special tokens.
/// </summary>
/// <param name="token">The token to map to Id</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public virtual int? TokenToId(string token, bool skipSpecialTokens) => TokenToId(token);

/// <summary>
/// Map the tokenized Id to the token.
/// </summary>
Expand Down
Loading