Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/providers/inference/remote_ollama.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Ollama inference provider for running local models through the Ollama runtime.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `url` | `<class 'str'>` | No | http://localhost:11434 | |
| `refresh_models` | `<class 'bool'>` | No | False | refresh and re-register models periodically |
| `refresh_models_interval` | `<class 'int'>` | No | 300 | interval in seconds to refresh models |

## Sample Configuration

Expand Down
6 changes: 6 additions & 0 deletions llama_stack/apis/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,12 @@ class OpenAIEmbeddingsResponse(BaseModel):
class ModelStore(Protocol):
    """Structural interface to the model registry used by inference providers."""

    # Look up a registered model by its user-visible identifier.
    async def get_model(self, identifier: str) -> Model: ...

    # Replace the set of models registered under `provider_id` with `models`.
    # NOTE(review): implemented by the model routing table; it appears to
    # preserve user-chosen aliases across refreshes — confirm semantics there.
    async def update_registered_models(
        self,
        provider_id: str,
        models: list[Model],
    ) -> None: ...


class TextTruncation(Enum):
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
Expand Down
11 changes: 4 additions & 7 deletions llama_stack/distribution/library_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def __init__(
self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data

self.loop = asyncio.new_event_loop()

def initialize(self):
if in_notebook():
import nest_asyncio
Expand All @@ -136,7 +138,7 @@ def initialize(self):
if not self.skip_logger_removal:
self._remove_root_logger_handlers()

return asyncio.run(self.async_client.initialize())
return self.loop.run_until_complete(self.async_client.initialize())

def _remove_root_logger_handlers(self):
"""
Expand All @@ -149,10 +151,7 @@ def _remove_root_logger_handlers(self):
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")

def request(self, *args, **kwargs):
# NOTE: We are using AsyncLlamaStackClient under the hood
# A new event loop is needed to convert the AsyncStream
# from async client into SyncStream return type for streaming
loop = asyncio.new_event_loop()
loop = self.loop
asyncio.set_event_loop(loop)

if kwargs.get("stream"):
Expand All @@ -169,7 +168,6 @@ def sync_generator():
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()

return sync_generator()
else:
Expand All @@ -179,7 +177,6 @@ def sync_generator():
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
return result


Expand Down
31 changes: 31 additions & 0 deletions llama_stack/distribution/routing_tables/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,34 @@ async def unregister_model(self, model_id: str) -> None:
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)

    async def update_registered_models(
        self,
        provider_id: str,
        models: list[Model],
    ) -> None:
        """Replace every model registered under `provider_id` with `models`.

        All of the provider's existing registrations are removed first. If a
        previously registered model shared a `provider_resource_id` with an
        incoming one, its user-visible identifier (alias) is carried over so
        user references keep working across refreshes.
        """
        existing_models = await self.get_all_with_type("model")

        # we may have an alias for the model registered by the user (or during initialization
        # from run.yaml) that we need to keep track of
        model_ids = {}
        for model in existing_models:
            if model.provider_id == provider_id:
                # map provider_resource_id -> previously registered identifier
                model_ids[model.provider_resource_id] = model.identifier
                logger.debug(f"unregistering model {model.identifier}")
                await self.unregister_object(model)

        for model in models:
            if model.provider_resource_id in model_ids:
                # reuse the earlier alias instead of the provider-supplied name
                # (note: this mutates the incoming Model object in place)
                model.identifier = model_ids[model.provider_resource_id]

            logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
            await self.register_object(
                ModelWithOwner(
                    identifier=model.identifier,
                    provider_resource_id=model.provider_resource_id,
                    provider_id=provider_id,
                    metadata=model.metadata,
                    model_type=model.model_type,
                )
            )
4 changes: 3 additions & 1 deletion llama_stack/providers/remote/inference/ollama/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

from typing import Any

from pydantic import BaseModel
from pydantic import BaseModel, Field

DEFAULT_OLLAMA_URL = "http://localhost:11434"


class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL
refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")

@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
Expand Down
85 changes: 77 additions & 8 deletions llama_stack/providers/remote/inference/ollama/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# the root directory of this source tree.


import asyncio
import base64
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
Expand Down Expand Up @@ -91,23 +92,88 @@ class OllamaInferenceAdapter(
InferenceProvider,
ModelsProtocolPrivate,
):
# automatically set by the resolver when instantiating the provider
__provider_id__: str

def __init__(self, config: OllamaImplConfig) -> None:
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
self.url = config.url
self.config = config
self._client = None
self._openai_client = None

@property
def client(self) -> AsyncClient:
return AsyncClient(host=self.url)
if self._client is None:
self._client = AsyncClient(host=self.config.url)
return self._client

@property
def openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")
if self._openai_client is None:
self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama")
return self._openai_client

async def initialize(self) -> None:
logger.debug(f"checking connectivity to Ollama at `{self.url}`...")
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
health_response = await self.health()
if health_response["status"] == HealthStatus.ERROR:
raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
logger.warning(
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
)

if self.config.refresh_models:
logger.debug("ollama starting background model refresh task")
self._refresh_task = asyncio.create_task(self._refresh_models())

def cb(task):
if task.cancelled():
import traceback

logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
logger.error(f"ollama background refresh task died: {task.exception()}")
else:
logger.error("ollama background refresh task completed unexpectedly")

self._refresh_task.add_done_callback(cb)

    async def _refresh_models(self) -> None:
        """Background loop: periodically list Ollama's models and re-register them.

        Runs until cancelled (see shutdown), sleeping
        `config.refresh_models_interval` seconds between iterations.

        Raises:
            ValueError: if the model store is still unset after ~60 seconds.
        """
        # Wait for model store to be available (with timeout)
        waited_time = 0
        while not self.model_store and waited_time < 60:
            await asyncio.sleep(1)
            waited_time += 1

        if not self.model_store:
            raise ValueError("Model store not set after waiting 60 seconds")

        # set by the resolver when the provider is instantiated
        provider_id = self.__provider_id__
        while True:
            try:
                response = await self.client.list()
            except Exception as e:
                # transient failure: log and retry on the next interval
                logger.warning(f"Failed to list models: {str(e)}")
                await asyncio.sleep(self.config.refresh_models_interval)
                continue

            models = []
            for m in response.models:
                # NOTE(review): assumes "bert" is the only embedding family
                # Ollama reports — confirm as new embedding models appear.
                model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
                # unfortunately, ollama does not provide embedding dimension in the model list :(
                # we should likely add a hard-coded mapping of model name to embedding dimension
                models.append(
                    Model(
                        identifier=m.model,
                        provider_resource_id=m.model,
                        provider_id=provider_id,
                        metadata={"embedding_dimension": 384} if model_type == ModelType.embedding else {},
                        model_type=model_type,
                    )
                )
            await self.model_store.update_registered_models(provider_id, models)
            logger.debug(f"ollama refreshed model list ({len(models)} models)")

            await asyncio.sleep(self.config.refresh_models_interval)

async def health(self) -> HealthResponse:
"""
Expand All @@ -124,7 +190,12 @@ async def health(self) -> HealthResponse:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")

async def shutdown(self) -> None:
pass
if hasattr(self, "_refresh_task") and not self._refresh_task.done():
logger.debug("ollama cancelling background refresh task")
self._refresh_task.cancel()

self._client = None
self._openai_client = None

async def unregister_model(self, model_id: str) -> None:
pass
Expand Down Expand Up @@ -354,8 +425,6 @@ async def register_model(self, model: Model) -> Model:
raise ValueError("Model provider_resource_id cannot be None")

if model.model_type == ModelType.embedding:
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
# TODO: you should pull here only if the model is not found in a list
response = await self.client.list()
if model.provider_resource_id not in [m.model for m in response.models]:
await self.client.pull(model.provider_resource_id)
Expand Down
Loading