diff --git a/.gitignore b/.gitignore index 7cda86478664..864542128c05 100644 --- a/.gitignore +++ b/.gitignore @@ -227,3 +227,8 @@ ep_kernels_workspace/ # Allow tracked library source folders under submodules (e.g., benchmarks/lib) !vllm/benchmarks/lib/ + +# Generated gRPC protobuf files (compiled at build time from vllm_engine.proto) +vllm/grpc/vllm_engine_pb2.py +vllm/grpc/vllm_engine_pb2_grpc.py +vllm/grpc/vllm_engine_pb2.pyi diff --git a/mkdocs.yaml b/mkdocs.yaml index 8fb8f0568c6e..c5501e7db0f0 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -80,6 +80,7 @@ plugins: - "re:vllm\\._.*" # Internal modules - "vllm.third_party" - "vllm.vllm_flash_attn" + - "re:vllm\\.grpc\\..*_pb2.*" # Auto-generated protobuf files - !ENV [API_AUTONAV_EXCLUDE, "re:^$"] # Match nothing by default - mkdocstrings: handlers: @@ -87,7 +88,8 @@ plugins: options: show_symbol_type_heading: true show_symbol_type_toc: true - filters: [] + filters: + - "!.*_pb2_grpc" # Exclude auto-generated gRPC stubs summary: modules: true show_if_no_docstring: true diff --git a/pyproject.toml b/pyproject.toml index 773f832d650c..97651afeec82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ requires = [ "torch == 2.9.1", "wheel", "jinja2", + "grpcio-tools>=1.76.0", ] build-backend = "setuptools.build_meta" @@ -55,6 +56,10 @@ include = ["vllm*"] "vllm/third_party/**" = ["ALL"] "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] +# Exclude generated protobuf files +"vllm/grpc/*_pb2.py" = ["ALL"] +"vllm/grpc/*_pb2_grpc.py" = ["ALL"] +"vllm/grpc/*_pb2.pyi" = ["ALL"] [tool.ruff.lint] select = [ diff --git a/requirements/build.txt b/requirements/build.txt index 3756371638ba..b3ef0a71038f 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -9,3 +9,5 @@ wheel jinja2>=3.1.6 regex build +protobuf>=6.33.2 +grpcio-tools>=1.76.0 diff --git a/requirements/common.txt b/requirements/common.txt index 43f4a8676d79..29d59c8db8df 100644 --- a/requirements/common.txt +++ 
def compile_grpc_protos():
    """Compile gRPC protobuf definitions during build.

    This generates *_pb2.py, *_pb2_grpc.py, and *_pb2.pyi files from
    the vllm_engine.proto definition.
    """
    # grpcio-tools is an optional build dependency: degrade gracefully.
    try:
        from grpc_tools import protoc
    except ImportError:
        logger.warning(
            "grpcio-tools not installed, skipping gRPC proto compilation. "
            "gRPC server functionality will not be available."
        )
        return False

    proto_file = ROOT_DIR / "vllm" / "grpc" / "vllm_engine.proto"
    if not proto_file.exists():
        logger.warning("Proto file not found at %s, skipping compilation", proto_file)
        return False

    logger.info("Compiling gRPC protobuf: %s", proto_file)

    # Invoke protoc in-process; ROOT_DIR as proto_path keeps the generated
    # modules inside the vllm.grpc package.
    protoc_args = [
        "grpc_tools.protoc",
        f"--proto_path={ROOT_DIR}",
        f"--python_out={ROOT_DIR}",
        f"--grpc_python_out={ROOT_DIR}",
        f"--pyi_out={ROOT_DIR}",
        str(proto_file),
    ]
    exit_code = protoc.main(protoc_args)
    if exit_code != 0:
        logger.error("protoc failed with exit code %s", exit_code)
        return False

    # Prepend SPDX license headers and a mypy opt-out to each generated file
    # (idempotent: skipped when the header is already present).
    spdx_header = (
        "# SPDX-License-Identifier: Apache-2.0\n"
        "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n"
        "# mypy: ignore-errors\n"
    )
    grpc_dir = ROOT_DIR / "vllm" / "grpc"
    for name in (
        "vllm_engine_pb2.py",
        "vllm_engine_pb2_grpc.py",
        "vllm_engine_pb2.pyi",
    ):
        path = grpc_dir / name
        if not path.exists():
            continue
        text = path.read_text()
        if not text.startswith("# SPDX-License-Identifier"):
            path.write_text(spdx_header + text)

    logger.info("gRPC protobuf compilation successful")
    return True
