diff --git a/core/bifrost.go b/core/bifrost.go index a872935bfd..fae37aac76 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -6191,10 +6191,9 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p // - If key.Models is ["*"] → include key (supports all non-blacklisted models) // - If key.Models is empty → exclude key (deny-by-default) // - If key.Models is non-empty → only include if model is in list + // Blacklist wins over allowlist if model != nil && *model != "" { - if k.Models.IsUnrestricted() { - // wildcard: allow all models - } else if !k.Models.IsAllowed(*model) { + if k.BlacklistedModels.IsBlocked(*model) || !k.Models.IsAllowed(*model) { continue } } @@ -6289,7 +6288,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex } hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) // ["*"] = allow all models; [] = deny all; specific list = allow only listed - modelSupported := hasValue && key.Models.IsAllowed(model) + modelSupported := hasValue && key.Models.IsAllowed(model) && !key.BlacklistedModels.IsBlocked(model) // Additional deployment checks for Azure, Bedrock and Vertex deploymentSupported := true if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil { diff --git a/core/bifrost_test.go b/core/bifrost_test.go index 806fd4d844..6f91d24a86 100644 --- a/core/bifrost_test.go +++ b/core/bifrost_test.go @@ -891,7 +891,7 @@ func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) { t.Run("second key used when first blacklists", func(t *testing.T) { account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ {ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}}, - {ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1}, + {ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1, Models: []string{"*"}}, }) key, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err != nil { @@ -1226,4 +1226,3 @@ func TestUpdateProvider_ProviderSliceIntegrity(t *testing.T) { } }) } - diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index 4afba1c96d..0857214d7c 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -298,7 +298,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext, } // Create final response - response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, request.Unfiltered) + response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled diff --git a/core/providers/anthropic/models.go b/core/providers/anthropic/models.go index baf0ad098e..9ff3e1ca14 100644 --- a/core/providers/anthropic/models.go +++ b/core/providers/anthropic/models.go @@ -6,7 +6,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -24,7 +24,7 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide bifrostResponse.NextPageToken = *response.LastID } - if !unfiltered && allowedModels.IsEmpty() { + if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { return bifrostResponse } @@ -44,6 +44,9 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide continue } } + if !unfiltered && blacklistedModels.IsBlocked(modelID) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + modelID, Name: schemas.Ptr(model.DisplayName), @@ -55,6 +58,9 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide // Backfill allowed models that were not in the response if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { + if blacklistedModels.IsBlocked(allowedModel) { + continue + } if !includedModels[allowedModel] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index f25fee81f4..21253f8b14 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -366,7 +366,7 @@ func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key } // Convert to Bifrost response - response := azureResponse.ToBifrostListModelsResponse(key.Models, key.AzureKeyConfig.Deployments, request.Unfiltered) + response := azureResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.AzureKeyConfig.Deployments, request.Unfiltered) if response == nil { return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil, schemas.Azure) } diff --git a/core/providers/azure/models.go b/core/providers/azure/models.go index e6d6bda14c..0875cf781b 100644 --- a/core/providers/azure/models.go +++ b/core/providers/azure/models.go @@ -58,7 +58,24 @@ func findDeploymentMatch(deployments map[string]string, modelID string) (deploym return "", "" } -func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { +// matchesBlacklist reports whether modelID matches any entry in the blacklist, +// using the same matching logic as findMatchingAllowedModel (exact and base-model). +func matchesBlacklist(bl schemas.BlackList, modelID string) bool { + if bl.IsEmpty() { + return false + } + if bl.Contains(modelID) { + return true + } + for _, item := range bl { + if schemas.SameBaseModel(item, modelID) { + return true + } + } + return false +} + +func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -67,7 +84,7 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode Data: make([]schemas.Model, 0, len(response.Data)), } - if !unfiltered && allowedModels.IsEmpty() && len(deployments) == 0 { + if !unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || blacklistedModels.IsBlockAll()) { return bifrostResponse } @@ -113,6 +130,10 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode if shouldFilter { continue } + if !unfiltered && (matchesBlacklist(blacklistedModels, model.ID) || + (deploymentAlias != "" && matchesBlacklist(blacklistedModels, deploymentAlias))) { + continue + } // Use the matched name from allowedModels or deployments (like Anthropic) // Priority: deployment value > matched allowedModel > original model.ID @@ -148,6 +169,9 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode if restrictAllowed && !allowedModels.Contains(alias) { continue } + if !unfiltered && matchesBlacklist(blacklistedModels, alias) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(schemas.Azure) + "/" + alias, Name: schemas.Ptr(alias), @@ -160,6 +184,9 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode // Backfill allowed models that were not in the response if restrictAllowed { for _, allowedModel := range allowedModels { + if matchesBlacklist(blacklistedModels, allowedModel) { + continue + } if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(schemas.Azure) + "/" + allowedModel, diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index b5a9b49c50..f140c10db4 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -732,7 +732,7 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } // Convert to Bifrost response - response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, config.Deployments, request.Unfiltered) + response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, config.Deployments, request.Unfiltered) if response == nil { return nil, providerUtils.NewBifrostOperationError("failed to convert Bedrock model list response", nil, providerName) } diff --git a/core/providers/bedrock/models.go b/core/providers/bedrock/models.go index cfffd360f7..005998aa4a 100644 --- a/core/providers/bedrock/models.go +++ b/core/providers/bedrock/models.go @@ -220,7 +220,35 @@ func findDeploymentMatch(deployments map[string]string, modelID string) (deploym return "", "" } -func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { +// matchesBlacklist reports whether modelID matches any entry in the blacklist, +// using the same matching logic as findMatchingAllowedModel (exact, prefix-normalized, base-model). +func matchesBlacklist(bl schemas.BlackList, modelID string) bool { + if bl.IsEmpty() { + return false + } + if bl.Contains(modelID) { + return true + } + if extractPrefix(modelID) != "" { + if bl.Contains(removePrefix(modelID)) { + return true + } + } + for _, item := range bl { + if extractPrefix(item) != "" && removePrefix(item) == modelID { + return true + } + } + valueNormalized := removePrefix(modelID) + for _, item := range bl { + if schemas.SameBaseModel(removePrefix(item), valueNormalized) { + return true + } + } + return false +} + +func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -229,7 +257,7 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK Data: make([]schemas.Model, 0, len(response.ModelSummaries)), } - if !unfiltered && allowedModels.IsEmpty() && len(deployments) == 0 { + if !unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || blacklistedModels.IsBlockAll()) { return bifrostResponse } @@ -280,6 +308,10 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK if shouldFilter { continue } + if !unfiltered && (matchesBlacklist(blacklistedModels, model.ModelID) || + (deploymentAlias != "" && matchesBlacklist(blacklistedModels, deploymentAlias))) { + continue + } // Use the matched name from allowedModels or deployments (like Anthropic) // Priority: deployment value > matched allowedModel > original model.ModelID @@ -320,6 +352,9 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK if restrictAllowed && !allowedModels.Contains(alias) { continue } + if !unfiltered && matchesBlacklist(blacklistedModels, alias) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + alias, Name: schemas.Ptr(alias), @@ -332,6 +367,9 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK // Backfill allowed models that were not in the response if restrictAllowed { for _, allowedModel := range allowedModels { + if matchesBlacklist(blacklistedModels, allowedModel) { + continue + } if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go index d9636eee40..8d562364ba 100644 --- a/core/providers/cohere/cohere.go +++ b/core/providers/cohere/cohere.go @@ -288,7 +288,7 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key } // Convert Cohere v2 response to Bifrost response - response := cohereResponse.ToBifrostListModelsResponse(providerName, key.Models, request.Unfiltered) + response := cohereResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() diff --git a/core/providers/cohere/models.go b/core/providers/cohere/models.go index 2ef4a8b0fd..e66aeedb1b 100644 --- a/core/providers/cohere/models.go +++ b/core/providers/cohere/models.go @@ -44,7 +44,7 @@ type CohereRerankMeta struct { Tokens *CohereTokenUsage `json:"tokens,omitempty"` } -func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -53,7 +53,7 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Models)), } - if !unfiltered && allowedModels.IsEmpty() { + if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { return bifrostResponse } @@ -62,6 +62,9 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.Name) { continue } + if !unfiltered && blacklistedModels.IsBlocked(model.Name) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + model.Name, Name: schemas.Ptr(model.Name), @@ -74,6 +77,9 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe // Backfill allowed models that were not in the response if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { + if blacklistedModels.IsBlocked(allowedModel) { + continue + } if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/elevenlabs/elevenlabs.go b/core/providers/elevenlabs/elevenlabs.go index bddf520be3..6ca0172d6c 100644 --- a/core/providers/elevenlabs/elevenlabs.go +++ b/core/providers/elevenlabs/elevenlabs.go @@ -115,7 +115,7 @@ func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, return nil, bifrostErr } - response := elevenlabsResponse.ToBifrostListModelsResponse(providerName, key.Models, request.Unfiltered) + response := elevenlabsResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) diff --git a/core/providers/elevenlabs/models.go b/core/providers/elevenlabs/models.go index dd9136b27e..3c4e939fca 100644 --- a/core/providers/elevenlabs/models.go +++ b/core/providers/elevenlabs/models.go @@ -6,7 +6,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -15,7 +15,7 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid Data: make([]schemas.Model, 0, len(*response)), } - if !unfiltered && allowedModels.IsEmpty() { + if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { return bifrostResponse } @@ -24,6 +24,9 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ModelID) { continue } + if !unfiltered && blacklistedModels.IsBlocked(model.ModelID) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + model.ModelID, Name: schemas.Ptr(model.Name), @@ -34,6 +37,9 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid // Backfill allowed models that were not in the response if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { + if blacklistedModels.IsBlocked(allowedModel) { + continue + } if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index 57146bd737..7a4d299bd8 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -227,7 +227,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key } } - response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, request.Unfiltered) + response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() diff --git a/core/providers/gemini/models.go b/core/providers/gemini/models.go index 576201bcc4..34f021a5be 100644 --- a/core/providers/gemini/models.go +++ b/core/providers/gemini/models.go @@ -16,7 +16,7 @@ func toGeminiModelResourceName(modelID string) string { return "models/" + modelID } -func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -25,7 +25,7 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Models)), } - if !unfiltered && allowedModels.IsEmpty() { + if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { return bifrostResponse } @@ -38,6 +38,9 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(modelName) { continue } + if !unfiltered && blacklistedModels.IsBlocked(modelName) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + modelName, Name: schemas.Ptr(model.DisplayName), @@ -53,6 +56,9 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe // Backfill allowed models that were not in the response if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { + if blacklistedModels.IsBlocked(allowedModel) { + continue + } if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/huggingface/huggingface.go b/core/providers/huggingface/huggingface.go index d1b7f145f3..62f17c8931 100644 --- a/core/providers/huggingface/huggingface.go +++ b/core/providers/huggingface/huggingface.go @@ -384,7 +384,7 @@ func (provider *HuggingFaceProvider) listModelsByKey(ctx *schemas.BifrostContext } if result.response != nil { - providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, request.Unfiltered) + providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, key.BlacklistedModels, request.Unfiltered) if providerResponse != nil { aggregatedResponse.Data = append(aggregatedResponse.Data, providerResponse.Data...) totalLatency += result.latency diff --git a/core/providers/huggingface/models.go b/core/providers/huggingface/models.go index 749fb2ec38..bc2314af31 100644 --- a/core/providers/huggingface/models.go +++ b/core/providers/huggingface/models.go @@ -13,7 +13,7 @@ const ( maxModelFetchLimit = 1000 ) -func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -22,7 +22,7 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi Data: make([]schemas.Model, 0, len(response.Models)), } - if !unfiltered && allowedModels.IsEmpty() { + if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { return bifrostResponse } @@ -40,6 +40,9 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ModelID) { continue } + if !unfiltered && blacklistedModels.IsBlocked(model.ModelID) { + continue + } newModel := schemas.Model{ ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, model.ModelID), @@ -55,6 +58,9 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi // Backfill allowed models that were not in the response if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { + if blacklistedModels.IsBlocked(allowedModel) { + continue + } if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, allowedModel), diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index a3f1e1e7c4..44b7ba9573 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -116,7 +116,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } // Create final response - response := mistralResponse.ToBifrostListModelsResponse(key.Models, request.Unfiltered) + response := mistralResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() diff --git a/core/providers/mistral/models.go b/core/providers/mistral/models.go index 8030f5e1f8..9b1002e54b 100644 --- a/core/providers/mistral/models.go +++ b/core/providers/mistral/models.go @@ -6,7 +6,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -15,7 +15,7 @@ func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedMo Data: make([]schemas.Model, 0, len(response.Data)), } - if !unfiltered && allowedModels.IsEmpty() { + if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { return bifrostResponse } @@ -24,6 +24,9 @@ func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedMo if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ID) { continue } + if !unfiltered && blacklistedModels.IsBlocked(model.ID) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(schemas.Mistral) + "/" + model.ID, Name: schemas.Ptr(model.Name), @@ -38,6 +41,9 @@ func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedMo // Backfill allowed models that were not in the response if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { + if blacklistedModels.IsBlocked(allowedModel) { + continue + } if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(schemas.Mistral) + "/" + allowedModel, diff --git a/core/providers/openai/models.go b/core/providers/openai/models.go index cd79832290..8268608568 100644 --- a/core/providers/openai/models.go +++ b/core/providers/openai/models.go @@ -7,7 +7,7 @@ import ( ) // ToBifrostListModelsResponse converts an OpenAI list models response to a Bifrost list models response -func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -16,7 +16,7 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Data)), } - if !unfiltered && allowedModels.IsEmpty() { + if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { return bifrostResponse } @@ -25,6 +25,9 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ID) { continue } + if !unfiltered && blacklistedModels.IsBlocked(model.ID) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + model.ID, Created: model.Created, @@ -37,6 +40,9 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe // Backfill allowed models that were not in the response if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { + if blacklistedModels.IsBlocked(allowedModel) { + continue + } if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index fa3dbede3e..03952bd514 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -181,7 +181,7 @@ func ListModelsByKey( return nil, bifrostErr } - response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, unfiltered) + response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, unfiltered) response.ExtraFields.Provider = providerName response.ExtraFields.RequestType = schemas.ListModelsRequest diff --git a/core/providers/openrouter/openrouter.go b/core/providers/openrouter/openrouter.go index 79703be364..eda7353372 100644 --- a/core/providers/openrouter/openrouter.go +++ b/core/providers/openrouter/openrouter.go @@ -186,9 +186,10 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, // Filter by key.Models allowedModels := key.Models + blacklistedModels := key.BlacklistedModels providerPrefix := string(schemas.OpenRouter) + "/" - if !request.Unfiltered && allowedModels.IsEmpty() { + if !request.Unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { openrouterResponse.Data = make([]schemas.Model, 0) } else if !request.Unfiltered && allowedModels.IsRestricted() { filteredData := make([]schemas.Model, 0, len(openrouterResponse.Data)) @@ -198,6 +199,9 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, if !(allowedModels.Contains(rawID) || allowedModels.Contains(providerPrefix+rawID)) { continue } + if blacklistedModels.IsBlocked(rawID) || blacklistedModels.IsBlocked(providerPrefix+rawID) { + continue + } openrouterResponse.Data[i].ID = providerPrefix + rawID filteredData = append(filteredData, openrouterResponse.Data[i]) includedModels[strings.ToLower(rawID)] = true @@ -209,6 +213,9 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, if strings.HasPrefix(strings.ToLower(allowedModel), strings.ToLower(providerPrefix)) { rawID = allowedModel[len(providerPrefix):] } + if blacklistedModels.IsBlocked(rawID) || blacklistedModels.IsBlocked(providerPrefix+rawID) { + continue + } if !includedModels[strings.ToLower(rawID)] { filteredData = append(filteredData, schemas.Model{ ID: providerPrefix + rawID, diff --git a/core/providers/replicate/models.go b/core/providers/replicate/models.go index 12000f4b4b..206d0e0ca6 100644 --- a/core/providers/replicate/models.go +++ b/core/providers/replicate/models.go @@ -11,13 +11,14 @@ func ToBifrostListModelsResponse( deploymentsResponse *ReplicateDeploymentListResponse, providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, + blacklistedModels schemas.BlackList, unfiltered bool, ) *schemas.BifrostListModelsResponse { bifrostResponse := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } - if !unfiltered && allowedModels.IsEmpty() { + if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { return bifrostResponse } @@ -33,6 +34,9 @@ func ToBifrostListModelsResponse( if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(deploymentID) { continue } + if !unfiltered && blacklistedModels.IsBlocked(deploymentID) { + continue + } // Extract information from current release if available if deployment.CurrentRelease != nil { @@ -65,6 +69,9 @@ func ToBifrostListModelsResponse( // Backfill allowed models that were not in the response if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { + if blacklistedModels.IsBlocked(allowedModel) { + continue + } if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/replicate/replicate.go b/core/providers/replicate/replicate.go index 81cd67931d..e1ed9131e6 100644 --- a/core/providers/replicate/replicate.go +++ b/core/providers/replicate/replicate.go @@ -361,6 +361,7 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont deploymentsResponse, providerName, key.Models, + key.BlacklistedModels, request.Unfiltered, ) diff --git a/core/providers/vertex/models.go b/core/providers/vertex/models.go index 924b421d98..54ba41ac28 100644 --- a/core/providers/vertex/models.go +++ b/core/providers/vertex/models.go @@ -113,7 +113,7 @@ func findDeploymentMatch(deployments map[string]string, customModelID string) (d // - If allowedModels is empty, all models are allowed // - If allowedModels is non-empty, only models/deployments with keys in allowedModels are included // - Deployments map is used to match model IDs to aliases and filter accordingly -func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -122,7 +122,7 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod Data: make([]schemas.Model, 0, len(response.Models)), } - if !unfiltered && allowedModels.IsEmpty() && len(deployments) == 0 { + if !unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || blacklistedModels.IsBlockAll()) { return bifrostResponse } @@ -176,6 +176,9 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod if shouldFilter { continue } + if !unfiltered && blacklistedModels.IsBlocked(customModelID) { + continue + } modelID := customModelID @@ -209,6 +212,9 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod if restrictAllowed && !allowedModels.Contains(alias) { continue } + if blacklistedModels.IsBlocked(alias) { + continue + } modelName := formatDeploymentName(alias) modelEntry := schemas.Model{ @@ -230,6 +236,9 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod if addedModelIDs[modelID] { continue } + if blacklistedModels.IsBlocked(allowedModel) { + continue + } modelName := formatDeploymentName(allowedModel) modelEntry := schemas.Model{ @@ -249,7 +258,7 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod // ToBifrostListModelsResponse converts a Vertex AI publisher models response to Bifrost's format. // This is for foundation models from the Model Garden (publishers.models.list endpoint). -func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -258,7 +267,7 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a Data: make([]schemas.Model, 0, len(response.PublisherModels)), } - if !unfiltered && allowedModels.IsEmpty() { + if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { return bifrostResponse } @@ -276,6 +285,9 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(modelID) { continue } + if !unfiltered && blacklistedModels.IsBlocked(modelID) { + continue + } // Skip if already added (shouldn't happen, but safety check) fullModelID := string(schemas.Vertex) + "/" + modelID diff --git a/core/providers/vertex/utils.go b/core/providers/vertex/utils.go index be34e2d201..441bd13c50 100644 --- a/core/providers/vertex/utils.go +++ b/core/providers/vertex/utils.go @@ -185,11 +185,15 @@ func getCompleteURLForGeminiEndpoint(deployment string, region string, projectID // buildResponseFromConfig builds a list models response from configured deployments and allowedModels. // This is used when the user has explicitly configured which models they want to use. -func buildResponseFromConfig(deployments map[string]string, allowedModels schemas.WhiteList) *schemas.BifrostListModelsResponse { +func buildResponseFromConfig(deployments map[string]string, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList) *schemas.BifrostListModelsResponse { response := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } + if blacklistedModels.IsBlockAll() { + return response + } + addedModelIDs := make(map[string]bool) restrictAllowed := allowedModels.IsRestricted() @@ -199,6 +203,9 @@ func buildResponseFromConfig(deployments map[string]string, allowedModels schema if restrictAllowed && !allowedModels.Contains(alias) { continue } + if blacklistedModels.IsBlocked(alias) { + continue + } modelID := string(schemas.Vertex) + "/" + alias if addedModelIDs[modelID] { continue @@ -224,6 +231,9 @@ func buildResponseFromConfig(deployments map[string]string, allowedModels schema if addedModelIDs[modelID] { continue } + if blacklistedModels.IsBlocked(allowedModel) { + continue + } modelName := formatDeploymentName(allowedModel) modelEntry := schemas.Model{ diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index fee2e7f09d..67fa0a7c79 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -190,14 +190,14 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key deployments := key.VertexKeyConfig.Deployments allowedModels := key.Models - if !request.Unfiltered && allowedModels.IsEmpty() && len(deployments) == 0 { + if !request.Unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || key.BlacklistedModels.IsBlockAll()) { return &schemas.BifrostListModelsResponse{Data: make([]schemas.Model, 0)}, nil } // If deployments or allowedModels are configured, return those directly without API call // Skip this fast path when Unfiltered is set so the full Vertex catalog can be retrieved if !request.Unfiltered && (len(deployments) > 0 || allowedModels.IsRestricted()) { - return buildResponseFromConfig(deployments, allowedModels), nil + return buildResponseFromConfig(deployments, allowedModels, key.BlacklistedModels), nil } // No deployments configured - fetch from Model Garden API @@ -326,7 +326,7 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key PublisherModels: allPublisherModels, } - response := aggregatedResponse.ToBifrostListModelsResponse(nil, request.Unfiltered) + response := aggregatedResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, request.Unfiltered) if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { response.ExtraFields.RawRequest = rawRequests diff --git a/core/schemas/account.go b/core/schemas/account.go index 2ef61c7988..1ee9fc25c7 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -73,6 +73,50 @@ func (wl WhiteList) Validate() error { return nil } +// BlackList is a list of values that are denied. +// Semantics: +// - "*" (alone) means all values are blocked. +// - Empty list means nothing is blocked. +// - Non-empty list (without "*") means only the listed values are blocked. +type BlackList []string + +func (bl BlackList) Contains(value string) bool { + return slices.ContainsFunc(bl, func(s string) bool { + return strings.EqualFold(s, value) + }) +} + +// IsBlocked reports whether value is blocked. +func (bl BlackList) IsBlocked(value string) bool { + return bl.IsBlockAll() || bl.Contains(value) +} + +// IsEmpty reports whether the blacklist has no entries (nothing is blocked). +func (bl BlackList) IsEmpty() bool { + return len(bl) == 0 +} + +// IsBlockAll reports whether the blacklist contains "*", meaning all values are blocked. +func (bl BlackList) IsBlockAll() bool { + return len(bl) == 1 && bl[0] == "*" +} + +// Validate checks that the blacklist is well-formed. +func (bl BlackList) Validate() error { + if bl.Contains("*") && len(bl) > 1 { + return fmt.Errorf("wildcard '*' cannot be used with other values in the blacklist") + } + seen := make(map[string]struct{}, len(bl)) + for _, v := range bl { + normalized := strings.ToLower(v) + if _, ok := seen[normalized]; ok { + return fmt.Errorf("duplicate value '%s' in blacklist", v) + } + seen[normalized] = struct{}{} + } + return nil +} + // Key represents an API key and its associated configuration for a provider. // It contains the key value, supported models, and a weight for load balancing. type Key struct { @@ -80,7 +124,7 @@ type Key struct { Name string `json:"name"` // The name of the key (used by users to identify the key, not used by bifrost) Value EnvVar `json:"value"` // The actual API key value Models WhiteList `json:"models"` // List of models this key can access - BlacklistedModels WhiteList `json:"blacklisted_models"` // List of models this key cannot access + BlacklistedModels BlackList `json:"blacklisted_models"` // List of models this key cannot access Weight float64 `json:"weight"` // Weight for load balancing between multiple keys AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration diff --git a/framework/configstore/tables/key.go b/framework/configstore/tables/key.go index ce9206b420..6637ea1ee0 100644 --- a/framework/configstore/tables/key.go +++ b/framework/configstore/tables/key.go @@ -74,7 +74,7 @@ type TableKey struct { // Virtual fields for runtime use (not stored in DB) Models schemas.WhiteList `gorm:"-" json:"models"` // ["*"] allows all models; empty denies all (deny-by-default) - BlacklistedModels schemas.WhiteList `gorm:"-" json:"blacklisted_models"` + BlacklistedModels schemas.BlackList `gorm:"-" json:"blacklisted_models"` AzureKeyConfig *schemas.AzureKeyConfig `gorm:"-" json:"azure_key_config,omitempty"` VertexKeyConfig *schemas.VertexKeyConfig `gorm:"-" json:"vertex_key_config,omitempty"` BedrockKeyConfig *schemas.BedrockKeyConfig `gorm:"-" json:"bedrock_key_config,omitempty"` @@ -91,20 +91,22 @@ func (TableKey) TableName() string { return "config_keys" } // batch S3 config) before writing to the database. Encryption runs last to ensure it // operates on the final serialized values. func (k *TableKey) BeforeSave(tx *gorm.DB) error { + if err := k.Models.Validate(); err != nil { + return err + } data, err := json.Marshal(k.Models) if err != nil { return err } k.ModelsJSON = string(data) - if k.BlacklistedModels != nil { - data, err := json.Marshal(k.BlacklistedModels) - if err != nil { - return err - } - k.BlacklistedModelsJSON = string(data) - } else { - k.BlacklistedModelsJSON = "[]" + if err := k.BlacklistedModels.Validate(); err != nil { + return err } + data, err = json.Marshal(k.BlacklistedModels) + if err != nil { + return err + } + k.BlacklistedModelsJSON = string(data) if k.Enabled == nil { enabled := true // DB default k.Enabled = &enabled @@ -496,8 +498,6 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { if err := json.Unmarshal([]byte(k.BlacklistedModelsJSON), &k.BlacklistedModels); err != nil { return err } - } else { - k.BlacklistedModels = []string{} } if k.Enabled == nil { enabled := true // DB default diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 2c0e9c7bfe..f8f890b5fe 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -647,7 +647,7 @@ func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvide } // UpsertModelDataForProvider upserts model data for a given provider -func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model, deniedModels []schemas.Model) { +func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) { if modelData == nil { return } @@ -676,7 +676,7 @@ func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvide } } // If modelData is empty, then we allow all models - if len(modelData.Data) == 0 && len(allowedModels) == 0 && len(deniedModels) == 0 { + if len(modelData.Data) == 0 && len(allowedModels) == 0 { mc.modelPool[provider] = providerModels return } @@ -709,15 +709,7 @@ func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvide } if len(allowedModels) == 0 { - deniedSet := make(map[string]struct{}, len(deniedModels)) - for _, d := range deniedModels { - _, modelName := schemas.ParseModelString(d.ID, "") - deniedSet[modelName] = struct{}{} - } for _, model := range providerModels { - if _, denied := deniedSet[model]; denied { - continue - } if !seenModels[model] { seenModels[model] = true finalModelList = append(finalModelList, model) diff --git a/plugins/governance/http_transport_prehook_test.go b/plugins/governance/http_transport_prehook_test.go index c79ca20b15..f50511d740 100644 --- a/plugins/governance/http_transport_prehook_test.go +++ b/plugins/governance/http_transport_prehook_test.go @@ -23,7 +23,7 @@ func TestHTTPTransportPreHook_VirtualKeyReplicateRefinesNestedModel(t *testing.T Data: []schemas.Model{ {ID: "replicate/openai/gpt-5-nano"}, }, - }, nil, nil) + }, nil) virtualKey := buildVirtualKeyWithProviders( "vk1", diff --git a/transports/bifrost-http/handlers/provider_keys.go b/transports/bifrost-http/handlers/provider_keys.go index 70a86ebfcc..86fb51fd3d 100644 --- a/transports/bifrost-http/handlers/provider_keys.go +++ b/transports/bifrost-http/handlers/provider_keys.go @@ -98,6 +98,11 @@ func (h *ProviderHandler) createProviderKey(ctx *fasthttp.RequestCtx) { return } + if err := key.BlacklistedModels.Validate(); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid blacklisted_models: %v", err)) + return + } + if key.ID == "" { key.ID = uuid.NewString() } @@ -189,6 +194,11 @@ func (h *ProviderHandler) updateProviderKey(ctx *fasthttp.RequestCtx) { updateKey.ID = keyID mergedKey := h.mergeUpdatedKey(*oldRawKey, *oldRedactedKey, updateKey) + if err := mergedKey.BlacklistedModels.Validate(); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid blacklisted_models: %v", err)) + return + } + if err := h.inMemoryStore.UpdateProviderKey(ctx, provider, keyID, mergedKey); err != nil { logger.Warn("Failed to update key %s for provider %s: %v", keyID, provider, err) if errors.Is(err, lib.ErrNotFound) { diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 45f4ea8e8e..9c83a424b6 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -712,7 +712,11 @@ func keyAllowsModelForList(key schemas.Key, model string) bool { return true } -// filterModelsByKeys filters models based on key-level model restrictions +// filterModelsByKeys filters models based on key-level model restrictions. +// A model is included in the result if at least one of the specified keys grants access to it. +// A key grants access to a model when: the model is not blacklisted by that key AND +// the key's allowlist is unrestricted (wildcard) or explicitly contains the model. +// Empty allowlist (no wildcard) = deny-all for that key. func (h *ProviderHandler) filterModelsByKeys(provider schemas.ModelProvider, models []string, keyIDs []string) []string { // Get provider config to access keys config, err := h.inMemoryStore.GetProviderConfigRaw(provider) @@ -720,50 +724,40 @@ func (h *ProviderHandler) filterModelsByKeys(provider schemas.ModelProvider, mod logger.Warn("Failed to get config for provider %s: %v", provider, err) return models } - // Build a set of allowed models from the specified keys - // Track whether we have any unrestricted keys (which grant access to all models) - // and whether we have any restricted keys (which limit to specific models) - allowedModels := make(map[string]bool) - hasRestrictedKey := false - hasUnrestrictedKey := false - hasDenyAllKey := false + + // Index keys by ID for fast lookup + keyMap := make(map[string]schemas.Key, len(config.Keys)) + for _, key := range config.Keys { + keyMap[key.ID] = key + } + + // Collect only the keys referenced by keyIDs + matchedKeys := make([]schemas.Key, 0, len(keyIDs)) for _, keyID := range keyIDs { - for _, key := range config.Keys { - if key.ID == keyID { - if key.Models.IsUnrestricted() { - // Key allows all models (wildcard) - hasUnrestrictedKey = true - } else if !key.Models.IsEmpty() { - // Key has specific model restrictions - add them to allowedModels - hasRestrictedKey = true - for _, model := range key.Models { - allowedModels[model] = true - } - } else { - // Empty Models = explicit deny-all for this key - hasDenyAllKey = true - } - break - } + if key, ok := keyMap[keyID]; ok { + matchedKeys = append(matchedKeys, key) } } - // If any key is unrestricted, return all models (union of "all" and restricted subsets is "all") - if hasUnrestrictedKey { - return models - } - // If no keys were matched or restricted, but at least one key explicitly denies all, return nothing - if !hasRestrictedKey && hasDenyAllKey { - return []string{} - } - // If no keys have model restrictions (e.g., unknown key IDs), return all models - if !hasRestrictedKey { + + // If none of the requested key IDs exist in config, fall back to returning all models + if len(matchedKeys) == 0 { return models } - // Filter models based on restrictions from restricted keys only - filtered := []string{} + + // For each model, include it if at least one matched key grants access + filtered := make([]string, 0, len(models)) for _, model := range models { - if allowedModels[model] { - filtered = append(filtered, model) + for _, key := range matchedKeys { + // Blacklist wins over allowlist + if key.BlacklistedModels.IsBlocked(model) { + continue + } + // Unrestricted (wildcard) key grants access to all non-blacklisted models; + // restricted key grants access only if the model is explicitly listed + if key.Models.IsUnrestricted() || key.Models.Contains(model) { + filtered = append(filtered, model) + break + } } } return filtered diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 682ca97aad..0e8a31669f 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -574,7 +574,7 @@ func (s *BifrostHTTPServer) ReloadProvider(ctx context.Context, provider schemas }) } } - s.Config.ModelCatalog.UpsertModelDataForProvider(provider, allModels, modelsInKeys, nil) + s.Config.ModelCatalog.UpsertModelDataForProvider(provider, allModels, modelsInKeys) if listModelsErr != nil { if hasNoKeys { logger.Warn("unfiltered model discovery skipped for provider %s: no keys configured", provider) @@ -807,7 +807,7 @@ func (s *BifrostHTTPServer) ForceReloadPricing(ctx context.Context) error { }) } } - s.Config.ModelCatalog.UpsertModelDataForProvider(provider, modelData, allowedModels, nil) + s.Config.ModelCatalog.UpsertModelDataForProvider(provider, modelData, allowedModels) unfilteredModelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ Provider: provider, Unfiltered: true, @@ -1322,7 +1322,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { }) } } - s.Config.ModelCatalog.UpsertModelDataForProvider(provider, modelData, allowedModels, nil) + s.Config.ModelCatalog.UpsertModelDataForProvider(provider, modelData, allowedModels) unfilteredModelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ Provider: provider, Unfiltered: true, diff --git a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx index a6144f0f7c..c0f0697960 100644 --- a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx +++ b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx @@ -245,7 +245,7 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { render={({ field }) => (
- Blacklisted models + Blocked Models @@ -255,15 +255,39 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) {

- Comma-separated list of models this key must never use. If a model appears in both Allowed Models and here, the blacklist wins. Leave - empty if none. + Models this key must never serve. The denylist always wins — if a model appears in both Allowed Models and here, it is blocked. + Select "All Models" to block every model on this key.

- + { + const hadStar = (field.value || []).includes("*"); + const hasStar = models.includes("*"); + if (!hadStar && hasStar) { + field.onChange(["*"]); + } else if (hadStar && hasStar && models.length > 1) { + field.onChange(models.filter((m: string) => m !== "*")); + } else { + field.onChange(models); + } + }} + placeholder={ + (field.value || []).includes("*") + ? "All models blocked" + : (field.value || []).length === 0 + ? "No models blocked" + : "Search models..." + } + unfiltered={true} + />
diff --git a/ui/components/ui/modelMultiselect.tsx b/ui/components/ui/modelMultiselect.tsx index 53b6b64924..3c73c7af72 100644 --- a/ui/components/ui/modelMultiselect.tsx +++ b/ui/components/ui/modelMultiselect.tsx @@ -53,7 +53,7 @@ interface ModelOption { provider?: string; } -const ALL_MODELS_OPTION: ModelOption = { label: "Allow All Models", value: "*" }; +const ALL_MODELS_OPTION: ModelOption = { label: "All Models", value: "*" }; export function ModelMultiselect(props: ModelMultiselectProps) { const { @@ -114,7 +114,7 @@ export function ModelMultiselect(props: ModelMultiselectProps) { const loadOptions = useCallback( (query: string, callback: (options: ModelOption[]) => void) => { // Prepend "Allow All Models" when allowAllOption is enabled and query matches (or is empty) - const prefix: ModelOption[] = allowAllOption && (!query || "allow all models".includes(query.toLowerCase())) + const prefix: ModelOption[] = allowAllOption && (!query || "all models".includes(query.toLowerCase())) ? [ALL_MODELS_OPTION] : [];