Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -6096,10 +6096,14 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p
// Model filtering logic:
// - If model is nil or empty → include all keys (no model filter)
// - If model is specified:
// - If key.Models is empty → include key (supports all models)
// - If model is in key.BlacklistedModels → exclude (wins over Models allow list)
// - If key.Models is empty → include key (supports all non-blacklisted models)
// - If key.Models is non-empty → only include if model is in list
if model != nil && *model != "" && len(k.Models) > 0 {
if !slices.Contains(k.Models, *model) {
if model != nil && *model != "" {
if len(k.BlacklistedModels) > 0 && slices.Contains(k.BlacklistedModels, *model) {
continue
}
if len(k.Models) > 0 && !slices.Contains(k.Models, *model) {
continue
Comment thread
TejasGhatte marked this conversation as resolved.
}
}
Expand Down Expand Up @@ -6167,7 +6171,8 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex
keys = batchEnabledKeys
}

// filter out keys which don't support the model, if the key has no models, it is supported for all models
// Filter out keys that don't support the model: blacklisted_models wins over models allow list;
// if the key has no models list, it supports all models except those blacklisted.
var supportedKeys []schemas.Key

// Skip model check conditions
Expand All @@ -6192,7 +6197,12 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex
continue
}
hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType)
modelSupported := (len(key.Models) == 0 && hasValue) || (slices.Contains(key.Models, model) && hasValue)
var modelSupported bool
if len(key.BlacklistedModels) > 0 && slices.Contains(key.BlacklistedModels, model) {
modelSupported = false
} else {
modelSupported = (len(key.Models) == 0 && hasValue) || (slices.Contains(key.Models, model) && hasValue)
}
// Additional deployment checks for Azure, Bedrock and Vertex
deploymentSupported := true
if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil {
Expand Down
56 changes: 56 additions & 0 deletions core/bifrost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,62 @@ func TestSelectKeyFromProviderForModel_NoStickinessWithoutSessionID(t *testing.T
}
}

func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) {
account := NewMockAccount()
account.AddProvider(schemas.OpenAI, 5, 1000)

ctx := context.Background()
bifrost, err := Init(ctx, schemas.BifrostConfig{
Account: account,
Logger: NewDefaultLogger(schemas.LogLevelError),
})
if err != nil {
t.Fatalf("Init failed: %v", err)
}
bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

t.Run("all keys blacklist model", func(t *testing.T) {
account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{
{ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}},
})
_, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI)
if err == nil {
t.Fatal("expected error when model is only blacklisted")
}
if !strings.Contains(err.Error(), "no keys found that support model") {
t.Fatalf("unexpected error: %v", err)
}
})

t.Run("blacklist wins over models allow list", func(t *testing.T) {
account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{
{
ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1,
Models: []string{"gpt-4"},
BlacklistedModels: []string{"gpt-4"},
},
})
_, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI)
if err == nil {
t.Fatal("expected error when model is both allowed and blacklisted")
}
})

t.Run("second key used when first blacklists", func(t *testing.T) {
account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{
{ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}},
{ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1},
})
key, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if key.ID != "k2" {
t.Fatalf("expected k2, got %s", key.ID)
}
})
}

// Test UpdateProvider functionality
func TestUpdateProvider(t *testing.T) {
t.Run("SuccessfulUpdate", func(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions core/changelog.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- feat: added `blacklisted_models` on provider keys to exclude models from routing and filtered list-models
2 changes: 1 addition & 1 deletion core/providers/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext,
}

// Create final response
response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, request.Unfiltered)
response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, request.Unfiltered)
response.ExtraFields.Latency = latency.Milliseconds()

// Set raw request if enabled
Expand Down
11 changes: 9 additions & 2 deletions core/providers/anthropic/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package anthropic
import (
"time"

providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)

func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand Down Expand Up @@ -40,6 +41,9 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide
continue
}
}
if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, modelID) {
continue
}
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + modelID,
Name: schemas.Ptr(model.DisplayName),
Expand All @@ -48,9 +52,12 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide
includedModels[modelID] = true
}

// Backfill allowed models that were not in the response
// Backfill allowed models that were not in the response (skip blacklisted; blacklist wins over allow list)
if !unfiltered && len(allowedModels) > 0 {
for _, allowedModel := range allowedModels {
if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) {
continue
}
if !includedModels[allowedModel] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + allowedModel,
Expand Down
2 changes: 1 addition & 1 deletion core/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key
}

// Convert to Bifrost response
response := azureResponse.ToBifrostListModelsResponse(key.Models, key.AzureKeyConfig.Deployments, request.Unfiltered)
response := azureResponse.ToBifrostListModelsResponse(key.Models, key.AzureKeyConfig.Deployments, key.BlacklistedModels, request.Unfiltered)
if response == nil {
return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil, schemas.Azure)
}
Expand Down
13 changes: 12 additions & 1 deletion core/providers/azure/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package azure
import (
"slices"

providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)

Expand Down Expand Up @@ -58,7 +59,7 @@ func findDeploymentMatch(deployments map[string]string, modelID string) (deploym
return "", ""
}

func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand Down Expand Up @@ -116,6 +117,10 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
modelID = matchedAllowedModel
}

if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, model.ID, modelID, deploymentAlias, matchedAllowedModel) {
continue
}

