Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,11 @@ public override bool IsValidChar(char ch)
throw new NotImplementedException();
}

public override bool IsFirstTokenInWord(string token)
{
throw new NotImplementedException();
}

internal static readonly List<Token> EmptyTokensList = new();
}
}
2 changes: 0 additions & 2 deletions src/Microsoft.ML.Tokenizers/Model/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,5 @@ public abstract class Model
/// <param name="ch"></param>
/// <returns></returns>
public abstract bool IsValidChar(char ch);

}

}
13 changes: 8 additions & 5 deletions src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ private protected override torch.Tensor PrepareBatchTensor(ref List<Tensor> inpu
return DataUtils.CollateTokens(inputTensors, Tokenizer.RobertaModel().PadIndex, device: Device);
}

private protected override torch.Tensor PrepareRowTensor()
private protected override torch.Tensor PrepareRowTensor(ref TLabelCol target)
{
ReadOnlyMemory<char> sentence1 = default;
Sentence1Getter(ref sentence1);
Expand Down Expand Up @@ -494,7 +494,8 @@ private protected abstract class NasBertMapper : TorchSharpBaseMapper

private static readonly FuncInstanceMethodInfo1<NasBertMapper, DataViewSchema.DetachedColumn, Delegate> _makeLabelAnnotationGetter
= FuncInstanceMethodInfo1<NasBertMapper, DataViewSchema.DetachedColumn, Delegate>.Create(target => target.GetLabelAnnotations<int>);

internal static readonly int[] InitTokenArray = new[] { 0 /* InitToken */ };
internal static readonly int[] SeperatorTokenArray = new[] { 2 /* SeperatorToken */ };

public NasBertMapper(TorchSharpBaseTransformer<TLabelCol, TTargetsCol> parent, DataViewSchema inputSchema) :
base(parent, inputSchema)
Expand Down Expand Up @@ -583,13 +584,15 @@ private IList<int> PrepInputTokens(ref ReadOnlyMemory<char> sentence1, ref ReadO
getSentence1(ref sentence1);
if (getSentence2 == default)
{
return new[] { 0 /* InitToken */ }.Concat(tokenizer.EncodeToConverted(sentence1.ToString())).ToList();
List<int> newList = new List<int>(tokenizer.EncodeToConverted(sentence1.ToString()));
newList.Insert(0, 0);
return newList;
}
else
{
getSentence2(ref sentence2);
return new[] { 0 /* InitToken */ }.Concat(tokenizer.EncodeToConverted(sentence1.ToString()))
.Concat(new[] { 2 /* SeperatorToken */ }).Concat(tokenizer.EncodeToConverted(sentence2.ToString())).ToList();
return InitTokenArray.Concat(tokenizer.EncodeToConverted(sentence1.ToString()))
.Concat(SeperatorTokenArray).Concat(tokenizer.EncodeToConverted(sentence2.ToString())).ToList();
}
}

Expand Down
85 changes: 77 additions & 8 deletions src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
using Microsoft.ML.TorchSharp.NasBert.Models;
using TorchSharp;
using static Microsoft.ML.TorchSharp.NasBert.NasBertTrainer;
using static TorchSharp.torch;

[assembly: LoadableClass(typeof(NerTransformer), null, typeof(SignatureLoadModel),
NerTransformer.UserName, NerTransformer.LoaderSignature)]
Expand Down Expand Up @@ -61,6 +62,8 @@ namespace Microsoft.ML.TorchSharp.NasBert
///
public class NerTrainer : NasBertTrainer<VBuffer<uint>, TargetType>
{
private const char StartChar = (char)(' ' + 256);

public class NerOptions : NasBertOptions
{
public NerOptions()
Expand All @@ -69,6 +72,7 @@ public NerOptions()
EncoderOutputDim = 384;
EmbeddingDim = 128;
Arches = new int[] { 15, 16, 14, 0, 0, 0, 15, 16, 14, 0, 0, 0, 17, 14, 15, 0, 0, 0, 17, 14, 15, 0, 0, 0 };
TaskType = BertTaskType.NamedEntityRecognition;
}
}
internal NerTrainer(IHostEnvironment env, NerOptions options) : base(env, options)
Expand All @@ -93,7 +97,6 @@ internal NerTrainer(IHostEnvironment env,
BatchSize = batchSize,
MaxEpoch = maxEpochs,
ValidationSet = validationSet,
TaskType = BertTaskType.NamedEntityRecognition
})
{
}
Expand All @@ -108,9 +111,12 @@ private protected override TorchSharpBaseTransformer<VBuffer<uint>, TargetType>
return new NerTransformer(host, options as NasBertOptions, model as NasBertModel, labelColumn);
}

internal static bool TokenStartsWithSpace(string token) => token is null || (token.Length != 0 && token[0] == StartChar);

