13 changes: 10 additions & 3 deletions vllm/entrypoints/openai/api_server.py
@@ -414,11 +414,19 @@ async def init_render_app_state(
directly from the :class:`~vllm.config.VllmConfig`.
"""
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.plugins.io_processors import get_io_processor
from vllm.renderers import renderer_from_config

served_model_names = args.served_model_name or [args.model]
model_registry = OpenAIModelRegistry(
model_config=vllm_config.model_config,
base_model_paths=[
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
],
)

if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
@@ -435,7 +443,7 @@ async def init_render_app_state(
model_config=vllm_config.model_config,
renderer=renderer,
io_processor=io_processor,
served_model_names=served_model_names,
model_registry=model_registry,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
@@ -447,8 +455,7 @@
log_error_stack=args.log_error_stack,
)

# Expose models endpoint via the render handler.
state.openai_serving_models = state.openai_serving_render
state.openai_serving_models = model_registry

state.vllm_config = vllm_config
# Disable stats logging — there is no engine to poll.
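With this change, init_render_app_state stores a plain OpenAIModelRegistry on state.openai_serving_models instead of pointing the models endpoint at the render handler. As a minimal sketch of what that enables (an illustrative FastAPI route, not the PR's actual handler, and it assumes ModelList is a Pydantic model):

from fastapi import APIRouter, Request

router = APIRouter()

@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    # init_render_app_state attached the engine-free registry here:
    #   state.openai_serving_models = model_registry
    registry = raw_request.app.state.openai_serving_models
    model_list = await registry.show_available_models()
    return model_list.model_dump()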
4 changes: 1 addition & 3 deletions vllm/entrypoints/openai/generate/api_router.py
@@ -169,9 +169,7 @@ async def init_generate_state(
model_config=engine_client.model_config,
renderer=engine_client.renderer,
io_processor=engine_client.io_processor,
served_model_names=[
mp.name for mp in state.openai_serving_models.base_model_paths
],
model_registry=state.openai_serving_models.registry,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
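The generate router previously rebuilt a served_model_names list from state.openai_serving_models.base_model_paths; passing model_registry=state.openai_serving_models.registry hands the render layer the same names plus the ModelConfig behind them. A hypothetical helper (not part of the PR) showing the equivalence:

from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry

def served_names(registry: OpenAIModelRegistry) -> list[str]:
    # Same information as the old served_model_names list; the registry also
    # carries model_config (e.g. max_model_len) for building model cards.
    return [base_model.name for base_model in registry.base_model_paths]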
81 changes: 58 additions & 23 deletions vllm/entrypoints/openai/models/serving.py
@@ -5,6 +5,7 @@
from collections import defaultdict
from http import HTTPStatus

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
@@ -27,6 +28,51 @@
logger = init_logger(__name__)


class OpenAIModelRegistry:
"""Read-only view of the loaded base models with no engine dependency.

Suitable for CPU-only / render-only contexts that have no engine client
and no LoRA support.
"""

def __init__(
self,
model_config: ModelConfig,
base_model_paths: list[BaseModelPath],
) -> None:
self.model_config = model_config
self.base_model_paths = base_model_paths

def is_base_model(self, model_name: str) -> bool:
return any(model.name == model_name for model in self.base_model_paths)

async def check_model(self, model_name: str | None) -> ErrorResponse | None:
"""Return an ErrorResponse if model_name is not served, else None."""
if not model_name or self.is_base_model(model_name):
return None
return create_error_response(
message=f"The model `{model_name}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
param="model",
)

async def show_available_models(self) -> ModelList:
"""Show available models (base models only)."""
max_model_len = self.model_config.max_model_len
return ModelList(
data=[
ModelCard(
id=base_model.name,
max_model_len=max_model_len,
root=base_model.model_path,
permission=[ModelPermission()],
)
for base_model in self.base_model_paths
]
)


class OpenAIServingModels:
"""Shared instance to hold data about the loaded base model(s) and adapters.

@@ -45,6 +91,11 @@ def __init__(
):
super().__init__()

self.registry = OpenAIModelRegistry(
model_config=engine_client.model_config,
base_model_paths=base_model_paths,
)

self.engine_client = engine_client
self.base_model_paths = base_model_paths

@@ -79,34 +130,18 @@ async def init_static_loras(self):
if isinstance(load_result, ErrorResponse):
raise ValueError(load_result.error.message)

def is_base_model(self, model_name) -> bool:
return any(model.name == model_name for model in self.base_model_paths)
def is_base_model(self, model_name: str) -> bool:
return self.registry.is_base_model(model_name)

def model_name(self, lora_request: LoRARequest | None = None) -> str:
"""Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora: LoRARequest that contain a base_model_name.
Returns:
- str: The name of the base model or the first available model path.
"""
if lora_request is not None:
return lora_request.lora_name
return self.base_model_paths[0].name

async def show_available_models(self) -> ModelList:
"""Show available models. This includes the base model and all adapters."""
max_model_len = self.model_config.max_model_len

model_cards = [
ModelCard(
id=base_model.name,
max_model_len=max_model_len,
root=base_model.model_path,
permission=[ModelPermission()],
)
for base_model in self.base_model_paths
]
"""Show available models. This includes the base model and all
adapters."""
model_list = await self.registry.show_available_models()
lora_cards = [
ModelCard(
id=lora.lora_name,
@@ -118,8 +153,8 @@ async def show_available_models(self) -> ModelList:
)
for lora in self.lora_requests.values()
]
model_cards.extend(lora_cards)
return ModelList(data=model_cards)
model_list.data.extend(lora_cards)
return model_list

async def load_lora_adapter(
self, request: LoadLoRAAdapterRequest, base_model_name: str | None = None
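A short usage sketch of the new OpenAIModelRegistry in an engine-free context. Import paths follow this PR's layout, BaseModelPath's exact module is assumed, and the model name and path are placeholders; only the registry calls mirror the code above:

from vllm.config import ModelConfig
# BaseModelPath's home module is assumed here; adjust to where it actually lives.
from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIModelRegistry

async def demo(model_config: ModelConfig) -> None:
    registry = OpenAIModelRegistry(
        model_config=model_config,
        base_model_paths=[
            BaseModelPath(name="served-name", model_path="/models/served-name")
        ],
    )
    assert await registry.check_model("served-name") is None  # known base model
    error = await registry.check_model("unknown-model")       # ErrorResponse, HTTP 404
    listing = await registry.show_available_models()          # base models only
    print(error, [card.id for card in listing.data])

OpenAIServingModels keeps its public API: it builds one of these registries in __init__ and layers LoRA adapter cards on top of registry.show_available_models().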
37 changes: 4 additions & 33 deletions vllm/entrypoints/serve/render/serving.py
@@ -16,10 +16,8 @@
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
ModelCard,
ModelList,
ModelPermission,
)
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
from vllm.entrypoints.openai.parser.harmony_utils import (
get_developer_message,
get_system_message,
@@ -46,7 +44,7 @@ def __init__(
model_config: ModelConfig,
renderer: BaseRenderer,
io_processor: Any,
served_model_names: list[str],
model_registry: OpenAIModelRegistry,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
@@ -61,7 +59,7 @@
self.model_config = model_config
self.renderer = renderer
self.io_processor = io_processor
self.served_model_names = served_model_names
self.model_registry = model_registry
self.request_logger = request_logger
self.chat_template = chat_template
self.chat_template_content_format: ChatTemplateContentFormatOption = (
@@ -252,21 +250,6 @@ def _make_request_with_harmony(

return messages, [engine_prompt]

async def show_available_models(self) -> ModelList:
"""Returns the models served by this render server."""
max_model_len = self.model_config.max_model_len
return ModelList(
data=[
ModelCard(
id=name,
max_model_len=max_model_len,
root=self.model_config.model,
permission=[ModelPermission()],
)
for name in self.served_model_names
]
)

def create_error_response(
self,
message: str | Exception,
@@ -276,23 +259,11 @@ def create_error_response(
) -> ErrorResponse:
return create_error_response(message, err_type, status_code, param)

def _is_model_supported(self, model_name: str) -> bool:
"""Simplified from OpenAIServing._is_model_supported (no LoRA support)."""
return model_name in self.served_model_names

async def _check_model(
self,
request: Any,
) -> ErrorResponse | None:
"""Simplified from OpenAIServing._check_model (no LoRA support)."""
if self._is_model_supported(request.model):
return None
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
param="model",
)
return await self.model_registry.check_model(request.model)

def _validate_chat_template(
self,
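Since OpenAIServingRender._check_model is now a one-line delegation, render endpoints keep the usual guard pattern. A standalone sketch of that check (the function and parameter names are illustrative, not code from the PR):

from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry

async def guard(registry: OpenAIModelRegistry, requested_model: str | None):
    # The same check OpenAIServingRender._check_model now performs.
    error = await registry.check_model(requested_model)
    if error is not None:
        return error  # ErrorResponse: 404 NotFoundError, param="model"
    # Model is served (or unset); continue with rendering.
    return None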