diff --git a/core/bifrost.go b/core/bifrost.go index 09886fb78e..fd0e3e26c8 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -6096,10 +6096,14 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p // Model filtering logic: // - If model is nil or empty → include all keys (no model filter) // - If model is specified: - // - If key.Models is empty → include key (supports all models) + // - If model is in key.BlacklistedModels → exclude (wins over Models allow list) + // - If key.Models is empty → include key (supports all non-blacklisted models) // - If key.Models is non-empty → only include if model is in list - if model != nil && *model != "" && len(k.Models) > 0 { - if !slices.Contains(k.Models, *model) { + if model != nil && *model != "" { + if len(k.BlacklistedModels) > 0 && slices.Contains(k.BlacklistedModels, *model) { + continue + } + if len(k.Models) > 0 && !slices.Contains(k.Models, *model) { continue } } @@ -6167,7 +6171,8 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex keys = batchEnabledKeys } - // filter out keys which don't support the model, if the key has no models, it is supported for all models + // Filter out keys that don't support the model: blacklisted_models wins over models allow list; + // if the key has no models list, it supports all models except those blacklisted. var supportedKeys []schemas.Key // Skip model check conditions @@ -6192,7 +6197,12 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex continue } hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) - modelSupported := (len(key.Models) == 0 && hasValue) || (slices.Contains(key.Models, model) && hasValue) + var modelSupported bool + if len(key.BlacklistedModels) > 0 && slices.Contains(key.BlacklistedModels, model) { + modelSupported = false + } else { + modelSupported = (len(key.Models) == 0 && hasValue) || (slices.Contains(key.Models, model) && hasValue) + } // Additional deployment checks for Azure, Bedrock and Vertex deploymentSupported := true if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil { diff --git a/core/bifrost_test.go b/core/bifrost_test.go index 0db9356350..806fd4d844 100644 --- a/core/bifrost_test.go +++ b/core/bifrost_test.go @@ -847,6 +847,62 @@ func TestSelectKeyFromProviderForModel_NoStickinessWithoutSessionID(t *testing.T } } +func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) { + account := NewMockAccount() + account.AddProvider(schemas.OpenAI, 5, 1000) + + ctx := context.Background() + bifrost, err := Init(ctx, schemas.BifrostConfig{ + Account: account, + Logger: NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Init failed: %v", err) + } + bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + + t.Run("all keys blacklist model", func(t *testing.T) { + account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ + {ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}}, + }) + _, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + if err == nil { + t.Fatal("expected error when model is only blacklisted") + } + if !strings.Contains(err.Error(), "no keys found that support model") { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("blacklist wins over models allow list", func(t *testing.T) { + account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ + { + ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, + Models: []string{"gpt-4"}, + BlacklistedModels: []string{"gpt-4"}, + }, + }) + _, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + if err == nil { + t.Fatal("expected error when model is both allowed and blacklisted") + } + }) + + t.Run("second key used when first blacklists", func(t *testing.T) { + account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ + {ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}}, + {ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1}, + }) + key, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if key.ID != "k2" { + t.Fatalf("expected k2, got %s", key.ID) + } + }) +} + // Test UpdateProvider functionality func TestUpdateProvider(t *testing.T) { t.Run("SuccessfulUpdate", func(t *testing.T) { diff --git a/core/changelog.md b/core/changelog.md index e69de29bb2..9522a8dea0 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -0,0 +1 @@ +- feat: added `blacklisted_models` on provider keys to exclude models from routing and filtered list-models diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index 4afba1c96d..0857214d7c 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -298,7 +298,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext, } // Create final response - response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, request.Unfiltered) + response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled diff --git a/core/providers/anthropic/models.go b/core/providers/anthropic/models.go index 2aa26fa5c1..d56f1a80ac 100644 --- a/core/providers/anthropic/models.go +++ b/core/providers/anthropic/models.go @@ -3,10 +3,11 @@ package anthropic import ( "time" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -40,6 +41,9 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide continue } } + if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, modelID) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + modelID, Name: schemas.Ptr(model.DisplayName), @@ -48,9 +52,12 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide includedModels[modelID] = true } - // Backfill allowed models that were not in the response + // Backfill allowed models that were not in the response (skip blacklisted; blacklist wins over allow list) if !unfiltered && len(allowedModels) > 0 { for _, allowedModel := range allowedModels { + if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) { + continue + } if !includedModels[allowedModel] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index f25fee81f4..9f1928e006 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -366,7 +366,7 @@ func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key } // Convert to Bifrost response - response := azureResponse.ToBifrostListModelsResponse(key.Models, key.AzureKeyConfig.Deployments, request.Unfiltered) + response := azureResponse.ToBifrostListModelsResponse(key.Models, key.AzureKeyConfig.Deployments, key.BlacklistedModels, request.Unfiltered) if response == nil { return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil, schemas.Azure) } diff --git a/core/providers/azure/models.go b/core/providers/azure/models.go index c4ac069625..d5ff81229a 100644 --- a/core/providers/azure/models.go +++ b/core/providers/azure/models.go @@ -3,6 +3,7 @@ package azure import ( "slices" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -58,7 +59,7 @@ func findDeploymentMatch(deployments map[string]string, modelID string) (deploym return "", "" } -func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -116,6 +117,10 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode modelID = matchedAllowedModel } + if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, model.ID, modelID, deploymentAlias, matchedAllowedModel) { + continue + } + modelEntry := schemas.Model{ ID: string(schemas.Azure) + "/" + modelID, Created: schemas.Ptr(model.CreatedAt), @@ -142,6 +147,9 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) { continue } + if providerUtils.ModelMatchesDenylist(blacklistedModels, alias) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(schemas.Azure) + "/" + alias, Name: schemas.Ptr(alias), @@ -154,6 +162,9 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode // Backfill allowed models that were not in the response if !unfiltered && len(allowedModels) > 0 { for _, allowedModel := range allowedModels { + if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) { + continue + } if !includedModels[allowedModel] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(schemas.Azure) + "/" + allowedModel, diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index 8d2051ba9d..bae74820ea 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -81,7 +81,7 @@ func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) ( MaxIdleConns: schemas.DefaultMaxIdleConnsPerHost, MaxIdleConnsPerHost: schemas.DefaultMaxIdleConnsPerHost, IdleConnTimeout: 30 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, ResponseHeaderTimeout: requestTimeout, ExpectContinueTimeout: 1 * time.Second, ForceAttemptHTTP2: config.NetworkConfig.EnforceHTTP2, @@ -732,7 +732,7 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } // Convert to Bifrost response - response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, config.Deployments, request.Unfiltered) + response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, config.Deployments, key.BlacklistedModels, request.Unfiltered) if response == nil { return nil, providerUtils.NewBifrostOperationError("failed to convert Bedrock model list response", nil, providerName) } diff --git a/core/providers/bedrock/models.go b/core/providers/bedrock/models.go index a3bebb5890..e4e96a8017 100644 --- a/core/providers/bedrock/models.go +++ b/core/providers/bedrock/models.go @@ -4,6 +4,7 @@ import ( "slices" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -221,7 +222,7 @@ func findDeploymentMatch(deployments map[string]string, modelID string) (deploym return "", "" } -func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -284,6 +285,10 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK modelID = matchedAllowedModel } + if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, model.ModelID, modelID, deploymentAlias, matchedAllowedModel) { + continue + } + modelEntry := schemas.Model{ ID: string(providerKey) + "/" + modelID, Name: schemas.Ptr(model.ModelName), @@ -315,6 +320,9 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) { continue } + if providerUtils.ModelMatchesDenylist(blacklistedModels, alias) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + alias, Name: schemas.Ptr(alias), @@ -327,6 +335,9 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK // Backfill allowed models that were not in the response if !unfiltered && len(allowedModels) > 0 { for _, allowedModel := range allowedModels { + if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) { + continue + } if !includedModels[allowedModel] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go index d9636eee40..8d562364ba 100644 --- a/core/providers/cohere/cohere.go +++ b/core/providers/cohere/cohere.go @@ -288,7 +288,7 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key } // Convert Cohere v2 response to Bifrost response - response := cohereResponse.ToBifrostListModelsResponse(providerName, key.Models, request.Unfiltered) + response := cohereResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() diff --git a/core/providers/cohere/models.go b/core/providers/cohere/models.go index 83ffd8752f..3df2aab89a 100644 --- a/core/providers/cohere/models.go +++ b/core/providers/cohere/models.go @@ -25,8 +25,8 @@ func (r *CohereRerankRequest) GetExtraParams() map[string]interface{} { // CohereRerankResult represents a single result from Cohere rerank. type CohereRerankResult struct { - Index int `json:"index"` - RelevanceScore float64 `json:"relevance_score"` + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` Document json.RawMessage `json:"document,omitempty"` } @@ -44,7 +44,7 @@ type CohereRerankMeta struct { Tokens *CohereTokenUsage `json:"tokens,omitempty"` } -func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -58,6 +58,9 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.Name) { continue } + if !unfiltered && slices.Contains(blacklistedModels, model.Name) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + model.Name, Name: schemas.Ptr(model.Name), @@ -70,6 +73,9 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe // Backfill allowed models that were not in the response if !unfiltered && len(allowedModels) > 0 { for _, allowedModel := range allowedModels { + if slices.Contains(blacklistedModels, allowedModel) { + continue + } if !includedModels[allowedModel] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/elevenlabs/elevenlabs.go b/core/providers/elevenlabs/elevenlabs.go index bddf520be3..6ca0172d6c 100644 --- a/core/providers/elevenlabs/elevenlabs.go +++ b/core/providers/elevenlabs/elevenlabs.go @@ -115,7 +115,7 @@ func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, return nil, bifrostErr } - response := elevenlabsResponse.ToBifrostListModelsResponse(providerName, key.Models, request.Unfiltered) + response := elevenlabsResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) diff --git a/core/providers/elevenlabs/models.go b/core/providers/elevenlabs/models.go index a00b81847e..c211e85196 100644 --- a/core/providers/elevenlabs/models.go +++ b/core/providers/elevenlabs/models.go @@ -6,7 +6,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -20,6 +20,9 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ModelID) { continue } + if !unfiltered && slices.Contains(blacklistedModels, model.ModelID) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + model.ModelID, Name: schemas.Ptr(model.Name), @@ -30,6 +33,9 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid // Backfill allowed models that were not in the response if !unfiltered && len(allowedModels) > 0 { for _, allowedModel := range allowedModels { + if slices.Contains(blacklistedModels, allowedModel) { + continue + } if !includedModels[allowedModel] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index 57146bd737..7a4d299bd8 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -227,7 +227,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key } } - response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, request.Unfiltered) + response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() diff --git a/core/providers/gemini/models.go b/core/providers/gemini/models.go index ec36f9406b..4c8f83c364 100644 --- a/core/providers/gemini/models.go +++ b/core/providers/gemini/models.go @@ -17,7 +17,7 @@ func toGeminiModelResourceName(modelID string) string { return "models/" + modelID } -func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -35,6 +35,9 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, modelName) { continue } + if !unfiltered && slices.Contains(blacklistedModels, modelName) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + modelName, Name: schemas.Ptr(model.DisplayName), @@ -50,6 +53,9 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe // Backfill allowed models that were not in the response if !unfiltered && len(allowedModels) > 0 { for _, allowedModel := range allowedModels { + if slices.Contains(blacklistedModels, allowedModel) { + continue + } if !includedModels[allowedModel] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/huggingface/huggingface.go b/core/providers/huggingface/huggingface.go index d1b7f145f3..62f17c8931 100644 --- a/core/providers/huggingface/huggingface.go +++ b/core/providers/huggingface/huggingface.go @@ -384,7 +384,7 @@ func (provider *HuggingFaceProvider) listModelsByKey(ctx *schemas.BifrostContext } if result.response != nil { - providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, request.Unfiltered) + providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, key.BlacklistedModels, request.Unfiltered) if providerResponse != nil { aggregatedResponse.Data = append(aggregatedResponse.Data, providerResponse.Data...) totalLatency += result.latency diff --git a/core/providers/huggingface/models.go b/core/providers/huggingface/models.go index e306c4dd2f..c637c3b6f6 100644 --- a/core/providers/huggingface/models.go +++ b/core/providers/huggingface/models.go @@ -13,7 +13,7 @@ const ( maxModelFetchLimit = 1000 ) -func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -22,6 +22,14 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi Data: make([]schemas.Model, 0, len(response.Models)), } + var blacklisted map[string]struct{} + if !unfiltered && len(blacklistedModels) > 0 { + blacklisted = make(map[string]struct{}, len(blacklistedModels)) + for _, m := range blacklistedModels { + blacklisted[m] = struct{}{} + } + } + includedModels := make(map[string]bool) for _, model := range response.Models { if model.ModelID == "" { @@ -36,6 +44,9 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ModelID) { continue } + if _, ok := blacklisted[model.ModelID]; ok { + continue + } newModel := schemas.Model{ ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, model.ModelID), @@ -51,6 +62,9 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi // Backfill allowed models that were not in the response if !unfiltered && len(allowedModels) > 0 { for _, allowedModel := range allowedModels { + if _, ok := blacklisted[allowedModel]; ok { + continue + } if !includedModels[allowedModel] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, allowedModel), diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index dfc5d95e11..8ae93dd6a0 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -116,7 +116,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } // Create final response - response := mistralResponse.ToBifrostListModelsResponse(key.Models) + response := mistralResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels) response.ExtraFields.Latency = latency.Milliseconds() diff --git a/core/providers/mistral/models.go b/core/providers/mistral/models.go index 87bb59625d..ef3e5934c1 100644 --- a/core/providers/mistral/models.go +++ b/core/providers/mistral/models.go @@ -6,7 +6,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels []string) *schemas.BifrostListModelsResponse { +func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, blacklistedModels []string) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -20,6 +20,9 @@ func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedMo if len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ID) { continue } + if slices.Contains(blacklistedModels, model.ID) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(schemas.Mistral) + "/" + model.ID, Name: schemas.Ptr(model.Name), @@ -34,6 +37,9 @@ func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedMo // Backfill allowed models that were not in the response if len(allowedModels) > 0 { for _, allowedModel := range allowedModels { + if slices.Contains(blacklistedModels, allowedModel) { + continue + } if !includedModels[allowedModel] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(schemas.Mistral) + "/" + allowedModel, diff --git a/core/providers/openai/models.go b/core/providers/openai/models.go index 6cb2b57f4b..d00a8af112 100644 --- a/core/providers/openai/models.go +++ b/core/providers/openai/models.go @@ -7,7 +7,7 @@ import ( ) // ToBifrostListModelsResponse converts an OpenAI list models response to a Bifrost list models response -func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -21,6 +21,9 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ID) { continue } + if !unfiltered && slices.Contains(blacklistedModels, model.ID) { + continue + } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + model.ID, Created: model.Created, @@ -33,6 +36,9 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe // Backfill allowed models that were not in the response if !unfiltered && len(allowedModels) > 0 { for _, allowedModel := range allowedModels { + if slices.Contains(blacklistedModels, allowedModel) { + continue + } if !includedModels[allowedModel] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index efa2ecd1e5..6d309f3034 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -181,7 +181,7 @@ func ListModelsByKey( return nil, bifrostErr } - response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, unfiltered) + response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, unfiltered) response.ExtraFields.Provider = providerName response.ExtraFields.RequestType = schemas.ListModelsRequest diff --git a/core/providers/openrouter/openrouter.go b/core/providers/openrouter/openrouter.go index bc63fa2fc5..55ec63daad 100644 --- a/core/providers/openrouter/openrouter.go +++ b/core/providers/openrouter/openrouter.go @@ -187,6 +187,7 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, // Filter by key.Models allowedModels := key.Models + blacklistedModels := key.BlacklistedModels providerPrefix := string(schemas.OpenRouter) + "/" if !request.Unfiltered && len(allowedModels) > 0 { @@ -197,6 +198,9 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, if !(slices.Contains(allowedModels, rawID) || slices.Contains(allowedModels, providerPrefix+rawID)) { continue } + if slices.Contains(blacklistedModels, rawID) || slices.Contains(blacklistedModels, providerPrefix+rawID) { + continue + } openrouterResponse.Data[i].ID = providerPrefix + rawID filteredData = append(filteredData, openrouterResponse.Data[i]) includedModels[rawID] = true @@ -204,6 +208,9 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, // Backfill allowed models not in the API response for _, allowedModel := range allowedModels { rawID := strings.TrimPrefix(allowedModel, providerPrefix) + if slices.Contains(blacklistedModels, rawID) || slices.Contains(blacklistedModels, providerPrefix+rawID) { + continue + } if !includedModels[rawID] { filteredData = append(filteredData, schemas.Model{ ID: providerPrefix + rawID, diff --git a/core/providers/replicate/models.go b/core/providers/replicate/models.go index 1cb2016f82..3989628db1 100644 --- a/core/providers/replicate/models.go +++ b/core/providers/replicate/models.go @@ -12,6 +12,7 @@ func ToBifrostListModelsResponse( deploymentsResponse *ReplicateDeploymentListResponse, providerKey schemas.ModelProvider, allowedModels []string, + blacklistedModels []string, unfiltered bool, ) *schemas.BifrostListModelsResponse { bifrostResponse := &schemas.BifrostListModelsResponse{ @@ -30,6 +31,9 @@ func ToBifrostListModelsResponse( if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, deploymentID) { continue } + if !unfiltered && slices.Contains(blacklistedModels, deploymentID) { + continue + } // Extract information from current release if available if deployment.CurrentRelease != nil { @@ -62,6 +66,9 @@ func ToBifrostListModelsResponse( // Backfill allowed models that were not in the response if !unfiltered && len(allowedModels) > 0 { for _, allowedModel := range allowedModels { + if slices.Contains(blacklistedModels, allowedModel) { + continue + } if !includedModels[allowedModel] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, diff --git a/core/providers/replicate/replicate.go b/core/providers/replicate/replicate.go index 81cd67931d..e1ed9131e6 100644 --- a/core/providers/replicate/replicate.go +++ b/core/providers/replicate/replicate.go @@ -361,6 +361,7 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont deploymentsResponse, providerName, key.Models, + key.BlacklistedModels, request.Unfiltered, ) diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index e48693e97d..a7556b08dd 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -2672,3 +2672,26 @@ func CheckAndSetDefaultProvider(ctx *schemas.BifrostContext, defaultProvider sch } return defaultProvider } + +// ModelMatchesDenylist reports whether any of the candidate model IDs matches +// an entry in denylist, using both exact and base-model (SameBaseModel) matching. +// Empty candidates are skipped. Returns false immediately if denylist is empty. +func ModelMatchesDenylist(denylist []string, candidates ...string) bool { + if len(denylist) == 0 { + return false + } + for _, c := range candidates { + if c == "" { + continue + } + if slices.Contains(denylist, c) { + return true + } + for _, d := range denylist { + if schemas.SameBaseModel(d, c) { + return true + } + } + } + return false +} diff --git a/core/providers/vertex/models.go b/core/providers/vertex/models.go index cdec949e67..28b5598022 100644 --- a/core/providers/vertex/models.go +++ b/core/providers/vertex/models.go @@ -114,7 +114,7 @@ func findDeploymentMatch(deployments map[string]string, customModelID string) (d // - If allowedModels is empty, all models are allowed // - If allowedModels is non-empty, only models/deployments with keys in allowedModels are included // - Deployments map is used to match model IDs to aliases and filter accordingly -func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -176,6 +176,10 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod modelID := customModelID + if !unfiltered && (slices.Contains(blacklistedModels, customModelID) || slices.Contains(blacklistedModels, deploymentAlias)) { + continue + } + modelEntry := schemas.Model{ ID: string(schemas.Vertex) + "/" + modelID, Name: schemas.Ptr(model.DisplayName), @@ -204,6 +208,9 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) { continue } + if slices.Contains(blacklistedModels, alias) { + continue + } modelName := formatDeploymentName(alias) modelEntry := schemas.Model{ @@ -225,6 +232,9 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod if addedModelIDs[modelID] { continue } + if slices.Contains(blacklistedModels, allowedModel) { + continue + } modelName := formatDeploymentName(allowedModel) modelEntry := schemas.Model{ @@ -244,7 +254,7 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod // ToBifrostListModelsResponse converts a Vertex AI publisher models response to Bifrost's format. // This is for foundation models from the Model Garden (publishers.models.list endpoint). -func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -267,6 +277,9 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, modelID) { continue } + if !unfiltered && slices.Contains(blacklistedModels, modelID) { + continue + } // Skip if already added (shouldn't happen, but safety check) fullModelID := string(schemas.Vertex) + "/" + modelID diff --git a/core/providers/vertex/utils.go b/core/providers/vertex/utils.go index a08d9fbed8..b80743f07c 100644 --- a/core/providers/vertex/utils.go +++ b/core/providers/vertex/utils.go @@ -185,7 +185,7 @@ func getCompleteURLForGeminiEndpoint(deployment string, region string, projectID // buildResponseFromConfig builds a list models response from configured deployments and allowedModels. // This is used when the user has explicitly configured which models they want to use. -func buildResponseFromConfig(deployments map[string]string, allowedModels []string) *schemas.BifrostListModelsResponse { +func buildResponseFromConfig(deployments map[string]string, allowedModels []string, blacklistedModels []string) *schemas.BifrostListModelsResponse { response := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } @@ -197,12 +197,19 @@ func buildResponseFromConfig(deployments map[string]string, allowedModels []stri for _, m := range allowedModels { allowedSet[m] = true } + blacklistedSet := make(map[string]bool, len(blacklistedModels)) + for _, m := range blacklistedModels { + blacklistedSet[m] = true + } // First add models from deployments (filtered by allowedModels when set) for alias, deploymentValue := range deployments { if len(allowedSet) > 0 && !allowedSet[alias] { continue } + if len(blacklistedSet) > 0 && blacklistedSet[alias] { + continue + } modelID := string(schemas.Vertex) + "/" + alias if addedModelIDs[modelID] { continue @@ -221,6 +228,9 @@ func buildResponseFromConfig(deployments map[string]string, allowedModels []stri // Then add models from allowedModels that aren't already in deployments for _, allowedModel := range allowedModels { + if len(blacklistedSet) > 0 && blacklistedSet[allowedModel] { + continue + } modelID := string(schemas.Vertex) + "/" + allowedModel if addedModelIDs[modelID] { continue diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 0a3bb38154..5787cadadc 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -193,7 +193,7 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key // If deployments or allowedModels are configured, return those directly without API call // Skip this fast path when Unfiltered is set so the full Vertex catalog can be retrieved if !request.Unfiltered && (len(deployments) > 0 || len(allowedModels) > 0) { - return buildResponseFromConfig(deployments, allowedModels), nil + return buildResponseFromConfig(deployments, allowedModels, key.BlacklistedModels), nil } // No deployments configured - fetch from Model Garden API @@ -322,7 +322,7 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key PublisherModels: allPublisherModels, } - response := aggregatedResponse.ToBifrostListModelsResponse(nil, request.Unfiltered) + response := aggregatedResponse.ToBifrostListModelsResponse(nil, key.BlacklistedModels, request.Unfiltered) if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { response.ExtraFields.RawRequest = rawRequests diff --git a/core/schemas/account.go b/core/schemas/account.go index e89b55d8a8..ceaeb2de8a 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -17,6 +17,7 @@ type Key struct { Name string `json:"name"` // The name of the key (used by users to identify the key, not used by bifrost) Value EnvVar `json:"value"` // The actual API key value Models []string `json:"models"` // List of models this key can access + BlacklistedModels []string `json:"blacklisted_models"` // List of models this key cannot access Weight float64 `json:"weight"` // Weight for load balancing between multiple keys AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration diff --git a/docs/features/keys-management.mdx b/docs/features/keys-management.mdx index 73afaf7f55..f72f843aea 100644 --- a/docs/features/keys-management.mdx +++ b/docs/features/keys-management.mdx @@ -16,7 +16,7 @@ Bifrost follows a precise selection process for every request: 1. **Context Override Check**: First checks if a key is explicitly provided in context (bypassing management) 2. **Provider Key Lookup**: Retrieves all configured keys for the requested provider -3. **Model Filtering**: Filters keys that support the requested model +3. **Model Filtering**: Filters keys that support the requested model (respecting `models` allowlists and `blacklisted_models` denylists) 4. **Deployment Validation**: For Azure/Bedrock, validates deployment mappings 5. **Weighted Selection**: Uses weighted random selection among eligible keys @@ -161,8 +161,9 @@ Random selection ensures statistical distribution over time Keys can be restricted to specific models for access control and cost management: **Model Filtering Logic:** -- **Empty `models` array**: Key supports ALL models for that provider +- **Empty `models` array**: Key supports ALL models for that provider (except any listed under `blacklisted_models`) - **Populated `models` array**: Key only supports listed models +- **`blacklisted_models`**: Optional per-key denylist. If non-empty and the requested model appears in it, the key is excluded—even if that model is also in `models` (denylist wins over the allowlist) - **Model mismatch**: Key is excluded from selection for that request **Use Cases:** @@ -170,6 +171,7 @@ Keys can be restricted to specific models for access control and cost management - **Team Separation**: Different keys for different teams or projects - **Cost Control**: Restrict access to specific model tiers - **Compliance**: Separate keys for different security requirements +- **Denylist**: Block specific models on a key **Example Model Restrictions:** ```json @@ -186,6 +188,13 @@ Keys can be restricted to specific models for access control and cost management "value": "standard-key", "models": ["gpt-4o-mini", "gpt-3.5-turbo"], // Only standard models "weight": 1.0 + }, + { + "name": "openai-shared-key", + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o", "gpt-4o-mini"], + "blacklisted_models": ["gpt-5"], + "weight": 1.0 } ] } @@ -237,7 +246,7 @@ ctx := context.Background() ctx = context.WithValue(ctx, schemas.BifrostContextKeyAPIKeyName, "openai-key-1") ``` -Note: Both mechanisms reference a stored key (not the raw secret). The gateway resolves the key against configured provider keys and applies model filtering and deployment mapping. When an explicit key ID or name is supplied, weighted selection is bypassed and the referenced key is used directly. +Note: Both mechanisms reference a stored key (not the raw secret). The gateway resolves the key against configured provider keys and applies model allowlists, denylists, and deployment mapping. When an explicit key ID or name is supplied, weighted selection is bypassed and the referenced key is used directly. ```bash # Example: request referencing a stored key name that doesn't exist diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index ec436ea4ff..3b5be7672b 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -296,12 +296,17 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { if models == nil { models = []string{} // Ensure models is never nil in JSON response } + blacklistedModels := key.BlacklistedModels + if blacklistedModels == nil { + blacklistedModels = []string{} // Match models: empty JSON array, not null + } redactedConfig.Keys[i] = schemas.Key{ - ID: key.ID, - Name: key.Name, - Models: models, - Weight: key.Weight, - ConfigHash: key.ConfigHash, + ID: key.ID, + Name: key.Name, + Models: models, + BlacklistedModels: blacklistedModels, + Weight: key.Weight, + ConfigHash: key.ConfigHash, } if key.Enabled != nil { enabled := *key.Enabled @@ -508,6 +513,18 @@ func GenerateKeyHash(key schemas.Key) (string, error) { } hash.Write(data) } + // Hash BlacklistedModels (key-level deny list) + if len(key.BlacklistedModels) > 0 { + sortedBlacklistedModels := make([]string, len(key.BlacklistedModels)) + copy(sortedBlacklistedModels, key.BlacklistedModels) + sort.Strings(sortedBlacklistedModels) + data, err := sonic.Marshal(sortedBlacklistedModels) + if err != nil { + return "", err + } + hash.Write([]byte("blacklistedModels:")) + hash.Write(data) + } // Hash Weight data, err := sonic.Marshal(key.Weight) if err != nil { diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 86baa2018a..ed343c5d17 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -320,6 +320,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddOpenAIConfigJSONColumn(ctx, db); err != nil { return err } + if err := migrationAddKeyBlacklistedModelsJSONColumn(ctx, db); err != nil { + return err + } return nil } @@ -4801,3 +4804,38 @@ func migrationAddOpenAIConfigJSONColumn(ctx context.Context, db *gorm.DB) error } return nil } + +// migrationAddKeyBlacklistedModelsJSONColumn adds blacklisted_models_json to config_keys +// for per-key model deny lists (JSON array of model ids, default []). +func migrationAddKeyBlacklistedModelsJSONColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_key_blacklisted_models_json_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if !mg.HasColumn(&tables.TableKey{}, "blacklisted_models_json") { + if err := mg.AddColumn(&tables.TableKey{}, "blacklisted_models_json"); err != nil { + return fmt.Errorf("failed to add blacklisted_models_json column: %w", err) + } + } + if err := tx.Exec("UPDATE config_keys SET blacklisted_models_json = '[]' WHERE blacklisted_models_json IS NULL OR blacklisted_models_json = ''").Error; err != nil { + return fmt.Errorf("failed to backfill blacklisted_models_json: %w", err) + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if mg.HasColumn(&tables.TableKey{}, "blacklisted_models_json") { + if err := mg.DropColumn(&tables.TableKey{}, "blacklisted_models_json"); err != nil { + return fmt.Errorf("failed to drop blacklisted_models_json column: %w", err) + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running add_key_blacklisted_models_json_column migration: %s", err.Error()) + } + return nil +} diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index 2ec66ecb43..7a59ee18da 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -288,6 +288,7 @@ func (s *RDBConfigStore) UpdateProvidersConfig(ctx context.Context, providers ma Name: key.Name, Value: key.Value, Models: key.Models, + BlacklistedModels: key.BlacklistedModels, Weight: &key.Weight, Enabled: key.Enabled, UseForBatchAPI: key.UseForBatchAPI, @@ -457,6 +458,7 @@ func (s *RDBConfigStore) UpdateProvider(ctx context.Context, provider schemas.Mo Name: key.Name, Value: key.Value, Models: key.Models, + BlacklistedModels: key.BlacklistedModels, Weight: &key.Weight, Enabled: key.Enabled, UseForBatchAPI: key.UseForBatchAPI, @@ -579,6 +581,7 @@ func (s *RDBConfigStore) AddProvider(ctx context.Context, provider schemas.Model Name: key.Name, Value: key.Value, Models: key.Models, + BlacklistedModels: key.BlacklistedModels, Weight: &key.Weight, Enabled: key.Enabled, UseForBatchAPI: key.UseForBatchAPI, @@ -700,6 +703,7 @@ func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.Mo Name: dbKey.Name, Value: dbKey.Value, Models: dbKey.Models, + BlacklistedModels: dbKey.BlacklistedModels, Weight: getWeight(dbKey.Weight), Enabled: dbKey.Enabled, UseForBatchAPI: dbKey.UseForBatchAPI, @@ -750,6 +754,7 @@ func (s *RDBConfigStore) GetProviderConfig(ctx context.Context, provider schemas Name: dbKey.Name, Value: dbKey.Value, Models: dbKey.Models, + BlacklistedModels: dbKey.BlacklistedModels, Weight: getWeight(dbKey.Weight), Enabled: dbKey.Enabled, UseForBatchAPI: dbKey.UseForBatchAPI, @@ -1721,12 +1726,12 @@ func (s *RDBConfigStore) GetKeysByProvider(ctx context.Context, provider string) func (s *RDBConfigStore) GetAllRedactedKeys(ctx context.Context, ids []string) ([]schemas.Key, error) { var keys []tables.TableKey if len(ids) > 0 { - err := s.db.WithContext(ctx).Select("id, key_id, name, models_json, weight").Where("key_id IN ?", ids).Find(&keys).Error + err := s.db.WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Where("key_id IN ?", ids).Find(&keys).Error if err != nil { return nil, err } } else { - err := s.db.WithContext(ctx).Select("id, key_id, name, models_json, weight").Find(&keys).Error + err := s.db.WithContext(ctx).Select("id, key_id, name, models_json, blacklisted_models_json, weight").Find(&keys).Error if err != nil { return nil, err } @@ -1737,11 +1742,16 @@ func (s *RDBConfigStore) GetAllRedactedKeys(ctx context.Context, ids []string) ( if models == nil { models = []string{} // Ensure models is never nil in JSON response } + blacklisted := key.BlacklistedModels + if blacklisted == nil { + blacklisted = []string{} + } redactedKeys[i] = schemas.Key{ - ID: key.KeyID, - Name: key.Name, - Models: models, - Weight: getWeight(key.Weight), + ID: key.KeyID, + Name: key.Name, + Models: models, + BlacklistedModels: blacklisted, + Weight: getWeight(key.Weight), } } return redactedKeys, nil diff --git a/framework/configstore/tables/key.go b/framework/configstore/tables/key.go index 4fb6659dc8..c0763e9045 100644 --- a/framework/configstore/tables/key.go +++ b/framework/configstore/tables/key.go @@ -13,17 +13,18 @@ import ( // TableKey represents an API key configuration in the database type TableKey struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` - Name string `gorm:"type:varchar(255);uniqueIndex:idx_key_name;not null" json:"name"` - ProviderID uint `gorm:"index;not null" json:"provider_id"` - Provider string `gorm:"index;type:varchar(50)" json:"provider"` // ModelProvider as string - KeyID string `gorm:"type:varchar(255);uniqueIndex:idx_key_id;not null" json:"key_id"` // UUID from schemas.Key - Value schemas.EnvVar `gorm:"type:text;not null" json:"value"` - ModelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string - Weight *float64 `json:"weight"` - Enabled *bool `gorm:"default:true" json:"enabled,omitempty"` - CreatedAt time.Time `gorm:"index;not null" json:"created_at"` - UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + Name string `gorm:"type:varchar(255);uniqueIndex:idx_key_name;not null" json:"name"` + ProviderID uint `gorm:"index;not null" json:"provider_id"` + Provider string `gorm:"index;type:varchar(50)" json:"provider"` // ModelProvider as string + KeyID string `gorm:"type:varchar(255);uniqueIndex:idx_key_id;not null" json:"key_id"` // UUID from schemas.Key + Value schemas.EnvVar `gorm:"type:text;not null" json:"value"` + ModelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + BlacklistedModelsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + Weight *float64 `json:"weight"` + Enabled *bool `gorm:"default:true" json:"enabled,omitempty"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` // Config hash is used to detect changes synced from config.json file ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"` @@ -73,6 +74,7 @@ type TableKey struct { // Virtual fields for runtime use (not stored in DB) Models []string `gorm:"-" json:"models"` + BlacklistedModels []string `gorm:"-" json:"blacklisted_models"` AzureKeyConfig *schemas.AzureKeyConfig `gorm:"-" json:"azure_key_config,omitempty"` VertexKeyConfig *schemas.VertexKeyConfig `gorm:"-" json:"vertex_key_config,omitempty"` BedrockKeyConfig *schemas.BedrockKeyConfig `gorm:"-" json:"bedrock_key_config,omitempty"` @@ -98,6 +100,15 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { } else { k.ModelsJSON = "[]" } + if k.BlacklistedModels != nil { + data, err := json.Marshal(k.BlacklistedModels) + if err != nil { + return err + } + k.BlacklistedModelsJSON = string(data) + } else { + k.BlacklistedModelsJSON = "[]" + } if k.Enabled == nil { enabled := true // DB default k.Enabled = &enabled @@ -487,6 +498,13 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { } else { k.Models = []string{} } + if k.BlacklistedModelsJSON != "" { + if err := json.Unmarshal([]byte(k.BlacklistedModelsJSON), &k.BlacklistedModels); err != nil { + return err + } + } else { + k.BlacklistedModels = []string{} + } if k.Enabled == nil { enabled := true // DB default k.Enabled = &enabled diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 9b65d1d11e..3531fe94a9 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -626,7 +626,7 @@ func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvide } // UpsertModelDataForProvider upserts model data for a given provider -func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) { +func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model, deniedModels []schemas.Model) { if modelData == nil { return } @@ -655,7 +655,7 @@ func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvide } } // If modelData is empty, then we allow all models - if len(modelData.Data) == 0 && len(allowedModels) == 0 { + if len(modelData.Data) == 0 && len(allowedModels) == 0 && len(deniedModels) == 0 { mc.modelPool[provider] = providerModels return } @@ -686,9 +686,17 @@ func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvide finalModelList = append(finalModelList, parsedModel) } } - // If there are no allowed models, we add all models from the provider models + if len(allowedModels) == 0 { + deniedSet := make(map[string]struct{}, len(deniedModels)) + for _, d := range deniedModels { + _, modelName := schemas.ParseModelString(d.ID, "") + deniedSet[modelName] = struct{}{} + } for _, model := range providerModels { + if _, denied := deniedSet[model]; denied { + continue + } if !seenModels[model] { seenModels[model] = true finalModelList = append(finalModelList, model) diff --git a/plugins/governance/http_transport_prehook_test.go b/plugins/governance/http_transport_prehook_test.go index f50511d740..c79ca20b15 100644 --- a/plugins/governance/http_transport_prehook_test.go +++ b/plugins/governance/http_transport_prehook_test.go @@ -23,7 +23,7 @@ func TestHTTPTransportPreHook_VirtualKeyReplicateRefinesNestedModel(t *testing.T Data: []schemas.Model{ {ID: "replicate/openai/gpt-5-nano"}, }, - }, nil) + }, nil, nil) virtualKey := buildVirtualKeyWithProviders( "vk1", diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 677da0ea11..fd37ef8000 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -489,7 +489,7 @@ func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { // Merge proxy config - preserve secrets if redacted values were sent back if payload.ProxyConfig != nil && oldConfigRaw.ProxyConfig != nil { if payload.ProxyConfig.IsRedactedValue(payload.ProxyConfig.Password) { - payload.ProxyConfig.Password = oldConfigRaw.ProxyConfig.Password + payload.ProxyConfig.Password = oldConfigRaw.ProxyConfig.Password } if payload.ProxyConfig.IsRedactedValue(payload.ProxyConfig.CACertPEM) { payload.ProxyConfig.CACertPEM = oldConfigRaw.ProxyConfig.CACertPEM @@ -770,6 +770,17 @@ func (h *ProviderHandler) getModelParameters(ctx *fasthttp.RequestCtx) { ctx.SetBodyString(params.Data) } +// keyAllowsModelForList reports whether a provider key permits model for catalog listing. +func keyAllowsModelForList(key schemas.Key, model string) bool { + if len(key.BlacklistedModels) > 0 && slices.Contains(key.BlacklistedModels, model) { + return false + } + if len(key.Models) > 0 { + return slices.Contains(key.Models, model) + } + return true +} + // filterModelsByKeys filters models based on key-level model restrictions func (h *ProviderHandler) filterModelsByKeys(provider schemas.ModelProvider, models []string, keyIDs []string) []string { // Get provider config to access keys @@ -778,42 +789,28 @@ func (h *ProviderHandler) filterModelsByKeys(provider schemas.ModelProvider, mod logger.Warn("Failed to get config for provider %s: %v", provider, err) return models } - // Build a set of allowed models from the specified keys - // Track whether we have any unrestricted keys (which grant access to all models) - // and whether we have any restricted keys (which limit to specific models) - allowedModels := make(map[string]bool) - hasRestrictedKey := false - hasUnrestrictedKey := false + keysByID := make(map[string]schemas.Key, len(config.Keys)) + for i := range config.Keys { + k := config.Keys[i] + keysByID[k.ID] = k + } + matchedKeys := make([]schemas.Key, 0, len(keyIDs)) for _, keyID := range keyIDs { - for _, key := range config.Keys { - if key.ID == keyID { - if len(key.Models) > 0 { - // Key has model restrictions - add them to allowedModels - hasRestrictedKey = true - for _, model := range key.Models { - allowedModels[model] = true - } - } else { - // Key has no model restrictions - grants access to all models - hasUnrestrictedKey = true - } - break - } + if key, ok := keysByID[keyID]; ok { + matchedKeys = append(matchedKeys, key) } } - // If any key is unrestricted, return all models (union of "all" and restricted subsets is "all") - if hasUnrestrictedKey { - return models - } - // If no keys have model restrictions (e.g., unknown key IDs), return all models - if !hasRestrictedKey { + // Unknown key IDs (or empty keyIDs): do not filter + if len(matchedKeys) == 0 { return models } - // Filter models based on restrictions from restricted keys only - filtered := []string{} + filtered := make([]string, 0, len(models)) for _, model := range models { - if allowedModels[model] { - filtered = append(filtered, model) + for _, key := range matchedKeys { + if keyAllowsModelForList(key, model) { + filtered = append(filtered, model) + break + } } } return filtered diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 500e76dbcd..21be1180ab 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -187,17 +187,20 @@ func (cd *ConfigData) UnmarshalJSON(data []byte) error { if tableKey.Value.GetValue() != "" { // Full key definition - add to provider keysToAddToProvider = append(keysToAddToProvider, schemas.Key{ - ID: tableKey.KeyID, - Name: tableKey.Name, - Value: tableKey.Value, - Models: tableKey.Models, - Weight: getWeight(tableKey.Weight), - Enabled: tableKey.Enabled, - UseForBatchAPI: tableKey.UseForBatchAPI, - AzureKeyConfig: tableKey.AzureKeyConfig, - VertexKeyConfig: tableKey.VertexKeyConfig, - BedrockKeyConfig: tableKey.BedrockKeyConfig, - ConfigHash: tableKey.ConfigHash, + ID: tableKey.KeyID, + Name: tableKey.Name, + Value: tableKey.Value, + Models: tableKey.Models, + BlacklistedModels: tableKey.BlacklistedModels, + Weight: getWeight(tableKey.Weight), + Enabled: tableKey.Enabled, + UseForBatchAPI: tableKey.UseForBatchAPI, + AzureKeyConfig: tableKey.AzureKeyConfig, + VertexKeyConfig: tableKey.VertexKeyConfig, + BedrockKeyConfig: tableKey.BedrockKeyConfig, + ReplicateKeyConfig: tableKey.ReplicateKeyConfig, + VLLMKeyConfig: tableKey.VLLMKeyConfig, + ConfigHash: tableKey.ConfigHash, }) } // Reference lookups (no Value) are NOT added to provider - they already exist there @@ -825,13 +828,18 @@ func mergeProviderKeys(provider schemas.ModelProvider, fileKeys, dbKeys []schema } else { // No stored hash (legacy) - fall back to generating fresh hash dbKeyHash, err := configstore.GenerateKeyHash(schemas.Key{ - Name: dbKey.Name, - Value: dbKey.Value, - Models: dbKey.Models, - Weight: dbKey.Weight, - AzureKeyConfig: dbKey.AzureKeyConfig, - VertexKeyConfig: dbKey.VertexKeyConfig, - BedrockKeyConfig: dbKey.BedrockKeyConfig, + Name: dbKey.Name, + Value: dbKey.Value, + Models: dbKey.Models, + BlacklistedModels: dbKey.BlacklistedModels, + Weight: dbKey.Weight, + AzureKeyConfig: dbKey.AzureKeyConfig, + VertexKeyConfig: dbKey.VertexKeyConfig, + BedrockKeyConfig: dbKey.BedrockKeyConfig, + ReplicateKeyConfig: dbKey.ReplicateKeyConfig, + VLLMKeyConfig: dbKey.VLLMKeyConfig, + Enabled: dbKey.Enabled, + UseForBatchAPI: dbKey.UseForBatchAPI, }) if err != nil { logger.Warn("failed to generate key hash for db key %s (%s): %v, falling back to name comparison", dbKey.Name, provider, err) @@ -898,13 +906,18 @@ func reconcileProviderKeys(provider schemas.ModelProvider, fileKeys, dbKeys []sc } else { // No stored hash (legacy) - fall back to generating fresh hash for comparison dbKeyHash, err := configstore.GenerateKeyHash(schemas.Key{ - Name: dbKey.Name, - Value: dbKey.Value, - Models: dbKey.Models, - Weight: dbKey.Weight, - AzureKeyConfig: dbKey.AzureKeyConfig, - VertexKeyConfig: dbKey.VertexKeyConfig, - BedrockKeyConfig: dbKey.BedrockKeyConfig, + Name: dbKey.Name, + Value: dbKey.Value, + Models: dbKey.Models, + BlacklistedModels: dbKey.BlacklistedModels, + Weight: dbKey.Weight, + AzureKeyConfig: dbKey.AzureKeyConfig, + VertexKeyConfig: dbKey.VertexKeyConfig, + BedrockKeyConfig: dbKey.BedrockKeyConfig, + ReplicateKeyConfig: dbKey.ReplicateKeyConfig, + VLLMKeyConfig: dbKey.VLLMKeyConfig, + Enabled: dbKey.Enabled, + UseForBatchAPI: dbKey.UseForBatchAPI, }) if err != nil { logger.Warn("failed to generate key hash for db key %s (%s): %v", dbKey.Name, provider, err) @@ -3009,14 +3022,19 @@ func (c *Config) GetAllKeys() ([]configstoreTables.TableKey, error) { if models == nil { models = []string{} } + blacklisted := key.BlacklistedModels + if blacklisted == nil { + blacklisted = []string{} + } keys = append(keys, configstoreTables.TableKey{ - KeyID: key.ID, - Name: key.Name, - Value: *schemas.NewEnvVar(""), - Models: models, - Weight: bifrost.Ptr(key.Weight), - Provider: string(providerKey), - ConfigHash: key.ConfigHash, + KeyID: key.ID, + Name: key.Name, + Value: *schemas.NewEnvVar(""), + Models: models, + BlacklistedModels: blacklisted, + Weight: bifrost.Ptr(key.Weight), + Provider: string(providerKey), + ConfigHash: key.ConfigHash, }) } } diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 0c48fb9984..24e709a73b 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -9,6 +9,7 @@ import ( "net" "os" "os/signal" + "slices" "strings" "syscall" "time" @@ -529,15 +530,23 @@ func (s *BifrostHTTPServer) ReloadProvider(ctx context.Context, provider schemas if err != nil { return nil, fmt.Errorf("failed to update provider model catalog: failed to get keys by provider: %s", err) } - modelsInKeys := make([]schemas.Model, 0) + allowedInKeys := make([]schemas.Model, 0) + deniedInKeys := make([]schemas.Model, 0) for _, key := range providerKeys { for _, model := range key.Models { - modelsInKeys = append(modelsInKeys, schemas.Model{ + if !slices.Contains(key.BlacklistedModels, model) { + allowedInKeys = append(allowedInKeys, schemas.Model{ + ID: string(provider) + "/" + model, + }) + } + } + for _, model := range key.BlacklistedModels { + deniedInKeys = append(deniedInKeys, schemas.Model{ ID: string(provider) + "/" + model, }) } } - s.Config.ModelCatalog.UpsertModelDataForProvider(provider, allModels, modelsInKeys) + s.Config.ModelCatalog.UpsertModelDataForProvider(provider, allModels, allowedInKeys, deniedInKeys) unfilteredModelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ Provider: provider, Unfiltered: true, @@ -760,14 +769,22 @@ func (s *BifrostHTTPServer) ForceReloadPricing(ctx context.Context) error { logger.Error("failed to list models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) } allowedModels := make([]schemas.Model, 0) + deniedModels := make([]schemas.Model, 0) for _, key := range providerConfig.Keys { for _, model := range key.Models { - allowedModels = append(allowedModels, schemas.Model{ + if !slices.Contains(key.BlacklistedModels, model) { + allowedModels = append(allowedModels, schemas.Model{ + ID: string(provider) + "/" + model, + }) + } + } + for _, model := range key.BlacklistedModels { + deniedModels = append(deniedModels, schemas.Model{ ID: string(provider) + "/" + model, }) } } - s.Config.ModelCatalog.UpsertModelDataForProvider(provider, modelData, allowedModels) + s.Config.ModelCatalog.UpsertModelDataForProvider(provider, modelData, allowedModels, deniedModels) unfilteredModelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ Provider: provider, Unfiltered: true, @@ -1249,14 +1266,22 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { logger.Error("failed to list models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) } allowedModels := make([]schemas.Model, 0) + deniedModels := make([]schemas.Model, 0) for _, key := range providerConfig.Keys { for _, model := range key.Models { - allowedModels = append(allowedModels, schemas.Model{ + if !slices.Contains(key.BlacklistedModels, model) { + allowedModels = append(allowedModels, schemas.Model{ + ID: string(provider) + "/" + model, + }) + } + } + for _, model := range key.BlacklistedModels { + deniedModels = append(deniedModels, schemas.Model{ ID: string(provider) + "/" + model, }) } } - s.Config.ModelCatalog.UpsertModelDataForProvider(provider, modelData, allowedModels) + s.Config.ModelCatalog.UpsertModelDataForProvider(provider, modelData, allowedModels, deniedModels) unfilteredModelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ Provider: provider, Unfiltered: true, diff --git a/transports/changelog.md b/transports/changelog.md index 9a9f3aea83..37379b5328 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -1 +1,2 @@ -- fix: add support for `x-bf-mcp-include-clients` and `x-bf-mcp-include-tools` request headers to filter MCP tools/list response when using bifrost as an MCP gateway. \ No newline at end of file +- fix: add support for `x-bf-mcp-include-clients` and `x-bf-mcp-include-tools` request headers to filter MCP tools/list response when using bifrost as an +- feat: Provider keys support `blacklisted_models` in config and HTTP provider APIs; excluded models are omitted from filtered list-models and are not eligible for key selection (denylist wins over the `models` allow list). diff --git a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx index 3e94b129c5..7fac3eae34 100644 --- a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx +++ b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx @@ -187,13 +187,14 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { /> )} {!isVLLM && ( + <> (
- Models + Allowed Models @@ -214,6 +215,37 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { )} /> + ( + +
+ Blacklisted models + + + + + + + + +

+ Comma-separated list of models this key must never use. If a model appears in both Allowed Models and here, the blacklist wins. Leave + empty if none. +

+
+
+
+
+ + + + +
+ )} + /> + )} {supportsBatchAPI && !isBedrock && !isAzure && } {isAzure && ( diff --git a/ui/app/workspace/providers/views/providerKeyForm.tsx b/ui/app/workspace/providers/views/providerKeyForm.tsx index f44603b4dd..f9a550fc83 100644 --- a/ui/app/workspace/providers/views/providerKeyForm.tsx +++ b/ui/app/workspace/providers/views/providerKeyForm.tsx @@ -41,6 +41,7 @@ export default function ProviderKeyForm({ provider, keyIndex, onCancel, onSave } id: uuid(), name: "", models: [], + blacklisted_models: [], weight: 1.0, enabled: true, }, diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index fe20a244bc..02b097ffc9 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -118,6 +118,7 @@ export interface ModelProviderKey { name: string; value?: EnvVar; models?: string[]; + blacklisted_models?: string[]; weight: number; enabled?: boolean; use_for_batch_api?: boolean; @@ -141,6 +142,7 @@ export const DefaultModelProviderKey: ModelProviderKey = { from_env: false, }, models: [], + blacklisted_models: [], weight: 1.0, enabled: true, }; diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index 3208891e08..4874c92be6 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -171,6 +171,7 @@ export const modelProviderKeySchema = z name: z.string().min(1, "Name is required"), value: envVarSchema.optional(), models: z.array(z.string()).default([]).optional(), + blacklisted_models: z.array(z.string()).default([]).optional(), weight: z.union([ z.number().min(0, "Weight must be equal to or greater than 0").max(1, "Weight must be equal to or less than 1"), z