Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -4966,7 +4966,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas

key := schemas.Key{}
var keys []schemas.Key
if providerRequiresKey(baseProvider, config.CustomProviderConfig) {
if providerRequiresKey(config.CustomProviderConfig) {
Comment thread
sammaji marked this conversation as resolved.
// ListModels needs all enabled/supported keys so providers can aggregate
// and report per-key statuses (KeyStatuses).
if req.RequestType == schemas.ListModelsRequest {
Expand Down
90 changes: 71 additions & 19 deletions core/providers/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package ollama

import (
"fmt"
"strings"
"time"

Expand Down Expand Up @@ -50,11 +49,7 @@ func NewOllamaProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")

// BaseURL is required for Ollama
if config.NetworkConfig.BaseURL == "" {
return nil, fmt.Errorf("base_url is required for ollama provider")
}

// BaseURL is optional when keys have ollama_key_config with per-key URLs
return &OllamaProvider{
logger: logger,
client: client,
Expand All @@ -69,30 +64,71 @@ func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider {
return schemas.Ollama
}

// ListModels performs a list models request to Ollama's API.
func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
if provider.networkConfig.BaseURL == "" {
return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey())
// getBaseURL returns the base URL configured on the key's ollama_key_config,
// with any trailing slashes stripped. It returns the empty string when the key
// carries no per-key URL — there is no provider-level fallback.
func (provider *OllamaProvider) getBaseURL(key schemas.Key) string {
	cfg := key.OllamaKeyConfig
	if cfg == nil {
		return ""
	}
	raw := cfg.URL.GetValue()
	if raw == "" {
		return ""
	}
	return strings.TrimRight(raw, "/")
}

// baseURLOrError returns the resolved base URL or a BifrostError when none is configured.
func (provider *OllamaProvider) baseURLOrError(key schemas.Key) (string, *schemas.BifrostError) {
u := provider.getBaseURL(key)
if u == "" {
return "", providerUtils.NewBifrostOperationError(
"no base URL configured: set ollama_key_config.url on the key",
nil,
provider.GetProviderKey(),
)
}
return u, nil
Comment thread
sammaji marked this conversation as resolved.
}

// listModelsByKey performs a list models request for a single Ollama key,
// resolving the per-key URL so each backend is queried individually.
func (provider *OllamaProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
baseURL, bifrostErr := provider.baseURLOrError(key)
if bifrostErr != nil {
return nil, bifrostErr
}
return openai.HandleOpenAIListModelsRequest(
url := baseURL + providerUtils.GetPathFromContext(ctx, "/v1/models")
return openai.ListModelsByKey(
ctx,
provider.client,
request,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"),
keys,
url,
key,
request.Unfiltered,
provider.networkConfig.ExtraHeaders,
provider.GetProviderKey(),
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
)
}

// ListModels performs a list models request to Ollama's API. Requests are
// made concurrently per key so that each backend is queried with its own
// URL (from ollama_key_config).
func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	return providerUtils.HandleMultipleListModelsRequests(ctx, keys, request, provider.listModelsByKey)
}

// TextCompletion performs a text completion request to the Ollama API.
func (provider *OllamaProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
baseURL, bifrostErr := provider.baseURLOrError(key)
if bifrostErr != nil {
return nil, bifrostErr
}
return openai.HandleOpenAITextCompletionRequest(
ctx,
provider.client,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
request,
key,
provider.networkConfig.ExtraHeaders,
Expand All @@ -109,10 +145,14 @@ func (provider *OllamaProvider) TextCompletion(ctx *schemas.BifrostContext, key
// It formats the request, sends it to Ollama, and processes the response.
// Returns a channel of BifrostStreamChunk objects or an error if the request fails.
func (provider *OllamaProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
baseURL, bifrostErr := provider.baseURLOrError(key)
if bifrostErr != nil {
return nil, bifrostErr
}
return openai.HandleOpenAITextCompletionStreaming(
ctx,
provider.client,
provider.networkConfig.BaseURL+"/v1/completions",
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
request,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
nil,
provider.networkConfig.ExtraHeaders,
Expand All @@ -129,10 +169,14 @@ func (provider *OllamaProvider) TextCompletionStream(ctx *schemas.BifrostContext

// ChatCompletion performs a chat completion request to the Ollama API.
func (provider *OllamaProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
baseURL, bifrostErr := provider.baseURLOrError(key)
if bifrostErr != nil {
return nil, bifrostErr
}
return openai.HandleOpenAIChatCompletionRequest(
ctx,
provider.client,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
request,
key,
provider.networkConfig.ExtraHeaders,
Expand All @@ -150,11 +194,15 @@ func (provider *OllamaProvider) ChatCompletion(ctx *schemas.BifrostContext, key
// Uses Ollama's OpenAI-compatible streaming format.
// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails.
func (provider *OllamaProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
baseURL, bifrostErr := provider.baseURLOrError(key)
if bifrostErr != nil {
return nil, bifrostErr
}
// Use shared OpenAI-compatible streaming logic
return openai.HandleOpenAIChatCompletionStreaming(
ctx,
provider.client,
provider.networkConfig.BaseURL+"/v1/chat/completions",
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
request,
nil,
provider.networkConfig.ExtraHeaders,
Expand Down Expand Up @@ -199,10 +247,14 @@ func (provider *OllamaProvider) ResponsesStream(ctx *schemas.BifrostContext, pos

// Embedding performs an embedding request to the Ollama API.
func (provider *OllamaProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
baseURL, bifrostErr := provider.baseURLOrError(key)
if bifrostErr != nil {
return nil, bifrostErr
}
return openai.HandleOpenAIEmbeddingRequest(
ctx,
provider.client,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"),
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"),
request,
key,
provider.networkConfig.ExtraHeaders,
Expand Down
97 changes: 76 additions & 21 deletions core/providers/sgl/sgl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package sgl

import (
"fmt"
"strings"
"time"

Expand Down Expand Up @@ -50,11 +49,7 @@ func NewSGLProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*SGL
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")

// BaseURL is required for SGLang
if config.NetworkConfig.BaseURL == "" {
return nil, fmt.Errorf("base_url is required for sgl provider")
}

// BaseURL is optional when keys have sgl_key_config with per-key URLs
return &SGLProvider{
logger: logger,
client: client,
Expand All @@ -69,27 +64,71 @@ func (provider *SGLProvider) GetProviderKey() schemas.ModelProvider {
return schemas.SGL
}

// ListModels performs a list models request to SGL's API.
func (provider *SGLProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
return openai.HandleOpenAIListModelsRequest(
// getBaseURL returns the base URL configured on the key's sgl_key_config,
// with any trailing slashes stripped. It returns the empty string when the
// key carries no per-key URL — there is no provider-level fallback.
func (provider *SGLProvider) getBaseURL(key schemas.Key) string {
	cfg := key.SGLKeyConfig
	if cfg == nil {
		return ""
	}
	raw := cfg.URL.GetValue()
	if raw == "" {
		return ""
	}
	return strings.TrimRight(raw, "/")
}

// baseURLOrError returns the key's resolved base URL, or a BifrostError
// when the key has no sgl_key_config.url configured.
func (provider *SGLProvider) baseURLOrError(key schemas.Key) (string, *schemas.BifrostError) {
	base := provider.getBaseURL(key)
	if base != "" {
		return base, nil
	}
	return "", providerUtils.NewBifrostOperationError(
		"no base URL configured: set sgl_key_config.url on the key",
		nil,
		provider.GetProviderKey(),
	)
}

// listModelsByKey performs a list models request for a single SGL key,
// resolving the per-key URL so each backend is queried individually.
func (provider *SGLProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	baseURL, bifrostErr := provider.baseURLOrError(key)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	// The request path defaults to the OpenAI-compatible endpoint but may be
	// overridden through the context.
	url := baseURL + providerUtils.GetPathFromContext(ctx, "/v1/models")
	return openai.ListModelsByKey(
		ctx,
		provider.client,
		request,
		url,
		key,
		request.Unfiltered,
		provider.networkConfig.ExtraHeaders,
		provider.GetProviderKey(),
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
	)
}

// TextCompletion is not supported by the SGL provider.
// ListModels performs a list models request to SGL's API. Requests are made
// concurrently per key so that each backend is queried with its own URL
// (from sgl_key_config).
func (provider *SGLProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	return providerUtils.HandleMultipleListModelsRequests(ctx, keys, request, provider.listModelsByKey)
}

// TextCompletion performs a text completion request to the SGL API.
func (provider *SGLProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
baseURL, bifrostErr := provider.baseURLOrError(key)
if bifrostErr != nil {
return nil, bifrostErr
}
return openai.HandleOpenAITextCompletionRequest(
ctx,
provider.client,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
request,
key,
provider.networkConfig.ExtraHeaders,
Expand All @@ -106,10 +145,14 @@ func (provider *SGLProvider) TextCompletion(ctx *schemas.BifrostContext, key sch
// It formats the request, sends it to SGL, and processes the response.
// Returns a channel of BifrostStreamChunk objects or an error if the request fails.
func (provider *SGLProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
baseURL, bifrostErr := provider.baseURLOrError(key)
if bifrostErr != nil {
return nil, bifrostErr
}
return openai.HandleOpenAITextCompletionStreaming(
ctx,
provider.client,
provider.networkConfig.BaseURL+"/v1/completions",
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
request,
nil,
provider.networkConfig.ExtraHeaders,
Expand All @@ -126,10 +169,14 @@ func (provider *SGLProvider) TextCompletionStream(ctx *schemas.BifrostContext, p

// ChatCompletion performs a chat completion request to the SGL API.
func (provider *SGLProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
baseURL, bifrostErr := provider.baseURLOrError(key)
if bifrostErr != nil {
return nil, bifrostErr
}
return openai.HandleOpenAIChatCompletionRequest(
ctx,
provider.client,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
request,
key,
provider.networkConfig.ExtraHeaders,
Expand All @@ -147,11 +194,15 @@ func (provider *SGLProvider) ChatCompletion(ctx *schemas.BifrostContext, key sch
// Uses SGL's OpenAI-compatible streaming format.
// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails.
func (provider *SGLProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
baseURL, bifrostErr := provider.baseURLOrError(key)
if bifrostErr != nil {
return nil, bifrostErr
}
// Use shared OpenAI-compatible streaming logic
return openai.HandleOpenAIChatCompletionStreaming(
ctx,
provider.client,
provider.networkConfig.BaseURL+"/v1/chat/completions",
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
request,
nil,
provider.networkConfig.ExtraHeaders,
Expand Down Expand Up @@ -194,12 +245,16 @@ func (provider *SGLProvider) ResponsesStream(ctx *schemas.BifrostContext, postHo
)
}

// Embedding is not supported by the SGL provider.
// Embedding performs an embedding request to the SGL API.
func (provider *SGLProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
baseURL, bifrostErr := provider.baseURLOrError(key)
if bifrostErr != nil {
return nil, bifrostErr
}
return openai.HandleOpenAIEmbeddingRequest(
ctx,
provider.client,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"),
baseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"),
request,
key,
provider.networkConfig.ExtraHeaders,
Expand Down Expand Up @@ -403,4 +458,4 @@ func (provider *SGLProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Ke

// PassthroughStream is not supported by the SGL provider; it always returns
// an unsupported-operation error for PassthroughStreamRequest.
func (provider *SGLProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
}
}
Loading
Loading