From 74d3c9ff841b1370d435b4341c737537129f8a8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Thu, 18 Jul 2024 12:41:55 +0200 Subject: [PATCH 1/7] tgis_utils: base EnvVarArgumentParser on FlexibleArgumentParser --- src/vllm_tgis_adapter/tgis_utils/args.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/vllm_tgis_adapter/tgis_utils/args.py b/src/vllm_tgis_adapter/tgis_utils/args.py index 6d0464e9..60ec4255 100644 --- a/src/vllm_tgis_adapter/tgis_utils/args.py +++ b/src/vllm_tgis_adapter/tgis_utils/args.py @@ -3,6 +3,8 @@ import argparse import os +from vllm.utils import FlexibleArgumentParser + from vllm_tgis_adapter.grpc.validation import MAX_TOP_N_TOKENS from vllm_tgis_adapter.logging import init_logger @@ -24,7 +26,7 @@ def _switch_action_default(action: argparse.Action) -> None: action.default = val -class EnvVarArgumentParser(argparse.ArgumentParser): +class EnvVarArgumentParser(FlexibleArgumentParser): """Allows env var fallback for all args.""" class _EnvVarHelpFormatter(argparse.ArgumentDefaultsHelpFormatter): From f3767e9652f4885b7149fff2f69346545c87576a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Thu, 18 Jul 2024 15:04:44 +0200 Subject: [PATCH 2/7] update __main__ for vllm>=0.5.2 --- pyproject.toml | 2 +- src/vllm_tgis_adapter/__main__.py | 105 +++++++++++++++++++++--------- tests/conftest.py | 4 +- 3 files changed, 79 insertions(+), 32 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e06c59b4..1035c3ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ requires-python = ">=3.9" dynamic = ["version"] dependencies = [ - "vllm>=0.5.1", + "vllm>=0.5.2", "prometheus_client==0.20.0", "grpcio==1.62.2", "grpcio-health-checking==1.62.2", diff --git a/src/vllm_tgis_adapter/__main__.py b/src/vllm_tgis_adapter/__main__.py index 72fd7e3c..b43d924f 100644 --- a/src/vllm_tgis_adapter/__main__.py +++ b/src/vllm_tgis_adapter/__main__.py @@ -11,6 +11,7 @@ import fastapi import vllm +from fastapi import APIRouter from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -21,19 +22,23 @@ from vllm import envs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.api_server import app from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.protocol import ( # noqa: TCH002 # pydantic needs to access these annotations ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, + DetokenizeRequest, + DetokenizeResponse, EmbeddingRequest, ErrorResponse, + TokenizeRequest, + TokenizeResponse, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser from .grpc import run_grpc_server from .logging import init_logger @@ -45,18 +50,30 @@ from vllm.config import ModelConfig + +try: + from vllm.entrypoints.openai.serving_tokenization import ( + OpenAIServingTokenization, # noqa: TCH002 + ) +except ImportError: # vllm<=0.5.2 + has_tokenization = False +else: + has_tokenization = True + TIMEOUT_KEEP_ALIVE = 5 # seconds openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding +if has_tokenization: + openai_serving_tokenization: OpenAIServingTokenization logger = init_logger(__name__) _running_tasks: set[asyncio.Task] = set() -router = fastapi.APIRouter() +router = APIRouter() # Add prometheus asgi middleware to route /metrics requests route = Mount("/metrics", make_asgi_app()) @@ -72,15 +89,42 @@ async def health() -> Response: return Response(status_code=200) +if has_tokenization: + assert has_tokenization + + @router.post("/tokenize") + async def tokenize(request: TokenizeRequest) -> JSONResponse: + generator = await openai_serving_tokenization.create_tokenize(request) # noqa: F821 + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), + status_code=generator.code, + ) + assert isinstance(generator, TokenizeResponse) + return JSONResponse(content=generator.model_dump()) + + @router.post("/detokenize") + async def detokenize(request: DetokenizeRequest) -> JSONResponse: + generator = await openai_serving_tokenization.create_detokenize(request) # noqa: F821 + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), + status_code=generator.code, + ) + + assert isinstance(generator, DetokenizeResponse) + return JSONResponse(content=generator.model_dump()) + + @router.get("/v1/models") async def show_available_models() -> JSONResponse: - models = await openai_serving_chat.show_available_models() + models = await openai_serving_completion.show_available_models() return JSONResponse(content=models.model_dump()) @router.get("/version") async def show_version() -> JSONResponse: - ver = {"version": vllm.__version__} + ver = {"version": vllm.__version__, "commit": vllm.__commit__} return JSONResponse(content=ver) @@ -89,11 +133,15 @@ async def create_chat_completion( request: ChatCompletionRequest, raw_request: fastapi.Request, ) -> JSONResponse: - generator = await openai_serving_chat.create_chat_completion(request, raw_request) + generator = await openai_serving_chat.create_chat_completion( + request, + raw_request, + ) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) if request.stream: return StreamingResponse(content=generator, media_type="text/event-stream") + assert isinstance(generator, ChatCompletionResponse) return JSONResponse(content=generator.model_dump()) @@ -102,7 +150,10 @@ async def create_chat_completion( async def create_completion(request: CompletionRequest, raw_request: fastapi.Request): # noqa: ANN201 generator = await openai_serving_completion.create_completion(request, raw_request) if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), status_code=generator.code) + return JSONResponse( + content=generator.model_dump(), + status_code=generator.code, + ) if request.stream: return StreamingResponse(content=generator, media_type="text/event-stream") return JSONResponse(content=generator.model_dump()) @@ -116,7 +167,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: fastapi.Reque return JSONResponse(content=generator.model_dump()) -def build_app( # noqa: C901 # FIXME: waiting on https://github.com/vllm-project/vllm/pull/5090 to get rid of this +def build_app( # noqa: C901 engine: AsyncLLMEngine, args: argparse.Namespace ) -> fastapi.FastAPI: @asynccontextmanager @@ -148,7 +199,10 @@ async def _force_log(): # noqa: ANN202 @app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): # noqa: ANN001, ANN202 err = openai_serving_chat.create_error_response(message=str(exc)) - return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) + return JSONResponse( + err.model_dump(), + status_code=HTTPStatus.BAD_REQUEST, + ) if token := envs.VLLM_API_KEY or args.api_key: @@ -160,7 +214,10 @@ async def authentication(request: fastapi.Request, call_next): # noqa: ANN001, if not request.url.path.startswith(f"{root_path}/v1"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + token: - return JSONResponse(content={"error": "Unauthorized"}, status_code=401) + return JSONResponse( + content={"error": "Unauthorized"}, + status_code=401, + ) return await call_next(request) for middleware in args.middleware: @@ -203,17 +260,12 @@ async def run_http_server( args.chat_template, ) - kwargs = {} - # prompt adapter arg required for vllm >0.5.1 - if hasattr(args, "prompt_adapters"): - kwargs = {"prompt_adapters": args.prompt_adapters} - openai_serving_completion = OpenAIServingCompletion( engine, model_config, served_model_names, args.lora_modules, - **kwargs, + prompt_adapters=args.prompt_adapters, ) openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, served_model_names @@ -240,13 +292,19 @@ async def run_http_server( if __name__ == "__main__": + parser = FlexibleArgumentParser("vLLM TGIS GRPC + OpenAI Rest api server") # convert to our custom env var arg parser - parser = EnvVarArgumentParser(parser=make_arg_parser()) + parser = EnvVarArgumentParser(parser=make_arg_parser(parser)) parser = add_tgis_args(parser) args = postprocess_tgis_args(parser.parse_args()) assert args is not None - logger.info("vLLM version %s", vllm.__version__) + version_info = ( + f"{vllm.__version__}" + vllm.__commit__ + if vllm.__commit__ != "COMMIT_HASH_PLACEHOLDER" + else "unknown" + ) + logger.info("vLLM version %s", version_info) logger.info("args: %s", args) engine_args = AsyncEngineArgs.from_cli_args(args) @@ -282,19 +340,6 @@ async def run_http_server( # When using single vLLM without engine_use_ray model_config = asyncio.run(engine.get_model_config()) - app.root_path = args.root_path - uvicorn_config = UvicornConfig( - app=app, - host=args.host, - port=args.port, - log_level=args.uvicorn_log_level, - timeout_keep_alive=TIMEOUT_KEEP_ALIVE, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs, - ) - if event_loop is None: event_loop = asyncio.new_event_loop() diff --git a/tests/conftest.py b/tests/conftest.py index 3a6632fb..a4794747 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser from vllm_tgis_adapter.__main__ import run_http_server from vllm_tgis_adapter.grpc import run_grpc_server @@ -52,7 +53,8 @@ def args( ], ) - parser = EnvVarArgumentParser(parser=make_arg_parser()) + parser = FlexibleArgumentParser("testing parser") + parser = EnvVarArgumentParser(parser=make_arg_parser(parser)) parser = add_tgis_args(parser) args = postprocess_tgis_args(parser.parse_args()) From 8a087fef2e36ae5b70b59c27c6ae84f0c8d170d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Thu, 18 Jul 2024 15:05:28 +0200 Subject: [PATCH 3/7] gha: bump tested version to 0.5.2 --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 98978e82..a302525f 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -32,7 +32,7 @@ jobs: pyv: ["3.11"] vllm_version: # - "" # skip the pypi version as it will not work on CPU - - "git+https://github.com/vllm-project/vllm@v0.5.1" + - "git+https://github.com/vllm-project/vllm@v0.5.2" - "git+https://github.com/vllm-project/vllm@main" - "git+https://github.com/opendatahub-io/vllm@main" From 09f65fc7ae5b9af48b441ecc54eec975bf109be1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Thu, 18 Jul 2024 15:27:31 +0200 Subject: [PATCH 4/7] deps: pin transformers --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1035c3ec..4e3321c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "grpcio==1.62.2", "grpcio-health-checking==1.62.2", "grpcio-reflection==1.62.2", - "transformers", + "transformers==4.42.4", "accelerate==0.31.0", "hf-transfer==0.1.6", # additional dependencies for OpenTelemetry tracing From f77ed11672d74d564bc885951493a3ae602224ed Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 12 Jul 2024 14:49:08 -0700 Subject: [PATCH 5/7] Fix use of LoRA tokenizers --- src/vllm_tgis_adapter/grpc/grpc_server.py | 133 ++++++++++++---------- 1 file changed, 73 insertions(+), 60 deletions(-) diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index af309b9c..620fe309 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -2,6 +2,7 @@ import asyncio import inspect +import os import time import uuid from collections.abc import Callable, Coroutine @@ -68,7 +69,7 @@ from collections.abc import AsyncIterator, MutableSequence from grpc.aio import ServicerContext - from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + from transformers import PreTrainedTokenizer from vllm import CompletionOutput, RequestOutput from vllm.config import ModelConfig from vllm.lora.request import LoRARequest @@ -94,6 +95,10 @@ logger = init_logger(__name__) service_metrics = ServiceMetrics() +ADD_SPECIAL_TOKENS = os.getenv("ADD_SPECIAL_TOKENS") +if ADD_SPECIAL_TOKENS is not None: + ADD_SPECIAL_TOKENS = ADD_SPECIAL_TOKENS.lower() not in (0, "false") + def with_default(value: _T, default: _T) -> _T: return value if value else default @@ -174,8 +179,7 @@ def __init__( ): self.engine: AsyncLLMEngine = engine - # These are set in post_init() - self.tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None + # This is set in post_init() self.config: ModelConfig | None = None self.max_max_new_tokens = args.max_new_tokens @@ -193,9 +197,6 @@ def __init__( async def post_init(self) -> None: self.config = await self.engine.get_model_config() - self.tokenizer_group = self.engine.engine.get_tokenizer_group() - self.tokenizer = await self.engine.get_tokenizer() - assert self.tokenizer is not None # Swap in the special TGIS stats logger tgis_stats_logger = TGISStatLogger( @@ -218,8 +219,15 @@ async def Generate( start_time = time.time() service_metrics.count_generate_request(len(request.requests)) request_id = self.request_id(context) + adapter_kwargs = ( + await self._validate_adapters(request, context) + if adapters_available + else {} + ) + tokenizer = await self._get_tokenizer(adapter_kwargs) + sampling_params, deadline = await self._validate_and_convert_params( - request.params, context + request.params, tokenizer, context ) truncate_input_tokens = with_default(request.params.truncate_input_tokens, None) request_count = len(request.requests) @@ -227,15 +235,9 @@ async def Generate( generators = [] max_is_token_limit = [False] * request_count - adapter_kwargs = ( - await self._validate_adapters(request, context) - if adapters_available - else {} - ) - for i, req in enumerate(request.requests): input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize( - sampling_params, truncate_input_tokens, req.text, context + sampling_params, truncate_input_tokens, req.text, tokenizer, context ) inputs = TextTokensPrompt( @@ -318,20 +320,26 @@ async def GenerateStream( start_time = time.time() service_metrics.count_generate_request() request_id = self.request_id(context) + adapter_kwargs = ( + await self._validate_adapters(request, context) + if adapters_available + else {} + ) + tokenizer = await self._get_tokenizer(adapter_kwargs) + sampling_params, deadline = await self._validate_and_convert_params( - request.params, context + request.params, tokenizer, context ) truncate_input_tokens = with_default(request.params.truncate_input_tokens, None) input_ids, max_is_tok_limit = await self._validate_prompt_and_tokenize( - sampling_params, truncate_input_tokens, request.request.text, context + sampling_params, + truncate_input_tokens, + request.request.text, + tokenizer, + context, ) - adapter_kwargs = ( - await self._validate_adapters(request, context) - if adapters_available - else {} - ) inputs = TextTokensPrompt( prompt=request.request.text, prompt_token_ids=input_ids ) @@ -478,7 +486,10 @@ def request_id(context: ServicerContext) -> str: # noqa: ARG004 return uuid.uuid4().hex async def _validate_and_convert_params( - self, params: Parameters, context: ServicerContext + self, + params: Parameters, + tokenizer: PreTrainedTokenizer, + context: ServicerContext, ) -> tuple[SamplingParams, float | None]: """Return (sampling_params, deadline).""" # First run TGIS validation to raise errors that match the TGIS api @@ -539,19 +550,16 @@ async def _validate_and_convert_params( decoding.length_penalty.start_index, decoding.length_penalty.decay_factor, ) - assert self.tokenizer is not None logits_processors.append( ExpDecayLengthPenaltyWarper( length_penalty=length_penalty_tuple, - eos_token_id=self.tokenizer.eos_token_id, + eos_token_id=tokenizer.eos_token_id, ) ) guided_decode_logit_processor = ( - await get_outlines_guided_decoding_logits_processor( - decoding, self.tokenizer - ) + await get_outlines_guided_decoding_logits_processor(decoding, tokenizer) ) if guided_decode_logit_processor is not None: logits_processors.append(guided_decode_logit_processor) @@ -597,7 +605,10 @@ async def _validate_and_convert_params( async def _validate_adapters( self, - request: SingleGenerationRequest | BatchedGenerationRequest, + request: SingleGenerationRequest + | BatchedGenerationRequest + | TokenizeResponse + | BatchedTokenizeRequest, context: ServicerContext, ) -> dict[str, LoRARequest | PromptAdapterRequest]: try: @@ -609,6 +620,12 @@ async def _validate_adapters( await context.abort(StatusCode.INVALID_ARGUMENT, str(e)) return adapters + async def _get_tokenizer( + self, adapter_kwargs: dict[str, Any] + ) -> PreTrainedTokenizer: + lora_request = adapter_kwargs.get("lora_request") + return await self.engine.get_tokenizer(lora_request) + @staticmethod def _convert_reason( output: CompletionOutput, @@ -655,17 +672,15 @@ def _convert_tokens( # noqa: PLR0913 include_logprobs: bool, include_ranks: bool, top_n_tokens: int, + tokenizer: PreTrainedTokenizer, token_infos: MutableSequence[TokenInfo], # OUT token_start_offset: int = 0, ) -> None: - assert self.tokenizer - if token_start_offset: token_ids = token_ids[token_start_offset:] if logprobs_list is not None: logprobs_list = logprobs_list[token_start_offset:] - # TODO later use get_lora_tokenizer here - token_texts = self.tokenizer.convert_ids_to_tokens(token_ids) + token_texts = tokenizer.convert_ids_to_tokens(token_ids) for i, text in enumerate(token_texts): token_info = TokenInfo(text=text) if logprobs_list is None: @@ -692,10 +707,7 @@ def _convert_tokens( # noqa: PLR0913 key=lambda item: item[1].logprob, reverse=True, )[:top_n_tokens] - # TODO later use get_lora_tokenizer here - tt_texts = self.tokenizer.convert_ids_to_tokens( - [tid for tid, _ in items] - ) + tt_texts = tokenizer.convert_ids_to_tokens([tid for tid, _ in items]) token_info.top_tokens.extend( TokenInfo.TopToken( @@ -706,39 +718,36 @@ def _convert_tokens( # noqa: PLR0913 ) token_infos.append(token_info) - async def _validate_prompt_and_tokenize( + async def _validate_prompt_and_tokenize( # noqa: PLR0913 self, sampling_params: SamplingParams, truncate_input_tokens: int | None, prompt: str, + tokenizer: PreTrainedTokenizer, context: ServicerContext, ) -> tuple[list[int], bool]: assert self.config is not None max_model_len = self.config.max_model_len - # tokenize_kwargs = {"truncation": True, - # "max_length": truncate_input_tokens} \ - # if truncate_input_tokens is not None else { - # "truncation": True, "max_length": max_model_len + 1} - tokenize_kwargs: dict[str, Any] = {} - - input_ids = await self.tokenizer_group.encode_async( - prompt, - **tokenize_kwargs, - ) - # TODO this is temporary until truncation option is added - # to the TokenizerGroup encode methods - if truncate_input_tokens and truncate_input_tokens < len(input_ids): - input_ids = input_ids[-truncate_input_tokens:] - if not sampling_params.skip_special_tokens: - add_bos_token = getattr(self.tokenizer, "add_bos_token", False) - if add_bos_token: - assert self.tokenizer is not None + # Add special tokens based on env var or else only if the tokenizer + # does not have a chat template => this is not a chat model + add_special_tokens = ( + ADD_SPECIAL_TOKENS + if ADD_SPECIAL_TOKENS is not None + else not tokenizer.chat_template + ) - input_ids[0] = self.tokenizer.bos_token_id - # ----------------------------------------------- + tokenizer_kwargs: dict[str, Any] = {"add_special_tokens": add_special_tokens} + if truncate_input_tokens is not None: + tokenizer_kwargs.update( + { + "truncation": True, + "max_length": truncate_input_tokens, + } + ) + input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids token_num = len(input_ids) try: @@ -765,7 +774,7 @@ async def _validate_prompt_and_tokenize( async def Tokenize( self, request: BatchedTokenizeRequest, - context: ServicerContext, # noqa: ARG002 + context: ServicerContext, ) -> BatchedTokenizeResponse: """Handle tokenization requests by tokenizing input texts \ @@ -789,12 +798,16 @@ async def Tokenize( # Log the incoming tokenization request for metrics service_metrics.count_tokenization_request(request) + # TODO simplify to only check for lora adapter + adapter_kwargs = await self._validate_adapters(request, context) + tokenizer = await self._get_tokenizer(adapter_kwargs) + responses: list[TokenizeResponse] = [] # TODO: maybe parallelize, also move convert_ids_to_tokens into the # other threads for req in request.requests: - batch_encoding = self.tokenizer.encode_plus( + batch_encoding = tokenizer.encode_plus( text=req.text, return_offsets_mapping=request.return_offsets ) @@ -806,7 +819,7 @@ async def Tokenize( token_count = request.truncate_input_tokens # Initialize Tokens from ids - tokens = self.tokenizer.convert_ids_to_tokens(token_ids) + tokens = tokenizer.convert_ids_to_tokens(token_ids) offsets = None if request.return_offsets: From 36e4c1eb8faf0cadf0f04a15592440f5dae4b042 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Thu, 18 Jul 2024 16:20:27 +0200 Subject: [PATCH 6/7] fix tokenizer group usage --- src/vllm_tgis_adapter/grpc/grpc_server.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index 620fe309..ca4731fc 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -624,7 +624,18 @@ async def _get_tokenizer( self, adapter_kwargs: dict[str, Any] ) -> PreTrainedTokenizer: lora_request = adapter_kwargs.get("lora_request") - return await self.engine.get_tokenizer(lora_request) + try: + return await self.engine.get_tokenizer(lora_request) + except TypeError as exc: + # vllm <= 0.5.2 + if "takes 1 positional argument but 2 were given" not in str(exc): + raise + + return ( + await self.engine.engine.get_tokenizer_group().get_lora_tokenizer_async( + lora_request + ) + ) @staticmethod def _convert_reason( From 503f7f249bd7796b4a8db8d0930b1d65ebb11769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Thu, 18 Jul 2024 16:47:24 +0200 Subject: [PATCH 7/7] tests: add lora adapter test --- tests/conftest.py | 46 +++++++++++++++++++++++++++++++++++++-- tests/test_grpc_server.py | 6 +++++ tests/utils.py | 2 ++ 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a4794747..e2655226 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ import pytest import requests +import vllm from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.cli_args import make_arg_parser @@ -38,11 +39,51 @@ def monkeysession(): @pytest.fixture(scope="session") -def args( - monkeysession, grpc_server_thread_port, http_server_thread_port +def lora_enabled(): + # lora does not work on cpu + return not vllm.config.is_cpu() + + +@pytest.fixture(scope="session") +def requires_lora(lora_enabled): # noqa: PT004 + if not lora_enabled: + pytest.skip(reason="Lora is not enabled. (disabled on cpu)") + + +@pytest.fixture(scope="session") +def lora_adapter_name(requires_lora): + return "lora-test" + + +@pytest.fixture(scope="session") +def lora_adapter_path(requires_lora): + from huggingface_hub import snapshot_download + + path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + return path + + +@pytest.fixture(scope="session") +def args( # noqa: PLR0913 + monkeysession, + grpc_server_thread_port, + http_server_thread_port, + lora_enabled, + lora_adapter_name, + lora_adapter_path, ) -> argparse.Namespace: """Return parsed CLI arguments for the adapter/vLLM.""" # avoid parsing pytest arguments as vllm/vllm_tgis_adapter arguments + + extra_args: list[str] = [] + if lora_enabled: + extra_args.extend( + ( + "--enable-lora", + f"--lora-modules={lora_adapter_name}={lora_adapter_path}", + ) + ) + monkeysession.setattr( sys, "argv", @@ -50,6 +91,7 @@ def args( "__main__.py", f"--grpc-port={grpc_server_thread_port}", f"--port={http_server_thread_port}", + *extra_args, ], ) diff --git a/tests/test_grpc_server.py b/tests/test_grpc_server.py index 77116bc0..7a6d4018 100644 --- a/tests/test_grpc_server.py +++ b/tests/test_grpc_server.py @@ -48,3 +48,9 @@ def test_batched_generation_request(grpc_client, grpc_server_thread_port): assert len(responses) == 2 assert all(response.text for response in responses) + + +def test_lora_request(grpc_client, lora_adapter_name): + response = grpc_client.make_request("hello", adapter_id=lora_adapter_name) + + assert response.text diff --git a/tests/utils.py b/tests/utils.py index 578dd510..1ab56bff 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -117,6 +117,7 @@ def make_request( text: str | list[str], model_id: str | None = None, max_new_tokens: int = 10, + adapter_id: str | None = None, ) -> GenerationResponse | Sequence[GenerationResponse]: if single_request := isinstance(text, str): text = [text] @@ -127,6 +128,7 @@ def make_request( params=Parameters( stopping=StoppingCriteria(max_new_tokens=max_new_tokens), ), + adapter_id=adapter_id, ) response = self.generation_service_stub.Generate(