"build_ext": precompiled_build_ext if envs.VLLM_USE_PRECOMPILED - else cmake_build_ext + else cmake_build_ext, + "build_py": BuildPyAndGenerateGrpc, } setup( diff --git a/tests/entrypoints/test_grpc_server.py b/tests/entrypoints/test_grpc_server.py new file mode 100644 index 000000000000..5fb55843f750 --- /dev/null +++ b/tests/entrypoints/test_grpc_server.py @@ -0,0 +1,428 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +End-to-end tests for the vLLM gRPC server. +""" + +import asyncio +import socket +import subprocess +import sys +import time + +import grpc +import pytest +import pytest_asyncio + +from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc + +# Use a small model for fast testing +MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" + + +def find_free_port() -> int: + """Find a free port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + +async def wait_for_server(port: int, timeout: float = 30.0) -> bool: + """Wait for the gRPC server to be ready by trying health checks.""" + start_time = time.time() + print("waiting for server to start...") + while time.time() - start_time < timeout: + try: + channel = grpc.aio.insecure_channel(f"localhost:{port}") + stub = vllm_engine_pb2_grpc.VllmEngineStub(channel) + request = vllm_engine_pb2.HealthCheckRequest() + response = await stub.HealthCheck(request, timeout=5.0) + await channel.close() + if response.healthy: + print("server returned healthy=True") + return True + except Exception: + await asyncio.sleep(0.5) + return False + + +class GrpcServerProcess: + """Manages a gRPC server running in a subprocess.""" + + def __init__(self): + self.process: subprocess.Popen | None = None + self.port: int | None = None + + async def start(self): + """Start the gRPC server process.""" + self.port = find_free_port() + + # Start the server as a 
subprocess + self.process = subprocess.Popen( + [ + sys.executable, + "-m", + "vllm.entrypoints.grpc_server", + "--model", + MODEL_NAME, + "--host", + "localhost", + "--port", + str(self.port), + "--max-num-batched-tokens", + "512", + "--disable-log-stats-server", + ], + ) + + # Wait for server to be ready + if not await wait_for_server(self.port): + self.stop() + raise RuntimeError("gRPC server failed to start within timeout") + + def stop(self): + """Stop the gRPC server process.""" + if self.process: + self.process.terminate() + try: + self.process.wait(timeout=10) + except subprocess.TimeoutExpired: + self.process.kill() + self.process.wait() + + +@pytest_asyncio.fixture(scope="module") +async def grpc_server(): + """Fixture providing a running gRPC server in a subprocess.""" + server = GrpcServerProcess() + await server.start() + + yield server + + server.stop() + + +@pytest_asyncio.fixture +async def grpc_client(grpc_server): + """Fixture providing a gRPC client connected to the server.""" + channel = grpc.aio.insecure_channel(f"localhost:{grpc_server.port}") + stub = vllm_engine_pb2_grpc.VllmEngineStub(channel) + + yield stub + + await channel.close() + + +@pytest.mark.asyncio +async def test_health_check(grpc_client): + """Test the HealthCheck RPC.""" + request = vllm_engine_pb2.HealthCheckRequest() + response = await grpc_client.HealthCheck(request) + + assert response.healthy is True + assert response.message == "Health" + + +@pytest.mark.asyncio +async def test_get_model_info(grpc_client): + """Test the GetModelInfo RPC.""" + request = vllm_engine_pb2.GetModelInfoRequest() + response = await grpc_client.GetModelInfo(request) + + assert response.model_path == MODEL_NAME + assert response.is_generation is True + assert response.max_context_length > 0 + assert response.vocab_size > 0 + assert response.supports_vision is False + + +@pytest.mark.asyncio +async def test_get_server_info(grpc_client): + """Test the GetServerInfo RPC.""" + request = 
@pytest.mark.asyncio
async def test_generate_non_streaming(grpc_client):
    """Test the Generate RPC in non-streaming mode."""
    # Deterministic (temperature=0) single-completion request.
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-non-streaming-1",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello, my name is",
            input_ids=[15496, 11, 616, 1438, 318],  # GPT-2 tokens for the prompt
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            max_tokens=10,
            n=1,
        ),
        stream=False,
    )

    # Non-streaming mode must yield exactly one message: the final result.
    responses = [resp async for resp in grpc_client.Generate(request)]
    assert len(responses) == 1

    final = responses[0]
    assert final.HasField("complete")

    complete = final.complete
    assert len(complete.output_ids) > 0
    assert complete.finish_reason in ["stop", "length"]
    assert complete.prompt_tokens > 0
    assert complete.completion_tokens > 0