modelEntry := schemas.Model{
ID: string(schemas.Azure) + "/" + modelID,
Created: schemas.Ptr(model.CreatedAt),
Expand All @@ -142,6 +147,9 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) {
continue
}
if providerUtils.ModelMatchesDenylist(blacklistedModels, alias) {
continue
}
Comment thread
TejasGhatte marked this conversation as resolved.
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(schemas.Azure) + "/" + alias,
Name: schemas.Ptr(alias),
Expand All @@ -154,6 +162,9 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
// Backfill allowed models that were not in the response
if !unfiltered && len(allowedModels) > 0 {
for _, allowedModel := range allowedModels {
if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) {
continue
}
if !includedModels[allowedModel] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(schemas.Azure) + "/" + allowedModel,
Expand Down
4 changes: 2 additions & 2 deletions core/providers/bedrock/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) (
MaxIdleConns: schemas.DefaultMaxIdleConnsPerHost,
MaxIdleConnsPerHost: schemas.DefaultMaxIdleConnsPerHost,
IdleConnTimeout: 30 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: requestTimeout,
ExpectContinueTimeout: 1 * time.Second,
ForceAttemptHTTP2: config.NetworkConfig.EnforceHTTP2,
Expand Down Expand Up @@ -732,7 +732,7 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke
}

// Convert to Bifrost response
response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, config.Deployments, request.Unfiltered)
response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, config.Deployments, key.BlacklistedModels, request.Unfiltered)
if response == nil {
return nil, providerUtils.NewBifrostOperationError("failed to convert Bedrock model list response", nil, providerName)
}
Expand Down
13 changes: 12 additions & 1 deletion core/providers/bedrock/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"slices"
"strings"

providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)

Expand Down Expand Up @@ -221,7 +222,7 @@ func findDeploymentMatch(deployments map[string]string, modelID string) (deploym
return "", ""
}

func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand Down Expand Up @@ -284,6 +285,10 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK
modelID = matchedAllowedModel
}

if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, model.ModelID, modelID, deploymentAlias, matchedAllowedModel) {
continue
}

modelEntry := schemas.Model{
ID: string(providerKey) + "/" + modelID,
Name: schemas.Ptr(model.ModelName),
Expand Down Expand Up @@ -315,6 +320,9 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK
if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) {
continue
}
if providerUtils.ModelMatchesDenylist(blacklistedModels, alias) {
continue
}
Comment thread
TejasGhatte marked this conversation as resolved.
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + alias,
Name: schemas.Ptr(alias),
Expand All @@ -327,6 +335,9 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK
// Backfill allowed models that were not in the response
if !unfiltered && len(allowedModels) > 0 {
for _, allowedModel := range allowedModels {
if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) {
continue
}
if !includedModels[allowedModel] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + allowedModel,
Expand Down
2 changes: 1 addition & 1 deletion core/providers/cohere/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key
}

// Convert Cohere v2 response to Bifrost response
response := cohereResponse.ToBifrostListModelsResponse(providerName, key.Models, request.Unfiltered)
response := cohereResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered)
Comment thread
TejasGhatte marked this conversation as resolved.

response.ExtraFields.Latency = latency.Milliseconds()

Expand Down
12 changes: 9 additions & 3 deletions core/providers/cohere/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ func (r *CohereRerankRequest) GetExtraParams() map[string]interface{} {

// CohereRerankResult represents a single result from Cohere rerank.
type CohereRerankResult struct {
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
Document json.RawMessage `json:"document,omitempty"`
}

Expand All @@ -44,7 +44,7 @@ type CohereRerankMeta struct {
Tokens *CohereTokenUsage `json:"tokens,omitempty"`
}

func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand All @@ -58,6 +58,9 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe
if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.Name) {
continue
}
if !unfiltered && slices.Contains(blacklistedModels, model.Name) {
continue
}
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + model.Name,
Name: schemas.Ptr(model.Name),
Expand All @@ -70,6 +73,9 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe
// Backfill allowed models that were not in the response
if !unfiltered && len(allowedModels) > 0 {
for _, allowedModel := range allowedModels {
if slices.Contains(blacklistedModels, allowedModel) {
continue
}
if !includedModels[allowedModel] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + allowedModel,
Expand Down
2 changes: 1 addition & 1 deletion core/providers/elevenlabs/elevenlabs.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext,
return nil, bifrostErr
}

response := elevenlabsResponse.ToBifrostListModelsResponse(providerName, key.Models, request.Unfiltered)
response := elevenlabsResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered)

response.ExtraFields.Latency = latency.Milliseconds()
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
Expand Down
8 changes: 7 additions & 1 deletion core/providers/elevenlabs/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/maximhq/bifrost/core/schemas"
)

func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand All @@ -20,6 +20,9 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid
if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ModelID) {
continue
}
if !unfiltered && slices.Contains(blacklistedModels, model.ModelID) {
continue
}
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + model.ModelID,
Name: schemas.Ptr(model.Name),
Expand All @@ -30,6 +33,9 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid
// Backfill allowed models that were not in the response
if !unfiltered && len(allowedModels) > 0 {
for _, allowedModel := range allowedModels {
if slices.Contains(blacklistedModels, allowedModel) {
continue
}
if !includedModels[allowedModel] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + allowedModel,
Expand Down
2 changes: 1 addition & 1 deletion core/providers/gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key
}
}

response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, request.Unfiltered)
response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered)
Comment thread
TejasGhatte marked this conversation as resolved.

response.ExtraFields.Latency = latency.Milliseconds()

Expand Down
Loading
Loading