diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
index 0f8a76ddfc..13cde121e7 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
@@ -25,11 +25,11 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool
AddBeginningOfSentence = addBos;
AddEndOfSentence = addEos;
BeginningOfSentenceToken = modelProto.TrainerSpec.BosPiece ?? "";
- BeginningOfSentenceId = modelProto.TrainerSpec.BosId <= 0 ? 1 : modelProto.TrainerSpec.BosId;
+ BeginningOfSentenceId = Math.Max(0, modelProto.TrainerSpec.BosId);
EndOfSentenceToken = modelProto.TrainerSpec.EosPiece ?? "";
- EndOfSentenceId = modelProto.TrainerSpec.EosId <= 0 ? 1 : modelProto.TrainerSpec.EosId;
+ EndOfSentenceId = Math.Max(0, modelProto.TrainerSpec.EosId);
UnknownToken = modelProto.TrainerSpec.UnkPiece ?? "";
- UnknownId = modelProto.TrainerSpec.UnkId < 0 ? 0 : modelProto.TrainerSpec.UnkId;
+ UnknownId = Math.Max(0, modelProto.TrainerSpec.UnkId);
AddDummyPrefix = modelProto.NormalizerSpec.AddDummyPrefix;
EscapeWhiteSpaces = modelProto.NormalizerSpec.EscapeWhitespaces;
TreatWhitespaceAsSuffix = modelProto.TrainerSpec.TreatWhitespaceAsSuffix;
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
index dca346caea..057b34036b 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
@@ -62,8 +62,27 @@ public SentencePieceUnigramModel(ModelProto modelProto, bool addBos, bool addEos
_trie = new DoubleArrayTrie(_vocab);
- _vocabReverse[BeginningOfSentenceId] = (BeginningOfSentenceToken, 0f, 0);
- _vocabReverse[EndOfSentenceId] = (EndOfSentenceToken, 0f, 0);
+ // Once the trie is built, we need to add the special tokens to the vocabulary.
+ // Including these special tokens ensures they are mapped like regular tokens.
+ // SentencePiece specifically handles the BOS, EOS, and UNK tokens, while the PAD token is optional.
+
+ Debug.Assert(modelProto.TrainerSpec.UnkId >= 0);
+ Debug.Assert(modelProto.TrainerSpec.BosId >= 0);
+ Debug.Assert(modelProto.TrainerSpec.EosId >= 0);
+
+ _vocab[modelProto.TrainerSpec.UnkPiece] = modelProto.TrainerSpec.UnkId;
+ _vocab[modelProto.TrainerSpec.BosPiece] = modelProto.TrainerSpec.BosId;
+ _vocab[modelProto.TrainerSpec.EosPiece] = modelProto.TrainerSpec.EosId;
+
+ _vocabReverse[modelProto.TrainerSpec.BosId] = (modelProto.TrainerSpec.BosPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control);
+ _vocabReverse[modelProto.TrainerSpec.EosId] = (modelProto.TrainerSpec.EosPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control);
+ _vocabReverse[modelProto.TrainerSpec.UnkId] = (modelProto.TrainerSpec.UnkPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Unknown);
+
+ if (modelProto.TrainerSpec.PadId >= 0)
+ {
+ _vocab[modelProto.TrainerSpec.PadPiece] = modelProto.TrainerSpec.PadId;
+ _vocabReverse[modelProto.TrainerSpec.PadId] = (modelProto.TrainerSpec.PadPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control);
+ }
}
public SentencePieceUnigramModel(SentencePieceOptions options) : base(options)
diff --git a/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs b/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs
index 97b3f4a7a4..e939e7f35b 100644
--- a/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs
@@ -355,7 +355,7 @@ internal int Normalize(ReadOnlySpan input, ref Span normalized, ref
break;
}
- ReadOnlySpan normalizedByte = normalizedPrefix.Equals(default(Memory)) ? input.Slice(0, p) : normalizedPrefix.Span;
+ ReadOnlySpan normalizedByte = normalizedPrefix.Length == 0 ? input.Slice(0, p) : normalizedPrefix.Span;
if (normalizedByte[0] != (byte)' ')
{
break;
@@ -386,7 +386,7 @@ internal int Normalize(ReadOnlySpan input, ref Span normalized, ref
while (!input.IsEmpty)
{
int p = NormalizePrefix(input, out Memory normalizedPrefix);
- ReadOnlySpan sp = normalizedPrefix.Equals(default(Memory)) ? input.Slice(0, p) : normalizedPrefix.Span;
+ ReadOnlySpan sp = normalizedPrefix.Length == 0 ? input.Slice(0, p) : normalizedPrefix.Span;
// Removes heading spaces in sentence piece, if the previous sentence piece ends with whitespace.
while (isPrevSpace && sp.Length > 0 && sp[0] == (byte)' ')
diff --git a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
index 2e948a36ea..32a7fc666f 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
@@ -147,6 +147,26 @@ public static IEnumerable