[Model] Add Mistral Tokenization to improve robustness and chat encoding #7739

Merged · 66 commits · Aug 27, 2024

Changes from 30 commits (of 66).

Commits (66), all authored by patrickvonplaten:
f64f664 · WIP · Aug 21, 2024
99abde0 · WIP · Aug 21, 2024
441628b · WIP · Aug 21, 2024
3d13b50 · WIP · Aug 21, 2024
b2f2b2f · WIP · Aug 21, 2024
e2b4f29 · Reformat · Aug 21, 2024
d6bb4d8 · Up · Aug 21, 2024
3d046b6 · Update vllm/entrypoints/chat_utils.py · Aug 21, 2024
cfbe3cd · Apply suggestions from code review · Aug 21, 2024
8c7d0cf · Apply suggestions from code review · Aug 21, 2024
fb092dc · Apply suggestions from code review · Aug 21, 2024
9e2df7b · Update vllm/entrypoints/openai/serving_chat.py · Aug 21, 2024
662e1ce · Apply suggestions from code review · Aug 21, 2024
a56425e · Update vllm/transformers_utils/detokenizer.py · Aug 21, 2024
a3744fe · Update vllm/transformers_utils/detokenizer.py · Aug 21, 2024
513b4f5 · Update vllm/transformers_utils/detokenizer.py · Aug 21, 2024
6981440 · Update vllm/transformers_utils/detokenizer.py · Aug 21, 2024
373bbe7 · Apply suggestions from code review · Aug 21, 2024
bd00d7d · more format · Aug 21, 2024
1ae7b5b · More format · Aug 21, 2024
ba4d770 · isort · Aug 21, 2024
59dd456 · isort · Aug 21, 2024
73a8341 · up · Aug 21, 2024
0221276 · up · Aug 21, 2024
2dee65b · up · Aug 21, 2024
f32c3db · up · Aug 21, 2024
a7a282d · finish · Aug 21, 2024
5460ab6 · finish · Aug 21, 2024
ba290a4 · finish · Aug 21, 2024
1780de7 · finish · Aug 21, 2024
e0e9de1 · Merge branch 'main' into add_mistral_common · Aug 21, 2024
7034e3a · Update vllm/entrypoints/llm.py · Aug 21, 2024
4e543f8 · yapf again · Aug 21, 2024
0a3ccb4 · WIP · Aug 22, 2024
dd943da · Merge branch 'add_mistral_common' of https://github.com/patrickvonpla… · Aug 22, 2024
5dd4458 · Apply suggestions from code review · Aug 22, 2024
5c469d6 · WIP · Aug 22, 2024
9e6e239 · WIP · Aug 22, 2024
18cff00 · yapf again · Aug 22, 2024
f711795 · yapf again · Aug 22, 2024
31a04a9 · Up · Aug 22, 2024
7a02d3b · Up · Aug 22, 2024
b921956 · WIP · Aug 22, 2024
ec3f035 · WIP · Aug 22, 2024
d18dc1a · Up · Aug 22, 2024
908beed · finish · Aug 22, 2024
39a9234 · WIP · Aug 22, 2024
911ec07 · Merge branch 'add_mistral_common' of https://github.com/patrickvonpla… · Aug 22, 2024
002d0cf · WIP · Aug 22, 2024
ba25ce7 · finish · Aug 22, 2024
4dbf270 · WIP · Aug 22, 2024
e1a190c · Merge branch 'add_mistral_common' of https://github.com/patrickvonpla… · Aug 22, 2024
50296c5 · WIP · Aug 22, 2024
cb48bfc · WIP · Aug 22, 2024
3d43014 · WIP · Aug 22, 2024
1481229 · WIP · Aug 22, 2024
4963015 · finish · Aug 22, 2024
25ceeb1 · finish · Aug 22, 2024
4e99ef4 · WIP · Aug 24, 2024
91e5027 · isort · Aug 24, 2024
61c1817 · Trigger CI build · Aug 25, 2024
daf9a25 · Trigger CI build · Aug 25, 2024
e3c8046 · Trigger CI build · Aug 25, 2024
eba5658 · Trigger CI build · Aug 27, 2024
ba08ee6 · Trigger CI build · Aug 27, 2024
b905f1e · Trigger CI build · Aug 27, 2024
1 change: 1 addition & 0 deletions docs/requirements-docs.txt
@@ -11,4 +11,5 @@ pydantic >= 2.8
torch
py-cpuinfo
transformers
mistral_common >= 1.3.4
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -26,3 +26,4 @@ librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
importlib_metadata
mistral_common >= 1.3.4
7 changes: 4 additions & 3 deletions vllm/config.py
@@ -59,7 +59,8 @@ class ModelConfig:
output when `served_model_name` is not specified.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
available, "slow" will always use the slow tokenizer, and
"mistral" will always use the tokenizer from `mistral_common`.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
dtype: Data type for model weights and activations. The "auto" option
@@ -238,10 +239,10 @@ def _init_multimodal_config(

def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]:
if tokenizer_mode not in ["auto", "slow", "mistral"]:
raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto' or 'slow'.")
"either 'auto', 'slow' or 'mistral'.")
self.tokenizer_mode = tokenizer_mode

def _verify_embedding_mode(self) -> None:
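With this change, the new mode can be exercised from the Python API; a minimal sketch (the model name is only illustrative of a checkpoint that ships `mistral_common` tokenizer files):

```python
from vllm import LLM, SamplingParams

# Sketch of the new tokenizer mode; the model name is an assumption used for
# illustration, not something mandated by this PR.
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    tokenizer_mode="mistral",  # previously only "auto" and "slow" were accepted
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=16),
)
print(outputs[0].outputs[0].text)
```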
5 changes: 3 additions & 2 deletions vllm/engine/arg_utils.py
@@ -197,10 +197,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
choices=['auto', 'slow', 'mistral'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer.')
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
6 changes: 3 additions & 3 deletions vllm/entrypoints/chat_utils.py
@@ -23,7 +23,7 @@
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyHFTokenizer, AnyTokenizer

logger = init_logger(__name__)

@@ -113,7 +113,7 @@ def load_chat_template(


@lru_cache(maxsize=None)
def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
def _mm_token_str(model_config: ModelConfig, tokenizer: AnyHFTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
@@ -259,7 +259,7 @@ def parse_chat_messages(


def apply_chat_template(
tokenizer: AnyTokenizer,
tokenizer: AnyHFTokenizer,
conversation: List[ConversationMessage],
chat_template: Optional[str],
*,
29 changes: 19 additions & 10 deletions vllm/entrypoints/llm.py
@@ -19,7 +19,7 @@
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
@@ -383,20 +383,29 @@ def chat(
"""

tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()

conversations, _ = parse_chat_messages(messages, model_config,
tokenizer)
if isinstance(tokenizer, MistralTokenizer):
prompt_token_ids = tokenizer.encode_messages(messages)
prompts = None
else:
model_config = self.llm_engine.get_model_config()

conversations, _ = parse_chat_messages(messages, model_config,
tokenizer)

prompts = apply_chat_template(
tokenizer,
conversations,
chat_template=chat_template,
add_generation_template=add_generation_template,
)

prompts = apply_chat_template(
tokenizer,
conversations,
chat_template=chat_template,
add_generation_template=add_generation_template)
prompt_token_ids = None

return self.generate(
prompts,
prompts, # type: ignore[arg-type]
sampling_params,
prompt_token_ids=prompt_token_ids,
use_tqdm=use_tqdm,
lora_request=lora_request,
)
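For reference, a hedged sketch of how the new `LLM.chat()` branch looks from the caller's side: with a `MistralTokenizer`, the conversation is tokenized directly via `encode_messages` into `prompt_token_ids`, so no Jinja chat template is rendered (model name and messages below are only examples):

```python
from vllm import LLM, SamplingParams

# Illustrative only; assumes a checkpoint loadable with the mistral_common
# tokenizer, as enabled by tokenizer_mode="mistral" in this PR.
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    tokenizer_mode="mistral",
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize what a KV cache does in one sentence."},
]

# With a MistralTokenizer, chat() takes the encode_messages path above and
# forwards prompt_token_ids to generate(); otherwise it applies a chat template.
outputs = llm.chat(messages, SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```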
90 changes: 54 additions & 36 deletions vllm/entrypoints/openai/serving_chat.py
@@ -1,6 +1,7 @@
import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Final,
List, Optional)
from typing import Sequence as GenericSequence
from typing import Union

@@ -22,15 +23,17 @@
FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
PromptAdapterPath,
TextTokensPrompt)
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyHFTokenizer, AnyTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid

logger = init_logger(__name__)
@@ -83,35 +86,49 @@ async def create_chat_completion(
if error_check_ret is not None:
return error_check_ret

try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)

model_config = self.model_config
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)

conversation, mm_futures = parse_chat_messages(
request.messages, model_config, tokenizer)
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)

tool_dicts = None if request.tools is None else [
tool.model_dump() for tool in request.tools
if isinstance(tokenizer, MistralTokenizer):
encoded = tokenizer.encode_chat_completion(request)
conversation = [
ConversationMessage(role=m.role, content=m.content)
for m in encoded.pop("messages")
]

prompt = apply_chat_template(
tokenizer,
conversation=conversation,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
prompt_inputs = TextTokensPrompt(
prompt=encoded["prompt"],
prompt_token_ids=encoded["prompt_token_ids"],
)
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))

# multi-modal is not yet supported
mm_futures: List[Awaitable[MultiModalDataDict]] = []
else:
try:
model_config = self.model_config
conversation, mm_futures = parse_chat_messages(
request.messages, model_config, tokenizer)

tool_dicts = (None if request.tools is None else
[tool.model_dump() for tool in request.tools])

prompt = apply_chat_template(
tokenizer,
conversation=conversation,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
prompt_inputs = None
except Exception as e:
logger.error(
"Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))

mm_data: Optional[MultiModalDataDict] = None
try:
@@ -130,13 +147,14 @@ async def create_chat_completion(
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))

prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
if prompt_inputs is None:
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)

sampling_params = request.to_sampling_params(
tokenizer,
@@ -530,7 +548,7 @@ def _create_chat_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: AnyTokenizer,
tokenizer: AnyHFTokenizer,
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
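The same branch is exercised by the OpenAI-compatible server when it is launched with `--tokenizer-mode mistral`; a client-side sketch against such a server (base URL, API key, and model name are placeholders, not part of this PR):

```python
# Assumes a locally running vLLM OpenAI-compatible server started with
# --tokenizer-mode mistral; endpoint, key, and model name are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    messages=[{"role": "user", "content": "Write a haiku about tokenizers."}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```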
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/serving_completion.py
@@ -28,7 +28,7 @@
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyHFTokenizer, AnyTokenizer
from vllm.utils import merge_async_iterators, random_uuid

logger = init_logger(__name__)
@@ -434,7 +434,7 @@ def _create_completion_logprobs(
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
tokenizer: AnyTokenizer,
tokenizer: AnyHFTokenizer,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
6 changes: 3 additions & 3 deletions vllm/entrypoints/openai/serving_engine.py
@@ -31,7 +31,7 @@
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyHFTokenizer, AnyTokenizer

logger = init_logger(__name__)

@@ -219,7 +219,7 @@ def _normalize_prompt_text_to_input(
def _normalize_prompt_tokens_to_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
tokenizer: AnyHFTokenizer,
prompt_ids: List[int],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
) -> TextTokensPrompt:
@@ -395,7 +395,7 @@ def _log_inputs(
@staticmethod
def _get_decoded_token(logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
tokenizer: AnyHFTokenizer,
return_as_token_id: bool = False) -> str:
if return_as_token_id:
return f"token_id:{token_id}"
4 changes: 2 additions & 2 deletions vllm/multimodal/image.py
@@ -8,7 +8,7 @@
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.transformers_utils.tokenizer import AnyHFTokenizer, get_tokenizer
from vllm.utils import is_list_of

from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
@@ -39,7 +39,7 @@ def repeat_and_pad_token(


def repeat_and_pad_image_tokens(
tokenizer: AnyTokenizer,
tokenizer: AnyHFTokenizer,
prompt: Optional[str],
prompt_token_ids: List[int],
*,
6 changes: 3 additions & 3 deletions vllm/transformers_utils/detokenizer.py
@@ -2,7 +2,7 @@

from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup

from .tokenizer import AnyTokenizer
from .tokenizer import AnyHFTokenizer, AnyTokenizer
from .tokenizer_group import BaseTokenizerGroup

# Used eg. for marking rejected tokens in spec decoding.
@@ -172,7 +172,7 @@ def _convert_tokens_to_string_with_added_encoders(


def _convert_tokens_to_string_with_added_encoders(
tokenizer: AnyTokenizer,
tokenizer: AnyHFTokenizer,
output_tokens: List[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
@@ -230,7 +230,7 @@ def convert_prompt_ids_to_tokens(
prefix_offset = max(
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty(new_tokens)
_replace_none_with_empty(new_tokens) # type: ignore[arg-type]
return new_tokens, prefix_offset, read_offset

