Skip to content
Merged
8 changes: 8 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,14 @@ public override IReadOnlyList<Token> Tokenize(string sequence)
return null;
}

/// <summary>
/// Map the tokenized Id to the token.
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public override string? IdToString(int id, bool skipSpecialTokens = false) => throw new NotImplementedException();

/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
Expand Down
22 changes: 22 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,28 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
public override string? IdToToken(int id, bool skipSpecialTokens = false) =>
skipSpecialTokens && id < 0 ? null : _vocabReverse.TryGetValue(id, out var value) ? value : null;

/// <summary>
/// Map the tokenized Id to the original string.
/// </summary>
/// <param name="id">The Id to map to the string.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public override string? IdToString(int id, bool skipSpecialTokens = false)
{
if (skipSpecialTokens && id < 0)
return null;
if (_vocabReverse.TryGetValue(id, out var value))
{
var textChars = string.Join("", value)
.Where(c => _unicodeToByte.ContainsKey(c))
.Select(c => _unicodeToByte[c]);
var text = new string(textChars.ToArray());
return text;
}

return null;
}

/// <summary>
/// Save the model data into the vocabulary, merges, and occurrence mapping files.
/// </summary>
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ public abstract class Model
/// <returns>The mapped token of the Id.</returns>
public abstract string? IdToToken(int id, bool skipSpecialTokens = false);

public abstract string? IdToString(int id, bool skipSpecialTokens = false);

/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
Expand Down
5 changes: 4 additions & 1 deletion src/Microsoft.ML.Tokenizers/Tokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ public TokenizerResult Encode(string sequence)

foreach (int id in ids)
{
tokens.Add(Model.IdToToken(id) ?? "");
if (Model.GetType() == typeof(EnglishRoberta))
tokens.Add(Model.IdToString(id) ?? "");
else
tokens.Add(Model.IdToToken(id) ?? "");
}

return Decoder?.Decode(tokens) ?? string.Join("", tokens);
Expand Down
16 changes: 16 additions & 0 deletions src/Microsoft.ML.TorchSharp/NasBert/BertModelType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Text;

namespace Microsoft.ML.TorchSharp.NasBert
{
internal enum BertModelType
{
NasBert,
Roberta
}
}
4 changes: 3 additions & 1 deletion src/Microsoft.ML.TorchSharp/NasBert/BertTaskType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public enum BertTaskType
None = 0,
MaskedLM = 1,
TextClassification = 2,
SentenceRegression = 3
SentenceRegression = 3,
NameEntityRecognition = 4,
QuestionAnswering = 5
}
}
5 changes: 5 additions & 0 deletions src/Microsoft.ML.TorchSharp/NasBert/Models/BaseModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@ namespace Microsoft.ML.TorchSharp.NasBert.Models
internal abstract class BaseModel : torch.nn.Module<torch.Tensor, torch.Tensor, torch.Tensor>
{
protected readonly NasBertTrainer.NasBertOptions Options;
public BertModelType EncoderType => Options.ModelType;

public BertTaskType HeadType => Options.TaskType;

//public ModelType EncoderType => Options.ModelType;

#pragma warning disable CA1024 // Use properties where appropriate: Modules should be fields in TorchSharp
public abstract TransformerEncoder GetEncoder();

public abstract BaseHead GetHead();

#pragma warning restore CA1024 // Use properties where appropriate

protected BaseModel(NasBertTrainer.NasBertOptions options)
Expand Down
39 changes: 39 additions & 0 deletions src/Microsoft.ML.TorchSharp/NasBert/Models/ModelForNer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Text;
using TorchSharp;

namespace Microsoft.ML.TorchSharp.NasBert.Models
{
internal class ModelForNer : NasBertModel
{
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:Private field name not in: _camelCase format", Justification = "Has to match TorchSharp model.")]
private readonly SequenceLabelHead NerHead;

public override BaseHead GetHead() => NerHead;

public ModelForNer(NasBertTrainer.NasBertOptions options, int padIndex, int symbolsCount, int numLabels)
: base(options, padIndex, symbolsCount)
{
NerHead = new SequenceLabelHead(
inputDim: Options.EncoderOutputDim,
numLabels: numLabels,
dropoutRate: Options.PoolerDropout);
Initialize();
RegisterComponents();
}

[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
public override torch.Tensor forward(torch.Tensor srcTokens, torch.Tensor tokenMask = null)
{
using var disposeScope = torch.NewDisposeScope();
var x = ExtractFeatures(srcTokens);
x = NerHead.call(x);
return x.MoveToOuterDisposeScope();
}
}
}
36 changes: 36 additions & 0 deletions src/Microsoft.ML.TorchSharp/NasBert/Models/ModelPrediction.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using TorchSharp;

namespace Microsoft.ML.TorchSharp.NasBert.Models
{
internal sealed class ModelForPrediction : NasBertModel
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NERInferenceModel?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't for NER. Its for SentenceSimilarity and TextClassification. How about TextModel? TextModelForPrediction? Thoughts?

{
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:Private field name not in: _camelCase format", Justification = "Has to match TorchSharp model.")]
private readonly PredictionHead PredictionHead;

public override BaseHead GetHead() => PredictionHead;

public ModelForPrediction(NasBertTrainer.NasBertOptions options, int padIndex, int symbolsCount, int numClasses)
: base(options, padIndex, symbolsCount)
{
PredictionHead = new PredictionHead(
inputDim: Options.EncoderOutputDim,
numClasses: numClasses,
dropoutRate: Options.PoolerDropout);
Initialize();
RegisterComponents();
}

[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
public override torch.Tensor forward(torch.Tensor srcTokens, torch.Tensor tokenMask = null)
{
using var disposeScope = torch.NewDisposeScope();
var x = ExtractFeatures(srcTokens);
x = PredictionHead.call(x);
return x.MoveToOuterDisposeScope();
}
}
}
Loading