Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs
pyzmq >= 25.0.0
msgspec
gguf >= 0.17.0
mistral_common[image] >= 1.8.5
mistral_common[image] >= 1.8.8
opencv-python-headless >= 4.11.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
Expand Down
2 changes: 1 addition & 1 deletion requirements/nightly_torch_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.5 # required for voxtral test
mistral_common[image,audio] >= 1.8.8 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
Expand Down
2 changes: 1 addition & 1 deletion requirements/test.in
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ torchaudio==2.9.1
torchvision==0.24.1
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.5 # required for voxtral test
mistral_common[image,audio] >= 1.8.8 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test
opencv-python-headless >= 4.11.0 # required for video test
Expand Down
2 changes: 1 addition & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ mbstrdecoder==1.1.3
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.8.5
mistral-common==1.8.8
# via -r requirements/test.in
mlflow==2.22.0
# via terratorch
Expand Down
73 changes: 20 additions & 53 deletions vllm/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,28 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

from mistral_common.protocol.instruct.request import (
ChatCompletionRequest as MistralChatCompletionRequest,
)
from mistral_common.protocol.instruct.tool_calls import Function, Tool
from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
)
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.logger import init_logger

from .protocol import TokenizerLike

if TYPE_CHECKING:
from mistral_common.protocol.instruct.request import (
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mistral_common is very light weight (has no heavy dependencies) and is a necessary requirement in common.txt => so I think it's cleaner to directly import at the top

ChatCompletionRequest as MistralChatCompletionRequest,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from transformers import BatchEncoding

try:
Expand Down Expand Up @@ -101,8 +112,6 @@ def _prepare_apply_chat_template_tools_and_messages(
continue_final_message: bool = False,
add_generation_prompt: bool = False,
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
from mistral_common.protocol.instruct.tool_calls import Function, Tool

if add_generation_prompt and continue_final_message:
raise ValueError(
"Cannot set both `add_generation_prompt` and "
Expand Down Expand Up @@ -181,8 +190,6 @@ def validate_request_params(request: "ChatCompletionRequest"):


def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
from mistral_common.tokens.tokenizers.tekken import Tekkenizer

assert isinstance(tokenizer, Tekkenizer), type(tokenizer)

t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
Expand Down Expand Up @@ -210,8 +217,6 @@ def from_pretrained(
download_dir: str | None = None,
**kwargs,
) -> "MistralTokenizer":
from mistral_common.protocol.instruct.validator import ValidationMode

try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
Expand All @@ -235,12 +240,6 @@ def from_pretrained(
def __init__(self, tokenizer: "MistralCommonBackend") -> None:
super().__init__()

from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer

self.transformers_tokenizer = tokenizer
self.mistral = tokenizer.tokenizer
self.instruct = self.mistral.instruct_tokenizer
Expand Down Expand Up @@ -270,37 +269,20 @@ def __init__(self, tokenizer: "MistralCommonBackend") -> None:
# Sort the dict for convenience
self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))

# Vocab sorted by token id.
self._vocab = self.tokenizer.vocab()
self._max_token_id = self.vocab_size - 1

# Cache special tokens for faster access.
self._special_token_ids = self._get_special_token_ids()
self._special_token_ids_set = set(self._special_token_ids)
self._special_tokens = self._get_special_tokens(self._special_token_ids)
self._special_tokens_set = set(self._special_tokens)

# Vocab sorted by token id.
self._vocab = self.tokenizer._vocab
self._max_token_id = self.vocab_size - 1

def _get_special_token_ids(self) -> list[int]:
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer

if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
elif self.is_spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
special_ids = self.tokenizer._control_tokens
else:
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
return sorted(special_ids)
return [i for i in range(len(self._vocab)) if self.tokenizer.is_special(i)]

def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

return [
self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
for i in all_special_ids
Expand Down Expand Up @@ -460,15 +442,6 @@ def batch_decode(
)

def convert_tokens_to_string(self, tokens: list[str]) -> str:
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
)
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer

to_decode_special_tokens = {SpecialTokens.tool_calls}
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
Expand Down Expand Up @@ -523,12 +496,6 @@ def convert_ids_to_tokens(
ids: list[int],
skip_special_tokens: bool = False,
) -> list[str]:
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
)
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13

if not skip_special_tokens:
return [self.tokenizer.id_to_piece(token_id) for token_id in ids]

Expand Down