Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/providers/inference/remote_fireworks.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `<class 'str'>` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |

Expand Down
3 changes: 1 addition & 2 deletions docs/source/providers/inference/remote_ollama.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ Ollama inference provider for running local models through the Ollama runtime.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `url` | `<class 'str'>` | No | http://localhost:11434 | |
| `refresh_models` | `<class 'bool'>` | No | False | refresh and re-register models periodically |
| `refresh_models_interval` | `<class 'int'>` | No | 300 | interval in seconds to refresh models |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |

## Sample Configuration

Expand Down
1 change: 1 addition & 0 deletions docs/source/providers/inference/remote_together.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Together AI inference provider for open-source models and collaborative AI devel

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `<class 'str'>` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |

Expand Down
1 change: 0 additions & 1 deletion docs/source/providers/inference/remote_vllm.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ Remote vLLM inference provider for connecting to vLLM servers.
| `api_token` | `str \| None` | No | fake | The API token |
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
| `refresh_models_interval` | `<class 'int'>` | No | 300 | Interval in seconds to refresh models |

## Sample Configuration

Expand Down
6 changes: 0 additions & 6 deletions llama_stack/apis/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,12 +819,6 @@ class OpenAIEmbeddingsResponse(BaseModel):
class ModelStore(Protocol):
async def get_model(self, identifier: str) -> Model: ...

async def update_registered_llm_models(
self,
provider_id: str,
models: list[Model],
) -> None: ...


class TextTruncation(Enum):
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
Expand Down
6 changes: 6 additions & 0 deletions llama_stack/distribution/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
RoutingKey = str | list[str]


class RegistryEntrySource(StrEnum):
    """How an entry ended up in the distribution registry.

    Used to distinguish objects a user explicitly registered from objects
    auto-discovered by listing a provider, so refresh logic can tell which
    entries it owns and may safely replace.
    """

    # Entry was created by an explicit call to the register API.
    via_register_api = "via_register_api"
    # Entry was auto-discovered via a provider's list_models() call.
    listed_from_provider = "listed_from_provider"


class User(BaseModel):
principal: str
# further attributes that may be used for access control decisions
Expand All @@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
resource. This can be used to constrain access to the resource."""

owner: User | None = None
source: RegistryEntrySource = RegistryEntrySource.via_register_api


# Use the extended Resource for all routable objects
Expand Down
8 changes: 7 additions & 1 deletion llama_stack/distribution/library_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,13 @@ def initialize(self):
if not self.skip_logger_removal:
self._remove_root_logger_handlers()

return self.loop.run_until_complete(self.async_client.initialize())
# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)

def _remove_root_logger_handlers(self):
"""
Expand Down
4 changes: 3 additions & 1 deletion llama_stack/distribution/routing_tables/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
await p.shutdown()

async def refresh(self) -> None:
    # Default no-op. Routing tables that can re-list objects from their
    # providers (e.g. the models routing table) override this to sync the
    # registry with the provider's current state.
    pass

async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
from .benchmarks import BenchmarksRoutingTable
from .datasets import DatasetsRoutingTable
Expand Down Expand Up @@ -206,7 +209,6 @@ async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObje
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj

else:
await self.dist_registry.register(obj)
return obj
Expand Down
42 changes: 33 additions & 9 deletions llama_stack/distribution/routing_tables/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
ModelWithOwner,
RegistryEntrySource,
)
from llama_stack.log import get_logger

Expand All @@ -19,6 +20,26 @@


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
listed_providers: set[str] = set()

async def refresh(self) -> None:
    """Re-query providers for their model lists and sync the registry.

    A provider is polled when it opts in via should_refresh_models(), or
    when we have already listed it at least once (so previously discovered
    entries keep getting updated even if the provider stops opting in).
    Failures for one provider are logged and do not block the others.
    """
    for provider_id, provider in self.impls_by_provider_id.items():
        wants_refresh = await provider.should_refresh_models()
        already_listed = provider_id in self.listed_providers
        if not wants_refresh and not already_listed:
            continue

        try:
            models = await provider.list_models()
        except Exception as e:
            logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
            continue

        # Remember that we have listed this provider so future refreshes
        # keep it in sync even if should_refresh_models() flips to False.
        self.listed_providers.add(provider_id)

        if models is not None:
            await self.update_registered_models(provider_id, models)

async def list_models(self) -> ListModelsResponse:
    """Return every model currently present in the routing table."""
    registered = await self.get_all_with_type("model")
    return ListModelsResponse(data=registered)

Expand Down Expand Up @@ -81,6 +102,7 @@ async def register_model(
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
source=RegistryEntrySource.via_register_api,
)
registered_model = await self.register_object(model)
return registered_model
Expand All @@ -91,7 +113,7 @@ async def unregister_model(self, model_id: str) -> None:
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)