response.HasField("complete"): + complete_response = response.complete + + # Should have received some chunks + assert len(chunks) >= 0 # May have 0 chunks if generation is very fast + + # Should have a final complete response + assert complete_response is not None + assert complete_response.finish_reason in ["stop", "length"] + assert complete_response.prompt_tokens > 0 + + # Verify chunk structure + for chunk in chunks: + assert chunk.prompt_tokens > 0 + assert chunk.completion_tokens >= 0 + + +@pytest.mark.asyncio +async def test_generate_with_different_sampling_params(grpc_client): + """Test Generate with various sampling parameters.""" + # Test with temperature + request = vllm_engine_pb2.GenerateRequest( + request_id="test-sampling-temp", + tokenized=vllm_engine_pb2.TokenizedInput( + original_text="Hello", + input_ids=[15496], + ), + sampling_params=vllm_engine_pb2.SamplingParams( + temperature=0.8, top_p=0.95, max_tokens=5 + ), + stream=False, + ) + + responses = [r async for r in grpc_client.Generate(request)] + assert len(responses) == 1 + assert responses[0].HasField("complete") + + # Test with top_k + request = vllm_engine_pb2.GenerateRequest( + request_id="test-sampling-topk", + tokenized=vllm_engine_pb2.TokenizedInput( + original_text="Hello", + input_ids=[15496], + ), + sampling_params=vllm_engine_pb2.SamplingParams( + temperature=1.0, top_k=50, max_tokens=5 + ), + stream=False, + ) + + responses = [r async for r in grpc_client.Generate(request)] + assert len(responses) == 1 + assert responses[0].HasField("complete") + + +@pytest.mark.asyncio +async def test_generate_with_stop_strings(grpc_client): + """Test Generate with stop strings.""" + request = vllm_engine_pb2.GenerateRequest( + request_id="test-stop-strings", + tokenized=vllm_engine_pb2.TokenizedInput( + original_text="Hello", + input_ids=[15496], + ), + sampling_params=vllm_engine_pb2.SamplingParams( + temperature=0.0, + max_tokens=20, + stop=["\n", "END"], + ), + stream=False, + ) + + 
@pytest.mark.asyncio
async def test_generate_multiple_requests(grpc_client):
    """Test handling multiple concurrent Generate requests."""

    async def make_request(request_id: str):
        """Issue one non-streaming request and return its single response."""
        request = vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="Hello",
                input_ids=[15496],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=0.0, max_tokens=5
            ),
            stream=False,
        )

        responses = [r async for r in grpc_client.Generate(request)]
        return responses[0]

    # Send multiple requests concurrently.
    tasks = [make_request(f"test-concurrent-{i}") for i in range(3)]
    responses = await asyncio.gather(*tasks)

    # Verify all requests completed successfully.
    # (Fixed: original used `enumerate` but never used the index.)
    assert len(responses) == 3
    for response in responses:
        assert response.HasField("complete")
@pytest.mark.asyncio
async def test_generate_error_handling(grpc_client):
    """Test error handling in Generate RPC."""
    # top_p=-33 lies outside the valid (0, 1] range and must be rejected.
    bad_request = vllm_engine_pb2.GenerateRequest(
        request_id="test-error-invalid-topp",
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0, max_tokens=10, top_p=-33
        ),
        stream=False,
    )

    # Consuming the stream should surface the server-side abort as RpcError.
    with pytest.raises(grpc.RpcError) as exc_info:
        _ = [r async for r in grpc_client.Generate(bad_request)]

    error = exc_info.value
    assert error.code() == grpc.StatusCode.INVALID_ARGUMENT
    assert "top_p must be in (0, 1], got -33.0" in error.details()