private protected class Trainer : NasBertTrainerBase
{
private const string ModelUrlString = "models/pretrained_NasBert_14M_encoder.tsm";
internal static readonly int[] ZeroArray = new int[] { 0 /* InitToken */};

public Trainer(TorchSharpBaseTrainer<VBuffer<uint>, TargetType> parent, IChannel ch, IDataView input) : base(parent, ch, input, ModelUrlString)
{
Expand Down Expand Up @@ -155,6 +161,40 @@ private protected override torch.Tensor CreateTargetsTensor(ref List<TargetType>
return torch.tensor(targetArray, device: Device);
}

private protected override torch.Tensor PrepareRowTensor(ref VBuffer<uint> target)
{
ReadOnlyMemory<char> sentenceRom = default;
Sentence1Getter(ref sentenceRom);
var sentence = sentenceRom.ToString();
Tensor t;
var encoding = Tokenizer.Encode(sentence);

if (target.Length != encoding.Tokens.Count)
{
var targetIndex = 0;
var targetEditor = VBufferEditor.Create(ref target, encoding.Tokens.Count);
var newValues = targetEditor.Values;
for (var i = 0; i < encoding.Tokens.Count; i++)
{
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
{
newValues[i] = target.GetItemOrDefault(++targetIndex);
}
else
{
newValues[i] = target.GetItemOrDefault(targetIndex);
}
}
target = targetEditor.Commit();
}
t = torch.tensor((ZeroArray).Concat(Tokenizer.RobertaModel().IdsToOccurrenceRanks(encoding.Ids)).ToList(), device: Device);

if (t.NumberOfElements > 512)
t = t.slice(0, 0, 512, 1);

return t;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private protected override int GetNumCorrect(torch.Tensor predictions, torch.Tensor targets)
{
Expand Down Expand Up @@ -334,6 +374,41 @@ private protected override Delegate CreateGetter(DataViewRow input, int iinfo, T

}

private void CondenseOutput(ref VBuffer<UInt32> dst, string sentence, Tokenizer tokenizer, TensorCacher outputCacher)
{
var pre = tokenizer.PreTokenizer.PreTokenize(sentence);
TokenizerResult encoding = tokenizer.Encode(sentence);

var argmax = (outputCacher as BertTensorCacher).Result.argmax(-1);
var prediction = argmax.ToArray<long>();

var targetIndex = 0;
// Figure out actual count of output tokens
for (var i = 0; i < encoding.Tokens.Count; i++)
{
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
{
targetIndex++;
}
}

var editor = VBufferEditor.Create(ref dst, targetIndex + 1);
var newValues = editor.Values;
targetIndex = 0;

newValues[targetIndex++] = (uint)prediction[0];

for (var i = 1; i < encoding.Tokens.Count; i++)
{
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
{
newValues[targetIndex++] = (uint)prediction[i];
}
}

dst = editor.Commit();
}

private Delegate MakePredictedLabelGetter(DataViewRow input, IChannel ch, TensorCacher outputCacher)
{
ValueGetter<ReadOnlyMemory<char>> getSentence1 = default;
Expand All @@ -353,13 +428,7 @@ private Delegate MakePredictedLabelGetter(DataViewRow input, IChannel ch, Tensor
var argmax = (outputCacher as BertTensorCacher).Result.argmax(-1);
var prediction = argmax.ToArray<long>();

var editor = VBufferEditor.Create(ref dst, prediction.Length - 1);
for (int i = 1; i < prediction.Length; i++)
{
editor.Values[i - 1] = (uint)prediction[i];
}

dst = editor.Commit();
CondenseOutput(ref dst, sentence1.ToString(), tokenizer, outputCacher);
};

return classification;
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.TorchSharp/TorchSharpBaseTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,9 @@ private bool ValidateStep(DataViewRowCursor cursor,
cursorValid = cursor.MoveNext();
if (cursorValid)
{
inputTensors.Add(PrepareRowTensor());
TLabelCol target = default;
labelGetter(ref target);
inputTensors.Add(PrepareRowTensor(ref target));
targets.Add(AddToTargets(target));
}
else
Expand Down Expand Up @@ -312,9 +312,9 @@ private bool TrainStep(IHost host,
cursorValid = cursor.MoveNext();
if (cursorValid)
{
inputTensors.Add(PrepareRowTensor());
TLabelCol target = default;
labelGetter(ref target);
inputTensors.Add(PrepareRowTensor(ref target));
targets.Add(AddToTargets(target));
}
else
Expand Down Expand Up @@ -343,7 +343,7 @@ private bool TrainStep(IHost host,

private protected abstract void RunModelAndBackPropagate(ref List<Tensor> inputTensorm, ref Tensor targetsTensor);

private protected abstract torch.Tensor PrepareRowTensor();
private protected abstract torch.Tensor PrepareRowTensor(ref TLabelCol target);
private protected abstract torch.Tensor PrepareBatchTensor(ref List<Tensor> inputTensors, Device device);

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand Down
Loading