Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/Microsoft.ML.TorchSharp/NasBert/BertTaskType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Text;

namespace Microsoft.ML.TorchSharp.NasBert
Expand All @@ -17,7 +18,10 @@ public enum BertTaskType
MaskedLM = 1,
TextClassification = 2,
SentenceRegression = 3,
NameEntityRecognition = 4,
NamedEntityRecognition = 4,
[Obsolete("Please use NamedEntityRecognition instead", false)]
[EditorBrowsable(EditorBrowsableState.Never)]
NameEntityRecognition = NamedEntityRecognition,
QuestionAnswering = 5
}
}
12 changes: 6 additions & 6 deletions src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ private protected override Module CreateModule(IChannel ch, IDataView input)
EnglishRoberta tokenizerModel = Tokenizer.RobertaModel();

NasBertModel model;
if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
model = new NerModel(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
else
model = new ModelForPrediction(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
Expand Down Expand Up @@ -268,7 +268,7 @@ private protected override torch.Tensor PrepareRowTensor()
private protected override void RunModelAndBackPropagate(ref List<Tensor> inputTensors, ref Tensor targetsTensor)
{
Tensor logits = default;
if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
int[,] lengthArray = new int[inputTensors.Count, 1];
for (int i = 0; i < inputTensors.Count; i++)
Expand All @@ -293,7 +293,7 @@ private protected override void RunModelAndBackPropagate(ref List<Tensor> inputT
torch.Tensor loss;
if (Parent.BertOptions.TaskType == BertTaskType.TextClassification)
loss = torch.nn.CrossEntropyLoss(reduction: Parent.BertOptions.Reduction).forward(logits, targetsTensor);
else if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
targetsTensor = targetsTensor.@long().view(-1);
logits = logits.view(-1, logits.size(-1));
Expand Down Expand Up @@ -338,7 +338,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
outColumns[Option.ScoreColumnName] = new SchemaShape.Column(Option.ScoreColumnName, SchemaShape.Column.VectorKind.Vector,
NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol)));
}
else if (BertOptions.TaskType == BertTaskType.NameEntityRecognition)
else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
var metadata = new List<SchemaShape.Column>();
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
Expand Down Expand Up @@ -387,7 +387,7 @@ private protected override void CheckInputSchema(SchemaShape inputSchema)
TextDataViewType.Instance.ToString(), sentenceCol2.GetTypeString());
}
}
else if (BertOptions.TaskType == BertTaskType.NameEntityRecognition)
else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
if (labelCol.ItemType != NumberDataViewType.UInt32)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName,
Expand Down Expand Up @@ -535,7 +535,7 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
info[1] = new DataViewSchema.DetachedColumn(Parent.Options.ScoreColumnName, new VectorDataViewType(NumberDataViewType.Single, Parent.Options.NumberOfClasses), meta.ToAnnotations());
return info;
}
else if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
var info = new DataViewSchema.DetachedColumn[1];
var keyType = Parent.LabelColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// To create this trainer, use [NER](xref:Microsoft.ML.TorchSharpCatalog.NameEntityRecognition(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,Int32,Int32,Int32,Microsoft.ML.TorchSharp.NasBert.BertArchitecture,Microsoft.ML.IDataView)).
/// To create this trainer, use [NER](xref:Microsoft.ML.TorchSharpCatalog.NamedEntityRecognition(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,System.Int32,System.Int32,Microsoft.ML.TorchSharp.NasBert.BertArchitecture,Microsoft.ML.IDataView)).
///
/// ### Input and Output Columns
/// The input label column data must be a Vector of [string](xref:Microsoft.ML.Data.TextDataViewType) type and the sentence columns must be of type <xref:Microsoft.ML.Data.TextDataViewType>.
Expand All @@ -54,7 +54,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
/// | Exportable to ONNX | No |
///
/// ### Training Algorithm Details
/// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained NAS-BERT roBERTa model for the purpose of name entity recognition.
/// Trains a Deep Neural Network (DNN) by leveraging an existing pre-trained NAS-BERT RoBERTa model for the purpose of named entity recognition.
/// ]]>
/// </format>
/// </remarks>
Expand Down Expand Up @@ -93,7 +93,7 @@ internal NerTrainer(IHostEnvironment env,
BatchSize = batchSize,
MaxEpoch = maxEpochs,
ValidationSet = validationSet,
TaskType = BertTaskType.NameEntityRecognition
TaskType = BertTaskType.NamedEntityRecognition
})
{
}
Expand Down Expand Up @@ -295,7 +295,7 @@ private static NerTransformer Create(IHostEnvironment env, ModelLoadContext ctx)

options.Sentence1ColumnName = ctx.LoadString();
options.Sentence2ColumnName = ctx.LoadStringOrNull();
options.TaskType = BertTaskType.NameEntityRecognition;
options.TaskType = BertTaskType.NamedEntityRecognition;

