diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index af309b9c..ca4731fc 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -2,6 +2,7 @@ import asyncio import inspect +import os import time import uuid from collections.abc import Callable, Coroutine @@ -68,7 +69,7 @@ from collections.abc import AsyncIterator, MutableSequence from grpc.aio import ServicerContext - from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + from transformers import PreTrainedTokenizer from vllm import CompletionOutput, RequestOutput from vllm.config import ModelConfig from vllm.lora.request import LoRARequest @@ -94,6 +95,10 @@ logger = init_logger(__name__) service_metrics = ServiceMetrics() +ADD_SPECIAL_TOKENS = os.getenv("ADD_SPECIAL_TOKENS") +if ADD_SPECIAL_TOKENS is not None: + ADD_SPECIAL_TOKENS = ADD_SPECIAL_TOKENS.lower() not in (0, "false") + def with_default(value: _T, default: _T) -> _T: return value if value else default @@ -174,8 +179,7 @@ def __init__( ): self.engine: AsyncLLMEngine = engine - # These are set in post_init() - self.tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None + # This is set in post_init() self.config: ModelConfig | None = None self.max_max_new_tokens = args.max_new_tokens @@ -193,9 +197,6 @@ def __init__( async def post_init(self) -> None: self.config = await self.engine.get_model_config() - self.tokenizer_group = self.engine.engine.get_tokenizer_group() - self.tokenizer = await self.engine.get_tokenizer() - assert self.tokenizer is not None # Swap in the special TGIS stats logger tgis_stats_logger = TGISStatLogger( @@ -218,8 +219,15 @@ async def Generate( start_time = time.time() service_metrics.count_generate_request(len(request.requests)) request_id = self.request_id(context) + adapter_kwargs = ( + await self._validate_adapters(request, context) + if adapters_available + else {} + ) + tokenizer = await self._get_tokenizer(adapter_kwargs) + sampling_params, deadline = await self._validate_and_convert_params( - request.params, context + request.params, tokenizer, context ) truncate_input_tokens = with_default(request.params.truncate_input_tokens, None) request_count = len(request.requests) @@ -227,15 +235,9 @@ async def Generate( generators = [] max_is_token_limit = [False] * request_count - adapter_kwargs = ( - await self._validate_adapters(request, context) - if adapters_available - else {} - ) - for i, req in enumerate(request.requests): input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize( - sampling_params, truncate_input_tokens, req.text, context + sampling_params, truncate_input_tokens, req.text, tokenizer, context ) inputs = TextTokensPrompt( @@ -318,20 +320,26 @@ async def GenerateStream( start_time = time.time() service_metrics.count_generate_request() request_id = self.request_id(context) + adapter_kwargs = ( + await self._validate_adapters(request, context) + if adapters_available + else {} + ) + tokenizer = await self._get_tokenizer(adapter_kwargs) + sampling_params, deadline = await self._validate_and_convert_params( - request.params, context + request.params, tokenizer, context ) truncate_input_tokens = with_default(request.params.truncate_input_tokens, None) input_ids, max_is_tok_limit = await self._validate_prompt_and_tokenize( - sampling_params, truncate_input_tokens, request.request.text, context + sampling_params, + truncate_input_tokens, + request.request.text, + tokenizer, + context, ) - adapter_kwargs = ( - await self._validate_adapters(request, context) - if adapters_available - else {} - ) inputs = TextTokensPrompt( prompt=request.request.text, prompt_token_ids=input_ids ) @@ -478,7 +486,10 @@ def request_id(context: ServicerContext) -> str: # noqa: ARG004 return uuid.uuid4().hex async def _validate_and_convert_params( - self, params: Parameters, context: ServicerContext + self, + params: Parameters, + tokenizer: PreTrainedTokenizer, + context: ServicerContext, ) -> tuple[SamplingParams, float | None]: """Return (sampling_params, deadline).""" # First run TGIS validation to raise errors that match the TGIS api @@ -539,19 +550,16 @@ async def _validate_and_convert_params( decoding.length_penalty.start_index, decoding.length_penalty.decay_factor, ) - assert self.tokenizer is not None logits_processors.append( ExpDecayLengthPenaltyWarper( length_penalty=length_penalty_tuple, - eos_token_id=self.tokenizer.eos_token_id, + eos_token_id=tokenizer.eos_token_id, ) ) guided_decode_logit_processor = ( - await get_outlines_guided_decoding_logits_processor( - decoding, self.tokenizer - ) + await get_outlines_guided_decoding_logits_processor(decoding, tokenizer) ) if guided_decode_logit_processor is not None: logits_processors.append(guided_decode_logit_processor) @@ -597,7 +605,10 @@ async def _validate_and_convert_params( async def _validate_adapters( self, - request: SingleGenerationRequest | BatchedGenerationRequest, + request: SingleGenerationRequest + | BatchedGenerationRequest + | TokenizeResponse + | BatchedTokenizeRequest, context: ServicerContext, ) -> dict[str, LoRARequest | PromptAdapterRequest]: try: @@ -609,6 +620,23 @@ async def _validate_adapters( await context.abort(StatusCode.INVALID_ARGUMENT, str(e)) return adapters + async def _get_tokenizer( + self, adapter_kwargs: dict[str, Any] + ) -> PreTrainedTokenizer: + lora_request = adapter_kwargs.get("lora_request") + try: + return await self.engine.get_tokenizer(lora_request) + except TypeError as exc: + # vllm <= 0.5.2 + if "takes 1 positional argument but 2 were given" not in str(exc): + raise + + return ( + await self.engine.engine.get_tokenizer_group().get_lora_tokenizer_async( + lora_request + ) + ) + @staticmethod def _convert_reason( output: CompletionOutput, @@ -655,17 +683,15 @@ def _convert_tokens( # noqa: PLR0913 include_logprobs: bool, include_ranks: bool, top_n_tokens: int, + tokenizer: PreTrainedTokenizer, token_infos: MutableSequence[TokenInfo], # OUT token_start_offset: int = 0, ) -> None: - assert self.tokenizer - if token_start_offset: token_ids = token_ids[token_start_offset:] if logprobs_list is not None: logprobs_list = logprobs_list[token_start_offset:] - # TODO later use get_lora_tokenizer here - token_texts = self.tokenizer.convert_ids_to_tokens(token_ids) + token_texts = tokenizer.convert_ids_to_tokens(token_ids) for i, text in enumerate(token_texts): token_info = TokenInfo(text=text) if logprobs_list is None: @@ -692,10 +718,7 @@ def _convert_tokens( # noqa: PLR0913 key=lambda item: item[1].logprob, reverse=True, )[:top_n_tokens] - # TODO later use get_lora_tokenizer here - tt_texts = self.tokenizer.convert_ids_to_tokens( - [tid for tid, _ in items] - ) + tt_texts = tokenizer.convert_ids_to_tokens([tid for tid, _ in items]) token_info.top_tokens.extend( TokenInfo.TopToken( @@ -706,39 +729,36 @@ def _convert_tokens( # noqa: PLR0913 ) token_infos.append(token_info) - async def _validate_prompt_and_tokenize( + async def _validate_prompt_and_tokenize( # noqa: PLR0913 self, sampling_params: SamplingParams, truncate_input_tokens: int | None, prompt: str, + tokenizer: PreTrainedTokenizer, context: ServicerContext, ) -> tuple[list[int], bool]: assert self.config is not None max_model_len = self.config.max_model_len - # tokenize_kwargs = {"truncation": True, - # "max_length": truncate_input_tokens} \ - # if truncate_input_tokens is not None else { - # "truncation": True, "max_length": max_model_len + 1} - tokenize_kwargs: dict[str, Any] = {} - - input_ids = await self.tokenizer_group.encode_async( - prompt, - **tokenize_kwargs, - ) - # TODO this is temporary until truncation option is added - # to the TokenizerGroup encode methods - if truncate_input_tokens and truncate_input_tokens < len(input_ids): - input_ids = input_ids[-truncate_input_tokens:] - if not sampling_params.skip_special_tokens: - add_bos_token = getattr(self.tokenizer, "add_bos_token", False) - if add_bos_token: - assert self.tokenizer is not None + # Add special tokens based on env var or else only if the tokenizer + # does not have a chat template => this is not a chat model + add_special_tokens = ( + ADD_SPECIAL_TOKENS + if ADD_SPECIAL_TOKENS is not None + else not tokenizer.chat_template + ) - input_ids[0] = self.tokenizer.bos_token_id - # ----------------------------------------------- + tokenizer_kwargs: dict[str, Any] = {"add_special_tokens": add_special_tokens} + if truncate_input_tokens is not None: + tokenizer_kwargs.update( + { + "truncation": True, + "max_length": truncate_input_tokens, + } + ) + input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids token_num = len(input_ids) try: @@ -765,7 +785,7 @@ async def _validate_prompt_and_tokenize( async def Tokenize( self, request: BatchedTokenizeRequest, - context: ServicerContext, # noqa: ARG002 + context: ServicerContext, ) -> BatchedTokenizeResponse: """Handle tokenization requests by tokenizing input texts \ @@ -789,12 +809,16 @@ async def Tokenize( # Log the incoming tokenization request for metrics service_metrics.count_tokenization_request(request) + # TODO simplify to only check for lora adapter + adapter_kwargs = await self._validate_adapters(request, context) + tokenizer = await self._get_tokenizer(adapter_kwargs) + responses: list[TokenizeResponse] = [] # TODO: maybe parallelize, also move convert_ids_to_tokens into the # other threads for req in request.requests: - batch_encoding = self.tokenizer.encode_plus( + batch_encoding = tokenizer.encode_plus( text=req.text, return_offsets_mapping=request.return_offsets ) @@ -806,7 +830,7 @@ async def Tokenize( token_count = request.truncate_input_tokens # Initialize Tokens from ids - tokens = self.tokenizer.convert_ids_to_tokens(token_ids) + tokens = tokenizer.convert_ids_to_tokens(token_ids) offsets = None if request.return_offsets: diff --git a/tests/conftest.py b/tests/conftest.py index a4794747..e2655226 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ import pytest import requests +import vllm from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.cli_args import make_arg_parser @@ -38,11 +39,51 @@ def monkeysession(): @pytest.fixture(scope="session") -def args( - monkeysession, grpc_server_thread_port, http_server_thread_port +def lora_enabled(): + # lora does not work on cpu + return not vllm.config.is_cpu() + + +@pytest.fixture(scope="session") +def requires_lora(lora_enabled): # noqa: PT004 + if not lora_enabled: + pytest.skip(reason="Lora is not enabled. (disabled on cpu)") + + +@pytest.fixture(scope="session") +def lora_adapter_name(requires_lora): + return "lora-test" + + +@pytest.fixture(scope="session") +def lora_adapter_path(requires_lora): + from huggingface_hub import snapshot_download + + path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + return path + + +@pytest.fixture(scope="session") +def args( # noqa: PLR0913 + monkeysession, + grpc_server_thread_port, + http_server_thread_port, + lora_enabled, + lora_adapter_name, + lora_adapter_path, ) -> argparse.Namespace: """Return parsed CLI arguments for the adapter/vLLM.""" # avoid parsing pytest arguments as vllm/vllm_tgis_adapter arguments + + extra_args: list[str] = [] + if lora_enabled: + extra_args.extend( + ( + "--enable-lora", + f"--lora-modules={lora_adapter_name}={lora_adapter_path}", + ) + ) + monkeysession.setattr( sys, "argv", @@ -50,6 +91,7 @@ def args( "__main__.py", f"--grpc-port={grpc_server_thread_port}", f"--port={http_server_thread_port}", + *extra_args, ], ) diff --git a/tests/test_grpc_server.py b/tests/test_grpc_server.py index 77116bc0..7a6d4018 100644 --- a/tests/test_grpc_server.py +++ b/tests/test_grpc_server.py @@ -48,3 +48,9 @@ def test_batched_generation_request(grpc_client, grpc_server_thread_port): assert len(responses) == 2 assert all(response.text for response in responses) + + +def test_lora_request(grpc_client, lora_adapter_name): + response = grpc_client.make_request("hello", adapter_id=lora_adapter_name) + + assert response.text diff --git a/tests/utils.py b/tests/utils.py index 578dd510..1ab56bff 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -117,6 +117,7 @@ def make_request( text: str | list[str], model_id: str | None = None, max_new_tokens: int = 10, + adapter_id: str | None = None, ) -> GenerationResponse | Sequence[GenerationResponse]: if single_request := isinstance(text, str): text = [text] @@ -127,6 +128,7 @@ def make_request( params=Parameters( stopping=StoppingCriteria(max_new_tokens=max_new_tokens), ), + adapter_id=adapter_id, ) response = self.generation_service_stub.Generate(