started + await asyncio.sleep(0.1) + abort_request = vllm_engine_pb2.AbortRequest(request_ids=[request_id]) + await grpc_client.Abort(abort_request) + + # Run generate and abort concurrently + await asyncio.gather(run_generate(), abort_after_delay()) + + # The request should have been aborted (received final chunk with + # "abort" finish reason) and finished early due to the abort. + assert was_aborted and received_chunks < 500, ( + "Request should have been aborted before generating all 500 tokens" + ) diff --git a/vllm/entrypoints/grpc_server.py b/vllm/entrypoints/grpc_server.py new file mode 100755 index 000000000000..2778385c9998 --- /dev/null +++ b/vllm/entrypoints/grpc_server.py @@ -0,0 +1,531 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# mypy: ignore-errors +""" +vLLM gRPC Server + +Starts a gRPC server for vLLM using the VllmEngine protocol. + +Usage: + python -m vllm.entrypoints.grpc_server --model + +Example: + python -m vllm.entrypoints.grpc_server \ + --model meta-llama/Llama-2-7b-hf \ + --host 0.0.0.0 \ + --port 50051 +""" + +import argparse +import asyncio +import signal +import sys +import time +from collections.abc import AsyncGenerator + +import grpc +import uvloop +from grpc_reflection.v1alpha import reflection + +from vllm import SamplingParams, TextPrompt, TokensPrompt +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.sampling_params import RequestOutputKind, StructuredOutputsParams +from vllm.usage.usage_lib import UsageContext +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger(__name__) + + +class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer): + """ + gRPC 
    async def Generate(
        self,
        request: vllm_engine_pb2.GenerateRequest,
        context: grpc.aio.ServicerContext,
    ) -> AsyncGenerator[vllm_engine_pb2.GenerateResponse, None]:
        """
        Handle streaming generation requests.

        For ``request.stream=True`` this yields one ``chunk`` message per
        engine step plus a final ``complete`` message; for
        ``request.stream=False`` it yields only the final ``complete``
        message (the engine is driven with FINAL_ONLY output kind, see
        ``_sampling_params_from_proto``).

        Args:
            request: The GenerateRequest protobuf
            context: gRPC context

        Yields:
            GenerateResponse protobuf messages (streaming)
        """
        request_id = request.request_id
        logger.debug("Generate request %s received.", request_id)

        try:
            # Extract tokenized input. The `input` oneof carries either
            # pre-tokenized ids (with optional original text) or raw text.
            if request.WhichOneof("input") == "tokenized":
                prompt: TokensPrompt = {
                    "prompt_token_ids": list(request.tokenized.input_ids)
                }
                if request.tokenized.original_text:
                    prompt["prompt"] = request.tokenized.original_text
            else:
                prompt: TextPrompt = {"prompt": request.text}

            # Build sampling params with detokenize=False
            sampling_params = self._sampling_params_from_proto(
                request.sampling_params, stream=request.stream
            )

            async for output in self.async_llm.generate(
                prompt=prompt,
                sampling_params=sampling_params,
                request_id=request_id,
            ):
                # Convert vLLM output to protobuf
                # For streaming, always send chunks
                if request.stream:
                    yield self._chunk_response(output)

                # Send complete response when finished
                if output.finished:
                    yield self._complete_response(output)

        except ValueError as e:
            # Invalid request error (equiv to 400).
            # NOTE: context.abort() raises, terminating this generator and
            # delivering the status to the client.
            await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
        except Exception as e:
            logger.exception("Error in Generate for request %s", request_id)
            await context.abort(grpc.StatusCode.INTERNAL, str(e))
    @staticmethod
    def _sampling_params_from_proto(
        params: vllm_engine_pb2.SamplingParams, stream: bool = True
    ) -> SamplingParams:
        """
        Convert protobuf SamplingParams to vLLM SamplingParams.

        Args:
            params: Protobuf SamplingParams message
            stream: Whether streaming is enabled

        Returns:
            vLLM SamplingParams. ``detokenize`` is enabled only when stop
            strings are supplied (stop-string matching needs decoded text),
            and ``structured_outputs`` is populated from the ``constraint``
            oneof when present.
        """
        # Build stop sequences (empty repeated proto fields map to None).
        stop = list(params.stop) if params.stop else None
        stop_token_ids = list(params.stop_token_ids) if params.stop_token_ids else None

        # Handle structured outputs constraints: at most one field of the
        # `constraint` oneof may be set.
        structured_outputs = None
        constraint_field = params.WhichOneof("constraint")
        if constraint_field:
            if constraint_field == "json_schema":
                structured_outputs = StructuredOutputsParams(json=params.json_schema)
            elif constraint_field == "regex":
                structured_outputs = StructuredOutputsParams(regex=params.regex)
            elif constraint_field == "grammar":
                structured_outputs = StructuredOutputsParams(grammar=params.grammar)
            elif constraint_field == "structural_tag":
                structured_outputs = StructuredOutputsParams(
                    structural_tag=params.structural_tag
                )
            elif constraint_field == "json_object":
                structured_outputs = StructuredOutputsParams(
                    json_object=params.json_object
                )
            elif constraint_field == "choice":
                structured_outputs = StructuredOutputsParams(
                    choice=list(params.choice.choices)
                )

        # Create SamplingParams
        # output_kind=DELTA: Return only new tokens in each chunk (for streaming)
        # NOTE(review): HasField() on scalar fields requires them to be
        # declared `optional` (explicit presence) in vllm_engine.proto —
        # confirm against the proto definition.
        return SamplingParams(
            temperature=params.temperature if params.HasField("temperature") else 1.0,
            # Proto3 scalar default 0.0 is treated as "unset" for these two.
            top_p=params.top_p if params.top_p != 0.0 else 1.0,
            top_k=params.top_k,
            min_p=params.min_p,
            frequency_penalty=params.frequency_penalty,
            presence_penalty=params.presence_penalty,
            repetition_penalty=params.repetition_penalty
            if params.repetition_penalty != 0.0
            else 1.0,
            max_tokens=params.max_tokens if params.HasField("max_tokens") else None,
            min_tokens=params.min_tokens,
            stop=stop,
            stop_token_ids=stop_token_ids,
            skip_special_tokens=params.skip_special_tokens,
            spaces_between_special_tokens=params.spaces_between_special_tokens,
            ignore_eos=params.ignore_eos,
            n=params.n if params.n > 0 else 1,
            logprobs=params.logprobs if params.HasField("logprobs") else None,
            prompt_logprobs=params.prompt_logprobs
            if params.HasField("prompt_logprobs")
            else None,
            seed=params.seed if params.HasField("seed") else None,
            include_stop_str_in_output=params.include_stop_str_in_output,
            logit_bias=dict(params.logit_bias) if params.logit_bias else None,
            truncate_prompt_tokens=params.truncate_prompt_tokens
            if params.HasField("truncate_prompt_tokens")
            else None,
            structured_outputs=structured_outputs,
            # detokenize must be True if stop strings are used
            detokenize=bool(stop),
            output_kind=RequestOutputKind.DELTA
            if stream
            else RequestOutputKind.FINAL_ONLY,
        )