BinarySaver saver = new BinarySaver(env, new BinarySaver.Arguments());
DataViewType type;
Expand Down
47 changes: 43 additions & 4 deletions src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Text;
using Microsoft.ML.Data;
using Microsoft.ML.TorchSharp.AutoFormerV2;
Expand Down Expand Up @@ -161,7 +162,45 @@ public static ObjectDetectionMetrics EvaluateObjectDetection(
}

/// <summary>
/// Obsolete name for this API. Every argument is forwarded unchanged to
/// <see cref="NamedEntityRecognition(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, int, int, BertArchitecture, IDataView)"/>;
/// call that method directly instead.
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="labelColumnName">Name of the label column. Column should be a key type.</param>
/// <param name="outputColumnName">Name of the output column. It will be a key type. It is the predicted label.</param>
/// <param name="sentence1ColumnName">Name of the column for the first sentence.</param>
/// <param name="batchSize">Number of rows in the batch.</param>
/// <param name="maxEpochs">Maximum number of times to loop through your training set.</param>
/// <param name="architecture">Architecture for the model. Defaults to Roberta.</param>
/// <param name="validationSet">The validation set used while training to improve model quality.</param>
/// <returns>The <see cref="NerTrainer"/> returned by the forwarded <c>NamedEntityRecognition</c> call.</returns>
[Obsolete("Please use NamedEntityRecognition method instead", false)]
[EditorBrowsable(EditorBrowsableState.Never)]
public static NerTrainer NameEntityRecognition(
    this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
    string labelColumnName = DefaultColumnNames.Label,
    string outputColumnName = DefaultColumnNames.PredictedLabel,
    string sentence1ColumnName = "Sentence",
    int batchSize = 32,
    int maxEpochs = 10,
    BertArchitecture architecture = BertArchitecture.Roberta,
    IDataView validationSet = null)
{
    // Pure pass-through: the replacement overload receives the arguments in the same order.
    return NamedEntityRecognition(
        catalog,
        labelColumnName,
        outputColumnName,
        sentence1ColumnName,
        batchSize,
        maxEpochs,
        architecture,
        validationSet);
}

/// <summary>
/// Obsolete name for this API. The options object is forwarded unchanged to
/// <see cref="NamedEntityRecognition(MulticlassClassificationCatalog.MulticlassClassificationTrainers, NerTrainer.NerOptions)"/>;
/// call that method directly instead.
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="options">The full set of advanced options.</param>
/// <returns>The <see cref="NerTrainer"/> returned by the forwarded <c>NamedEntityRecognition</c> call.</returns>
[Obsolete("Please use NamedEntityRecognition method instead", false)]
[EditorBrowsable(EditorBrowsableState.Never)]
public static NerTrainer NameEntityRecognition(
    this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
    NerTrainer.NerOptions options)
{
    // Pure pass-through to the replacement overload.
    return NamedEntityRecognition(catalog, options);
}

/// <summary>
/// Fine tune a NAS-BERT model for Named Entity Recognition. The limit for any sentence is 512 tokens. Each word typically
/// will map to a single token, and we automatically add 2 special tokens (a start token and a separator token)
/// so in general this limit will be 510 words for all sentences.
/// </summary>
Expand All @@ -174,7 +213,7 @@ public static ObjectDetectionMetrics EvaluateObjectDetection(
/// <param name="architecture">Architecture for the model. Defaults to Roberta.</param>
/// <param name="validationSet">The validation set used while training to improve model quality.</param>
/// <returns></returns>
public static NerTrainer NameEntityRecognition(
public static NerTrainer NamedEntityRecognition(
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
string labelColumnName = DefaultColumnNames.Label,
string outputColumnName = DefaultColumnNames.PredictedLabel,
Expand All @@ -186,12 +225,12 @@ public static NerTrainer NameEntityRecognition(
=> new NerTrainer(CatalogUtils.GetEnvironment(catalog), labelColumnName, outputColumnName, sentence1ColumnName, batchSize, maxEpochs, validationSet, architecture);

/// <summary>
/// Fine tune a Name Entity Recognition model.
/// Fine tune a Named Entity Recognition model.
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="options">The full set of advanced options.</param>
/// <returns></returns>
public static NerTrainer NameEntityRecognition(
public static NerTrainer NamedEntityRecognition(
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
NerTrainer.NerOptions options)
=> new NerTrainer(CatalogUtils.GetEnvironment(catalog), options);
Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.Tests/NerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public void TestSimpleNer()
}));
var chain = new EstimatorChain<ITransformer>();
var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
.Append(ML.MulticlassClassification.Trainers.NameEntityRecognition(outputColumnName: "outputColumn"))
.Append(ML.MulticlassClassification.Trainers.NamedEntityRecognition(outputColumnName: "outputColumn"))
.Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn"));

var estimatorSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
Expand Down