diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs index 0f8a76ddfc..13cde121e7 100644 --- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs +++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs @@ -25,11 +25,11 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool AddBeginningOfSentence = addBos; AddEndOfSentence = addEos; BeginningOfSentenceToken = modelProto.TrainerSpec.BosPiece ?? ""; - BeginningOfSentenceId = modelProto.TrainerSpec.BosId <= 0 ? 1 : modelProto.TrainerSpec.BosId; + BeginningOfSentenceId = Math.Max(0, modelProto.TrainerSpec.BosId); EndOfSentenceToken = modelProto.TrainerSpec.EosPiece ?? ""; - EndOfSentenceId = modelProto.TrainerSpec.EosId <= 0 ? 1 : modelProto.TrainerSpec.EosId; + EndOfSentenceId = Math.Max(0, modelProto.TrainerSpec.EosId); UnknownToken = modelProto.TrainerSpec.UnkPiece ?? ""; - UnknownId = modelProto.TrainerSpec.UnkId < 0 ? 0 : modelProto.TrainerSpec.UnkId; + UnknownId = Math.Max(0, modelProto.TrainerSpec.UnkId); AddDummyPrefix = modelProto.NormalizerSpec.AddDummyPrefix; EscapeWhiteSpaces = modelProto.NormalizerSpec.EscapeWhitespaces; TreatWhitespaceAsSuffix = modelProto.TrainerSpec.TreatWhitespaceAsSuffix; diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs index dca346caea..057b34036b 100644 --- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs +++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs @@ -62,8 +62,27 @@ public SentencePieceUnigramModel(ModelProto modelProto, bool addBos, bool addEos _trie = new DoubleArrayTrie(_vocab); - _vocabReverse[BeginningOfSentenceId] = (BeginningOfSentenceToken, 0f, 0); - _vocabReverse[EndOfSentenceId] = (EndOfSentenceToken, 0f, 0); + // Once the trie is built, we need to add the special tokens to the vocabulary. 
+ // Including these special tokens ensures they are mapped like regular tokens. + // SentencePiece specifically handles the BOS, EOS, and UNK tokens, while the PAD token is optional. + + Debug.Assert(modelProto.TrainerSpec.UnkId >= 0); + Debug.Assert(modelProto.TrainerSpec.BosId >= 0); + Debug.Assert(modelProto.TrainerSpec.EosId >= 0); + + _vocab[modelProto.TrainerSpec.UnkPiece] = modelProto.TrainerSpec.UnkId; + _vocab[modelProto.TrainerSpec.BosPiece] = modelProto.TrainerSpec.BosId; + _vocab[modelProto.TrainerSpec.EosPiece] = modelProto.TrainerSpec.EosId; + + _vocabReverse[modelProto.TrainerSpec.BosId] = (modelProto.TrainerSpec.BosPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control); + _vocabReverse[modelProto.TrainerSpec.EosId] = (modelProto.TrainerSpec.EosPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control); + _vocabReverse[modelProto.TrainerSpec.UnkId] = (modelProto.TrainerSpec.UnkPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Unknown); + + if (modelProto.TrainerSpec.PadId >= 0) + { + _vocab[modelProto.TrainerSpec.PadPiece] = modelProto.TrainerSpec.PadId; + _vocabReverse[modelProto.TrainerSpec.PadId] = (modelProto.TrainerSpec.PadPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control); + } } public SentencePieceUnigramModel(SentencePieceOptions options) : base(options) diff --git a/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs b/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs index 97b3f4a7a4..e939e7f35b 100644 --- a/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs +++ b/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs @@ -355,7 +355,7 @@ internal int Normalize(ReadOnlySpan input, ref Span normalized, ref break; } - ReadOnlySpan normalizedByte = normalizedPrefix.Equals(default(Memory)) ? input.Slice(0, p) : normalizedPrefix.Span; + ReadOnlySpan normalizedByte = normalizedPrefix.Length == 0 ? 
input.Slice(0, p) : normalizedPrefix.Span; if (normalizedByte[0] != (byte)' ') { break; @@ -386,7 +386,7 @@ internal int Normalize(ReadOnlySpan input, ref Span normalized, ref while (!input.IsEmpty) { int p = NormalizePrefix(input, out Memory normalizedPrefix); - ReadOnlySpan sp = normalizedPrefix.Equals(default(Memory)) ? input.Slice(0, p) : normalizedPrefix.Span; + ReadOnlySpan sp = normalizedPrefix.Length == 0 ? input.Slice(0, p) : normalizedPrefix.Span; // Removes heading spaces in sentence piece, if the previous sentence piece ends with whitespace. while (isPrevSpace && sp.Length > 0 && sp[0] == (byte)' ') diff --git a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs index 2e948a36ea..32a7fc666f 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs @@ -147,6 +147,26 @@ public static IEnumerable UnigramTestData() new Range[0] }; + yield return new object[] + { + "\u001f", // string starts with a control character + "▁\u001f", + "", + new int[] { 5, 0 }, + new string[] { "▁", "\u001f"}, + new Range[] { new Range(0, 1), new Range(1, 2) } + }; + + yield return new object[] + { + "\ufe7b", // Decomposes into two letters + "▁\u0640\u0650", + "\u0640\u0650", + new int[] { 17637, 487 }, + new string[] { "▁\u0640", "\u0650" }, + new Range[] { new Range(0, 2), new Range(2, 3) } + }; + yield return new object[] { @"The sun dipped below the horizon, casting a warm golden hue across the tranquil meadow. Birds fluttered from " +