async def serve_grpc(args: argparse.Namespace):
    """
    Main serving function.

    Builds the AsyncLLM engine from CLI args, starts an insecure grpc.aio
    server with reflection enabled, then blocks until SIGTERM/SIGINT and
    performs an ordered shutdown (gRPC server first, then the engine).

    Args:
        args: Parsed command line arguments
    """
    logger.info("vLLM gRPC server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Recorded so GetServerInfo can report uptime_seconds.
    start_time = time.time()

    # Create engine args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Build vLLM config
    vllm_config = engine_args.create_engine_config(
        usage_context=UsageContext.OPENAI_API_SERVER
    )

    # Create AsyncLLM
    async_llm = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=UsageContext.OPENAI_API_SERVER,
        enable_log_requests=args.enable_log_requests,
        disable_log_stats=args.disable_log_stats_server,
    )

    # Create servicer
    servicer = VllmEngineServicer(async_llm, start_time)

    # Create gRPC server. -1 removes the message-size limits, since token-id
    # payloads for long prompts can exceed gRPC's defaults.
    server = grpc.aio.server(
        options=[
            ("grpc.max_send_message_length", -1),
            ("grpc.max_receive_message_length", -1),
        ],
    )

    # Add servicer to server
    vllm_engine_pb2_grpc.add_VllmEngineServicer_to_server(servicer, server)

    # Enable reflection for grpcurl and other tools
    service_names = (
        vllm_engine_pb2.DESCRIPTOR.services_by_name["VllmEngine"].full_name,
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(service_names, server)

    # Bind to address
    address = f"{args.host}:{args.port}"
    server.add_insecure_port(address)

    # Start server
    await server.start()
    logger.info("vLLM gRPC server started on %s", address)
    logger.info("Server is ready to accept requests")

    # Handle shutdown signals via an asyncio Event so the handler stays
    # non-blocking inside the running loop.
    loop = asyncio.get_running_loop()
    stop_event = asyncio.Event()

    def signal_handler():
        logger.info("Received shutdown signal")
        stop_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    # Serve until shutdown signal
    try:
        await stop_event.wait()
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    finally:
        logger.info("Shutting down vLLM gRPC server...")

        # Stop gRPC server (5s grace lets in-flight RPCs drain first).
        await server.stop(grace=5.0)
        logger.info("gRPC server stopped")

        # Shutdown AsyncLLM
        async_llm.shutdown()
        logger.info("AsyncLLM engine stopped")

        logger.info("Shutdown complete")
import sys
from pathlib import Path


def compile_protos():
    """Compile the vLLM protobuf definitions into Python code.

    Runs grpcio-tools' protoc on vllm_engine.proto, generating
    vllm_engine_pb2.py, vllm_engine_pb2_grpc.py, and vllm_engine_pb2.pyi
    (type stubs) next to this module, then prepends SPDX license headers
    (plus a mypy ignore directive) to each generated file.

    Returns:
        0 on success, a non-zero exit code on failure.
    """
    # Get the vllm package root directory
    script_dir = Path(__file__).parent
    # vllm/vllm/grpc -> vllm/, so the generated package is vllm.grpc
    vllm_package_root = script_dir.parent.parent

    proto_file = script_dir / "vllm_engine.proto"

    if not proto_file.exists():
        print(f"Error: Proto file not found at {proto_file}")
        return 1

    print(f"Compiling protobuf: {proto_file}")
    print(f"Output directory: {script_dir}")

    # Compile the proto file.
    # We use the package root as the proto_path so the generated modules
    # import each other via the vllm.grpc package path.
    try:
        from grpc_tools import protoc

        result = protoc.main(
            [
                "grpc_tools.protoc",
                f"--proto_path={vllm_package_root}",
                f"--python_out={vllm_package_root}",
                f"--grpc_python_out={vllm_package_root}",
                f"--pyi_out={vllm_package_root}",  # Generate type stubs
                str(proto_file),  # reuse the path validated above
            ]
        )

        if result != 0:
            print(f"Error: protoc returned {result}")
            return result

        # Add SPDX headers to generated files (idempotent: skipped when
        # the header is already present).
        spdx_header = (
            "# SPDX-License-Identifier: Apache-2.0\n"
            "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n"
        )

        for generated_file in (
            script_dir / "vllm_engine_pb2.py",
            script_dir / "vllm_engine_pb2_grpc.py",
            script_dir / "vllm_engine_pb2.pyi",
        ):
            if generated_file.exists():
                content = generated_file.read_text()
                if not content.startswith("# SPDX-License-Identifier"):
                    # Add mypy ignore-errors comment for all generated files
                    header = spdx_header + "# mypy: ignore-errors\n"
                    generated_file.write_text(header + content)

        print("✓ Protobuf compilation successful!")
        print(f"  Generated: {script_dir / 'vllm_engine_pb2.py'}")
        print(f"  Generated: {script_dir / 'vllm_engine_pb2_grpc.py'}")
        print(f"  Generated: {script_dir / 'vllm_engine_pb2.pyi'} (type stubs)")
        return 0

    except ImportError:
        print("Error: grpcio-tools not installed")
        print("Install with: pip install grpcio-tools")
        return 1
    except Exception as e:
        # Broad catch is deliberate: this is a CLI tool and any failure
        # should surface as a message plus a non-zero exit code.
        print(f"Error during compilation: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(compile_protos())
service VllmEngine {
  // Submit a generation request (supports streaming)
  rpc Generate(GenerateRequest) returns (stream GenerateResponse);

  // Submit an embedding request
  rpc Embed(EmbedRequest) returns (EmbedResponse);

  // Health check
  rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);

  // Abort a running request
  rpc Abort(AbortRequest) returns (AbortResponse);

  // Get model information
  rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse);

  // Get server information
  rpc GetServerInfo(GetServerInfoRequest) returns (GetServerInfoResponse);
}

