Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 84 additions & 60 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import inspect
import os
import time
import uuid
from collections.abc import Callable, Coroutine
Expand Down Expand Up @@ -68,7 +69,7 @@
from collections.abc import AsyncIterator, MutableSequence

from grpc.aio import ServicerContext
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import PreTrainedTokenizer
from vllm import CompletionOutput, RequestOutput
from vllm.config import ModelConfig
from vllm.lora.request import LoRARequest
Expand All @@ -94,6 +95,10 @@
logger = init_logger(__name__)
service_metrics = ServiceMetrics()

ADD_SPECIAL_TOKENS = os.getenv("ADD_SPECIAL_TOKENS")
if ADD_SPECIAL_TOKENS is not None:
ADD_SPECIAL_TOKENS = ADD_SPECIAL_TOKENS.lower() not in (0, "false")


def with_default(value: _T, default: _T) -> _T:
return value if value else default
Expand Down Expand Up @@ -174,8 +179,7 @@ def __init__(
):
self.engine: AsyncLLMEngine = engine

# These are set in post_init()
self.tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None
# This is set in post_init()
self.config: ModelConfig | None = None

self.max_max_new_tokens = args.max_new_tokens
Expand All @@ -193,9 +197,6 @@ def __init__(

async def post_init(self) -> None:
self.config = await self.engine.get_model_config()
self.tokenizer_group = self.engine.engine.get_tokenizer_group()
self.tokenizer = await self.engine.get_tokenizer()
assert self.tokenizer is not None

# Swap in the special TGIS stats logger
tgis_stats_logger = TGISStatLogger(
Expand All @@ -218,24 +219,25 @@ async def Generate(
start_time = time.time()
service_metrics.count_generate_request(len(request.requests))
request_id = self.request_id(context)
adapter_kwargs = (
await self._validate_adapters(request, context)
if adapters_available
else {}
)
tokenizer = await self._get_tokenizer(adapter_kwargs)

sampling_params, deadline = await self._validate_and_convert_params(
request.params, context
request.params, tokenizer, context
)
truncate_input_tokens = with_default(request.params.truncate_input_tokens, None)
request_count = len(request.requests)

generators = []
max_is_token_limit = [False] * request_count

adapter_kwargs = (
await self._validate_adapters(request, context)
if adapters_available
else {}
)

for i, req in enumerate(request.requests):
input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize(
sampling_params, truncate_input_tokens, req.text, context
sampling_params, truncate_input_tokens, req.text, tokenizer, context
)

inputs = TextTokensPrompt(
Expand Down Expand Up @@ -318,20 +320,26 @@ async def GenerateStream(
start_time = time.time()
service_metrics.count_generate_request()
request_id = self.request_id(context)
adapter_kwargs = (
await self._validate_adapters(request, context)
if adapters_available
else {}
)
tokenizer = await self._get_tokenizer(adapter_kwargs)

sampling_params, deadline = await self._validate_and_convert_params(
request.params, context
request.params, tokenizer, context
)
truncate_input_tokens = with_default(request.params.truncate_input_tokens, None)

input_ids, max_is_tok_limit = await self._validate_prompt_and_tokenize(
sampling_params, truncate_input_tokens, request.request.text, context
sampling_params,
truncate_input_tokens,
request.request.text,
tokenizer,
context,
)

adapter_kwargs = (
await self._validate_adapters(request, context)
if adapters_available
else {}
)
inputs = TextTokensPrompt(
prompt=request.request.text, prompt_token_ids=input_ids
)
Expand Down Expand Up @@ -478,7 +486,10 @@ def request_id(context: ServicerContext) -> str: # noqa: ARG004
return uuid.uuid4().hex

async def _validate_and_convert_params(
self, params: Parameters, context: ServicerContext
self,
params: Parameters,
tokenizer: PreTrainedTokenizer,
context: ServicerContext,
) -> tuple[SamplingParams, float | None]:
"""Return (sampling_params, deadline)."""
# First run TGIS validation to raise errors that match the TGIS api
Expand Down Expand Up @@ -539,19 +550,16 @@ async def _validate_and_convert_params(
decoding.length_penalty.start_index,
decoding.length_penalty.decay_factor,
)
assert self.tokenizer is not None

logits_processors.append(
ExpDecayLengthPenaltyWarper(
length_penalty=length_penalty_tuple,
eos_token_id=self.tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
)
)

guided_decode_logit_processor = (
await get_outlines_guided_decoding_logits_processor(
decoding, self.tokenizer
)
await get_outlines_guided_decoding_logits_processor(decoding, tokenizer)
)
if guided_decode_logit_processor is not None:
logits_processors.append(guided_decode_logit_processor)
Expand Down Expand Up @@ -597,7 +605,10 @@ async def _validate_and_convert_params(

async def _validate_adapters(
self,
request: SingleGenerationRequest | BatchedGenerationRequest,
request: SingleGenerationRequest
| BatchedGenerationRequest
| TokenizeResponse
| BatchedTokenizeRequest,
context: ServicerContext,
) -> dict[str, LoRARequest | PromptAdapterRequest]:
try:
Expand All @@ -609,6 +620,23 @@ async def _validate_adapters(
await context.abort(StatusCode.INVALID_ARGUMENT, str(e))
return adapters

async def _get_tokenizer(
self, adapter_kwargs: dict[str, Any]
) -> PreTrainedTokenizer:
lora_request = adapter_kwargs.get("lora_request")
try:
return await self.engine.get_tokenizer(lora_request)
except TypeError as exc:
# vllm <= 0.5.2
if "takes 1 positional argument but 2 were given" not in str(exc):
raise

return (
await self.engine.engine.get_tokenizer_group().get_lora_tokenizer_async(
lora_request
)
)

@staticmethod
def _convert_reason(
output: CompletionOutput,
Expand Down Expand Up @@ -655,17 +683,15 @@ def _convert_tokens( # noqa: PLR0913
include_logprobs: bool,
include_ranks: bool,
top_n_tokens: int,
tokenizer: PreTrainedTokenizer,
token_infos: MutableSequence[TokenInfo], # OUT
token_start_offset: int = 0,
) -> None:
assert self.tokenizer

if token_start_offset:
token_ids = token_ids[token_start_offset:]
if logprobs_list is not None:
logprobs_list = logprobs_list[token_start_offset:]
# TODO later use get_lora_tokenizer here
token_texts = self.tokenizer.convert_ids_to_tokens(token_ids)
token_texts = tokenizer.convert_ids_to_tokens(token_ids)
for i, text in enumerate(token_texts):
token_info = TokenInfo(text=text)
if logprobs_list is None:
Expand All @@ -692,10 +718,7 @@ def _convert_tokens( # noqa: PLR0913
key=lambda item: item[1].logprob,
reverse=True,
)[:top_n_tokens]
# TODO later use get_lora_tokenizer here
tt_texts = self.tokenizer.convert_ids_to_tokens(
[tid for tid, _ in items]
)
tt_texts = tokenizer.convert_ids_to_tokens([tid for tid, _ in items])

token_info.top_tokens.extend(
TokenInfo.TopToken(
Expand All @@ -706,39 +729,36 @@ def _convert_tokens( # noqa: PLR0913
)
token_infos.append(token_info)

async def _validate_prompt_and_tokenize(
async def _validate_prompt_and_tokenize( # noqa: PLR0913
self,
sampling_params: SamplingParams,
truncate_input_tokens: int | None,
prompt: str,
tokenizer: PreTrainedTokenizer,
context: ServicerContext,
) -> tuple[list[int], bool]:
assert self.config is not None

max_model_len = self.config.max_model_len
# tokenize_kwargs = {"truncation": True,
# "max_length": truncate_input_tokens} \
# if truncate_input_tokens is not None else {
# "truncation": True, "max_length": max_model_len + 1}
tokenize_kwargs: dict[str, Any] = {}

input_ids = await self.tokenizer_group.encode_async(
prompt,
**tokenize_kwargs,
)

# TODO this is temporary until truncation option is added
# to the TokenizerGroup encode methods
if truncate_input_tokens and truncate_input_tokens < len(input_ids):
input_ids = input_ids[-truncate_input_tokens:]
if not sampling_params.skip_special_tokens:
Comment thread
prashantgupta24 marked this conversation as resolved.
add_bos_token = getattr(self.tokenizer, "add_bos_token", False)
if add_bos_token:
assert self.tokenizer is not None
# Add special tokens based on env var or else only if the tokenizer
# does not have a chat template => this is not a chat model
add_special_tokens = (
ADD_SPECIAL_TOKENS
if ADD_SPECIAL_TOKENS is not None
else not tokenizer.chat_template
)

input_ids[0] = self.tokenizer.bos_token_id
# -----------------------------------------------
tokenizer_kwargs: dict[str, Any] = {"add_special_tokens": add_special_tokens}
if truncate_input_tokens is not None:
tokenizer_kwargs.update(
{
"truncation": True,
"max_length": truncate_input_tokens,
}
)

input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
token_num = len(input_ids)

try:
Expand All @@ -765,7 +785,7 @@ async def _validate_prompt_and_tokenize(
async def Tokenize(
self,
request: BatchedTokenizeRequest,
context: ServicerContext, # noqa: ARG002
context: ServicerContext,
) -> BatchedTokenizeResponse:
"""Handle tokenization requests by tokenizing input texts \

Expand All @@ -789,12 +809,16 @@ async def Tokenize(
# Log the incoming tokenization request for metrics
service_metrics.count_tokenization_request(request)

# TODO simplify to only check for lora adapter
adapter_kwargs = await self._validate_adapters(request, context)
tokenizer = await self._get_tokenizer(adapter_kwargs)

responses: list[TokenizeResponse] = []

# TODO: maybe parallelize, also move convert_ids_to_tokens into the
# other threads
for req in request.requests:
batch_encoding = self.tokenizer.encode_plus(
batch_encoding = tokenizer.encode_plus(
text=req.text, return_offsets_mapping=request.return_offsets
)

Expand All @@ -806,7 +830,7 @@ async def Tokenize(
token_count = request.truncate_input_tokens

# Initialize Tokens from ids
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
tokens = tokenizer.convert_ids_to_tokens(token_ids)
offsets = None

if request.return_offsets:
Expand Down
46 changes: 44 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest
import requests
import vllm
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
Expand Down Expand Up @@ -38,18 +39,59 @@ def monkeysession():


@pytest.fixture(scope="session")
def args(
monkeysession, grpc_server_thread_port, http_server_thread_port
def lora_enabled():
# lora does not work on cpu
return not vllm.config.is_cpu()


@pytest.fixture(scope="session")
def requires_lora(lora_enabled): # noqa: PT004
if not lora_enabled:
pytest.skip(reason="Lora is not enabled. (disabled on cpu)")


@pytest.fixture(scope="session")
def lora_adapter_name(requires_lora):
return "lora-test"


@pytest.fixture(scope="session")
def lora_adapter_path(requires_lora):
from huggingface_hub import snapshot_download

path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
return path
Comment thread
prashantgupta24 marked this conversation as resolved.


@pytest.fixture(scope="session")
def args( # noqa: PLR0913
monkeysession,
grpc_server_thread_port,
http_server_thread_port,
lora_enabled,
lora_adapter_name,
lora_adapter_path,
) -> argparse.Namespace:
"""Return parsed CLI arguments for the adapter/vLLM."""
# avoid parsing pytest arguments as vllm/vllm_tgis_adapter arguments

extra_args: list[str] = []
if lora_enabled:
extra_args.extend(
(
"--enable-lora",
f"--lora-modules={lora_adapter_name}={lora_adapter_path}",
)
)

monkeysession.setattr(
sys,
"argv",
[
"__main__.py",
f"--grpc-port={grpc_server_thread_port}",
f"--port={http_server_thread_port}",
*extra_args,
],
)

Expand Down
6 changes: 6 additions & 0 deletions tests/test_grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,9 @@ def test_batched_generation_request(grpc_client, grpc_server_thread_port):

assert len(responses) == 2
assert all(response.text for response in responses)


def test_lora_request(grpc_client, lora_adapter_name):
response = grpc_client.make_request("hello", adapter_id=lora_adapter_name)

assert response.text
Loading