docs/source/providers/inference/remote_vllm.md (4 changes: 3 additions & 1 deletion)

@@ -12,11 +12,13 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
+| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
+| `refresh_models_interval` | `<class 'int'>` | No | 300 | Interval in seconds to refresh models |

 ## Sample Configuration

 ```yaml
-url: ${env.VLLM_URL}
+url: ${env.VLLM_URL:=}
 max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
 api_token: ${env.VLLM_API_TOKEN:=fake}
 tls_verify: ${env.VLLM_TLS_VERIFY:=true}
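With the two new fields, periodic refresh can be switched on in the same sample-configuration style. A sketch of what that could look like; note the `VLLM_REFRESH_MODELS*` environment variable names are hypothetical and not part of this change:

```yaml
url: ${env.VLLM_URL:=}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
refresh_models: ${env.VLLM_REFRESH_MODELS:=false}                  # enable background refresh
refresh_models_interval: ${env.VLLM_REFRESH_MODELS_INTERVAL:=300}  # seconds between refreshes
```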
llama_stack/apis/inference/inference.py (2 changes: 1 addition & 1 deletion)

@@ -819,7 +819,7 @@ class OpenAIEmbeddingsResponse(BaseModel):
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...

-    async def update_registered_models(
+    async def update_registered_llm_models(
         self,
         provider_id: str,
         models: list[Model],
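For context, a provider holding a `ModelStore` is expected to call the renamed method like this. A minimal sketch, assuming `ModelStore` is importable from `llama_stack.apis.inference` (as the import added to vllm.py below suggests) and with an illustrative model list:

```python
from llama_stack.apis.inference import ModelStore
from llama_stack.apis.models import Model, ModelType

async def push_models(store: ModelStore, provider_id: str) -> None:
    # Illustrative model list; real providers build this from a live listing.
    models = [
        Model(
            identifier="meta-llama/Llama-3.1-8B-Instruct",
            provider_resource_id="meta-llama/Llama-3.1-8B-Instruct",
            provider_id=provider_id,
            metadata={},
            model_type=ModelType.llm,
        )
    ]
    # Only LLM entries are reconciled; embedding models are left untouched.
    await store.update_registered_llm_models(provider_id, models)
```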
llama_stack/distribution/routing_tables/models.py (8 changes: 6 additions & 2 deletions)

@@ -81,7 +81,7 @@ async def unregister_model(self, model_id: str) -> None:
             raise ValueError(f"Model {model_id} not found")
         await self.unregister_object(existing_model)

-    async def update_registered_models(
+    async def update_registered_llm_models(
         self,
         provider_id: str,
         models: list[Model],

@@ -92,12 +92,16 @@ async def update_registered_models(
         # from run.yaml) that we need to keep track of
         model_ids = {}
         for model in existing_models:
-            if model.provider_id == provider_id:
+            # we leave embedding models alone because often we don't get metadata
+            # (embedding dimension, etc.) from the provider
+            if model.provider_id == provider_id and model.model_type == ModelType.llm:
                 model_ids[model.provider_resource_id] = model.identifier
                 logger.debug(f"unregistering model {model.identifier}")
                 await self.unregister_object(model)

         for model in models:
+            if model.model_type != ModelType.llm:
+                continue
             if model.provider_resource_id in model_ids:
                 model.identifier = model_ids[model.provider_resource_id]
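The `model_ids` map is what lets a user-chosen alias survive a refresh: the identifier recorded when the old entry is unregistered gets re-applied to the freshly listed model. A tiny self-contained illustration; the stand-in class and model names are invented for the example:

```python
from dataclasses import dataclass

# Stand-in for llama_stack.apis.models.Model, just for this illustration.
@dataclass
class FakeModel:
    identifier: str
    provider_resource_id: str

# Alias recorded while unregistering the old entry (e.g. registered via run.yaml).
model_ids = {"Qwen/Qwen2-7B-Instruct": "my-qwen"}

# The same model as reported by the provider on refresh: identifier == resource id.
refreshed = FakeModel(
    identifier="Qwen/Qwen2-7B-Instruct",
    provider_resource_id="Qwen/Qwen2-7B-Instruct",
)

if refreshed.provider_resource_id in model_ids:
    refreshed.identifier = model_ids[refreshed.provider_resource_id]

print(refreshed.identifier)  # -> "my-qwen": the user-facing alias is preserved
```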
llama_stack/providers/remote/inference/ollama/ollama.py (8 changes: 4 additions & 4 deletions)

@@ -159,18 +159,18 @@ async def _refresh_models(self) -> None:
             models = []
             for m in response.models:
                 model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
-                # unfortunately, ollama does not provide embedding dimension in the model list :(
-                # we should likely add a hard-coded mapping of model name to embedding dimension
+                if model_type == ModelType.embedding:
+                    continue
                 models.append(
                     Model(
                         identifier=m.model,
                         provider_resource_id=m.model,
                         provider_id=provider_id,
-                        metadata={"embedding_dimension": 384} if model_type == ModelType.embedding else {},
+                        metadata={},
                         model_type=model_type,
                     )
                 )
-            await self.model_store.update_registered_models(provider_id, models)
+            await self.model_store.update_registered_llm_models(provider_id, models)
             logger.debug(f"ollama refreshed model list ({len(models)} models)")

             await asyncio.sleep(self.config.refresh_models_interval)
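Both the Ollama and vLLM refreshers follow the same loop shape: list models, push the LLM entries to the store, then sleep for `refresh_models_interval` and repeat. A stripped-down sketch of that pattern; the callables and names here are placeholders, not the adapters' real signatures:

```python
import asyncio
from collections.abc import Awaitable, Callable

async def refresh_loop(
    list_models: Callable[[], Awaitable[list]],        # e.g. wraps client.models.list()
    push_to_store: Callable[[list], Awaitable[None]],  # wraps update_registered_llm_models
    interval: int,
) -> None:
    while True:
        try:
            models = await list_models()
            await push_to_store(models)
        except Exception as e:
            # The adapters log and keep looping rather than letting the task die.
            print(f"model refresh failed: {e}")
        await asyncio.sleep(interval)
```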
llama_stack/providers/remote/inference/vllm/config.py (10 changes: 9 additions & 1 deletion)

@@ -29,6 +29,14 @@ class VLLMInferenceAdapterConfig(BaseModel):
         default=True,
         description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
     )
+    refresh_models: bool = Field(
+        default=False,
+        description="Whether to refresh models periodically",
+    )
+    refresh_models_interval: int = Field(
+        default=300,
+        description="Interval in seconds to refresh models",
+    )

     @field_validator("tls_verify")
     @classmethod

@@ -46,7 +54,7 @@ def validate_tls_verify(cls, v):
     @classmethod
     def sample_run_config(
         cls,
-        url: str = "${env.VLLM_URL}",
+        url: str = "${env.VLLM_URL:=}",
         **kwargs,
     ):
         return {
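The new fields are plain pydantic settings, so enabling refresh programmatically is just constructor arguments. A quick sketch, assuming the module path shown in the diff header; the URL value is illustrative:

```python
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig

config = VLLMInferenceAdapterConfig(
    url="http://localhost:8000/v1",  # an empty URL leaves the provider dormant
    refresh_models=True,             # default is False: no background task
    refresh_models_interval=60,      # default is 300 seconds
)
print(config.refresh_models, config.refresh_models_interval)  # True 60
```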
llama_stack/providers/remote/inference/vllm/vllm.py (78 changes: 74 additions & 4 deletions)

@@ -3,8 +3,8 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import asyncio
 import json
-import logging
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any

@@ -38,6 +38,7 @@
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
+    ModelStore,
     OpenAIChatCompletion,
     OpenAICompletion,
     OpenAIEmbeddingData,

@@ -54,6 +55,7 @@
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import (

@@ -84,7 +86,7 @@

 from .config import VLLMInferenceAdapterConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")


 def build_hf_repo_model_entries():

@@ -288,16 +290,76 @@ async def _process_vllm_chat_completion_stream_response(


 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+    # automatically set by the resolver when instantiating the provider
+    __provider_id__: str
+    model_store: ModelStore | None = None
+    _refresh_task: asyncio.Task | None = None
+
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
         self.client = None

     async def initialize(self) -> None:
-        pass
+        if not self.config.url:
+            # intentionally don't raise an error here, we want to allow the provider to be "dormant"
+            # or available in distributions like "starter" without causing a ruckus
+            return
+
+        if self.config.refresh_models:
+            self._refresh_task = asyncio.create_task(self._refresh_models())
+
+            def cb(task):
+                import traceback
+
+                if task.cancelled():
+                    log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
+                elif task.exception():
+                    # print the stack trace for the exception
+                    exc = task.exception()
+                    log.error(f"vLLM background refresh task died: {exc}")
+                    traceback.print_exception(exc)
+                else:
+                    log.error("vLLM background refresh task completed unexpectedly")
+
+            self._refresh_task.add_done_callback(cb)
+
+    async def _refresh_models(self) -> None:
+        provider_id = self.__provider_id__
+        waited_time = 0
+        while not self.model_store and waited_time < 60:
+            await asyncio.sleep(1)
+            waited_time += 1
+
+        if not self.model_store:
+            raise ValueError("Model store not set after waiting 60 seconds")
+
+        self._lazy_initialize_client()
+        assert self.client is not None  # mypy
+        while True:
+            try:
+                models = []
+                async for m in self.client.models.list():
+                    model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
+                    models.append(
+                        Model(
+                            identifier=m.id,
+                            provider_resource_id=m.id,
+                            provider_id=provider_id,
+                            metadata={},
+                            model_type=model_type,
+                        )
+                    )
+                await self.model_store.update_registered_llm_models(provider_id, models)
+                log.debug(f"vLLM refreshed model list ({len(models)} models)")
+            except Exception as e:
+                log.error(f"vLLM background refresh task failed: {e}")
+            await asyncio.sleep(self.config.refresh_models_interval)

     async def shutdown(self) -> None:
-        pass
+        if self._refresh_task:
+            self._refresh_task.cancel()
+            self._refresh_task = None

     async def unregister_model(self, model_id: str) -> None:
         pass

@@ -312,6 +374,9 @@ async def health(self) -> HealthResponse:
             HealthResponse: A dictionary containing the health status.
         """
         try:
+            if not self.config.url:
+                return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
+
             client = self._create_client() if self.client is None else self.client
             _ = [m async for m in client.models.list()]  # Ensure the client is initialized
             return HealthResponse(status=HealthStatus.OK)

@@ -327,6 +392,11 @@ def _lazy_initialize_client(self):
         if self.client is not None:
             return

+        if not self.config.url:
+            raise ValueError(
+                "You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
+            )
+
         log.info(f"Initializing vLLM client with base_url={self.config.url}")
         self.client = self._create_client()

Review thread on the error message above ("or set the VLLM_URL environment variable"):

ehhuang (Contributor): is this correct in general? (non-starter templates)?

Author: @ehhuang yes I think it is correct, because any template will have the same config struct derived from sample_run_config()
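The `initialize`/`shutdown` pair above is the standard asyncio background-task pattern: spawn the task, attach a done-callback so a task that dies or exits is never silent, and cancel it on shutdown. A self-contained sketch of the same pattern outside the adapter; all names here are illustrative:

```python
import asyncio

async def periodic_refresh(interval: float) -> None:
    while True:
        # ... refresh work would go here ...
        await asyncio.sleep(interval)

def watch(task: asyncio.Task) -> None:
    # Mirrors cb() above: check cancelled() first, because calling
    # task.exception() on a cancelled task raises CancelledError.
    if task.cancelled():
        print("refresh task canceled")  # expected path during shutdown
    elif task.exception():
        print(f"refresh task died: {task.exception()!r}")
    else:
        print("refresh task completed unexpectedly")  # a while-True loop never should

async def main() -> None:
    task = asyncio.create_task(periodic_refresh(300))
    task.add_done_callback(watch)
    await asyncio.sleep(0.1)  # the application does its work here
    task.cancel()             # shutdown(): cancel and drop the reference
    try:
        await task            # let the cancellation complete
    except asyncio.CancelledError:
        pass

asyncio.run(main())
```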
llama_stack/templates/starter/run.yaml (2 changes: 1 addition & 1 deletion)

@@ -26,7 +26,7 @@ providers:
   - provider_id: ${env.ENABLE_VLLM:=__disabled__}
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL}
+      url: ${env.VLLM_URL:=}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}