// =====================
// Common Types
// =====================

// Sampling parameters for text generation
message SamplingParams {
  optional float temperature = 1;
  float top_p = 2;
  uint32 top_k = 3;
  float min_p = 4;
  float frequency_penalty = 5;
  float presence_penalty = 6;
  float repetition_penalty = 7;

  optional uint32 max_tokens = 8;
  uint32 min_tokens = 9;

  repeated string stop = 10;
  repeated uint32 stop_token_ids = 11;

  bool skip_special_tokens = 12;
  bool spaces_between_special_tokens = 13;
  bool ignore_eos = 14;

  uint32 n = 15; // Number of parallel samples

  // Logprobs configuration
  optional int32 logprobs = 22; // Number of log probabilities per output token (-1 for all)
  optional int32 prompt_logprobs = 23; // Number of log probabilities per prompt token (-1 for all)

  // Additional vLLM fields
  optional int32 seed = 24; // Random seed for reproducibility
  bool include_stop_str_in_output = 25; // Whether to include stop strings in output
  // Token ID to bias mapping (-100 to 100). Key/value types restored:
  // uint32 matches the token-id convention used elsewhere in this file
  // — confirm against the router's expectations.
  map<uint32, float> logit_bias = 26;
  optional int32 truncate_prompt_tokens = 27; // Prompt truncation (-1 for model max)

  // Structured outputs (one of) - matches vLLM's StructuredOutputsParams
  oneof constraint {
    string json_schema = 16; // JSON schema for structured output
    string regex = 17; // Regex pattern
    string grammar = 18; // Grammar/EBNF for structured output
    string structural_tag = 19; // Structural tag (e.g., Harmony models)
    bool json_object = 20; // Force JSON object output
    ChoiceConstraint choice = 21; // List of allowed choices
  }
}

