From a193c9fc3f30c99c0897a49f38a2ba9c2e8f2f2c Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Mon, 7 Apr 2025 21:27:06 -0400 Subject: [PATCH 01/19] Add OpenAI-Compatible models, completions, chat/completions endpoints This stubs in some OpenAI server-side compatibility with three new endpoints: /v1/openai/v1/models /v1/openai/v1/completions /v1/openai/v1/chat/completions This gives common inference apps using OpenAI clients the ability to talk to Llama Stack using an endpoint like http://localhost:8321/v1/openai/v1 . The two "v1" instances in there aren't awesome, but the thinking is that Llama Stack's API is v1 and then our OpenAI compatibility layer is compatible with OpenAI V1. And some OpenAI clients implicitly assume the URL ends with "v1", so this gives maximum compatibility. The OpenAI models endpoint is implemented in the routing layer and just returns all the models Llama Stack knows about. The completion and chat completion endpoints are only implemented for the remote-vllm provider right now, which just proxies the requests to the backend vLLM. The goal is to support this for every inference provider - proxying directly to the provider's OpenAI endpoint for OpenAI-compatible providers. For providers that don't have an OpenAI-compatible API, we'll add a mixin to translate incoming OpenAI requests to Llama Stack inference requests and translate the Llama Stack inference responses to OpenAI responses. --- llama_stack/apis/inference/inference.py | 57 +++++++++ llama_stack/apis/models/models.py | 8 ++ llama_stack/distribution/routers/routers.py | 120 ++++++++++++++++++ .../distribution/routers/routing_tables.py | 17 ++- .../inference/meta_reference/inference.py | 6 + .../sentence_transformers.py | 56 +++++++- .../remote/inference/ollama/ollama.py | 9 +- .../providers/remote/inference/vllm/vllm.py | 109 +++++++++++++++- .../utils/inference/openai_compat.py | 58 ++++++++- pyproject.toml | 1 + requirements.txt | 2 + uv.lock | 8 +- 12 files changed, 443 insertions(+), 8 deletions(-) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index e59132e330..864bef2d5d 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -17,6 +17,9 @@ runtime_checkable, ) +from openai.types.chat import ChatCompletion as OpenAIChatCompletion +from openai.types.chat import ChatCompletionMessageParam as OpenAIChatCompletionMessageParam +from openai.types.completion import Completion as OpenAICompletion from pydantic import BaseModel, Field, field_validator from typing_extensions import Annotated @@ -564,3 +567,57 @@ async def embeddings( :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id} """ ...
+ + @webmethod(route="/openai/v1/completions", method="POST") + async def openai_completion( + self, + model: str, + prompt: str, + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAICompletion: + """Generate an OpenAI-compatible completion for the given prompt using the specified model.""" + ... + + @webmethod(route="/openai/v1/chat/completions", method="POST") + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIChatCompletionMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[Dict[str, str]] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAIChatCompletion: + """Generate an OpenAI-compatible chat completion for the given messages using the specified model.""" + ... diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 893ebc179e..e48add8823 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable +from openai.types.model import Model as OpenAIModel from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.resource import Resource, ResourceType @@ -56,12 +57,19 @@ class ListModelsResponse(BaseModel): data: List[Model] +class OpenAIListModelsResponse(BaseModel): + data: List[OpenAIModel] + + @runtime_checkable @trace_protocol class Models(Protocol): @webmethod(route="/models", method="GET") async def list_models(self) -> ListModelsResponse: ... + @webmethod(route="/openai/v1/models", method="GET") + async def openai_list_models(self) -> OpenAIListModelsResponse: ... 
+ @webmethod(route="/models/{model_id:path}", method="GET") async def get_model( self, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index eed96a40a7..146ac50212 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -7,6 +7,10 @@ import time from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union +from openai.types.chat import ChatCompletion as OpenAIChatCompletion +from openai.types.chat import ChatCompletionMessageParam as OpenAIChatCompletionMessageParam +from openai.types.completion import Completion as OpenAICompletion + from llama_stack.apis.common.content_types import ( URL, InterleavedContent, @@ -419,6 +423,122 @@ async def embeddings( task_type=task_type, ) + async def openai_completion( + self, + model: str, + prompt: str, + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAICompletion: + logger.debug( + f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}", + ) + model_obj = await self.routing_table.get_model(model) + if model_obj is None: + raise ValueError(f"Model '{model}' not found") + if model_obj.model_type == ModelType.embedding: + raise ValueError(f"Model '{model}' is an embedding model and does not support completions") + + params = dict( + model=model_obj.identifier, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + ) + + provider = self.routing_table.get_provider_impl(model_obj.identifier) + return await provider.openai_completion(**params) + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIChatCompletionMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[Dict[str, str]] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAIChatCompletion: + logger.debug( + f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}", + ) + model_obj = await self.routing_table.get_model(model) + 
if model_obj is None: + raise ValueError(f"Model '{model}' not found") + if model_obj.model_type == ModelType.embedding: + raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions") + + params = dict( + model=model_obj.identifier, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + + provider = self.routing_table.get_provider_impl(model_obj.identifier) + return await provider.openai_chat_completion(**params) + class SafetyRouter(Safety): def __init__( diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index f6adae49d3..5ec90864ea 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -5,9 +5,11 @@ # the root directory of this source tree. import logging +import time import uuid from typing import Any, Dict, List, Optional +from openai.types.model import Model as OpenAIModel from pydantic import TypeAdapter from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse @@ -23,7 +25,7 @@ RowsDataSource, URIDataSource, ) -from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType +from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ( ListScoringFunctionsResponse, @@ -254,6 +256,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> ListModelsResponse: return ListModelsResponse(data=await self.get_all_with_type("model")) + async def openai_list_models(self) -> OpenAIListModelsResponse: + models = await self.get_all_with_type("model") + openai_models = [ + OpenAIModel( + id=model.identifier, + object="model", + created=int(time.time()), + owned_by="llama_stack", + ) + for model in models + ] + return OpenAIListModelsResponse(data=openai_models) + async def get_model(self, model_id: str) -> Model: model = await self.get_object_by_identifier("model", model_id) if model is None: diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 5f81d64215..3a7632065b 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -54,6 +54,10 @@ ModelRegistryHelper, build_hf_repo_model_entry, ) +from llama_stack.providers.utils.inference.openai_compat import ( + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, +) from llama_stack.providers.utils.inference.prompt_adapter import ( augment_content_with_response_format_prompt, chat_completion_request_to_messages, @@ -79,6 +83,8 @@ def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama class MetaReferenceInferenceImpl( + OpenAICompletionUnsupportedMixin, + OpenAIChatCompletionUnsupportedMixin, 
SentenceTransformerEmbeddingMixin, Inference, ModelsProtocolPrivate, diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 39847e085f..26a34064d2 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -5,7 +5,11 @@ # the root directory of this source tree. import logging -from typing import AsyncGenerator, List, Optional, Union +from typing import Any, AsyncGenerator, Dict, List, Optional, Union + +from openai.types.chat import ChatCompletion as OpenAIChatCompletion +from openai.types.chat import ChatCompletionMessageParam as OpenAIChatCompletionMessageParam +from openai.types.completion import Completion as OpenAICompletion from llama_stack.apis.inference import ( CompletionResponse, @@ -74,3 +78,53 @@ async def chat_completion( tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: raise ValueError("Sentence transformers don't support chat completion") + + async def openai_completion( + self, + model: str, + prompt: str, + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAICompletion: + raise ValueError("Sentence transformers don't support openai completion") + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIChatCompletionMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[Dict[str, str]] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAIChatCompletion: + raise ValueError("Sentence transformers don't support openai chat completion") diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 12902996b1..944493b6d3 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -45,8 +45,10 @@ ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( + OpenAIChatCompletionUnsupportedMixin, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, + OpenAICompletionUnsupportedMixin, 
get_sampling_options, process_chat_completion_response, process_chat_completion_stream_response, @@ -67,7 +69,12 @@ logger = get_logger(name=__name__, category="inference") -class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): +class OllamaInferenceAdapter( + OpenAICompletionUnsupportedMixin, + OpenAIChatCompletionUnsupportedMixin, + Inference, + ModelsProtocolPrivate, +): def __init__(self, url: str) -> None: self.register_helper = ModelRegistryHelper(model_entries) self.url = url diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 6a828322f5..18e6a19722 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -5,13 +5,16 @@ # the root directory of this source tree. import json import logging -from typing import Any, AsyncGenerator, List, Optional, Union +from typing import Any, AsyncGenerator, Dict, List, Optional, Union import httpx from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion as OpenAIChatCompletion +from openai.types.chat import ChatCompletionMessageParam as OpenAIChatCompletionMessageParam from openai.types.chat.chat_completion_chunk import ( ChatCompletionChunk as OpenAIChatCompletionChunk, ) +from openai.types.completion import Completion as OpenAICompletion from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -418,3 +421,107 @@ async def embeddings( embeddings = [data.embedding for data in response.data] return EmbeddingsResponse(embeddings=embeddings) + + async def openai_completion( + self, + model: str, + prompt: str, + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAICompletion: + model_obj = await self._get_model(model) + params = { + k: v + for k, v in { + "model": model_obj.provider_resource_id, + "prompt": prompt, + "best_of": best_of, + "echo": echo, + "frequency_penalty": frequency_penalty, + "logit_bias": logit_bias, + "logprobs": logprobs, + "max_tokens": max_tokens, + "n": n, + "presence_penalty": presence_penalty, + "seed": seed, + "stop": stop, + "stream": stream, + "stream_options": stream_options, + "temperature": temperature, + "top_p": top_p, + "user": user, + }.items() + if v is not None + } + return await self.client.completions.create(**params) # type: ignore + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIChatCompletionMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[Dict[str, str]] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: 
Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAIChatCompletion: + model_obj = await self._get_model(model) + params = { + k: v + for k, v in { + "model": model_obj.provider_resource_id, + "messages": messages, + "frequency_penalty": frequency_penalty, + "function_call": function_call, + "functions": functions, + "logit_bias": logit_bias, + "logprobs": logprobs, + "max_completion_tokens": max_completion_tokens, + "max_tokens": max_tokens, + "n": n, + "parallel_tool_calls": parallel_tool_calls, + "presence_penalty": presence_penalty, + "response_format": response_format, + "seed": seed, + "stop": stop, + "stream": stream, + "stream_options": stream_options, + "temperature": temperature, + "tool_choice": tool_choice, + "tools": tools, + "top_logprobs": top_logprobs, + "top_p": top_p, + "user": user, + }.items() + if v is not None + } + return await self.client.chat.completions.create(**params) # type: ignore diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 0f3945b34c..3f1846b768 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -6,9 +6,10 @@ import json import logging import warnings -from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union +from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union from openai import AsyncStream +from openai.types.chat import ChatCompletion as OpenAIChatCompletion from openai.types.chat import ( ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, ) @@ -54,6 +55,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import ( Function as OpenAIFunction, ) +from openai.types.completion import Completion as OpenAICompletion from pydantic import BaseModel from llama_stack.apis.common.content_types import ( @@ -1049,3 +1051,57 @@ async def convert_openai_chat_completion_stream( stop_reason=stop_reason, ) ) + + +class OpenAICompletionUnsupportedMixin: + async def openai_completion( + self, + model: str, + prompt: str, + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAICompletion: + raise ValueError(f"{self.__class__.__name__} doesn't support openai completion") + + +class OpenAIChatCompletionUnsupportedMixin: + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIChatCompletionMessage], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + 
n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[Dict[str, str]] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAIChatCompletion: + raise ValueError(f"{self.__class__.__name__} doesn't support openai chat completion") diff --git a/pyproject.toml b/pyproject.toml index 83260b681f..9ef3abe68f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "jinja2>=3.1.6", "jsonschema", "llama-stack-client>=0.2.1", + "openai>=1.66", "prompt-toolkit", "python-dotenv", "pydantic>=2", diff --git a/requirements.txt b/requirements.txt index 6645e4e368..ef5782905f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,7 @@ httpx==0.28.1 huggingface-hub==0.29.0 idna==3.10 jinja2==3.1.6 +jiter==0.8.2 jsonschema==4.23.0 jsonschema-specifications==2024.10.1 llama-stack-client==0.2.1 @@ -27,6 +28,7 @@ markdown-it-py==3.0.0 markupsafe==3.0.2 mdurl==0.1.2 numpy==2.2.3 +openai==1.71.0 packaging==24.2 pandas==2.2.3 pillow==11.1.0 diff --git a/uv.lock b/uv.lock index 1f7adea828..c6c9b10049 100644 --- a/uv.lock +++ b/uv.lock @@ -1384,6 +1384,7 @@ dependencies = [ { name = "jinja2" }, { name = "jsonschema" }, { name = "llama-stack-client" }, + { name = "openai" }, { name = "pillow" }, { name = "prompt-toolkit" }, { name = "pydantic" }, @@ -1485,6 +1486,7 @@ requires-dist = [ { name = "mcp", marker = "extra == 'test'" }, { name = "myst-parser", marker = "extra == 'docs'" }, { name = "nbval", marker = "extra == 'dev'" }, + { name = "openai", specifier = ">=1.66" }, { name = "openai", marker = "extra == 'test'" }, { name = "openai", marker = "extra == 'unit'" }, { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" }, @@ -2016,7 +2018,7 @@ wheels = [ [[package]] name = "openai" -version = "1.63.2" +version = "1.71.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -2028,9 +2030,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e6/1c/11b520deb71f9ea54ced3c52cd6a5f7131215deba63ad07f23982e328141/openai-1.63.2.tar.gz", hash = "sha256:aeabeec984a7d2957b4928ceaa339e2ead19c61cfcf35ae62b7c363368d26360", size = 356902 } +sdist = { url = "https://files.pythonhosted.org/packages/d9/19/b8f0347090a649dce55a008ec54ac6abb50553a06508cdb5e7abb2813e99/openai-1.71.0.tar.gz", hash = "sha256:52b20bb990a1780f9b0b8ccebac93416343ebd3e4e714e3eff730336833ca207", size = 409926 } wheels = [ - { url = "https://files.pythonhosted.org/packages/15/64/db3462b358072387b8e93e6e6a38d3c741a17b4a84171ef01d6c85c63f25/openai-1.63.2-py3-none-any.whl", hash = "sha256:1f38b27b5a40814c2b7d8759ec78110df58c4a614c25f182809ca52b080ff4d4", size = 472282 }, + { url = "https://files.pythonhosted.org/packages/c4/f7/049e85faf6a000890e5ca0edca8e9183f8a43c9e7bba869cad871da0caba/openai-1.71.0-py3-none-any.whl", hash = "sha256:e1c643738f1fff1af52bce6ef06a7716c95d089281e7011777179614f32937aa", size = 598975 }, ] [[package]] From 92fdf6d0c965b684619e7481c6ffd65c6d9e486f Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Tue, 8 
Apr 2025 09:01:35 -0400 Subject: [PATCH 02/19] Use our own pydantic models for OpenAI Server APIs Importing the models from the OpenAI client library required a top-level dependency on the openai python package, and also was incompatible with our API generation code due to some quirks in how the OpenAI pydantic models are defined. So, this creates our own stubs of those pydantic models so that we're in more direct control of our API surface for this OpenAI-compatible API, so that it works with our code generation, and so that the openai python client isn't a hard requirement of Llama Stack's API. --- docs/_static/llama-stack-spec.html | 898 ++++++++++++++++++ docs/_static/llama-stack-spec.yaml | 647 +++++++++++++ llama_stack/apis/inference/inference.py | 264 ++++- llama_stack/apis/models/models.py | 17 +- llama_stack/distribution/routers/routers.py | 4 +- .../distribution/routers/routing_tables.py | 3 +- .../sentence_transformers.py | 4 +- .../providers/remote/inference/vllm/vllm.py | 4 +- 8 files changed, 1826 insertions(+), 15 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 567110829a..e92deaa41f 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -3092,6 +3092,125 @@ } } }, + "/v1/openai/v1/chat/completions": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAIChatCompletion" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Inference" + ], + "description": "Generate an OpenAI-compatible chat completion for the given messages using the specified model.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenaiChatCompletionRequest" + } + } + }, + "required": true + } + } + }, + "/v1/openai/v1/completions": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAICompletion" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Inference" + ], + "description": "Generate an OpenAI-compatible completion for the given prompt using the specified model.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenaiCompletionRequest" + } + } + }, + "required": true + } + } + }, + "/v1/openai/v1/models": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAIListModelsResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Models" + ], + "description": "", + "parameters": [] + } + }, 
"/v1/post-training/preference-optimize": { "post": { "responses": { @@ -8713,6 +8832,785 @@ ], "title": "LogEventRequest" }, + "OpenAIAssistantMessageParam": { + "type": "object", + "properties": { + "role": { + "type": "string", + "const": "assistant", + "default": "assistant", + "description": "Must be \"assistant\" to identify this as the model's response" + }, + "content": { + "$ref": "#/components/schemas/InterleavedContent", + "description": "The content of the model's response" + }, + "name": { + "type": "string", + "description": "(Optional) The name of the assistant message participant." + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolCall" + }, + "description": "List of tool calls. Each tool call is a ToolCall object." + } + }, + "additionalProperties": false, + "required": [ + "role", + "content" + ], + "title": "OpenAIAssistantMessageParam", + "description": "A message containing the model's (assistant) response in an OpenAI-compatible chat completion request." + }, + "OpenAIDeveloperMessageParam": { + "type": "object", + "properties": { + "role": { + "type": "string", + "const": "developer", + "default": "developer", + "description": "Must be \"developer\" to identify this as a developer message" + }, + "content": { + "$ref": "#/components/schemas/InterleavedContent", + "description": "The content of the developer message" + }, + "name": { + "type": "string", + "description": "(Optional) The name of the developer message participant." + } + }, + "additionalProperties": false, + "required": [ + "role", + "content" + ], + "title": "OpenAIDeveloperMessageParam", + "description": "A message from the developer in an OpenAI-compatible chat completion request." + }, + "OpenAIMessageParam": { + "oneOf": [ + { + "$ref": "#/components/schemas/OpenAIUserMessageParam" + }, + { + "$ref": "#/components/schemas/OpenAISystemMessageParam" + }, + { + "$ref": "#/components/schemas/OpenAIAssistantMessageParam" + }, + { + "$ref": "#/components/schemas/OpenAIToolMessageParam" + }, + { + "$ref": "#/components/schemas/OpenAIDeveloperMessageParam" + } + ], + "discriminator": { + "propertyName": "role", + "mapping": { + "user": "#/components/schemas/OpenAIUserMessageParam", + "system": "#/components/schemas/OpenAISystemMessageParam", + "assistant": "#/components/schemas/OpenAIAssistantMessageParam", + "tool": "#/components/schemas/OpenAIToolMessageParam", + "developer": "#/components/schemas/OpenAIDeveloperMessageParam" + } + } + }, + "OpenAISystemMessageParam": { + "type": "object", + "properties": { + "role": { + "type": "string", + "const": "system", + "default": "system", + "description": "Must be \"system\" to identify this as a system message" + }, + "content": { + "$ref": "#/components/schemas/InterleavedContent", + "description": "The content of the \"system prompt\". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions)." + }, + "name": { + "type": "string", + "description": "(Optional) The name of the system message participant." + } + }, + "additionalProperties": false, + "required": [ + "role", + "content" + ], + "title": "OpenAISystemMessageParam", + "description": "A system message providing instructions or context to the model." 
+ }, + "OpenAIToolMessageParam": { + "type": "object", + "properties": { + "role": { + "type": "string", + "const": "tool", + "default": "tool", + "description": "Must be \"tool\" to identify this as a tool response" + }, + "tool_call_id": { + "type": "string", + "description": "Unique identifier for the tool call this response is for" + }, + "content": { + "$ref": "#/components/schemas/InterleavedContent", + "description": "The response content from the tool" + } + }, + "additionalProperties": false, + "required": [ + "role", + "tool_call_id", + "content" + ], + "title": "OpenAIToolMessageParam", + "description": "A message representing the result of a tool invocation in an OpenAI-compatible chat completion request." + }, + "OpenAIUserMessageParam": { + "type": "object", + "properties": { + "role": { + "type": "string", + "const": "user", + "default": "user", + "description": "Must be \"user\" to identify this as a user message" + }, + "content": { + "$ref": "#/components/schemas/InterleavedContent", + "description": "The content of the message, which can include text and other media" + }, + "name": { + "type": "string", + "description": "(Optional) The name of the user message participant." + } + }, + "additionalProperties": false, + "required": [ + "role", + "content" + ], + "title": "OpenAIUserMessageParam", + "description": "A message from the user in an OpenAI-compatible chat completion request." + }, + "OpenaiChatCompletionRequest": { + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint." + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIMessageParam" + }, + "description": "List of messages in the conversation" + }, + "frequency_penalty": { + "type": "number", + "description": "(Optional) The penalty for repeated tokens" + }, + "function_call": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + ], + "description": "(Optional) The function call to use" + }, + "functions": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "description": "(Optional) List of functions to use" + }, + "logit_bias": { + "type": "object", + "additionalProperties": { + "type": "number" + }, + "description": "(Optional) The logit bias to use" + }, + "logprobs": { + "type": "boolean", + "description": "(Optional) The log probabilities to use" + }, + "max_completion_tokens": { + "type": "integer", + "description": "(Optional) The maximum number of tokens to generate" + }, + "max_tokens": { + "type": "integer", + "description": "(Optional) The maximum number of tokens to generate" + }, + "n": { + "type": "integer", + "description": "(Optional) The number of completions to generate" + }, + "parallel_tool_calls": { + "type": "boolean", + "description": "(Optional) Whether to parallelize tool calls" + }, + "presence_penalty": { + "type": "number", + "description": "(Optional) The penalty for repeated tokens" + }, + "response_format": { + "type": "object", + 
"additionalProperties": { + "type": "string" + }, + "description": "(Optional) The response format to use" + }, + "seed": { + "type": "integer", + "description": "(Optional) The seed to use" + }, + "stop": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ], + "description": "(Optional) The stop tokens to use" + }, + "stream": { + "type": "boolean", + "description": "(Optional) Whether to stream the response" + }, + "stream_options": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + }, + "description": "(Optional) The stream options to use" + }, + "temperature": { + "type": "number", + "description": "(Optional) The temperature to use" + }, + "tool_choice": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + ], + "description": "(Optional) The tool choice to use" + }, + "tools": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "description": "(Optional) The tools to use" + }, + "top_logprobs": { + "type": "integer", + "description": "(Optional) The top log probabilities to use" + }, + "top_p": { + "type": "number", + "description": "(Optional) The top p to use" + }, + "user": { + "type": "string", + "description": "(Optional) The user to use" + } + }, + "additionalProperties": false, + "required": [ + "model", + "messages" + ], + "title": "OpenaiChatCompletionRequest" + }, + "OpenAIChatCompletion": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The ID of the chat completion" + }, + "choices": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIChoice" + }, + "description": "List of choices" + }, + "object": { + "type": "string", + "const": "chat.completion", + "default": "chat.completion", + "description": "The object type, which will be \"chat.completion\"" + }, + "created": { + "type": "integer", + "description": "The Unix timestamp in seconds when the chat completion was created" + }, + "model": { + "type": "string", + "description": "The model that was used to generate the chat completion" + } + }, + "additionalProperties": false, + "required": [ + "id", + "choices", + "object", + "created", + "model" + ], + "title": "OpenAIChatCompletion", + "description": "Response from an OpenAI-compatible chat completion request." + }, + "OpenAIChoice": { + "type": "object", + "properties": { + "message": { + "$ref": "#/components/schemas/OpenAIMessageParam", + "description": "The message from the model" + }, + "finish_reason": { + "type": "string", + "description": "The reason the model stopped generating" + }, + "index": { + "type": "integer" + }, + "logprobs": { + "$ref": "#/components/schemas/OpenAIChoiceLogprobs" + } + }, + "additionalProperties": false, + "required": [ + "message", + "finish_reason", + "index" + ], + "title": "OpenAIChoice", + "description": "A choice from an OpenAI-compatible chat completion response." 
+ }, + "OpenAIChoiceLogprobs": { + "type": "object", + "properties": { + "content": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAITokenLogProb" + } + }, + "refusal": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAITokenLogProb" + } + } + }, + "additionalProperties": false, + "title": "OpenAIChoiceLogprobs", + "description": "The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response." + }, + "OpenAITokenLogProb": { + "type": "object", + "properties": { + "token": { + "type": "string" + }, + "bytes": { + "type": "array", + "items": { + "type": "integer" + } + }, + "logprob": { + "type": "number" + }, + "top_logprobs": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAITopLogProb" + } + } + }, + "additionalProperties": false, + "required": [ + "token", + "logprob", + "top_logprobs" + ], + "title": "OpenAITokenLogProb", + "description": "The log probability for a token from an OpenAI-compatible chat completion response." + }, + "OpenAITopLogProb": { + "type": "object", + "properties": { + "token": { + "type": "string" + }, + "bytes": { + "type": "array", + "items": { + "type": "integer" + } + }, + "logprob": { + "type": "number" + } + }, + "additionalProperties": false, + "required": [ + "token", + "logprob" + ], + "title": "OpenAITopLogProb", + "description": "The top log probability for a token from an OpenAI-compatible chat completion response." + }, + "OpenaiCompletionRequest": { + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint." + }, + "prompt": { + "type": "string", + "description": "The prompt to generate a completion for" + }, + "best_of": { + "type": "integer", + "description": "(Optional) The number of completions to generate" + }, + "echo": { + "type": "boolean", + "description": "(Optional) Whether to echo the prompt" + }, + "frequency_penalty": { + "type": "number", + "description": "(Optional) The penalty for repeated tokens" + }, + "logit_bias": { + "type": "object", + "additionalProperties": { + "type": "number" + }, + "description": "(Optional) The logit bias to use" + }, + "logprobs": { + "type": "boolean", + "description": "(Optional) The log probabilities to use" + }, + "max_tokens": { + "type": "integer", + "description": "(Optional) The maximum number of tokens to generate" + }, + "n": { + "type": "integer", + "description": "(Optional) The number of completions to generate" + }, + "presence_penalty": { + "type": "number", + "description": "(Optional) The penalty for repeated tokens" + }, + "seed": { + "type": "integer", + "description": "(Optional) The seed to use" + }, + "stop": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ], + "description": "(Optional) The stop tokens to use" + }, + "stream": { + "type": "boolean", + "description": "(Optional) Whether to stream the response" + }, + "stream_options": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + }, + "description": "(Optional) The stream options to use" + }, + "temperature": { + "type": "number", + "description": "(Optional) The temperature to use" + }, + "top_p": { + "type": "number", + 
"description": "(Optional) The top p to use" + }, + "user": { + "type": "string", + "description": "(Optional) The user to use" + } + }, + "additionalProperties": false, + "required": [ + "model", + "prompt" + ], + "title": "OpenaiCompletionRequest" + }, + "OpenAICompletion": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "choices": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAICompletionChoice" + } + }, + "created": { + "type": "integer" + }, + "model": { + "type": "string" + }, + "object": { + "type": "string", + "const": "text_completion", + "default": "text_completion" + } + }, + "additionalProperties": false, + "required": [ + "id", + "choices", + "created", + "model", + "object" + ], + "title": "OpenAICompletion", + "description": "Response from an OpenAI-compatible completion request." + }, + "OpenAICompletionChoice": { + "type": "object", + "properties": { + "finish_reason": { + "type": "string" + }, + "text": { + "type": "string" + }, + "index": { + "type": "integer" + }, + "logprobs": { + "$ref": "#/components/schemas/OpenAIChoiceLogprobs" + } + }, + "additionalProperties": false, + "required": [ + "finish_reason", + "text", + "index" + ], + "title": "OpenAICompletionChoice", + "description": "A choice from an OpenAI-compatible completion response." + }, + "OpenAIModel": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "object": { + "type": "string", + "const": "model", + "default": "model" + }, + "created": { + "type": "integer" + }, + "owned_by": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "id", + "object", + "created", + "owned_by" + ], + "title": "OpenAIModel", + "description": "A model from OpenAI." + }, + "OpenAIListModelsResponse": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIModel" + } + } + }, + "additionalProperties": false, + "required": [ + "data" + ], + "title": "OpenAIListModelsResponse" + }, "DPOAlignmentConfig": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 1dfd17f55c..f0c5d1a799 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2131,6 +2131,91 @@ paths: schema: $ref: '#/components/schemas/LogEventRequest' required: true + /v1/openai/v1/chat/completions: + post: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/OpenAIChatCompletion' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Inference + description: >- + Generate an OpenAI-compatible chat completion for the given messages using + the specified model. 
+ parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/OpenaiChatCompletionRequest' + required: true + /v1/openai/v1/completions: + post: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/OpenAICompletion' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Inference + description: >- + Generate an OpenAI-compatible completion for the given prompt using the specified + model. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/OpenaiCompletionRequest' + required: true + /v1/openai/v1/models: + get: + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/OpenAIListModelsResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Models + description: '' + parameters: [] /v1/post-training/preference-optimize: post: responses: @@ -5980,6 +6065,568 @@ components: - event - ttl_seconds title: LogEventRequest + OpenAIAssistantMessageParam: + type: object + properties: + role: + type: string + const: assistant + default: assistant + description: >- + Must be "assistant" to identify this as the model's response + content: + $ref: '#/components/schemas/InterleavedContent' + description: The content of the model's response + name: + type: string + description: >- + (Optional) The name of the assistant message participant. + tool_calls: + type: array + items: + $ref: '#/components/schemas/ToolCall' + description: >- + List of tool calls. Each tool call is a ToolCall object. + additionalProperties: false + required: + - role + - content + title: OpenAIAssistantMessageParam + description: >- + A message containing the model's (assistant) response in an OpenAI-compatible + chat completion request. + OpenAIDeveloperMessageParam: + type: object + properties: + role: + type: string + const: developer + default: developer + description: >- + Must be "developer" to identify this as a developer message + content: + $ref: '#/components/schemas/InterleavedContent' + description: The content of the developer message + name: + type: string + description: >- + (Optional) The name of the developer message participant. + additionalProperties: false + required: + - role + - content + title: OpenAIDeveloperMessageParam + description: >- + A message from the developer in an OpenAI-compatible chat completion request. 
+ OpenAIMessageParam: + oneOf: + - $ref: '#/components/schemas/OpenAIUserMessageParam' + - $ref: '#/components/schemas/OpenAISystemMessageParam' + - $ref: '#/components/schemas/OpenAIAssistantMessageParam' + - $ref: '#/components/schemas/OpenAIToolMessageParam' + - $ref: '#/components/schemas/OpenAIDeveloperMessageParam' + discriminator: + propertyName: role + mapping: + user: '#/components/schemas/OpenAIUserMessageParam' + system: '#/components/schemas/OpenAISystemMessageParam' + assistant: '#/components/schemas/OpenAIAssistantMessageParam' + tool: '#/components/schemas/OpenAIToolMessageParam' + developer: '#/components/schemas/OpenAIDeveloperMessageParam' + OpenAISystemMessageParam: + type: object + properties: + role: + type: string + const: system + default: system + description: >- + Must be "system" to identify this as a system message + content: + $ref: '#/components/schemas/InterleavedContent' + description: >- + The content of the "system prompt". If multiple system messages are provided, + they are concatenated. The underlying Llama Stack code may also add other + system messages (for example, for formatting tool definitions). + name: + type: string + description: >- + (Optional) The name of the system message participant. + additionalProperties: false + required: + - role + - content + title: OpenAISystemMessageParam + description: >- + A system message providing instructions or context to the model. + OpenAIToolMessageParam: + type: object + properties: + role: + type: string + const: tool + default: tool + description: >- + Must be "tool" to identify this as a tool response + tool_call_id: + type: string + description: >- + Unique identifier for the tool call this response is for + content: + $ref: '#/components/schemas/InterleavedContent' + description: The response content from the tool + additionalProperties: false + required: + - role + - tool_call_id + - content + title: OpenAIToolMessageParam + description: >- + A message representing the result of a tool invocation in an OpenAI-compatible + chat completion request. + OpenAIUserMessageParam: + type: object + properties: + role: + type: string + const: user + default: user + description: >- + Must be "user" to identify this as a user message + content: + $ref: '#/components/schemas/InterleavedContent' + description: >- + The content of the message, which can include text and other media + name: + type: string + description: >- + (Optional) The name of the user message participant. + additionalProperties: false + required: + - role + - content + title: OpenAIUserMessageParam + description: >- + A message from the user in an OpenAI-compatible chat completion request. + OpenaiChatCompletionRequest: + type: object + properties: + model: + type: string + description: >- + The identifier of the model to use. The model must be registered with + Llama Stack and available via the /models endpoint. 
+ messages: + type: array + items: + $ref: '#/components/schemas/OpenAIMessageParam' + description: List of messages in the conversation + frequency_penalty: + type: number + description: >- + (Optional) The penalty for repeated tokens + function_call: + oneOf: + - type: string + - type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: (Optional) The function call to use + functions: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: (Optional) List of functions to use + logit_bias: + type: object + additionalProperties: + type: number + description: (Optional) The logit bias to use + logprobs: + type: boolean + description: (Optional) The log probabilities to use + max_completion_tokens: + type: integer + description: >- + (Optional) The maximum number of tokens to generate + max_tokens: + type: integer + description: >- + (Optional) The maximum number of tokens to generate + n: + type: integer + description: >- + (Optional) The number of completions to generate + parallel_tool_calls: + type: boolean + description: >- + (Optional) Whether to parallelize tool calls + presence_penalty: + type: number + description: >- + (Optional) The penalty for repeated tokens + response_format: + type: object + additionalProperties: + type: string + description: (Optional) The response format to use + seed: + type: integer + description: (Optional) The seed to use + stop: + oneOf: + - type: string + - type: array + items: + type: string + description: (Optional) The stop tokens to use + stream: + type: boolean + description: >- + (Optional) Whether to stream the response + stream_options: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: (Optional) The stream options to use + temperature: + type: number + description: (Optional) The temperature to use + tool_choice: + oneOf: + - type: string + - type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: (Optional) The tool choice to use + tools: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: (Optional) The tools to use + top_logprobs: + type: integer + description: >- + (Optional) The top log probabilities to use + top_p: + type: number + description: (Optional) The top p to use + user: + type: string + description: (Optional) The user to use + additionalProperties: false + required: + - model + - messages + title: OpenaiChatCompletionRequest + OpenAIChatCompletion: + type: object + properties: + id: + type: string + description: The ID of the chat completion + choices: + type: array + items: + $ref: '#/components/schemas/OpenAIChoice' + description: List of choices + object: + type: string + const: chat.completion + default: chat.completion + description: >- + The object type, which will be "chat.completion" + created: + type: integer + description: >- + The Unix timestamp in seconds when the chat completion was created + model: + type: string + description: >- + The model that was used to generate the chat completion + additionalProperties: false + required: 
+ - id + - choices + - object + - created + - model + title: OpenAIChatCompletion + description: >- + Response from an OpenAI-compatible chat completion request. + OpenAIChoice: + type: object + properties: + message: + $ref: '#/components/schemas/OpenAIMessageParam' + description: The message from the model + finish_reason: + type: string + description: The reason the model stopped generating + index: + type: integer + logprobs: + $ref: '#/components/schemas/OpenAIChoiceLogprobs' + additionalProperties: false + required: + - message + - finish_reason + - index + title: OpenAIChoice + description: >- + A choice from an OpenAI-compatible chat completion response. + OpenAIChoiceLogprobs: + type: object + properties: + content: + type: array + items: + $ref: '#/components/schemas/OpenAITokenLogProb' + refusal: + type: array + items: + $ref: '#/components/schemas/OpenAITokenLogProb' + additionalProperties: false + title: OpenAIChoiceLogprobs + description: >- + The log probabilities for the tokens in the message from an OpenAI-compatible + chat completion response. + OpenAITokenLogProb: + type: object + properties: + token: + type: string + bytes: + type: array + items: + type: integer + logprob: + type: number + top_logprobs: + type: array + items: + $ref: '#/components/schemas/OpenAITopLogProb' + additionalProperties: false + required: + - token + - logprob + - top_logprobs + title: OpenAITokenLogProb + description: >- + The log probability for a token from an OpenAI-compatible chat completion + response. + OpenAITopLogProb: + type: object + properties: + token: + type: string + bytes: + type: array + items: + type: integer + logprob: + type: number + additionalProperties: false + required: + - token + - logprob + title: OpenAITopLogProb + description: >- + The top log probability for a token from an OpenAI-compatible chat completion + response. + OpenaiCompletionRequest: + type: object + properties: + model: + type: string + description: >- + The identifier of the model to use. The model must be registered with + Llama Stack and available via the /models endpoint. 
+ prompt: + type: string + description: The prompt to generate a completion for + best_of: + type: integer + description: >- + (Optional) The number of completions to generate + echo: + type: boolean + description: (Optional) Whether to echo the prompt + frequency_penalty: + type: number + description: >- + (Optional) The penalty for repeated tokens + logit_bias: + type: object + additionalProperties: + type: number + description: (Optional) The logit bias to use + logprobs: + type: boolean + description: (Optional) The log probabilities to use + max_tokens: + type: integer + description: >- + (Optional) The maximum number of tokens to generate + n: + type: integer + description: >- + (Optional) The number of completions to generate + presence_penalty: + type: number + description: >- + (Optional) The penalty for repeated tokens + seed: + type: integer + description: (Optional) The seed to use + stop: + oneOf: + - type: string + - type: array + items: + type: string + description: (Optional) The stop tokens to use + stream: + type: boolean + description: >- + (Optional) Whether to stream the response + stream_options: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: (Optional) The stream options to use + temperature: + type: number + description: (Optional) The temperature to use + top_p: + type: number + description: (Optional) The top p to use + user: + type: string + description: (Optional) The user to use + additionalProperties: false + required: + - model + - prompt + title: OpenaiCompletionRequest + OpenAICompletion: + type: object + properties: + id: + type: string + choices: + type: array + items: + $ref: '#/components/schemas/OpenAICompletionChoice' + created: + type: integer + model: + type: string + object: + type: string + const: text_completion + default: text_completion + additionalProperties: false + required: + - id + - choices + - created + - model + - object + title: OpenAICompletion + description: >- + Response from an OpenAI-compatible completion request. + OpenAICompletionChoice: + type: object + properties: + finish_reason: + type: string + text: + type: string + index: + type: integer + logprobs: + $ref: '#/components/schemas/OpenAIChoiceLogprobs' + additionalProperties: false + required: + - finish_reason + - text + - index + title: OpenAICompletionChoice + description: >- + A choice from an OpenAI-compatible completion response. + OpenAIModel: + type: object + properties: + id: + type: string + object: + type: string + const: model + default: model + created: + type: integer + owned_by: + type: string + additionalProperties: false + required: + - id + - object + - created + - owned_by + title: OpenAIModel + description: A model from OpenAI. 
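# Example client call against the /openai/v1/completions endpoint described by the
# OpenaiCompletionRequest and OpenAICompletion schemas above. A minimal sketch only:
# the base URL, placeholder API key, and model id here are illustrative assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

completion = client.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    prompt="Complete this sentence: the quick brown fox",
    max_tokens=32,
    temperature=0.7,
)
print(completion.choices[0].text)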
+ OpenAIListModelsResponse: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/OpenAIModel' + additionalProperties: false + required: + - data + title: OpenAIListModelsResponse DPOAlignmentConfig: type: object properties: diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 864bef2d5d..6271466d49 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -17,9 +17,6 @@ runtime_checkable, ) -from openai.types.chat import ChatCompletion as OpenAIChatCompletion -from openai.types.chat import ChatCompletionMessageParam as OpenAIChatCompletionMessageParam -from openai.types.completion import Completion as OpenAICompletion from pydantic import BaseModel, Field, field_validator from typing_extensions import Annotated @@ -445,6 +442,217 @@ class EmbeddingsResponse(BaseModel): embeddings: List[List[float]] +@json_schema_type +class OpenAIUserMessageParam(BaseModel): + """A message from the user in an OpenAI-compatible chat completion request. + + :param role: Must be "user" to identify this as a user message + :param content: The content of the message, which can include text and other media + :param name: (Optional) The name of the user message participant. + """ + + role: Literal["user"] = "user" + content: InterleavedContent + name: Optional[str] = None + + +@json_schema_type +class OpenAISystemMessageParam(BaseModel): + """A system message providing instructions or context to the model. + + :param role: Must be "system" to identify this as a system message + :param content: The content of the "system prompt". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions). + :param name: (Optional) The name of the system message participant. + """ + + role: Literal["system"] = "system" + content: InterleavedContent + name: Optional[str] = None + + +@json_schema_type +class OpenAIAssistantMessageParam(BaseModel): + """A message containing the model's (assistant) response in an OpenAI-compatible chat completion request. + + :param role: Must be "assistant" to identify this as the model's response + :param content: The content of the model's response + :param name: (Optional) The name of the assistant message participant. + :param tool_calls: List of tool calls. Each tool call is a ToolCall object. + """ + + role: Literal["assistant"] = "assistant" + content: InterleavedContent + name: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = Field(default_factory=list) + + +@json_schema_type +class OpenAIToolMessageParam(BaseModel): + """A message representing the result of a tool invocation in an OpenAI-compatible chat completion request. + + :param role: Must be "tool" to identify this as a tool response + :param tool_call_id: Unique identifier for the tool call this response is for + :param content: The response content from the tool + """ + + role: Literal["tool"] = "tool" + tool_call_id: str + content: InterleavedContent + + +@json_schema_type +class OpenAIDeveloperMessageParam(BaseModel): + """A message from the developer in an OpenAI-compatible chat completion request. + + :param role: Must be "developer" to identify this as a developer message + :param content: The content of the developer message + :param name: (Optional) The name of the developer message participant. 
+ """ + + role: Literal["developer"] = "developer" + content: InterleavedContent + name: Optional[str] = None + + +OpenAIMessageParam = Annotated[ + Union[ + OpenAIUserMessageParam, + OpenAISystemMessageParam, + OpenAIAssistantMessageParam, + OpenAIToolMessageParam, + OpenAIDeveloperMessageParam, + ], + Field(discriminator="role"), +] +register_schema(OpenAIMessageParam, name="OpenAIMessageParam") + + +@json_schema_type +class OpenAITopLogProb(BaseModel): + """The top log probability for a token from an OpenAI-compatible chat completion response. + + :token: The token + :bytes: (Optional) The bytes for the token + :logprob: The log probability of the token + """ + + token: str + bytes: Optional[List[int]] = None + logprob: float + + +@json_schema_type +class OpenAITokenLogProb(BaseModel): + """The log probability for a token from an OpenAI-compatible chat completion response. + + :token: The token + :bytes: (Optional) The bytes for the token + :logprob: The log probability of the token + :top_logprobs: The top log probabilities for the token + """ + + token: str + bytes: Optional[List[int]] = None + logprob: float + top_logprobs: List[OpenAITopLogProb] + + +@json_schema_type +class OpenAIChoiceLogprobs(BaseModel): + """The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response. + + :content: (Optional) The log probabilities for the tokens in the message + :refusal: (Optional) The log probabilities for the tokens in the message + """ + + content: Optional[List[OpenAITokenLogProb]] = None + refusal: Optional[List[OpenAITokenLogProb]] = None + + +@json_schema_type +class OpenAIChoice(BaseModel): + """A choice from an OpenAI-compatible chat completion response. + + :param message: The message from the model + :param finish_reason: The reason the model stopped generating + :index: The index of the choice + :logprobs: (Optional) The log probabilities for the tokens in the message + """ + + message: OpenAIMessageParam + finish_reason: str + index: int + logprobs: Optional[OpenAIChoiceLogprobs] = None + + +@json_schema_type +class OpenAIChatCompletion(BaseModel): + """Response from an OpenAI-compatible chat completion request. + + :param id: The ID of the chat completion + :param choices: List of choices + :param object: The object type, which will be "chat.completion" + :param created: The Unix timestamp in seconds when the chat completion was created + :param model: The model that was used to generate the chat completion + """ + + id: str + choices: List[OpenAIChoice] + object: Literal["chat.completion"] = "chat.completion" + created: int + model: str + + +@json_schema_type +class OpenAICompletionLogprobs(BaseModel): + """The log probabilities for the tokens in the message from an OpenAI-compatible completion response. + + :text_offset: (Optional) The offset of the token in the text + :token_logprobs: (Optional) The log probabilities for the tokens + :tokens: (Optional) The tokens + :top_logprobs: (Optional) The top log probabilities for the tokens + """ + + text_offset: Optional[List[int]] = None + token_logprobs: Optional[List[float]] = None + tokens: Optional[List[str]] = None + top_logprobs: Optional[List[Dict[str, float]]] = None + + +@json_schema_type +class OpenAICompletionChoice(BaseModel): + """A choice from an OpenAI-compatible completion response. 
+ + :finish_reason: The reason the model stopped generating + :text: The text of the choice + :index: The index of the choice + :logprobs: (Optional) The log probabilities for the tokens in the choice + """ + + finish_reason: str + text: str + index: int + logprobs: Optional[OpenAIChoiceLogprobs] = None + + +@json_schema_type +class OpenAICompletion(BaseModel): + """Response from an OpenAI-compatible completion request. + + :id: The ID of the completion + :choices: List of choices + :created: The Unix timestamp in seconds when the completion was created + :model: The model that was used to generate the completion + :object: The object type, which will be "text_completion" + """ + + id: str + choices: List[OpenAICompletionChoice] + created: int + model: str + object: Literal["text_completion"] = "text_completion" + + class ModelStore(Protocol): async def get_model(self, identifier: str) -> Model: ... @@ -589,14 +797,33 @@ async def openai_completion( top_p: Optional[float] = None, user: Optional[str] = None, ) -> OpenAICompletion: - """Generate an OpenAI-compatible completion for the given prompt using the specified model.""" + """Generate an OpenAI-compatible completion for the given prompt using the specified model. + + :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. + :param prompt: The prompt to generate a completion for + :param best_of: (Optional) The number of completions to generate + :param echo: (Optional) Whether to echo the prompt + :param frequency_penalty: (Optional) The penalty for repeated tokens + :param logit_bias: (Optional) The logit bias to use + :param logprobs: (Optional) The log probabilities to use + :param max_tokens: (Optional) The maximum number of tokens to generate + :param n: (Optional) The number of completions to generate + :param presence_penalty: (Optional) The penalty for repeated tokens + :param seed: (Optional) The seed to use + :param stop: (Optional) The stop tokens to use + :param stream: (Optional) Whether to stream the response + :param stream_options: (Optional) The stream options to use + :param temperature: (Optional) The temperature to use + :param top_p: (Optional) The top p to use + :param user: (Optional) The user to use + """ ... @webmethod(route="/openai/v1/chat/completions", method="POST") async def openai_chat_completion( self, model: str, - messages: List[OpenAIChatCompletionMessageParam], + messages: List[OpenAIMessageParam], frequency_penalty: Optional[float] = None, function_call: Optional[Union[str, Dict[str, Any]]] = None, functions: Optional[List[Dict[str, Any]]] = None, @@ -619,5 +846,30 @@ async def openai_chat_completion( top_p: Optional[float] = None, user: Optional[str] = None, ) -> OpenAIChatCompletion: - """Generate an OpenAI-compatible chat completion for the given messages using the specified model.""" + """Generate an OpenAI-compatible chat completion for the given messages using the specified model. + + :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. 
+ :param messages: List of messages in the conversation + :param frequency_penalty: (Optional) The penalty for repeated tokens + :param function_call: (Optional) The function call to use + :param functions: (Optional) List of functions to use + :param logit_bias: (Optional) The logit bias to use + :param logprobs: (Optional) The log probabilities to use + :param max_completion_tokens: (Optional) The maximum number of tokens to generate + :param max_tokens: (Optional) The maximum number of tokens to generate + :param n: (Optional) The number of completions to generate + :param parallel_tool_calls: (Optional) Whether to parallelize tool calls + :param presence_penalty: (Optional) The penalty for repeated tokens + :param response_format: (Optional) The response format to use + :param seed: (Optional) The seed to use + :param stop: (Optional) The stop tokens to use + :param stream: (Optional) Whether to stream the response + :param stream_options: (Optional) The stream options to use + :param temperature: (Optional) The temperature to use + :param tool_choice: (Optional) The tool choice to use + :param tools: (Optional) The tools to use + :param top_logprobs: (Optional) The top log probabilities to use + :param top_p: (Optional) The top p to use + :param user: (Optional) The user to use + """ ... diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index e48add8823..97398ce75e 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -7,7 +7,6 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable -from openai.types.model import Model as OpenAIModel from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.resource import Resource, ResourceType @@ -57,6 +56,22 @@ class ListModelsResponse(BaseModel): data: List[Model] +@json_schema_type +class OpenAIModel(BaseModel): + """A model from OpenAI. 
+ + :id: The ID of the model + :object: The object type, which will be "model" + :created: The Unix timestamp in seconds when the model was created + :owned_by: The owner of the model + """ + + id: str + object: Literal["model"] = "model" + created: int + owned_by: str + + class OpenAIListModelsResponse(BaseModel): data: List[OpenAIModel] diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 146ac50212..4f3e977788 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -8,7 +8,6 @@ from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from openai.types.chat import ChatCompletion as OpenAIChatCompletion -from openai.types.chat import ChatCompletionMessageParam as OpenAIChatCompletionMessageParam from openai.types.completion import Completion as OpenAICompletion from llama_stack.apis.common.content_types import ( @@ -39,6 +38,7 @@ ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import OpenAIMessageParam from llama_stack.apis.models import Model, ModelType from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.scoring import ( @@ -478,7 +478,7 @@ async def openai_completion( async def openai_chat_completion( self, model: str, - messages: List[OpenAIChatCompletionMessageParam], + messages: List[OpenAIMessageParam], frequency_penalty: Optional[float] = None, function_call: Optional[Union[str, Dict[str, Any]]] = None, functions: Optional[List[Dict[str, Any]]] = None, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 5ec90864ea..18b0c891fe 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -9,7 +9,6 @@ import uuid from typing import Any, Dict, List, Optional -from openai.types.model import Model as OpenAIModel from pydantic import TypeAdapter from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse @@ -25,7 +24,7 @@ RowsDataSource, URIDataSource, ) -from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse +from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ( ListScoringFunctionsResponse, diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 26a34064d2..7cce2fb92f 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -8,7 +8,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union from openai.types.chat import ChatCompletion as OpenAIChatCompletion -from openai.types.chat import ChatCompletionMessageParam as OpenAIChatCompletionMessageParam from openai.types.completion import Completion as OpenAICompletion from llama_stack.apis.inference import ( @@ -23,6 +22,7 @@ ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import OpenAIMessageParam from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( 
SentenceTransformerEmbeddingMixin, @@ -104,7 +104,7 @@ async def openai_completion( async def openai_chat_completion( self, model: str, - messages: List[OpenAIChatCompletionMessageParam], + messages: List[OpenAIMessageParam], frequency_penalty: Optional[float] = None, function_call: Optional[Union[str, Dict[str, Any]]] = None, functions: Optional[List[Dict[str, Any]]] = None, diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 18e6a19722..696e72a323 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -10,7 +10,6 @@ import httpx from openai import AsyncOpenAI from openai.types.chat import ChatCompletion as OpenAIChatCompletion -from openai.types.chat import ChatCompletionMessageParam as OpenAIChatCompletionMessageParam from openai.types.chat.chat_completion_chunk import ( ChatCompletionChunk as OpenAIChatCompletionChunk, ) @@ -48,6 +47,7 @@ ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import OpenAIMessageParam from llama_stack.apis.models import Model, ModelType from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall from llama_stack.models.llama.sku_list import all_registered_models @@ -471,7 +471,7 @@ async def openai_completion( async def openai_chat_completion( self, model: str, - messages: List[OpenAIChatCompletionMessageParam], + messages: List[OpenAIMessageParam], frequency_penalty: Optional[float] = None, function_call: Optional[Union[str, Dict[str, Any]]] = None, functions: Optional[List[Dict[str, Any]]] = None, From 5bc5fed6df7e77c41d0c9e4725fed044fa3f12b7 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Tue, 8 Apr 2025 09:10:52 -0400 Subject: [PATCH 03/19] Clean up some more usage of direct OpenAI types --- llama_stack/distribution/routers/routers.py | 5 +- .../sentence_transformers.py | 62 +++---------------- .../providers/remote/inference/vllm/vllm.py | 4 +- .../utils/inference/openai_compat.py | 3 +- 4 files changed, 10 insertions(+), 64 deletions(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 4f3e977788..19cc8ac095 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -7,9 +7,6 @@ import time from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union -from openai.types.chat import ChatCompletion as OpenAIChatCompletion -from openai.types.completion import Completion as OpenAICompletion - from llama_stack.apis.common.content_types import ( URL, InterleavedContent, @@ -38,7 +35,7 @@ ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.inference.inference import OpenAIMessageParam +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.models import Model, ModelType from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.scoring import ( diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 7cce2fb92f..9c370b6c56 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -5,10 +5,7 @@ # the root directory of this source tree. 
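# The native response models above (OpenAIChatCompletion, OpenAIChoice,
# OpenAIAssistantMessageParam) can be built without importing the OpenAI SDK, which is
# what this cleanup relies on. A minimal sketch; the field values are illustrative.
import time

from llama_stack.apis.inference.inference import (
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChoice,
)

example = OpenAIChatCompletion(
    id="chatcmpl-example",
    created=int(time.time()),
    model="meta-llama/Llama-3.1-8B-Instruct",
    choices=[
        OpenAIChoice(
            index=0,
            finish_reason="stop",
            message=OpenAIAssistantMessageParam(content="Hello from Llama Stack!"),
        )
    ],
)
print(example.model_dump_json(indent=2))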
import logging -from typing import Any, AsyncGenerator, Dict, List, Optional, Union - -from openai.types.chat import ChatCompletion as OpenAIChatCompletion -from openai.types.completion import Completion as OpenAICompletion +from typing import AsyncGenerator, List, Optional, Union from llama_stack.apis.inference import ( CompletionResponse, @@ -22,11 +19,14 @@ ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.inference.inference import OpenAIMessageParam from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, ) +from llama_stack.providers.utils.inference.openai_compat import ( + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, +) from .config import SentenceTransformersInferenceConfig @@ -34,6 +34,8 @@ class SentenceTransformersInferenceImpl( + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, SentenceTransformerEmbeddingMixin, Inference, ModelsProtocolPrivate, @@ -78,53 +80,3 @@ async def chat_completion( tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: raise ValueError("Sentence transformers don't support chat completion") - - async def openai_completion( - self, - model: str, - prompt: str, - best_of: Optional[int] = None, - echo: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - presence_penalty: Optional[float] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - ) -> OpenAICompletion: - raise ValueError("Sentence transformers don't support openai completion") - - async def openai_chat_completion( - self, - model: str, - messages: List[OpenAIMessageParam], - frequency_penalty: Optional[float] = None, - function_call: Optional[Union[str, Dict[str, Any]]] = None, - functions: Optional[List[Dict[str, Any]]] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - parallel_tool_calls: Optional[bool] = None, - presence_penalty: Optional[float] = None, - response_format: Optional[Dict[str, str]] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - tool_choice: Optional[Union[str, Dict[str, Any]]] = None, - tools: Optional[List[Dict[str, Any]]] = None, - top_logprobs: Optional[int] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - ) -> OpenAIChatCompletion: - raise ValueError("Sentence transformers don't support openai chat completion") diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 696e72a323..d7555c39f7 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -9,11 +9,9 @@ import httpx from openai import AsyncOpenAI -from openai.types.chat import ChatCompletion as OpenAIChatCompletion from openai.types.chat.chat_completion_chunk import ( 
ChatCompletionChunk as OpenAIChatCompletionChunk, ) -from openai.types.completion import Completion as OpenAICompletion from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -47,7 +45,7 @@ ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.inference.inference import OpenAIMessageParam +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.models import Model, ModelType from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall from llama_stack.models.llama.sku_list import all_registered_models diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 3f1846b768..d9091d5c8c 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -9,7 +9,6 @@ from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union from openai import AsyncStream -from openai.types.chat import ChatCompletion as OpenAIChatCompletion from openai.types.chat import ( ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, ) @@ -55,7 +54,6 @@ from openai.types.chat.chat_completion_message_tool_call_param import ( Function as OpenAIFunction, ) -from openai.types.completion import Completion as OpenAICompletion from pydantic import BaseModel from llama_stack.apis.common.content_types import ( @@ -85,6 +83,7 @@ TopPSamplingStrategy, UserMessage, ) +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion from llama_stack.models.llama.datatypes import ( BuiltinTool, StopReason, From 1dbdff14966fd78db026a8c0e07f847ae01853c4 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Tue, 8 Apr 2025 09:29:49 -0400 Subject: [PATCH 04/19] ollama OpenAI-compatible completions and chat completions --- .../remote/inference/ollama/ollama.py | 116 +++++++++++++++++- 1 file changed, 111 insertions(+), 5 deletions(-) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 944493b6d3..dc2c8b3f5d 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -5,10 +5,11 @@ # the root directory of this source tree. 
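# The ollama changes below proxy openai_completion/openai_chat_completion directly to
# Ollama's own OpenAI-compatible server under /v1. A standalone sketch of that backend
# assumption; the host, port, and model name here are assumptions.
import asyncio

from openai import AsyncOpenAI


async def check_ollama_openai_endpoint() -> None:
    client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
    response = await client.chat.completions.create(
        model="llama3.2:3b",
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(response.choices[0].message.content)


asyncio.run(check_ollama_openai_endpoint())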
-from typing import Any, AsyncGenerator, List, Optional, Union +from typing import Any, AsyncGenerator, Dict, List, Optional, Union import httpx from ollama import AsyncClient +from openai import AsyncOpenAI from llama_stack.apis.common.content_types import ( ImageContentItem, @@ -38,6 +39,7 @@ ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger from llama_stack.providers.datatypes import ModelsProtocolPrivate @@ -45,10 +47,8 @@ ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionUnsupportedMixin, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, - OpenAICompletionUnsupportedMixin, get_sampling_options, process_chat_completion_response, process_chat_completion_stream_response, @@ -70,8 +70,6 @@ class OllamaInferenceAdapter( - OpenAICompletionUnsupportedMixin, - OpenAIChatCompletionUnsupportedMixin, Inference, ModelsProtocolPrivate, ): @@ -83,6 +81,10 @@ def __init__(self, url: str) -> None: def client(self) -> AsyncClient: return AsyncClient(host=self.url) + @property + def openai_client(self) -> AsyncOpenAI: + return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama") + async def initialize(self) -> None: logger.info(f"checking connectivity to Ollama at `{self.url}`...") try: @@ -326,6 +328,110 @@ async def register_model(self, model: Model) -> Model: return model + async def openai_completion( + self, + model: str, + prompt: str, + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAICompletion: + model_obj = await self._get_model(model) + params = { + k: v + for k, v in { + "model": model_obj.provider_resource_id, + "prompt": prompt, + "best_of": best_of, + "echo": echo, + "frequency_penalty": frequency_penalty, + "logit_bias": logit_bias, + "logprobs": logprobs, + "max_tokens": max_tokens, + "n": n, + "presence_penalty": presence_penalty, + "seed": seed, + "stop": stop, + "stream": stream, + "stream_options": stream_options, + "temperature": temperature, + "top_p": top_p, + "user": user, + }.items() + if v is not None + } + return await self.openai_client.completions.create(**params) # type: ignore + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[Dict[str, str]] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: 
Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAIChatCompletion: + model_obj = await self._get_model(model) + params = { + k: v + for k, v in { + "model": model_obj.provider_resource_id, + "messages": messages, + "frequency_penalty": frequency_penalty, + "function_call": function_call, + "functions": functions, + "logit_bias": logit_bias, + "logprobs": logprobs, + "max_completion_tokens": max_completion_tokens, + "max_tokens": max_tokens, + "n": n, + "parallel_tool_calls": parallel_tool_calls, + "presence_penalty": presence_penalty, + "response_format": response_format, + "seed": seed, + "stop": stop, + "stream": stream, + "stream_options": stream_options, + "temperature": temperature, + "tool_choice": tool_choice, + "tools": tools, + "top_logprobs": top_logprobs, + "top_p": top_p, + "user": user, + }.items() + if v is not None + } + return await self.openai_client.chat.completions.create(**params) # type: ignore + async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]: async def _convert_content(content) -> dict: From 00c4493bda19a65f3936e3392abbcb100a7eb2b1 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Tue, 8 Apr 2025 12:35:16 -0400 Subject: [PATCH 05/19] OpenAI-compatible completions and chats for litellm and together This adds OpenAI-compatible completions and chat completions support for the native Together provider as well as all providers implemented with litellm. --- .../remote/inference/together/together.py | 111 +++++++++++++++++- .../providers/remote/inference/vllm/vllm.py | 97 +++++++-------- .../utils/inference/litellm_openai_mixin.py | 100 +++++++++++++++- .../utils/inference/openai_compat.py | 5 + 4 files changed, 259 insertions(+), 54 deletions(-) diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index df76109355..3e43a844c4 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -4,8 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
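# With a provider such as ollama, vLLM, or Together wired in, a standard OpenAI client
# can talk to Llama Stack through these endpoints. A minimal sketch; the server URL,
# placeholder API key, and model id are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

for model in client.models.list().data:
    print(model.id)

chat = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Write one sentence about llamas."}],
)
print(chat.choices[0].message.content)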
-from typing import AsyncGenerator, List, Optional, Union +from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from openai import AsyncOpenAI from together import AsyncTogether from llama_stack.apis.common.content_types import ( @@ -30,12 +31,14 @@ ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, get_sampling_options, + prepare_openai_completion_params, process_chat_completion_response, process_chat_completion_stream_response, process_completion_response, @@ -60,6 +63,7 @@ def __init__(self, config: TogetherImplConfig) -> None: ModelRegistryHelper.__init__(self, MODEL_ENTRIES) self.config = config self._client = None + self._openai_client = None async def initialize(self) -> None: pass @@ -110,6 +114,15 @@ def _get_client(self) -> AsyncTogether: self._client = AsyncTogether(api_key=together_api_key) return self._client + def _get_openai_client(self) -> AsyncOpenAI: + if not self._openai_client: + together_client = self._get_client().client + self._openai_client = AsyncOpenAI( + base_url=together_client.base_url, + api_key=together_client.api_key, + ) + return self._openai_client + async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse: params = await self._get_params(request) client = self._get_client() @@ -243,3 +256,99 @@ async def embeddings( ) embeddings = [item.embedding for item in r.data] return EmbeddingsResponse(embeddings=embeddings) + + async def openai_completion( + self, + model: str, + prompt: str, + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAICompletion: + model_obj = await self._get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + ) + return await self._get_openai_client().completions.create(**params) # type: ignore + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: 
Optional[float] = None, + response_format: Optional[Dict[str, str]] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAIChatCompletion: + model_obj = await self._get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + return await self._get_openai_client().chat.completions.create(**params) # type: ignore diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index d7555c39f7..66fb986f94 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -59,6 +59,7 @@ convert_message_to_openai_dict, convert_tool_call, get_sampling_options, + prepare_openai_completion_params, process_chat_completion_stream_response, process_completion_response, process_completion_stream_response, @@ -441,29 +442,25 @@ async def openai_completion( user: Optional[str] = None, ) -> OpenAICompletion: model_obj = await self._get_model(model) - params = { - k: v - for k, v in { - "model": model_obj.provider_resource_id, - "prompt": prompt, - "best_of": best_of, - "echo": echo, - "frequency_penalty": frequency_penalty, - "logit_bias": logit_bias, - "logprobs": logprobs, - "max_tokens": max_tokens, - "n": n, - "presence_penalty": presence_penalty, - "seed": seed, - "stop": stop, - "stream": stream, - "stream_options": stream_options, - "temperature": temperature, - "top_p": top_p, - "user": user, - }.items() - if v is not None - } + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + ) return await self.client.completions.create(**params) # type: ignore async def openai_chat_completion( @@ -493,33 +490,29 @@ async def openai_chat_completion( user: Optional[str] = None, ) -> OpenAIChatCompletion: model_obj = await self._get_model(model) - params = { - k: v - for k, v in { - "model": model_obj.provider_resource_id, - "messages": messages, - "frequency_penalty": frequency_penalty, - "function_call": function_call, - "functions": functions, - "logit_bias": logit_bias, - "logprobs": logprobs, - "max_completion_tokens": max_completion_tokens, - "max_tokens": max_tokens, - "n": n, - "parallel_tool_calls": parallel_tool_calls, - "presence_penalty": 
presence_penalty, - "response_format": response_format, - "seed": seed, - "stop": stop, - "stream": stream, - "stream_options": stream_options, - "temperature": temperature, - "tool_choice": tool_choice, - "tools": tools, - "top_logprobs": top_logprobs, - "top_p": top_p, - "user": user, - }.items() - if v is not None - } + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) return await self.client.chat.completions.create(**params) # type: ignore diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index bd1eb39782..8111e44639 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator, AsyncIterator, List, Optional, Union +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union import litellm @@ -30,6 +30,7 @@ ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.models.models import Model from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger @@ -40,6 +41,7 @@ convert_openai_chat_completion_stream, convert_tooldef_to_openai_tool, get_sampling_options, + prepare_openai_completion_params, ) from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, @@ -245,3 +247,99 @@ async def embeddings( embeddings = [data["embedding"] for data in response["data"]] return EmbeddingsResponse(embeddings=embeddings) + + async def openai_completion( + self, + model: str, + prompt: str, + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAICompletion: + model_obj = await self._get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + ) + return 
litellm.text_completion(**params) + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[Dict[str, str]] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAIChatCompletion: + model_obj = await self._get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + return litellm.completion(**params) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index d9091d5c8c..ea02de573a 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -1052,6 +1052,11 @@ async def convert_openai_chat_completion_stream( ) +async def prepare_openai_completion_params(**params): + completion_params = {k: v for k, v in params.items() if v is not None} + return completion_params + + class OpenAICompletionUnsupportedMixin: async def openai_completion( self, From 15d37fde195eadb64d304211fa2a53c4c1049394 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Tue, 8 Apr 2025 12:50:23 -0400 Subject: [PATCH 06/19] Add unsupported OpenAI mixin to all remaining inference providers --- .../providers/remote/inference/bedrock/bedrock.py | 9 ++++++++- .../providers/remote/inference/cerebras/cerebras.py | 9 ++++++++- .../providers/remote/inference/databricks/databricks.py | 9 ++++++++- llama_stack/providers/remote/inference/nvidia/nvidia.py | 9 ++++++++- llama_stack/providers/remote/inference/runpod/runpod.py | 9 ++++++++- .../providers/remote/inference/sambanova/sambanova.py | 9 ++++++++- llama_stack/providers/remote/inference/tgi/tgi.py | 9 ++++++++- 7 files changed, 56 insertions(+), 7 deletions(-) diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 120da5bd4d..0a485da8fe 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -36,8 +36,10 @@ ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( + 
OpenAIChatCompletionUnsupportedMixin, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, + OpenAICompletionUnsupportedMixin, get_sampling_strategy_options, process_chat_completion_response, process_chat_completion_stream_response, @@ -51,7 +53,12 @@ from .models import MODEL_ENTRIES -class BedrockInferenceAdapter(ModelRegistryHelper, Inference): +class BedrockInferenceAdapter( + ModelRegistryHelper, + Inference, + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, +): def __init__(self, config: BedrockConfig) -> None: ModelRegistryHelper.__init__(self, MODEL_ENTRIES) self._config = config diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 43d986b86e..5e0a5b484f 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -34,6 +34,8 @@ ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, get_sampling_options, process_chat_completion_response, process_chat_completion_stream_response, @@ -49,7 +51,12 @@ from .models import MODEL_ENTRIES -class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): +class CerebrasInferenceAdapter( + ModelRegistryHelper, + Inference, + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, +): def __init__(self, config: CerebrasImplConfig) -> None: ModelRegistryHelper.__init__( self, diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 0eaf0135bb..a10878b27a 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -34,6 +34,8 @@ build_hf_repo_model_entry, ) from llama_stack.providers.utils.inference.openai_compat import ( + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, get_sampling_options, process_chat_completion_response, process_chat_completion_stream_response, @@ -56,7 +58,12 @@ ] -class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): +class DatabricksInferenceAdapter( + ModelRegistryHelper, + Inference, + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, +): def __init__(self, config: DatabricksImplConfig) -> None: ModelRegistryHelper.__init__(self, model_entries=model_entries) self.config = config diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index e1f5d7a6a3..3ed458058c 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -40,6 +40,8 @@ ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, convert_openai_chat_completion_choice, convert_openai_chat_completion_stream, ) @@ -58,7 +60,12 @@ logger = logging.getLogger(__name__) -class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): +class NVIDIAInferenceAdapter( + Inference, + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, + ModelRegistryHelper, +): def __init__(self, config: NVIDIAConfig) -> None: # TODO(mf): filter by available models ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) diff --git 
a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 72f858cd81..878460122f 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -12,6 +12,8 @@ # from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, get_sampling_options, process_chat_completion_response, process_chat_completion_stream_response, @@ -38,7 +40,12 @@ } -class RunpodInferenceAdapter(ModelRegistryHelper, Inference): +class RunpodInferenceAdapter( + ModelRegistryHelper, + Inference, + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, +): def __init__(self, config: RunpodImplConfig) -> None: ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS) self.config = config diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index a3badd468b..c503657ebe 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -42,6 +42,8 @@ ) from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, process_chat_completion_stream_response, ) from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -52,7 +54,12 @@ from .models import MODEL_ENTRIES -class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): +class SambaNovaInferenceAdapter( + ModelRegistryHelper, + Inference, + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, +): def __init__(self, config: SambaNovaImplConfig) -> None: ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) self.config = config diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index fe99fafe18..8f5b5e3ccd 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -40,8 +40,10 @@ build_hf_repo_model_entry, ) from llama_stack.providers.utils.inference.openai_compat import ( + OpenAIChatCompletionUnsupportedMixin, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, + OpenAICompletionUnsupportedMixin, get_sampling_options, process_chat_completion_response, process_chat_completion_stream_response, @@ -69,7 +71,12 @@ def build_hf_repo_model_entries(): ] -class _HfAdapter(Inference, ModelsProtocolPrivate): +class _HfAdapter( + Inference, + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, + ModelsProtocolPrivate, +): client: AsyncInferenceClient max_tokens: int model_id: str From de01b1455b7b8fc2ca962fd144ed65297d84ec5e Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 9 Apr 2025 08:35:36 -0400 Subject: [PATCH 07/19] Passthrough inference support for OpenAI-compatible APIs Signed-off-by: Ben Browning --- .../inference/passthrough/passthrough.py | 106 +++++++++++++++++- 1 file changed, 105 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py 
b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 96b2d73d84..cbe6e6cae2 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional, Union from llama_stack_client import AsyncLlamaStackClient @@ -26,9 +26,11 @@ ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.models import Model from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params from .config import PassthroughImplConfig @@ -201,6 +203,108 @@ async def embeddings( task_type=task_type, ) + async def openai_completion( + self, + model: str, + prompt: str, + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAICompletion: + client = self._get_client() + model_obj = await self.model_store.get_model(model) + + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + ) + + return await client.inference.openai_completion(**params) + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[Dict[str, str]] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAIChatCompletion: + client = self._get_client() + model_obj = await self.model_store.get_model(model) + + params = await 
prepare_openai_completion_params( + model=model_obj.provider_resource_id, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + + return await client.inference.openai_chat_completion(**params) + def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]: json_params = {} for key, value in request_params.items(): From 24cfa1ef1aaab15d355f57c4baa9d591f01afcda Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 9 Apr 2025 08:36:01 -0400 Subject: [PATCH 08/19] Mark inline vllm as OpenAI unsupported inference Signed-off-by: Ben Browning --- llama_stack/providers/inline/inference/vllm/vllm.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index ea2643b7a9..085c79d6ba 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -66,8 +66,10 @@ ModelsProtocolPrivate, ) from llama_stack.providers.utils.inference.openai_compat import ( + OpenAIChatCompletionUnsupportedMixin, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, + OpenAICompletionUnsupportedMixin, get_stop_reason, process_chat_completion_stream_response, ) @@ -172,7 +174,12 @@ def _convert_sampling_params( return vllm_sampling_params -class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): +class VLLMInferenceImpl( + Inference, + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, + ModelsProtocolPrivate, +): """ vLLM-based inference model adapter for Llama Stack with support for multiple models. From a6cf8fa12b179f54af3c68ae874cb7af79a095e3 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 9 Apr 2025 09:28:50 -0400 Subject: [PATCH 09/19] OpenAI completion prompt can also be an array The OpenAI completion prompt field can be a string or an array, so update things to use and pass that properly. This also stubs in a basic conversion of OpenAI non-streaming completion requests to Llama Stack completion calls, for those providers that don't actually have an OpenAI backend to allow them to still accept requests via the OpenAI APIs. Signed-off-by: Ben Browning --- docs/_static/llama-stack-spec.html | 12 ++- docs/_static/llama-stack-spec.yaml | 6 +- llama_stack/apis/inference/inference.py | 2 +- llama_stack/distribution/routers/routers.py | 2 +- .../remote/inference/ollama/ollama.py | 2 +- .../inference/passthrough/passthrough.py | 2 +- .../remote/inference/together/together.py | 2 +- .../providers/remote/inference/vllm/vllm.py | 2 +- .../utils/inference/litellm_openai_mixin.py | 2 +- .../utils/inference/openai_compat.py | 75 ++++++++++++++++++- 10 files changed, 95 insertions(+), 12 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index e92deaa41f..34f060386f 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9401,7 +9401,17 @@ "description": "The identifier of the model to use. 
The model must be registered with Llama Stack and available via the /models endpoint." }, "prompt": { - "type": "string", + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ], "description": "The prompt to generate a completion for" }, "best_of": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index f0c5d1a799..85a2876435 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6477,7 +6477,11 @@ components: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. prompt: - type: string + oneOf: + - type: string + - type: array + items: + type: string description: The prompt to generate a completion for best_of: type: integer diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 6271466d49..13eacd217e 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -780,7 +780,7 @@ async def embeddings( async def openai_completion( self, model: str, - prompt: str, + prompt: Union[str, List[str]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 19cc8ac095..89f1744513 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -423,7 +423,7 @@ async def embeddings( async def openai_completion( self, model: str, - prompt: str, + prompt: Union[str, List[str]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index dc2c8b3f5d..fc1cf22655 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -331,7 +331,7 @@ async def register_model(self, model: Model) -> Model: async def openai_completion( self, model: str, - prompt: str, + prompt: Union[str, List[str]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index cbe6e6cae2..09bd22b4c2 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -206,7 +206,7 @@ async def embeddings( async def openai_completion( self, model: str, - prompt: str, + prompt: Union[str, List[str]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 3e43a844c4..bde32593c5 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -260,7 +260,7 @@ async def embeddings( async def openai_completion( self, model: str, - prompt: str, + prompt: Union[str, List[str]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py 
b/llama_stack/providers/remote/inference/vllm/vllm.py index 66fb986f94..daeb95b27e 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -424,7 +424,7 @@ async def embeddings( async def openai_completion( self, model: str, - prompt: str, + prompt: Union[str, List[str]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 8111e44639..cdb4b21aa3 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -251,7 +251,7 @@ async def embeddings( async def openai_completion( self, model: str, - prompt: str, + prompt: Union[str, List[str]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index ea02de573a..bc6eed1044 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -5,6 +5,8 @@ # the root directory of this source tree. import json import logging +import time +import uuid import warnings from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union @@ -83,7 +85,7 @@ TopPSamplingStrategy, UserMessage, ) -from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice from llama_stack.models.llama.datatypes import ( BuiltinTool, StopReason, @@ -844,6 +846,31 @@ def _convert_openai_logprobs( ] +def _convert_openai_sampling_params( + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, +) -> SamplingParams: + sampling_params = SamplingParams() + + if max_tokens: + sampling_params.max_tokens = max_tokens + + # Map an explicit temperature of 0 to greedy sampling + if temperature == 0: + strategy = GreedySamplingStrategy() + else: + # OpenAI defaults to 1.0 for temperature and top_p if unset + if temperature is None: + temperature = 1.0 + if top_p is None: + top_p = 1.0 + strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p) + + sampling_params.strategy = strategy + return sampling_params + + def convert_openai_chat_completion_choice( choice: OpenAIChoice, ) -> ChatCompletionResponse: @@ -1061,7 +1088,7 @@ class OpenAICompletionUnsupportedMixin: async def openai_completion( self, model: str, - prompt: str, + prompt: Union[str, List[str]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, @@ -1078,7 +1105,49 @@ async def openai_completion( top_p: Optional[float] = None, user: Optional[str] = None, ) -> OpenAICompletion: - raise ValueError(f"{self.__class__.__name__} doesn't support openai completion") + if stream: + raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions") + + # This is a pretty hacky way to do emulate completions - + # basically just de-batches them... 
+ prompts = [prompt] if not isinstance(prompt, list) else prompt + + sampling_params = _convert_openai_sampling_params( + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + ) + + choices = [] + # "n" is the number of completions to generate per prompt + for _i in range(0, n): + # and we may have multiple prompts, if batching was used + + for prompt in prompts: + result = self.completion( + model_id=model, + content=prompt, + sampling_params=sampling_params, + ) + + index = len(choices) + text = result.content + finish_reason = _convert_openai_finish_reason(result.stop_reason) + + choice = OpenAICompletionChoice( + index=index, + text=text, + finish_reason=finish_reason, + ) + choices.append(choice) + + return OpenAICompletion( + id=f"cmpl-{uuid.uuid4()}", + choices=choices, + created=int(time.time()), + model=model, + object="text_completion", + ) class OpenAIChatCompletionUnsupportedMixin: From fcdeb3d7bfe6cb5ea4bc0b48e030b0c898ae6fb1 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 9 Apr 2025 10:05:50 -0400 Subject: [PATCH 10/19] OpenAI completion prompt can also include tokens The OpenAI completion API supports strings, array of strings, array of tokens, or array of token arrays. So, expand our type hinting to support all of these types. Signed-off-by: Ben Browning --- llama_stack/apis/inference/inference.py | 2 +- llama_stack/distribution/routers/routers.py | 2 +- llama_stack/providers/remote/inference/ollama/ollama.py | 2 +- .../providers/remote/inference/passthrough/passthrough.py | 2 +- llama_stack/providers/remote/inference/together/together.py | 2 +- llama_stack/providers/remote/inference/vllm/vllm.py | 2 +- llama_stack/providers/utils/inference/litellm_openai_mixin.py | 2 +- llama_stack/providers/utils/inference/openai_compat.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 13eacd217e..b29e165f76 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -780,7 +780,7 @@ async def embeddings( async def openai_completion( self, model: str, - prompt: Union[str, List[str]], + prompt: Union[str, List[str], List[int], List[List[int]]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 89f1744513..2d0c956882 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -423,7 +423,7 @@ async def embeddings( async def openai_completion( self, model: str, - prompt: Union[str, List[str]], + prompt: Union[str, List[str], List[int], List[List[int]]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index fc1cf22655..1fbc9e7476 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -331,7 +331,7 @@ async def register_model(self, model: Model) -> Model: async def openai_completion( self, model: str, - prompt: Union[str, List[str]], + prompt: Union[str, List[str], List[int], List[List[int]]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git 
a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 09bd22b4c2..7d19c78134 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -206,7 +206,7 @@ async def embeddings( async def openai_completion( self, model: str, - prompt: Union[str, List[str]], + prompt: Union[str, List[str], List[int], List[List[int]]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index bde32593c5..be984167a0 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -260,7 +260,7 @@ async def embeddings( async def openai_completion( self, model: str, - prompt: Union[str, List[str]], + prompt: Union[str, List[str], List[int], List[List[int]]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index daeb95b27e..7425d68bd0 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -424,7 +424,7 @@ async def embeddings( async def openai_completion( self, model: str, - prompt: Union[str, List[str]], + prompt: Union[str, List[str], List[int], List[List[int]]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index cdb4b21aa3..3119c8b408 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -251,7 +251,7 @@ async def embeddings( async def openai_completion( self, model: str, - prompt: Union[str, List[str]], + prompt: Union[str, List[str], List[int], List[List[int]]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index bc6eed1044..74587c7f5b 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -1088,7 +1088,7 @@ class OpenAICompletionUnsupportedMixin: async def openai_completion( self, model: str, - prompt: Union[str, List[str]], + prompt: Union[str, List[str], List[int], List[List[int]]], best_of: Optional[int] = None, echo: Optional[bool] = None, frequency_penalty: Optional[float] = None, From a1e9cff37c53589997ff3091232e02ab9c8c315a Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 9 Apr 2025 10:08:10 -0400 Subject: [PATCH 11/19] Update spec with latest changes as well Signed-off-by: Ben Browning --- docs/_static/llama-stack-spec.html | 15 +++++++++++++++ docs/_static/llama-stack-spec.yaml | 8 ++++++++ 2 files changed, 23 insertions(+) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 34f060386f..a749321470 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9410,6 +9410,21 @@ "items": { "type": "string" } + }, 
+ { + "type": "array", + "items": { + "type": "integer" + } + }, + { + "type": "array", + "items": { + "type": "array", + "items": { + "type": "integer" + } + } } ], "description": "The prompt to generate a completion for" diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 85a2876435..b475dc142b 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6482,6 +6482,14 @@ components: - type: array items: type: string + - type: array + items: + type: integer + - type: array + items: + type: array + items: + type: integer description: The prompt to generate a completion for best_of: type: integer From 52b4766949040a0eda617290a0db3b908aeaa1ad Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 9 Apr 2025 13:55:34 -0400 Subject: [PATCH 12/19] Start some integration tests with an OpenAI client This starts to stub in some integration tests for the OpenAI-compatible server APIs using an OpenAI client. Signed-off-by: Ben Browning --- .../inference/test_openai_completion.py | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 tests/integration/inference/test_openai_completion.py diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py new file mode 100644 index 0000000000..fe368b20f4 --- /dev/null +++ b/tests/integration/inference/test_openai_completion.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import pytest +from openai import OpenAI + +from llama_stack.distribution.library_client import LlamaStackAsLibraryClient + +from ..test_cases.test_case import TestCase + + +def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id): + if isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI completions are not supported when testing with library client yet.") + + models = {m.identifier: m for m in client_with_models.models.list()} + models.update({m.provider_resource_id: m for m in client_with_models.models.list()}) + provider_id = models[model_id].provider_id + providers = {p.provider_id: p for p in client_with_models.providers.list()} + provider = providers[provider_id] + if provider.provider_type in ( + "inline::meta-reference", + "inline::sentence-transformers", + "inline::vllm", + "remote::bedrock", + "remote::cerebras", + "remote::databricks", + "remote::nvidia", + "remote::runpod", + "remote::sambanova", + "remote::tgi", + ): + pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.") + + +@pytest.fixture +def openai_client(client_with_models, text_model_id): + skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) + base_url = f"{client_with_models.base_url}/v1/openai/v1" + return OpenAI(base_url=base_url, api_key="bar") + + +@pytest.mark.parametrize( + "test_case", + [ + "inference:completion:sanity", + ], +) +def test_openai_completion_non_streaming(openai_client, text_model_id, test_case): + tc = TestCase(test_case) + + response = openai_client.completions.create( + model=text_model_id, + prompt=tc["content"], + stream=False, + ) + assert len(response.choices) > 0 + choice = response.choices[0] + assert len(choice.text) > 10 + + +@pytest.mark.parametrize( + "test_case", + [ + "inference:completion:sanity", + ], +) +def 
test_openai_completion_streaming(openai_client, text_model_id, test_case): + tc = TestCase(test_case) + + response = openai_client.completions.create( + model=text_model_id, + prompt=tc["content"], + stream=True, + max_tokens=50, + ) + streamed_content = [chunk.choices[0].text for chunk in response] + content_str = "".join(streamed_content).lower().strip() + assert len(content_str) > 10 From ef684ff178f6b22c1fea9e50fbea65f58e2a1172 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 9 Apr 2025 15:22:52 -0400 Subject: [PATCH 13/19] Fix openai_completion tests for ollama When called via the OpenAI API, ollama is responding with more brief responses than when called via its native API. This adjusts the prompting for its OpenAI calls to ask it to be more verbose. --- llama_stack/providers/remote/inference/ollama/ollama.py | 3 +++ tests/integration/inference/test_openai_completion.py | 8 ++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 1fbc9e7476..cdd41e372e 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -348,6 +348,9 @@ async def openai_completion( top_p: Optional[float] = None, user: Optional[str] = None, ) -> OpenAICompletion: + if not isinstance(prompt, str): + raise ValueError("Ollama does not support non-string prompts for completion") + model_obj = await self._get_model(model) params = { k: v diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index fe368b20f4..78df64af0e 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -53,9 +53,11 @@ def openai_client(client_with_models, text_model_id): def test_openai_completion_non_streaming(openai_client, text_model_id, test_case): tc = TestCase(test_case) + # ollama needs more verbose prompting for some reason here... + prompt = "Respond to this question and explain your answer. " + tc["content"] response = openai_client.completions.create( model=text_model_id, - prompt=tc["content"], + prompt=prompt, stream=False, ) assert len(response.choices) > 0 @@ -72,9 +74,11 @@ def test_openai_completion_non_streaming(openai_client, text_model_id, test_case def test_openai_completion_streaming(openai_client, text_model_id, test_case): tc = TestCase(test_case) + # ollama needs more verbose prompting for some reason here... + prompt = "Respond to this question and explain your answer. " + tc["content"] response = openai_client.completions.create( model=text_model_id, - prompt=tc["content"], + prompt=prompt, stream=True, max_tokens=50, ) From ac5dc8fae29d433a394363bd001fb150f686a0d3 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 9 Apr 2025 15:43:53 -0400 Subject: [PATCH 14/19] Add prompt_logprobs and guided_choice to OpenAI completions This adds the vLLM-specific extra_body parameters of prompt_logprobs and guided_choice to our openai_completion inference endpoint. The plan here would be to expand this to support all common optional parameters of any of the OpenAI providers, allowing each provider to use or ignore these parameters based on whether their server supports them. 
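As a rough illustration (not part of this patch), a client could exercise these vLLM-specific parameters through the standard OpenAI Python client by placing them in extra_body; the base URL and model id below are placeholders, and the API key is just a dummy value:

    from openai import OpenAI

    # Illustrative only: assumes a running Llama Stack server backed by the
    # remote vLLM provider; URL and model id are placeholders.
    client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="bar")

    # guided_choice constrains the completion text to one of the given strings.
    guided = client.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        prompt="I am feeling really sad today.",
        extra_body={"guided_choice": ["joy", "sadness"]},
    )
    print(guided.choices[0].text)

    # prompt_logprobs asks vLLM to also return logprobs for the prompt tokens.
    with_logprobs = client.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        prompt="Hello, world!",
        extra_body={"prompt_logprobs": 1},
    )
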
Signed-off-by: Ben Browning --- docs/_static/llama-stack-spec.html | 9 ++++ docs/_static/llama-stack-spec.yaml | 6 +++ llama_stack/apis/inference/inference.py | 4 ++ llama_stack/distribution/routers/routers.py | 4 ++ .../remote/inference/ollama/ollama.py | 2 + .../inference/passthrough/passthrough.py | 4 ++ .../remote/inference/together/together.py | 4 ++ .../providers/remote/inference/vllm/vllm.py | 10 ++++ .../utils/inference/litellm_openai_mixin.py | 4 ++ .../utils/inference/openai_compat.py | 2 + .../inference/test_openai_completion.py | 54 +++++++++++++++++-- 11 files changed, 98 insertions(+), 5 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index a749321470..36bfad49e2 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9523,6 +9523,15 @@ "user": { "type": "string", "description": "(Optional) The user to use" + }, + "guided_choice": { + "type": "array", + "items": { + "type": "string" + } + }, + "prompt_logprobs": { + "type": "integer" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index b475dc142b..82faf450a0 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6556,6 +6556,12 @@ components: user: type: string description: (Optional) The user to use + guided_choice: + type: array + items: + type: string + prompt_logprobs: + type: integer additionalProperties: false required: - model diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index b29e165f76..3390a3fef6 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -779,6 +779,7 @@ async def embeddings( @webmethod(route="/openai/v1/completions", method="POST") async def openai_completion( self, + # Standard OpenAI completion parameters model: str, prompt: Union[str, List[str], List[int], List[List[int]]], best_of: Optional[int] = None, @@ -796,6 +797,9 @@ async def openai_completion( temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + # vLLM-specific parameters + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: """Generate an OpenAI-compatible completion for the given prompt using the specified model. 
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 2d0c956882..bc313036f6 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -439,6 +439,8 @@ async def openai_completion( temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: logger.debug( f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}", @@ -467,6 +469,8 @@ async def openai_completion( temperature=temperature, top_p=top_p, user=user, + guided_choice=guided_choice, + prompt_logprobs=prompt_logprobs, ) provider = self.routing_table.get_provider_impl(model_obj.identifier) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index cdd41e372e..b8671197ef 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -347,6 +347,8 @@ async def openai_completion( temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: if not isinstance(prompt, str): raise ValueError("Ollama does not support non-string prompts for completion") diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 7d19c78134..0eb38c3954 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -222,6 +222,8 @@ async def openai_completion( temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: client = self._get_client() model_obj = await self.model_store.get_model(model) @@ -244,6 +246,8 @@ async def openai_completion( temperature=temperature, top_p=top_p, user=user, + guided_choice=guided_choice, + prompt_logprobs=prompt_logprobs, ) return await client.inference.openai_completion(**params) diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index be984167a0..2c9a7ec034 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -276,6 +276,8 @@ async def openai_completion( temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: model_obj = await self._get_model(model) params = await prepare_openai_completion_params( @@ -296,6 +298,8 @@ async def openai_completion( temperature=temperature, top_p=top_p, user=user, + guided_choice=guided_choice, + prompt_logprobs=prompt_logprobs, ) return await self._get_openai_client().completions.create(**params) # type: ignore diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 7425d68bd0..cac3106136 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -440,8 +440,17 @@ async def 
openai_completion( temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: model_obj = await self._get_model(model) + + extra_body: Dict[str, Any] = {} + if prompt_logprobs: + extra_body["prompt_logprobs"] = prompt_logprobs + if guided_choice: + extra_body["guided_choice"] = guided_choice + params = await prepare_openai_completion_params( model=model_obj.provider_resource_id, prompt=prompt, @@ -460,6 +469,7 @@ async def openai_completion( temperature=temperature, top_p=top_p, user=user, + extra_body=extra_body, ) return await self.client.completions.create(**params) # type: ignore diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 3119c8b408..2d2f0400ac 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -267,6 +267,8 @@ async def openai_completion( temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: model_obj = await self._get_model(model) params = await prepare_openai_completion_params( @@ -287,6 +289,8 @@ async def openai_completion( temperature=temperature, top_p=top_p, user=user, + guided_choice=guided_choice, + prompt_logprobs=prompt_logprobs, ) return litellm.text_completion(**params) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 74587c7f5b..f33cb44432 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -1104,6 +1104,8 @@ async def openai_completion( temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: if stream: raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions") diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 78df64af0e..410c1fe22f 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -13,15 +13,19 @@ from ..test_cases.test_case import TestCase -def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id): - if isinstance(client_with_models, LlamaStackAsLibraryClient): - pytest.skip("OpenAI completions are not supported when testing with library client yet.") - +def provider_from_model(client_with_models, model_id): models = {m.identifier: m for m in client_with_models.models.list()} models.update({m.provider_resource_id: m for m in client_with_models.models.list()}) provider_id = models[model_id].provider_id providers = {p.provider_id: p for p in client_with_models.providers.list()} - provider = providers[provider_id] + return providers[provider_id] + + +def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id): + if isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI completions are not supported when testing with library client yet.") + + provider = provider_from_model(client_with_models, model_id) if 
provider.provider_type in ( "inline::meta-reference", "inline::sentence-transformers", @@ -37,6 +41,12 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.") +def skip_if_provider_isnt_vllm(client_with_models, model_id): + provider = provider_from_model(client_with_models, model_id) + if provider.provider_type != "remote::vllm": + pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vllm extra_body parameters.") + + @pytest.fixture def openai_client(client_with_models, text_model_id): skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) @@ -85,3 +95,37 @@ def test_openai_completion_streaming(openai_client, text_model_id, test_case): streamed_content = [chunk.choices[0].text for chunk in response] content_str = "".join(streamed_content).lower().strip() assert len(content_str) > 10 + + +def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id): + skip_if_provider_isnt_vllm(client_with_models, text_model_id) + + prompt = "Hello, world!" + response = openai_client.completions.create( + model=text_model_id, + prompt=prompt, + stream=False, + extra_body={ + "prompt_logprobs": 1, + }, + ) + assert len(response.choices) > 0 + choice = response.choices[0] + assert len(choice.prompt_logprobs) > 0 + + +def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id): + skip_if_provider_isnt_vllm(client_with_models, text_model_id) + + prompt = "I am feeling really sad today." + response = openai_client.completions.create( + model=text_model_id, + prompt=prompt, + stream=False, + extra_body={ + "guided_choice": ["joy", "sadness"], + }, + ) + assert len(response.choices) > 0 + choice = response.choices[0] + assert choice.text in ["joy", "sadness"] From 8d10556ce397e6b47f4269a180b25dd8f52065a9 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 9 Apr 2025 16:18:13 -0400 Subject: [PATCH 15/19] Add basic tests for OpenAI Chat Completions API Signed-off-by: Ben Browning --- .../inference/test_openai_completion.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 410c1fe22f..48c8282609 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -129,3 +129,53 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text assert len(response.choices) > 0 choice = response.choices[0] assert choice.text in ["joy", "sadness"] + + +@pytest.mark.parametrize( + "test_case", + [ + "inference:chat_completion:non_streaming_01", + "inference:chat_completion:non_streaming_02", + ], +) +def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test_case): + tc = TestCase(test_case) + question = tc["question"] + expected = tc["expected"] + + response = openai_client.chat.completions.create( + model=text_model_id, + messages=[ + { + "role": "user", + "content": question, + } + ], + stream=False, + ) + message_content = response.choices[0].message.content.lower().strip() + assert len(message_content) > 0 + assert expected.lower() in message_content + + +@pytest.mark.parametrize( + "test_case", + [ + "inference:chat_completion:streaming_01", + "inference:chat_completion:streaming_02", + ], +) +def 
test_openai_chat_completion_streaming(openai_client, text_model_id, test_case): + tc = TestCase(test_case) + question = tc["question"] + expected = tc["expected"] + + response = openai_client.chat.completions.create( + model=text_model_id, + messages=[{"role": "user", "content": question}], + stream=True, + timeout=120, # Increase timeout to 2 minutes for large conversation history + ) + streamed_content = [str(chunk.choices[0].delta.content.lower().strip()) for chunk in response] + assert len(streamed_content) > 0 + assert expected.lower() in "".join(streamed_content) From 8f5cd491590177532b1d3c231f946a5790656955 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 9 Apr 2025 17:32:03 -0400 Subject: [PATCH 16/19] vllm prompt_logprobs can also be 0 This adjusts the vllm openai_completion endpoint to also pass a value of 0 for prompt_logprobs, instead of only passing values greater than zero to the backend. The existing test_openai_completion_prompt_logprobs was parameterized to test this case as well. Signed-off-by: Ben Browning --- llama_stack/providers/remote/inference/vllm/vllm.py | 2 +- tests/integration/inference/test_openai_completion.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index cac3106136..79f92adce7 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -446,7 +446,7 @@ async def openai_completion( model_obj = await self._get_model(model) extra_body: Dict[str, Any] = {} - if prompt_logprobs: + if prompt_logprobs is not None and prompt_logprobs >= 0: extra_body["prompt_logprobs"] = prompt_logprobs if guided_choice: extra_body["guided_choice"] = guided_choice diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 48c8282609..d94390b8f9 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -97,7 +97,14 @@ def test_openai_completion_streaming(openai_client, text_model_id, test_case): assert len(content_str) > 10 -def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id): +@pytest.mark.parametrize( + "prompt_logprobs", + [ + 1, + 0, + ], +) +def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs): skip_if_provider_isnt_vllm(client_with_models, text_model_id) prompt = "Hello, world!" @@ -106,7 +113,7 @@ def test_openai_completion_prompt_logprobs(openai_client, client_with_models, te prompt=prompt, stream=False, extra_body={ - "prompt_logprobs": 1, + "prompt_logprobs": prompt_logprobs, }, ) assert len(response.choices) > 0 From a5827f7cb33916ba69b088e531ee27ae3960c9da Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Thu, 10 Apr 2025 13:43:28 -0400 Subject: [PATCH 17/19] Nvidia provider support for OpenAI API endpoints This wires up the openai_completion and openai_chat_completion API methods for the remote Nvidia inference provider, and adds it to the chat completions part of the OpenAI test suite. The hosted Nvidia service doesn't actually host any Llama models with functioning completions and chat completions endpoints, so for now the test suite only activates the nvidia provider for chat completions. 
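For reference, a minimal sketch of exercising the chat completions path this wires up, using the standard OpenAI client pointed at Llama Stack's OpenAI-compatible base URL; the URL, model id, and question are placeholders, and the API key is a dummy value:

    from openai import OpenAI

    # Illustrative only: assumes a Llama Stack server configured with the
    # remote NVIDIA inference provider and a registered chat model.
    client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="bar")

    response = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "Which planet do humans live on?"}],
        stream=False,
    )
    print(response.choices[0].message.content)
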
Signed-off-by: Ben Browning --- .../remote/inference/nvidia/nvidia.py | 121 ++++++++++++++++-- .../inference/test_openai_completion.py | 37 +++++- 2 files changed, 143 insertions(+), 15 deletions(-) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 3ed458058c..d6f717719d 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -7,7 +7,7 @@ import logging import warnings from functools import lru_cache -from typing import AsyncIterator, List, Optional, Union +from typing import Any, AsyncIterator, Dict, List, Optional, Union from openai import APIConnectionError, AsyncOpenAI, BadRequestError @@ -35,15 +35,15 @@ ToolConfig, ToolDefinition, ) +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.models.llama.datatypes import ToolPromptFormat from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionUnsupportedMixin, - OpenAICompletionUnsupportedMixin, convert_openai_chat_completion_choice, convert_openai_chat_completion_stream, + prepare_openai_completion_params, ) from llama_stack.providers.utils.inference.prompt_adapter import content_has_media @@ -60,12 +60,7 @@ logger = logging.getLogger(__name__) -class NVIDIAInferenceAdapter( - Inference, - OpenAIChatCompletionUnsupportedMixin, - OpenAICompletionUnsupportedMixin, - ModelRegistryHelper, -): +class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): def __init__(self, config: NVIDIAConfig) -> None: # TODO(mf): filter by available models ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) @@ -270,3 +265,111 @@ async def chat_completion( else: # we pass n=1 to get only one completion return convert_openai_chat_completion_choice(response.choices[0]) + + async def openai_completion( + self, + model: str, + prompt: Union[str, List[str], List[int], List[List[int]]], + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, + ) -> OpenAICompletion: + provider_model_id = self.get_provider_model_id(model) + + params = await prepare_openai_completion_params( + model=provider_model_id, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + ) + + try: + return await self._get_client(provider_model_id).completions.create(**params) + except APIConnectionError as e: + raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e + + async def openai_chat_completion( + self, + model: str, + messages: 
List[OpenAIMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[Dict[str, str]] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAIChatCompletion: + provider_model_id = self.get_provider_model_id(model) + + params = await prepare_openai_completion_params( + model=provider_model_id, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + + try: + return await self._get_client(provider_model_id).chat.completions.create(**params) + except APIConnectionError as e: + raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index d94390b8f9..e6e584727f 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -33,6 +33,9 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) "remote::bedrock", "remote::cerebras", "remote::databricks", + # Technically Nvidia does support OpenAI completions, but none of their hosted models + # support both completions and chat completions endpoint and all the Llama models are + # just chat completions "remote::nvidia", "remote::runpod", "remote::sambanova", @@ -41,6 +44,25 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.") +def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id): + if isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI chat completions are not supported when testing with library client yet.") + + provider = provider_from_model(client_with_models, model_id) + if provider.provider_type in ( + "inline::meta-reference", + "inline::sentence-transformers", + "inline::vllm", + "remote::bedrock", + "remote::cerebras", + "remote::databricks", + "remote::runpod", + "remote::sambanova", + "remote::tgi", + ): + pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.") + + def skip_if_provider_isnt_vllm(client_with_models, model_id): provider = 
provider_from_model(client_with_models, model_id) if provider.provider_type != "remote::vllm": @@ -48,8 +70,7 @@ def skip_if_provider_isnt_vllm(client_with_models, model_id): @pytest.fixture -def openai_client(client_with_models, text_model_id): - skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) +def openai_client(client_with_models): base_url = f"{client_with_models.base_url}/v1/openai/v1" return OpenAI(base_url=base_url, api_key="bar") @@ -60,7 +81,8 @@ def openai_client(client_with_models, text_model_id): "inference:completion:sanity", ], ) -def test_openai_completion_non_streaming(openai_client, text_model_id, test_case): +def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) tc = TestCase(test_case) # ollama needs more verbose prompting for some reason here... @@ -81,7 +103,8 @@ def test_openai_completion_non_streaming(openai_client, text_model_id, test_case "inference:completion:sanity", ], ) -def test_openai_completion_streaming(openai_client, text_model_id, test_case): +def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) tc = TestCase(test_case) # ollama needs more verbose prompting for some reason here... @@ -145,7 +168,8 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text "inference:chat_completion:non_streaming_02", ], ) -def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test_case): +def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) tc = TestCase(test_case) question = tc["question"] expected = tc["expected"] @@ -172,7 +196,8 @@ def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test "inference:chat_completion:streaming_02", ], ) -def test_openai_chat_completion_streaming(openai_client, text_model_id, test_case): +def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) tc = TestCase(test_case) question = tc["question"] expected = tc["expected"] From ffae192540d5bfbeccfaf43735ca95e5016b2ca3 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Thu, 10 Apr 2025 14:19:48 -0400 Subject: [PATCH 18/19] Bug fixes for together.ai OpenAI endpoints After actually running the test_openai_completion.py tests against together.ai, turns out there were a couple of bugs in the initial implementation. This fixes those. 
Signed-off-by: Ben Browning --- llama_stack/providers/remote/inference/together/together.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 2c9a7ec034..1615b8cd15 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -279,7 +279,7 @@ async def openai_completion( guided_choice: Optional[List[str]] = None, prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: - model_obj = await self._get_model(model) + model_obj = await self.model_store.get_model(model) params = await prepare_openai_completion_params( model=model_obj.provider_resource_id, prompt=prompt, @@ -298,8 +298,6 @@ async def openai_completion( temperature=temperature, top_p=top_p, user=user, - guided_choice=guided_choice, - prompt_logprobs=prompt_logprobs, ) return await self._get_openai_client().completions.create(**params) # type: ignore @@ -329,7 +327,7 @@ async def openai_chat_completion( top_p: Optional[float] = None, user: Optional[str] = None, ) -> OpenAIChatCompletion: - model_obj = await self._get_model(model) + model_obj = await self.model_store.get_model(model) params = await prepare_openai_completion_params( model=model_obj.provider_resource_id, messages=messages, From 31181c070bb1d10c770bf9fc4bec9395a4a3864f Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Thu, 10 Apr 2025 15:29:32 -0400 Subject: [PATCH 19/19] Fireworks provider support for OpenAI API endpoints This wires up the openai_completion and openai_chat_completion API methods for the remote Fireworks inference provider. Signed-off-by: Ben Browning --- .../remote/inference/fireworks/fireworks.py | 109 +++++++++++++++++- .../inference/test_openai_completion.py | 5 +- 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 4acbe43f80..b59e9f2cb0 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -4,9 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import AsyncGenerator, List, Optional, Union +from typing import Any, AsyncGenerator, Dict, List, Optional, Union from fireworks.client import Fireworks +from openai import AsyncOpenAI from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -31,6 +32,7 @@ ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ( @@ -39,6 +41,7 @@ from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, get_sampling_options, + prepare_openai_completion_params, process_chat_completion_response, process_chat_completion_stream_response, process_completion_response, @@ -81,10 +84,16 @@ def _get_api_key(self) -> str: ) return provider_data.fireworks_api_key + def _get_base_url(self) -> str: + return "https://api.fireworks.ai/inference/v1" + def _get_client(self) -> Fireworks: fireworks_api_key = self._get_api_key() return Fireworks(api_key=fireworks_api_key) + def _get_openai_client(self) -> AsyncOpenAI: + return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key()) + async def completion( self, model_id: str, @@ -268,3 +277,101 @@ async def embeddings( embeddings = [data.embedding for data in response.data] return EmbeddingsResponse(embeddings=embeddings) + + async def openai_completion( + self, + model: str, + prompt: Union[str, List[str], List[int], List[List[int]]], + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, + ) -> OpenAICompletion: + model_obj = await self.model_store.get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + ) + return await self._get_openai_client().completions.create(**params) + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[Dict[str, str]] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] 
= None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAIChatCompletion: + model_obj = await self.model_store.get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + return await self._get_openai_client().chat.completions.create(**params) diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index e6e584727f..0905d58176 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -208,6 +208,9 @@ def test_openai_chat_completion_streaming(openai_client, client_with_models, tex stream=True, timeout=120, # Increase timeout to 2 minutes for large conversation history ) - streamed_content = [str(chunk.choices[0].delta.content.lower().strip()) for chunk in response] + streamed_content = [] + for chunk in response: + if chunk.choices[0].delta.content: + streamed_content.append(chunk.choices[0].delta.content.lower().strip()) assert len(streamed_content) > 0 assert expected.lower() in "".join(streamed_content)
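
As a closing usage sketch (placeholders as in the tests above; not prescribed by these patches), streaming chat completions flow through the same OpenAI-compatible endpoint. Some providers emit chunks whose delta carries no content, which is why the test fix above filters them before joining:

    from openai import OpenAI

    # Illustrative only: base URL and model id are placeholders, API key is a dummy.
    client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="bar")

    stream = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "Name the two largest planets."}],
        stream=True,
    )
    pieces = []
    for chunk in stream:
        # Guard against role-only or finish chunks that have no content delta.
        if chunk.choices and chunk.choices[0].delta.content:
            pieces.append(chunk.choices[0].delta.content)
    print("".join(pieces))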