Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool
AddBeginningOfSentence = addBos;
AddEndOfSentence = addEos;
BeginningOfSentenceToken = modelProto.TrainerSpec.BosPiece ?? "<s>";
BeginningOfSentenceId = modelProto.TrainerSpec.BosId <= 0 ? 1 : modelProto.TrainerSpec.BosId;
BeginningOfSentenceId = modelProto.TrainerSpec.BosId;
EndOfSentenceToken = modelProto.TrainerSpec.EosPiece ?? "</s>";
EndOfSentenceId = modelProto.TrainerSpec.EosId <= 0 ? 1 : modelProto.TrainerSpec.EosId;
EndOfSentenceId = modelProto.TrainerSpec.EosId;
UnknownToken = modelProto.TrainerSpec.UnkPiece ?? "<unk>";
UnknownId = modelProto.TrainerSpec.UnkId < 0 ? 0 : modelProto.TrainerSpec.UnkId;
UnknownId = modelProto.TrainerSpec.UnkId;
AddDummyPrefix = modelProto.NormalizerSpec.AddDummyPrefix;
EscapeWhiteSpaces = modelProto.NormalizerSpec.EscapeWhitespaces;
TreatWhitespaceAsSuffix = modelProto.TrainerSpec.TreatWhitespaceAsSuffix;
Expand Down
23 changes: 21 additions & 2 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,27 @@ public SentencePieceUnigramModel(ModelProto modelProto, bool addBos, bool addEos

_trie = new DoubleArrayTrie(_vocab);

_vocabReverse[BeginningOfSentenceId] = (BeginningOfSentenceToken, 0f, 0);
_vocabReverse[EndOfSentenceId] = (EndOfSentenceToken, 0f, 0);
// Once the trie is built, we need to add the special tokens to the vocabulary.
// Including these special tokens ensures they are mapped like regular tokens.
// SentencePiece specifically handles the BOS, EOS, and UNK tokens, while the PAD token is optional.

Debug.Assert(modelProto.TrainerSpec.UnkId >= 0);
Debug.Assert(modelProto.TrainerSpec.BosId >= 0);
Debug.Assert(modelProto.TrainerSpec.EosId >= 0);

_vocab[modelProto.TrainerSpec.UnkPiece] = modelProto.TrainerSpec.UnkId;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this present in the original tokenizer or is it special to our port? Asking because I don't understand why this isn't handled via modelProto.Pieces. Are we certain we only need these 3 special cases? Comment might help.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is special to our port, to ensure these special tokens are added to the vocabs for easier lookup. Adding these here does not change any behavior; it only allows the vocabulary to map these tokens, which helps us in some operations — decoding, for example.

I'll add some detailed comment here. Thanks!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you know for sure that modelProto.Pieces.Count used when allocating _vocabReverse is greater than these IDs? I guess so since these IDs come from the modelProto too. Just seems odd to me that we step through all the modelProto.Pieces.Count above, but then we come back and overwrite some of the IDs like this.

Copy link
Member Author

@tarekgh tarekgh Mar 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just seems odd to me that we step through all the modelProto.Pieces.Count above, but then we come back and overwrite some of the IDs like this.

This is done this way to avoid adding these special tokens to the trie, as these shouldn't be part of it. The native code also does not add such tokens when enumerating modelProto.Pieces. It is only our addition that we add these to the vocab after we are done building the trie, for easier mapping internally.

Can you know for sure that modelProto.Pieces.Count used when allocating _vocabReverse is greater than these IDs?

I can check, but I want to know what you suggest doing when, for any reason, we have wrong data — just throw an exception?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explicit exception might be better than index out of range, but your call.

This has made me wonder twice — in the initial PR and here — so this logic warrants a comment in the source to explain what's going on.

_vocab[modelProto.TrainerSpec.BosPiece] = modelProto.TrainerSpec.BosId;
_vocab[modelProto.TrainerSpec.EosPiece] = modelProto.TrainerSpec.EosId;

_vocabReverse[modelProto.TrainerSpec.BosId] = (modelProto.TrainerSpec.BosPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control);
_vocabReverse[modelProto.TrainerSpec.EosId] = (modelProto.TrainerSpec.EosPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control);
_vocabReverse[modelProto.TrainerSpec.UnkId] = (modelProto.TrainerSpec.UnkPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Unknown);

if (modelProto.TrainerSpec.PadId >= 0)
{
_vocab[modelProto.TrainerSpec.PadPiece] = modelProto.TrainerSpec.PadId;
_vocabReverse[modelProto.TrainerSpec.PadId] = (modelProto.TrainerSpec.PadPiece, 0f, ModelProto.Types.SentencePiece.Types.Type.Control);
}
}

public SentencePieceUnigramModel(SentencePieceOptions options) : base(options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ internal int Normalize(ReadOnlySpan<byte> input, ref Span<byte> normalized, ref
break;
}

ReadOnlySpan<byte> normalizedByte = normalizedPrefix.Equals(default(Memory<byte>)) ? input.Slice(0, p) : normalizedPrefix.Span;
ReadOnlySpan<byte> normalizedByte = normalizedPrefix.Length == 0 ? input.Slice(0, p) : normalizedPrefix.Span;
if (normalizedByte[0] != (byte)' ')
{
break;
Expand Down Expand Up @@ -386,7 +386,7 @@ internal int Normalize(ReadOnlySpan<byte> input, ref Span<byte> normalized, ref
while (!input.IsEmpty)
{
int p = NormalizePrefix(input, out Memory<byte> normalizedPrefix);
ReadOnlySpan<byte> sp = normalizedPrefix.Equals(default(Memory<byte>)) ? input.Slice(0, p) : normalizedPrefix.Span;
ReadOnlySpan<byte> sp = normalizedPrefix.Length == 0 ? input.Slice(0, p) : normalizedPrefix.Span;

// Removes heading spaces in sentence piece, if the previous sentence piece ends with whitespace.
while (isPrevSpace && sp.Length > 0 && sp[0] == (byte)' ')
Expand Down
20 changes: 20 additions & 0 deletions test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,26 @@ public static IEnumerable<object[]> UnigramTestData()
new Range[0]
};

yield return new object[]
{
"\u001f", // string start with control character
"▁\u001f",
"",
new int[] { 5, 0 },
new string[] { "▁", "\u001f"},
new Range[] { new Range(0, 1), new Range(1, 2) }
};

yield return new object[]
{
"\ufe7b", // Decompose to two letters
"▁\u0640\u0650",
"\u0640\u0650",
new int[] { 17637, 487 },
new string[] { "▁\u0640", "\u0650" },
new Range[] { new Range(0, 2), new Range(2, 3) }
};

yield return new object[]
{
@"The sun dipped below the horizon, casting a warm golden hue across the tranquil meadow. Birds fluttered from " +
Expand Down
Loading