// Choice constraint for structured outputs
message ChoiceConstraint {
  repeated string choices = 1;
}

// Pre-tokenized input from Rust router
message TokenizedInput {
  string original_text = 1; // For reference/debugging
  repeated uint32 input_ids = 2; // Actual token IDs to process
}

// =====================
// Generate Request
// =====================

message GenerateRequest {
  string request_id = 1;

  // Prompt input
  oneof input {
    TokenizedInput tokenized = 2;
    string text = 3;
  }

  // Generation parameters (includes logprobs config)
  SamplingParams sampling_params = 4;

  // Streaming
  bool stream = 5;
}

// =====================
// Generate Response
// =====================

message GenerateResponse {
  oneof response {
    GenerateStreamChunk chunk = 1; // For streaming
    GenerateComplete complete = 2; // For final/non-streaming
  }
}

message GenerateStreamChunk {
  repeated uint32 token_ids = 1; // Incremental tokens
  uint32 prompt_tokens = 2;
  uint32 completion_tokens = 3;
  uint32 cached_tokens = 4;

  // Logprobs support (TODO: implement in Phase 4)
  // OutputLogProbs output_logprobs = 5;
  // InputLogProbs input_logprobs = 6; // Only in first chunk
}

message GenerateComplete {
  repeated uint32 output_ids = 1; // All output tokens
  string finish_reason = 2; // "stop", "length", "abort"
  uint32 prompt_tokens = 3;
  uint32 completion_tokens = 4;
  uint32 cached_tokens = 5;

  // Logprobs support (TODO: implement in Phase 4)
  // OutputLogProbs output_logprobs = 6;
  // InputLogProbs input_logprobs = 7;
}

// =====================
// Embedding Request
// =====================

message EmbedRequest {
  string request_id = 1;
  TokenizedInput tokenized = 2;
}

message EmbedResponse {
  repeated float embedding = 1;
  uint32 prompt_tokens = 2;
  uint32 embedding_dim = 3;
}

// =====================
// Management Operations
// =====================

message HealthCheckRequest {}

message HealthCheckResponse {
  bool healthy = 1;
  string message = 2;
}

message AbortRequest {
  repeated string request_ids = 1;
}

message AbortResponse {
}

// =====================
// Model and Server Info
// =====================

message GetModelInfoRequest {}

message GetModelInfoResponse {
  string model_path = 1;
  bool is_generation = 2;
  uint32 max_context_length = 3;
  uint32 vocab_size = 4;
  bool supports_vision = 5;
}

message GetServerInfoRequest {}

message GetServerInfoResponse {
  uint32 active_requests = 1;
  bool is_paused = 2;
  double last_receive_timestamp = 3;
  double uptime_seconds = 4;
  string server_type = 5; // "vllm-grpc"
}