async def update_registered_llm_models(
async def update_registered_models(
self,
provider_id: str,
models: list[Model],
Expand All @@ -102,18 +124,19 @@ async def update_registered_llm_models(
# from run.yaml) that we need to keep track of
model_ids = {}
for model in existing_models:
# we leave embeddings models alone because often we don't get metadata
# (embedding dimension, etc.) from the provider
if model.provider_id == provider_id and model.model_type == ModelType.llm:
if model.provider_id != provider_id:
continue
if model.source == RegistryEntrySource.via_register_api:
model_ids[model.provider_resource_id] = model.identifier
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
continue

logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)

for model in models:
if model.model_type != ModelType.llm:
continue
if model.provider_resource_id in model_ids:
model.identifier = model_ids[model.provider_resource_id]
# avoid overwriting a non-provider-registered model entry
continue

logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object(
Expand All @@ -123,5 +146,6 @@ async def update_registered_llm_models(
provider_id=provider_id,
metadata=model.metadata,
model_type=model.model_type,
source=RegistryEntrySource.listed_from_provider,
)
)
29 changes: 29 additions & 0 deletions llama_stack/distribution/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import importlib.resources
import os
import re
Expand Down Expand Up @@ -38,6 +39,7 @@
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
Expand Down Expand Up @@ -90,6 +92,9 @@ class LlamaStack(
]


REGISTRY_REFRESH_INTERVAL_SECONDS = 300


async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
Expand Down Expand Up @@ -324,9 +329,33 @@ async def construct_stack(
add_internal_implementations(impls, run_config)

await register_resources(run_config, impls)

task = asyncio.create_task(refresh_registry(impls))

def cb(task):
import traceback

if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")

task.add_done_callback(cb)
return impls


async def refresh_registry(impls: dict[Api, Any]):
    """Periodically ask every routing table to re-sync with its providers.

    Runs forever; intended to be scheduled as a background task. Sleeps
    REGISTRY_REFRESH_INTERVAL_SECONDS between passes.
    """
    tables = [impl for impl in impls.values() if isinstance(impl, CommonRoutingTableImpl)]
    while True:
        for table in tables:
            await table.refresh()

        await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)


def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"

Expand Down
11 changes: 11 additions & 0 deletions llama_stack/providers/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ async def register_model(self, model: Model) -> Model: ...

async def unregister_model(self, model_id: str) -> None: ...

# the Stack router will query each provider for their list of models
# if a `refresh_interval_seconds` is provided, this method will be called
# periodically to refresh the list of models
#
# NOTE: each model returned will be registered with the model registry. this means
# a callback to the `register_model()` method will be made. this is duplicative and
# may be removed in the future.
async def list_models(self) -> list[Model] | None: ...

async def should_refresh_models(self) -> bool: ...


class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()

async def should_refresh_models(self) -> bool:
    # This provider's model set is static, so it never opts in to
    # periodic refresh by the stack's registry loop.
    return False

async def list_models(self) -> list[Model] | None:
    # No dynamically discoverable models to report; None tells the
    # caller to leave the registry untouched for this provider.
    return None

async def unregister_model(self, model_id: str) -> None:
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
Expand All @@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl(
InferenceProvider,
ModelsProtocolPrivate,
):
__provider_id__: str

def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
self.config = config

Expand All @@ -50,6 +53,22 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
pass

async def should_refresh_models(self) -> bool:
    # The bundled embedding model never changes, so no periodic refresh
    # is needed for this provider.
    return False

async def list_models(self) -> list[Model] | None:
    """Advertise the single built-in sentence-transformers embedding model."""
    embedding_model = Model(
        identifier="all-MiniLM-L6-v2",
        provider_resource_id="all-MiniLM-L6-v2",
        provider_id=self.__provider_id__,
        metadata={"embedding_dimension": 384},
        model_type=ModelType.embedding,
    )
    return [embedding_model]

async def register_model(self, model: Model) -> Model:
    # Accept any model as-is: no provider-side validation or state is
    # maintained, so the registration is a pass-through.
    return model

Expand Down
5 changes: 3 additions & 2 deletions llama_stack/providers/remote/inference/fireworks/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@

from typing import Any

from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class FireworksImplConfig(BaseModel):
class FireworksImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
    # Seed the registry helper with the static Fireworks model entries,
    # passing along config.allowed_models — presumably used to restrict
    # which entries get registered; confirm in ModelRegistryHelper.
    ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
    self.config = config

async def initialize(self) -> None:
Expand Down
6 changes: 4 additions & 2 deletions llama_stack/providers/remote/inference/ollama/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@

class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL
refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically",
)

@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
Expand Down
Loading
Loading