diff --git a/core/bifrost.go b/core/bifrost.go index 8b77fda769..f805ed03df 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -6103,9 +6103,9 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p // - If key.Models is empty → exclude key (deny-by-default) // - If key.Models is non-empty → only include if model is in list if model != nil && *model != "" { - if slices.Contains(k.Models, "*") { + if k.Models.IsUnrestricted() { // wildcard: allow all models - } else if len(k.Models) == 0 || !slices.Contains(k.Models, *model) { + } else if !k.Models.IsAllowed(*model) { continue } } @@ -6199,7 +6199,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex } hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) // ["*"] = allow all models; [] = deny all; specific list = allow only listed - modelSupported := hasValue && (slices.Contains(key.Models, "*") || slices.Contains(key.Models, model)) + modelSupported := hasValue && key.Models.IsAllowed(model) // Additional deployment checks for Azure, Bedrock and Vertex deploymentSupported := true if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil { diff --git a/core/mcp/agent.go b/core/mcp/agent.go index 7874d03530..88ec11e5fb 100644 --- a/core/mcp/agent.go +++ b/core/mcp/agent.go @@ -4,12 +4,11 @@ import ( "fmt" "strings" "sync" - + "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" ) - type AgentModeExecutor struct { logger schemas.Logger } @@ -39,7 +38,7 @@ func (a *AgentModeExecutor) ExecuteAgentForChatRequest( makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), - clientManager ClientManager, + clientManager ClientManager, ) (*schemas.BifrostChatResponse, *schemas.BifrostError) { // Create adapter for Chat API adapter := &chatAPIAdapter{ @@ -142,7 +141,7 @@ func (a *AgentModeExecutor) executeAgent( adapter agentAPIAdapter, fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), - clientManager ClientManager, + clientManager ClientManager, ) (interface{}, *schemas.BifrostError) { // Get initial response from adapter currentResponse := adapter.getInitialResponse() @@ -454,25 +453,23 @@ func buildAllowedAutoExecutionTools(ctx *schemas.BifrostContext, clientManager C // Get auto-executable tools from config toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute - if len(toolsToAutoExecute) == 0 { + if toolsToAutoExecute.IsEmpty() { // No auto-executable tools configured for this client continue } // Parse tool names (as they appear in JavaScript code) autoExecutableTools := []string{} - for _, originalToolName := range toolsToAutoExecute { - // Handle wildcard "*" - means all tools are auto-executable - if originalToolName == "*" { - autoExecutableTools = append(autoExecutableTools, "*") - continue + if toolsToAutoExecute.IsUnrestricted() { + autoExecutableTools = append(autoExecutableTools, "*") + } else { + for _, originalToolName := range toolsToAutoExecute { + // Replace - with _ for code mode compatibility, then parse for JS compatibility + toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_") + parsedToolName := parseToolName(toolNameForCode) + autoExecutableTools = append(autoExecutableTools, parsedToolName) } - // Replace - with _ for code mode compatibility, then parse for JS compatibility - toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_") - parsedToolName := parseToolName(toolNameForCode) - autoExecutableTools = append(autoExecutableTools, parsedToolName) } - // Add to map if there are auto-executable tools if len(autoExecutableTools) > 0 { allowedTools[clientName] = autoExecutableTools diff --git a/core/mcp/utils.go b/core/mcp/utils.go index dcd39acd6c..5123187cac 100644 --- a/core/mcp/utils.go +++ b/core/mcp/utils.go @@ -381,12 +381,12 @@ func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) b // If ToolsToExecute is specified (not nil), apply filtering if config.ToolsToExecute != nil { // Handle empty array [] - means no tools are allowed - if len(config.ToolsToExecute) == 0 { + if config.ToolsToExecute.IsEmpty() { return true // No tools allowed } // Handle wildcard "*" - if present, all tools are allowed - if slices.Contains(config.ToolsToExecute, "*") { + if config.ToolsToExecute.IsUnrestricted() { return false // All tools allowed } @@ -396,7 +396,7 @@ func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) b unprefixedToolName := stripClientPrefix(toolName, config.Name) // Check if specific tool is in the allowed list - return !slices.Contains(config.ToolsToExecute, unprefixedToolName) // Tool not in allowed list + return !config.ToolsToExecute.Contains(unprefixedToolName) // Tool not in allowed list } return true // Tool is skipped (nil is treated as [] - no tools) @@ -413,12 +413,12 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool { // If ToolsToAutoExecute is specified (not nil), apply filtering if config.ToolsToAutoExecute != nil { // Handle empty array [] - means no tools are auto-executed - if len(config.ToolsToAutoExecute) == 0 { + if config.ToolsToAutoExecute.IsEmpty() { return false // No tools auto-executed } // Handle wildcard "*" - if present, all tools are auto-executed - if slices.Contains(config.ToolsToAutoExecute, "*") { + if config.ToolsToAutoExecute.IsUnrestricted() { return true // All tools auto-executed } @@ -428,7 +428,7 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool { unprefixedToolName := stripClientPrefix(toolName, config.Name) // Check if specific tool is in the auto-execute list - return slices.Contains(config.ToolsToAutoExecute, unprefixedToolName) + return config.ToolsToAutoExecute.Contains(unprefixedToolName) } return false // Tool is not auto-executed (nil is treated as [] - no tools) diff --git a/core/providers/anthropic/models.go b/core/providers/anthropic/models.go index 2aa26fa5c1..baf0ad098e 100644 --- a/core/providers/anthropic/models.go +++ b/core/providers/anthropic/models.go @@ -6,7 +6,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -24,10 +24,14 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide bifrostResponse.NextPageToken = *response.LastID } + if !unfiltered && allowedModels.IsEmpty() { + return bifrostResponse + } + includedModels := make(map[string]bool) for _, model := range response.Data { modelID := model.ID - if !unfiltered && len(allowedModels) > 0 { + if !unfiltered && allowedModels.IsRestricted() { allowed := false for _, allowedModel := range allowedModels { if schemas.SameBaseModel(model.ID, allowedModel) { @@ -49,7 +53,7 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide } // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { + if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { if !includedModels[allowedModel] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ diff --git a/core/providers/azure/models.go b/core/providers/azure/models.go index c4ac069625..e6d6bda14c 100644 --- a/core/providers/azure/models.go +++ b/core/providers/azure/models.go @@ -1,25 +1,25 @@ package azure import ( - "slices" + "strings" "github.com/maximhq/bifrost/core/schemas" ) -// findMatchingAllowedModel finds a matching item in a slice, considering both +// findMatchingAllowedModel finds a matching item in a whitelist, considering both // exact match and base model matches (ignoring version suffixes). -// Returns the matched item from the slice if found, empty string otherwise. -// If matched via base model, returns the item from slice (not the value parameter). -func findMatchingAllowedModel(slice []string, value string) string { - // First check exact match - if slices.Contains(slice, value) { +// Returns the matched item from the whitelist if found, empty string otherwise. +// If matched via base model, returns the item from whitelist (not the value parameter). +func findMatchingAllowedModel(wl schemas.WhiteList, value string) string { + // First check exact match (case-insensitive) + if wl.Contains(value) { return value } // Additional layer: check base model matches (ignoring version suffixes) // This handles cases where model versions differ but base model is the same - // Return the item from slice (not value) to use the actual name from allowedModels - for _, item := range slice { + // Return the item from whitelist (not value) to use the actual name from allowedModels + for _, item := range wl { if schemas.SameBaseModel(item, value) { return item } @@ -58,7 +58,7 @@ func findDeploymentMatch(deployments map[string]string, modelID string) (deploym return "", "" } -func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -67,6 +67,12 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode Data: make([]schemas.Model, 0, len(response.Data)), } + if !unfiltered && allowedModels.IsEmpty() && len(deployments) == 0 { + return bifrostResponse + } + + restrictAllowed := !unfiltered && allowedModels.IsRestricted() + includedModels := make(map[string]bool) for _, model := range response.Data { modelID := model.ID @@ -78,7 +84,7 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode // Empty lists mean "allow all" for that dimension // Check considering base model matches (ignoring version suffixes) shouldFilter := false - if !unfiltered && len(allowedModels) > 0 && len(deployments) > 0 { + if restrictAllowed && len(deployments) > 0 { // Both lists are present: model must be in allowedModels AND deployments // AND the deployment alias must also be in allowedModels matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ID) @@ -88,12 +94,12 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode // Check if deployment alias is also in allowedModels (direct string match) deploymentAliasInAllowedModels := false if deploymentAlias != "" { - deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias) + deploymentAliasInAllowedModels = allowedModels.Contains(deploymentAlias) } // Filter if: model not in deployments OR deployment alias not in allowedModels shouldFilter = !inDeployments || !deploymentAliasInAllowedModels - } else if !unfiltered && len(allowedModels) > 0 { + } else if restrictAllowed { // Only allowedModels is present: filter if model is not in allowedModels matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ID) shouldFilter = matchedAllowedModel == "" @@ -102,7 +108,7 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ID) shouldFilter = deploymentValue == "" } - // If both are empty, shouldFilter remains false (allow all) + // If both are empty (or allowedModels is unrestricted and no deployments), shouldFilter remains false if shouldFilter { continue @@ -124,9 +130,9 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode if deploymentValue != "" && deploymentAlias != "" { modelEntry.ID = string(schemas.Azure) + "/" + deploymentAlias modelEntry.Deployment = schemas.Ptr(deploymentValue) - includedModels[deploymentAlias] = true + includedModels[strings.ToLower(deploymentAlias)] = true } else { - includedModels[modelID] = true + includedModels[strings.ToLower(modelID)] = true } bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) @@ -135,11 +141,11 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode // Backfill deployments that were not matched from the API response if !unfiltered && len(deployments) > 0 { for alias, deploymentValue := range deployments { - if includedModels[alias] { + if includedModels[strings.ToLower(alias)] { continue } - // If allowedModels is non-empty, only include if alias is in the list - if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) { + // If allowedModels is restricted, only include if alias is in the list + if restrictAllowed && !allowedModels.Contains(alias) { continue } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ @@ -147,14 +153,14 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode Name: schemas.Ptr(alias), Deployment: schemas.Ptr(deploymentValue), }) - includedModels[alias] = true + includedModels[strings.ToLower(alias)] = true } } // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { + if restrictAllowed { for _, allowedModel := range allowedModels { - if !includedModels[allowedModel] { + if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(schemas.Azure) + "/" + allowedModel, Name: schemas.Ptr(allowedModel), diff --git a/core/providers/bedrock/models.go b/core/providers/bedrock/models.go index a3bebb5890..cfffd360f7 100644 --- a/core/providers/bedrock/models.go +++ b/core/providers/bedrock/models.go @@ -1,7 +1,6 @@ package bedrock import ( - "slices" "strings" "github.com/maximhq/bifrost/core/schemas" @@ -117,29 +116,29 @@ func removePrefix(s string) string { return s } -// findMatchingAllowedModel finds a matching item in a slice, considering both +// findMatchingAllowedModel finds a matching item in a whitelist, considering both // exact match and match with/without region prefixes (e.g., "global.", "us.", "eu."), // and also checks base model matches (ignoring version suffixes). -// Returns the matched item from the slice if found, empty string otherwise. -// If matched via base model, returns the item from slice (not the value parameter). -func findMatchingAllowedModel(slice []string, value string) string { - // First check exact matches - if slices.Contains(slice, value) { +// Returns the matched item from the whitelist if found, empty string otherwise. +// If matched via base model, returns the item from whitelist (not the value parameter). +func findMatchingAllowedModel(wl schemas.WhiteList, value string) string { + // First check exact matches (case-insensitive) + if wl.Contains(value) { return value } // Check with region prefix added/removed valuePrefix := extractPrefix(value) if valuePrefix != "" { - // value has a prefix, check if slice contains version without prefix + // value has a prefix, check if whitelist contains version without prefix withoutPrefix := removePrefix(value) - if slices.Contains(slice, withoutPrefix) { + if wl.Contains(withoutPrefix) { return withoutPrefix } } - // Check if any item in slice has a prefix that matches value without prefix - for _, item := range slice { + // Check if any item in whitelist has a prefix that matches value without prefix + for _, item := range wl { itemPrefix := extractPrefix(item) if itemPrefix != "" { // item has prefix, check if value matches without the prefix @@ -155,12 +154,12 @@ func findMatchingAllowedModel(slice []string, value string) string { // Normalize value by removing any region prefix for base model comparison valueNormalized := removePrefix(value) - for _, item := range slice { + for _, item := range wl { // Normalize item by removing any region prefix for base model comparison itemNormalized := removePrefix(item) // Check base model match with normalized values (prefix removed from both) - // Return the item from slice (not value) to use the actual name from allowedModels + // Return the item from whitelist (not value) to use the actual name from allowedModels if schemas.SameBaseModel(itemNormalized, valueNormalized) { return item } @@ -221,7 +220,7 @@ func findDeploymentMatch(deployments map[string]string, modelID string) (deploym return "", "" } -func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -230,11 +229,17 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK Data: make([]schemas.Model, 0, len(response.ModelSummaries)), } + if !unfiltered && allowedModels.IsEmpty() && len(deployments) == 0 { + return bifrostResponse + } + deploymentValues := make([]string, 0, len(deployments)) for _, deployment := range deployments { deploymentValues = append(deploymentValues, deployment) } + restrictAllowed := !unfiltered && allowedModels.IsRestricted() + includedModels := make(map[string]bool) for _, model := range response.ModelSummaries { modelID := model.ModelID @@ -246,7 +251,7 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK // Empty lists mean "allow all" for that dimension // Check considering global prefix variations shouldFilter := false - if !unfiltered && len(allowedModels) > 0 && len(deploymentValues) > 0 { + if restrictAllowed && len(deploymentValues) > 0 { // Both lists are present: model must be in allowedModels AND deployments // AND the deployment alias must also be in allowedModels matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ModelID) @@ -256,12 +261,12 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK // Check if deployment alias is also in allowedModels (direct string match) deploymentAliasInAllowedModels := false if deploymentAlias != "" { - deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias) + deploymentAliasInAllowedModels = allowedModels.Contains(deploymentAlias) } // Filter if: model not in deployments OR deployment alias not in allowedModels shouldFilter = !inDeployments || !deploymentAliasInAllowedModels - } else if !unfiltered && len(allowedModels) > 0 { + } else if restrictAllowed { // Only allowedModels is present: filter if model is not in allowedModels matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ModelID) shouldFilter = matchedAllowedModel == "" @@ -270,7 +275,7 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ModelID) shouldFilter = deploymentValue == "" } - // If both are empty, shouldFilter remains false (allow all) + // If both are empty (or allowedModels is unrestricted and no deployments), shouldFilter remains false if shouldFilter { continue @@ -298,9 +303,9 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK modelEntry.ID = string(providerKey) + "/" + deploymentAlias // Use the actual deployment value (which might have global prefix) modelEntry.Deployment = schemas.Ptr(deploymentValue) - includedModels[deploymentAlias] = true + includedModels[strings.ToLower(deploymentAlias)] = true } else { - includedModels[modelID] = true + includedModels[strings.ToLower(modelID)] = true } bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) } @@ -308,11 +313,11 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK // Backfill deployments that were not matched from the API response if !unfiltered && len(deployments) > 0 { for alias, deploymentValue := range deployments { - if includedModels[alias] { + if includedModels[strings.ToLower(alias)] { continue } - // If allowedModels is non-empty, only include if alias is in the list - if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) { + // If allowedModels is restricted, only include if alias is in the list + if restrictAllowed && !allowedModels.Contains(alias) { continue } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ @@ -320,14 +325,14 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK Name: schemas.Ptr(alias), Deployment: schemas.Ptr(deploymentValue), }) - includedModels[alias] = true + includedModels[strings.ToLower(alias)] = true } } // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { + if restrictAllowed { for _, allowedModel := range allowedModels { - if !includedModels[allowedModel] { + if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, Name: schemas.Ptr(allowedModel), diff --git a/core/providers/cohere/models.go b/core/providers/cohere/models.go index fb399731ca..7219cedf35 100644 --- a/core/providers/cohere/models.go +++ b/core/providers/cohere/models.go @@ -1,7 +1,7 @@ package cohere import ( - "slices" + "strings" "github.com/maximhq/bifrost/core/schemas" ) @@ -43,7 +43,7 @@ type CohereRerankMeta struct { Tokens *CohereTokenUsage `json:"tokens,omitempty"` } -func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -52,9 +52,13 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Models)), } + if !unfiltered && allowedModels.IsEmpty() { + return bifrostResponse + } + includedModels := make(map[string]bool) for _, model := range response.Models { - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.Name) { + if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.Name) { continue } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ @@ -63,13 +67,13 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe ContextLength: schemas.Ptr(int(model.ContextLength)), SupportedMethods: model.Endpoints, }) - includedModels[model.Name] = true + includedModels[strings.ToLower(model.Name)] = true } // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { + if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { - if !includedModels[allowedModel] { + if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, Name: schemas.Ptr(allowedModel), diff --git a/core/providers/elevenlabs/models.go b/core/providers/elevenlabs/models.go index a00b81847e..dd9136b27e 100644 --- a/core/providers/elevenlabs/models.go +++ b/core/providers/elevenlabs/models.go @@ -1,12 +1,12 @@ package elevenlabs import ( - "slices" + "strings" "github.com/maximhq/bifrost/core/schemas" ) -func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -15,22 +15,26 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid Data: make([]schemas.Model, 0, len(*response)), } + if !unfiltered && allowedModels.IsEmpty() { + return bifrostResponse + } + includedModels := make(map[string]bool) for _, model := range *response { - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ModelID) { + if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ModelID) { continue } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + model.ModelID, Name: schemas.Ptr(model.Name), }) - includedModels[model.ModelID] = true + includedModels[strings.ToLower(model.ModelID)] = true } // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { + if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { - if !includedModels[allowedModel] { + if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, Name: schemas.Ptr(allowedModel), diff --git a/core/providers/gemini/models.go b/core/providers/gemini/models.go index ec36f9406b..576201bcc4 100644 --- a/core/providers/gemini/models.go +++ b/core/providers/gemini/models.go @@ -1,7 +1,6 @@ package gemini import ( - "slices" "strings" "github.com/maximhq/bifrost/core/schemas" @@ -17,7 +16,7 @@ func toGeminiModelResourceName(modelID string) string { return "models/" + modelID } -func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -26,13 +25,17 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Models)), } + if !unfiltered && allowedModels.IsEmpty() { + return bifrostResponse + } + includedModels := make(map[string]bool) for _, model := range response.Models { contextLength := model.InputTokenLimit + model.OutputTokenLimit // Remove prefix models/ from model.Name modelName := strings.TrimPrefix(model.Name, "models/") - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, modelName) { + if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(modelName) { continue } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ @@ -44,13 +47,13 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe MaxOutputTokens: schemas.Ptr(model.OutputTokenLimit), SupportedMethods: model.SupportedGenerationMethods, }) - includedModels[modelName] = true + includedModels[strings.ToLower(modelName)] = true } // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { + if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { - if !includedModels[allowedModel] { + if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, Name: schemas.Ptr(allowedModel), diff --git a/core/providers/huggingface/models.go b/core/providers/huggingface/models.go index e306c4dd2f..749fb2ec38 100644 --- a/core/providers/huggingface/models.go +++ b/core/providers/huggingface/models.go @@ -13,7 +13,7 @@ const ( maxModelFetchLimit = 1000 ) -func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -22,6 +22,10 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi Data: make([]schemas.Model, 0, len(response.Models)), } + if !unfiltered && allowedModels.IsEmpty() { + return bifrostResponse + } + includedModels := make(map[string]bool) for _, model := range response.Models { if model.ModelID == "" { @@ -33,7 +37,7 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi continue } - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ModelID) { + if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ModelID) { continue } @@ -45,13 +49,13 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi } bifrostResponse.Data = append(bifrostResponse.Data, newModel) - includedModels[model.ModelID] = true + includedModels[strings.ToLower(model.ModelID)] = true } // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { + if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { - if !includedModels[allowedModel] { + if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, allowedModel), Name: schemas.Ptr(allowedModel), diff --git a/core/providers/mistral/models.go b/core/providers/mistral/models.go index 45d638a0b2..8030f5e1f8 100644 --- a/core/providers/mistral/models.go +++ b/core/providers/mistral/models.go @@ -1,12 +1,12 @@ package mistral import ( - "slices" + "strings" "github.com/maximhq/bifrost/core/schemas" ) -func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -15,9 +15,13 @@ func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedMo Data: make([]schemas.Model, 0, len(response.Data)), } + if !unfiltered && allowedModels.IsEmpty() { + return bifrostResponse + } + includedModels := make(map[string]bool) for _, model := range response.Data { - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ID) { + if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ID) { continue } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ @@ -28,18 +32,18 @@ func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedMo ContextLength: schemas.Ptr(int(model.MaxContextLength)), OwnedBy: schemas.Ptr(model.OwnedBy), }) - includedModels[model.ID] = true + includedModels[strings.ToLower(model.ID)] = true } // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { + if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { - if !includedModels[allowedModel] { + if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(schemas.Mistral) + "/" + allowedModel, Name: schemas.Ptr(allowedModel), }) - includedModels[allowedModel] = true + includedModels[strings.ToLower(allowedModel)] = true } } } diff --git a/core/providers/openai/models.go b/core/providers/openai/models.go index 6cb2b57f4b..cd79832290 100644 --- a/core/providers/openai/models.go +++ b/core/providers/openai/models.go @@ -1,13 +1,13 @@ package openai import ( - "slices" + "strings" "github.com/maximhq/bifrost/core/schemas" ) // ToBifrostListModelsResponse converts an OpenAI list models response to a Bifrost list models response -func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -16,9 +16,13 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Data)), } + if !unfiltered && allowedModels.IsEmpty() { + return bifrostResponse + } + includedModels := make(map[string]bool) for _, model := range response.Data { - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ID) { + if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ID) { continue } bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ @@ -27,13 +31,13 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe OwnedBy: schemas.Ptr(model.OwnedBy), ContextLength: model.ContextWindow, }) - includedModels[model.ID] = true + includedModels[strings.ToLower(model.ID)] = true } // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { + if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { - if !includedModels[allowedModel] { + if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, Name: schemas.Ptr(allowedModel), diff --git a/core/providers/openrouter/openrouter.go b/core/providers/openrouter/openrouter.go index a8a20f28ca..58001be53a 100644 --- a/core/providers/openrouter/openrouter.go +++ b/core/providers/openrouter/openrouter.go @@ -4,7 +4,6 @@ package openrouter import ( "fmt" "net/http" - "slices" "strings" "time" @@ -189,27 +188,33 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, allowedModels := key.Models providerPrefix := string(schemas.OpenRouter) + "/" - if !request.Unfiltered && len(allowedModels) > 0 { + if !request.Unfiltered && allowedModels.IsEmpty() { + openrouterResponse.Data = make([]schemas.Model, 0) + } else if !request.Unfiltered && allowedModels.IsRestricted() { filteredData := make([]schemas.Model, 0, len(openrouterResponse.Data)) includedModels := make(map[string]bool) for i := range openrouterResponse.Data { rawID := openrouterResponse.Data[i].ID - if !(slices.Contains(allowedModels, rawID) || slices.Contains(allowedModels, providerPrefix+rawID)) { + if !(allowedModels.Contains(rawID) || allowedModels.Contains(providerPrefix+rawID)) { continue } openrouterResponse.Data[i].ID = providerPrefix + rawID filteredData = append(filteredData, openrouterResponse.Data[i]) - includedModels[rawID] = true + includedModels[strings.ToLower(rawID)] = true } // Backfill allowed models not in the API response for _, allowedModel := range allowedModels { - rawID := strings.TrimPrefix(allowedModel, providerPrefix) - if !includedModels[rawID] { + // Strip provider prefix case-insensitively to handle any casing users may supply + rawID := allowedModel + if strings.HasPrefix(strings.ToLower(allowedModel), strings.ToLower(providerPrefix)) { + rawID = allowedModel[len(providerPrefix):] + } + if !includedModels[strings.ToLower(rawID)] { filteredData = append(filteredData, schemas.Model{ ID: providerPrefix + rawID, Name: schemas.Ptr(rawID), }) - includedModels[rawID] = true // avoid duplicate backfill + includedModels[strings.ToLower(rawID)] = true // avoid duplicate backfill } } openrouterResponse.Data = filteredData diff --git a/core/providers/replicate/models.go b/core/providers/replicate/models.go index 1cb2016f82..12000f4b4b 100644 --- a/core/providers/replicate/models.go +++ b/core/providers/replicate/models.go @@ -1,7 +1,6 @@ package replicate import ( - "slices" "strings" "github.com/maximhq/bifrost/core/schemas" @@ -11,13 +10,17 @@ import ( func ToBifrostListModelsResponse( deploymentsResponse *ReplicateDeploymentListResponse, providerKey schemas.ModelProvider, - allowedModels []string, + allowedModels schemas.WhiteList, unfiltered bool, ) *schemas.BifrostListModelsResponse { bifrostResponse := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } + if !unfiltered && allowedModels.IsEmpty() { + return bifrostResponse + } + includedModels := make(map[string]bool) // Add deployments from /v1/deployments endpoint if deploymentsResponse != nil { @@ -27,7 +30,7 @@ func ToBifrostListModelsResponse( modelName := schemas.Ptr(deployment.Name) var created *int64 - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, deploymentID) { + if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(deploymentID) { continue } @@ -51,7 +54,7 @@ func ToBifrostListModelsResponse( } bifrostResponse.Data = append(bifrostResponse.Data, bifrostModel) - includedModels[deploymentID] = true + includedModels[strings.ToLower(deploymentID)] = true } if deploymentsResponse.Next != nil { @@ -60,9 +63,9 @@ func ToBifrostListModelsResponse( } // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { + if !unfiltered && allowedModels.IsRestricted() { for _, allowedModel := range allowedModels { - if !includedModels[allowedModel] { + if !includedModels[strings.ToLower(allowedModel)] { bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ ID: string(providerKey) + "/" + allowedModel, Name: schemas.Ptr(allowedModel), diff --git a/core/providers/vertex/models.go b/core/providers/vertex/models.go index cdec949e67..924b421d98 100644 --- a/core/providers/vertex/models.go +++ b/core/providers/vertex/models.go @@ -1,7 +1,6 @@ package vertex import ( - "slices" "strings" "github.com/maximhq/bifrost/core/schemas" @@ -114,7 +113,7 @@ func findDeploymentMatch(deployments map[string]string, customModelID string) (d // - If allowedModels is empty, all models are allowed // - If allowedModels is non-empty, only models/deployments with keys in allowedModels are included // - Deployments map is used to match model IDs to aliases and filter accordingly -func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -123,6 +122,10 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod Data: make([]schemas.Model, 0, len(response.Models)), } + if !unfiltered && allowedModels.IsEmpty() && len(deployments) == 0 { + return bifrostResponse + } + // Track which model IDs have been added to avoid duplicates addedModelIDs := make(map[string]bool) @@ -143,10 +146,10 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod } // Filter if model is not present in both lists (when both are non-empty) - // Empty lists mean "allow all" for that dimension var deploymentValue, deploymentAlias string + restrictAllowed := !unfiltered && allowedModels.IsRestricted() shouldFilter := false - if !unfiltered && len(allowedModels) > 0 && len(deployments) > 0 { + if restrictAllowed && len(deployments) > 0 { // Both lists are present: model must be in allowedModels AND deployments // AND the deployment alias must also be in allowedModels deploymentValue, deploymentAlias = findDeploymentMatch(deployments, customModelID) @@ -155,20 +158,20 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod // Check if deployment alias is also in allowedModels (direct string match) deploymentAliasInAllowedModels := false if deploymentAlias != "" { - deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias) + deploymentAliasInAllowedModels = allowedModels.Contains(deploymentAlias) } // Filter if: model not in deployments OR deployment alias not in allowedModels shouldFilter = !inDeployments || !deploymentAliasInAllowedModels - } else if !unfiltered && len(allowedModels) > 0 { + } else if restrictAllowed { // Only allowedModels is present: filter if model is not in allowedModels - shouldFilter = !slices.Contains(allowedModels, customModelID) + shouldFilter = !allowedModels.Contains(customModelID) } else if !unfiltered && len(deployments) > 0 { // Only deployments is present: filter if model is not in deployments deploymentValue, deploymentAlias = findDeploymentMatch(deployments, customModelID) shouldFilter = deploymentValue == "" } - // If both are empty, shouldFilter remains false (allow all) + // If both are empty (or allowedModels is unrestricted and no deployments), shouldFilter remains false if shouldFilter { continue @@ -192,6 +195,8 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod } } + restrictAllowed := !unfiltered && allowedModels.IsRestricted() + // Second pass: Backfill deployments that were not matched from the API response if !unfiltered && len(deployments) > 0 { for alias, deploymentValue := range deployments { @@ -200,8 +205,8 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod if addedModelIDs[modelID] { continue } - // If allowedModels is non-empty, only include if alias is in the list - if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) { + // If allowedModels is restricted, only include if alias is in the list + if restrictAllowed && !allowedModels.Contains(alias) { continue } @@ -218,7 +223,7 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod } // Third pass: Backfill allowed models that were not in the response or deployments - if !unfiltered && len(allowedModels) > 0 { + if restrictAllowed { for _, allowedModel := range allowedModels { // Check if model already exists in the list modelID := string(schemas.Vertex) + "/" + allowedModel @@ -244,7 +249,7 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod // ToBifrostListModelsResponse converts a Vertex AI publisher models response to Bifrost's format. // This is for foundation models from the Model Garden (publishers.models.list endpoint). -func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -253,6 +258,10 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a Data: make([]schemas.Model, 0, len(response.PublisherModels)), } + if !unfiltered && allowedModels.IsEmpty() { + return bifrostResponse + } + // Track which model IDs have been added to avoid duplicates addedModelIDs := make(map[string]bool) @@ -264,7 +273,7 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a } // Filter based on allowedModels if specified - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, modelID) { + if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(modelID) { continue } diff --git a/core/providers/vertex/utils.go b/core/providers/vertex/utils.go index 69bb4d51a6..b8b2dc67ea 100644 --- a/core/providers/vertex/utils.go +++ b/core/providers/vertex/utils.go @@ -155,22 +155,18 @@ func getCompleteURLForGeminiEndpoint(deployment string, region string, projectID // buildResponseFromConfig builds a list models response from configured deployments and allowedModels. // This is used when the user has explicitly configured which models they want to use. -func buildResponseFromConfig(deployments map[string]string, allowedModels []string) *schemas.BifrostListModelsResponse { +func buildResponseFromConfig(deployments map[string]string, allowedModels schemas.WhiteList) *schemas.BifrostListModelsResponse { response := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } addedModelIDs := make(map[string]bool) - // Build allowlist set for O(1) lookup - allowedSet := make(map[string]bool, len(allowedModels)) - for _, m := range allowedModels { - allowedSet[m] = true - } + restrictAllowed := allowedModels.IsRestricted() // First add models from deployments (filtered by allowedModels when set) for alias, deploymentValue := range deployments { - if len(allowedSet) > 0 && !allowedSet[alias] { + if restrictAllowed && !allowedModels.Contains(alias) { continue } modelID := string(schemas.Vertex) + "/" + alias @@ -189,7 +185,10 @@ func buildResponseFromConfig(deployments map[string]string, allowedModels []stri addedModelIDs[modelID] = true } - // Then add models from allowedModels that aren't already in deployments + // Then add models from allowedModels that aren't already in deployments (only when restricted) + if !restrictAllowed { + return response + } for _, allowedModel := range allowedModels { modelID := string(schemas.Vertex) + "/" + allowedModel if addedModelIDs[modelID] { diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index e803acf673..d80488dc5c 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -190,9 +190,13 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key deployments := key.VertexKeyConfig.Deployments allowedModels := key.Models + if !request.Unfiltered && allowedModels.IsEmpty() && len(deployments) == 0 { + return &schemas.BifrostListModelsResponse{Data: make([]schemas.Model, 0)}, nil + } + // If deployments or allowedModels are configured, return those directly without API call // Skip this fast path when Unfiltered is set so the full Vertex catalog can be retrieved - if !request.Unfiltered && (len(deployments) > 0 || len(allowedModels) > 0) { + if !request.Unfiltered && (len(deployments) > 0 || allowedModels.IsRestricted()) { return buildResponseFromConfig(deployments, allowedModels), nil } diff --git a/core/schemas/account.go b/core/schemas/account.go index e89b55d8a8..488fc5573a 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -1,7 +1,12 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas -import "context" +import ( + "context" + "fmt" + "slices" + "strings" +) type KeyStatusType string @@ -10,13 +15,71 @@ const ( KeyStatusListModelsFailed KeyStatusType = "list_models_failed" ) +// WhiteList is a list of values that are allowed to be used. +// Semantics: +// - "*" (alone) means all values are allowed. +// - Empty list means nothing is allowed. +// - Non-empty list (without "*") means only the listed values are allowed. +// +// This type is used generically for any field that needs whitelist behavior +// (e.g., allowed models, allowed tools). +type WhiteList []string + +// Contains reports whether value is in the whitelist. +// Returns true if value is in the list. +func (wl WhiteList) Contains(value string) bool { + return slices.ContainsFunc(wl, func(s string) bool { + return strings.EqualFold(s, value) + }) +} + +// IsAllowed reports whether value is in the whitelist. +// Returns true if value is in the list. +func (wl WhiteList) IsAllowed(value string) bool { + return wl.IsUnrestricted() || wl.Contains(value) +} + +// IsEmpty reports whether the whitelist has no entries. +func (wl WhiteList) IsEmpty() bool { + return len(wl) == 0 +} + +// IsUnrestricted reports whether the whitelist contains only "*", +// meaning all values are allowed. +func (wl WhiteList) IsUnrestricted() bool { + return len(wl) == 1 && wl[0] == "*" +} + +// IsRestricted reports whether the whitelist contains entries other than "*", +// meaning only the listed values are allowed. +func (wl WhiteList) IsRestricted() bool { + return !wl.IsUnrestricted() +} + +// Validate checks that the whitelist is well-formed. +// Returns an error if "*" is present alongside other values, or if there are duplicate entries. +func (wl WhiteList) Validate() error { + if wl.Contains("*") && len(wl) > 1 { + return fmt.Errorf("wildcard '*' cannot be used with other values in the whitelist") + } + seen := make(map[string]struct{}, len(wl)) + for _, v := range wl { + normalized := strings.ToLower(v) + if _, ok := seen[normalized]; ok { + return fmt.Errorf("duplicate value '%s' in whitelist", v) + } + seen[normalized] = struct{}{} + } + return nil +} + // Key represents an API key and its associated configuration for a provider. // It contains the key value, supported models, and a weight for load balancing. type Key struct { ID string `json:"id"` // The unique identifier for the key (used by bifrost to identify the key) Name string `json:"name"` // The name of the key (used by users to identify the key, not used by bifrost) Value EnvVar `json:"value"` // The actual API key value - Models []string `json:"models"` // List of models this key can access + Models WhiteList `json:"models"` // List of models this key can access Weight float64 `json:"weight"` // Weight for load balancing between multiple keys AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index a6183491e7..70cf436875 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -49,7 +49,7 @@ type MCPConfig struct { type MCPToolManagerConfig struct { ToolExecutionTimeout time.Duration `json:"tool_execution_timeout"` MaxAgentDepth int `json:"max_agent_depth"` - CodeModeBindingLevel CodeModeBindingLevel `json:"code_mode_binding_level,omitempty"` // How tools are exposed in VFS: "server" or "tool" + CodeModeBindingLevel CodeModeBindingLevel `json:"code_mode_binding_level,omitempty"` // How tools are exposed in VFS: "server" or "tool" DisableAutoToolInject bool `json:"disable_auto_tool_inject,omitempty"` // When true, MCP tools are not injected into requests by default } @@ -77,7 +77,7 @@ const ( // MCPClientConfig defines tool filtering for an MCP client. type MCPClientConfig struct { - ID string `json:"client_id"` // Client ID + ID string `json:"client_id"` // Client ID Name string `json:"name"` // Client name IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) @@ -88,13 +88,13 @@ type MCPClientConfig struct { State string `json:"state,omitempty"` // Connection state (connected, disconnected, error) Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type) InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only) - ToolsToExecute []string `json:"tools_to_execute,omitempty"` // Include-only list. + ToolsToExecute WhiteList `json:"tools_to_execute,omitempty"` // Include-only list. // ToolsToExecute semantics: // - ["*"] => all tools are included // - [] => no tools are included (deny-by-default) // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => include only the specified tools - ToolsToAutoExecute []string `json:"tools_to_auto_execute,omitempty"` // Auto-execute list. + ToolsToAutoExecute WhiteList `json:"tools_to_auto_execute,omitempty"` // Auto-execute list. // ToolsToAutoExecute semantics: // - ["*"] => all tools are auto-executed // - [] => no tools are auto-executed (deny-by-default) diff --git a/framework/configstore/tables/mcp.go b/framework/configstore/tables/mcp.go index 38a5a6b942..fedaf0f65a 100644 --- a/framework/configstore/tables/mcp.go +++ b/framework/configstore/tables/mcp.go @@ -43,8 +43,8 @@ type TableMCPClient struct { // Virtual fields for runtime use (not stored in DB) StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"` - ToolsToExecute []string `gorm:"-" json:"tools_to_execute"` - ToolsToAutoExecute []string `gorm:"-" json:"tools_to_auto_execute"` + ToolsToExecute schemas.WhiteList `gorm:"-" json:"tools_to_execute"` + ToolsToAutoExecute schemas.WhiteList `gorm:"-" json:"tools_to_auto_execute"` Headers map[string]schemas.EnvVar `gorm:"-" json:"headers"` ToolPricing map[string]float64 `gorm:"-" json:"tool_pricing"` } @@ -68,6 +68,9 @@ func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { } if c.ToolsToExecute != nil { + if err := c.ToolsToExecute.Validate(); err != nil { + return fmt.Errorf("invalid tools_to_execute: %w", err) + } data, err := json.Marshal(c.ToolsToExecute) if err != nil { return err @@ -78,6 +81,9 @@ func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { } if c.ToolsToAutoExecute != nil { + if err := c.ToolsToAutoExecute.Validate(); err != nil { + return fmt.Errorf("invalid tools_to_auto_execute: %w", err) + } data, err := json.Marshal(c.ToolsToAutoExecute) if err != nil { return err diff --git a/framework/configstore/tables/virtualkey.go b/framework/configstore/tables/virtualkey.go index a92fd0f980..af6b9f2b7e 100644 --- a/framework/configstore/tables/virtualkey.go +++ b/framework/configstore/tables/virtualkey.go @@ -24,14 +24,14 @@ func (TableVirtualKeyProviderConfigKey) TableName() string { // TableVirtualKeyProviderConfig represents a provider configuration for a virtual key type TableVirtualKeyProviderConfig struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` - VirtualKeyID string `gorm:"type:varchar(255);not null" json:"virtual_key_id"` - Provider string `gorm:"type:varchar(50);not null" json:"provider"` - Weight *float64 `json:"weight"` - AllowedModels []string `gorm:"type:text;serializer:json" json:"allowed_models"` // ["*"] allows all models; empty denies all (deny-by-default) - AllowAllKeys bool `gorm:"default:false" json:"allow_all_keys"` // True means all keys allowed; false with empty Keys means no keys allowed (deny-by-default) - BudgetID *string `gorm:"type:varchar(255);index" json:"budget_id,omitempty"` - RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"` + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + VirtualKeyID string `gorm:"type:varchar(255);not null" json:"virtual_key_id"` + Provider string `gorm:"type:varchar(50);not null" json:"provider"` + Weight *float64 `json:"weight"` + AllowedModels schemas.WhiteList `gorm:"type:text;serializer:json" json:"allowed_models"` // ["*"] allows all models; empty denies all (deny-by-default) + AllowAllKeys bool `gorm:"default:false" json:"allow_all_keys"` // True means all keys allowed; false with empty Keys means no keys allowed (deny-by-default) + BudgetID *string `gorm:"type:varchar(255);index" json:"budget_id,omitempty"` + RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"` // Relationships Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"` @@ -78,6 +78,14 @@ func (pc *TableVirtualKeyProviderConfig) UnmarshalJSON(data []byte) error { return nil } +// BeforeSave validates WhiteList fields before GORM persists the record. +func (pc *TableVirtualKeyProviderConfig) BeforeSave(tx *gorm.DB) error { + if err := pc.AllowedModels.Validate(); err != nil { + return fmt.Errorf("invalid allowed_models: %w", err) + } + return nil +} + // MarshalJSON custom marshaller to ensure AllowedModels is always an array (never null) func (pc TableVirtualKeyProviderConfig) MarshalJSON() ([]byte, error) { type Alias TableVirtualKeyProviderConfig @@ -147,11 +155,11 @@ func (pc *TableVirtualKeyProviderConfig) AfterFind(tx *gorm.DB) error { } type TableVirtualKeyMCPConfig struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` - VirtualKeyID string `gorm:"type:varchar(255);not null;uniqueIndex:idx_vk_mcpclient" json:"virtual_key_id"` - MCPClientID uint `gorm:"not null;uniqueIndex:idx_vk_mcpclient" json:"mcp_client_id"` - MCPClient TableMCPClient `gorm:"foreignKey:MCPClientID" json:"mcp_client"` - ToolsToExecute []string `gorm:"type:text;serializer:json" json:"tools_to_execute"` + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + VirtualKeyID string `gorm:"type:varchar(255);not null;uniqueIndex:idx_vk_mcpclient" json:"virtual_key_id"` + MCPClientID uint `gorm:"not null;uniqueIndex:idx_vk_mcpclient" json:"mcp_client_id"` + MCPClient TableMCPClient `gorm:"foreignKey:MCPClientID" json:"mcp_client"` + ToolsToExecute schemas.WhiteList `gorm:"type:text;serializer:json" json:"tools_to_execute"` // MCPClientName is used during config file parsing to resolve the MCP client by name. // This field is not persisted to the database - it's only used to capture @@ -164,6 +172,14 @@ func (TableVirtualKeyMCPConfig) TableName() string { return "governance_virtual_key_mcp_configs" } +// BeforeSave validates WhiteList fields before GORM persists the record. +func (mc *TableVirtualKeyMCPConfig) BeforeSave(tx *gorm.DB) error { + if err := mc.ToolsToExecute.Validate(); err != nil { + return fmt.Errorf("invalid tools_to_execute: %w", err) + } + return nil +} + // UnmarshalJSON custom unmarshaller to handle both "mcp_client_id" (database format) // and "mcp_client_name" (config file format) for MCP client references. func (mc *TableVirtualKeyMCPConfig) UnmarshalJSON(data []byte) error { diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index a61ec5c95f..be25cb69c8 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -530,14 +530,14 @@ func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvid // // Explicit allowedModels without prefix // mc.IsModelAllowedForProvider("openai", "gpt-4o", []string{"gpt-4o"}) // // Returns: true (direct match) -func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, allowedModels []string) bool { +func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, allowedModels schemas.WhiteList) bool { // Case 1: ["*"] = allow all models; use catalog to determine support // Empty allowedModels = deny all (fail-safe deny-by-default) - if slices.Contains(allowedModels, "*") { + if allowedModels.IsUnrestricted() { supportedProviders := mc.GetProvidersForModel(model) return slices.Contains(supportedProviders, provider) } - if len(allowedModels) == 0 { + if allowedModels.IsEmpty() { return false } diff --git a/plugins/governance/main.go b/plugins/governance/main.go index 1eb211c011..200afe1237 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -7,7 +7,6 @@ import ( "fmt" "math/rand/v2" "net/url" - "slices" "sort" "strings" "sync" @@ -614,11 +613,7 @@ func (p *GovernancePlugin) loadBalanceProvider(ctx *schemas.BifrostContext, req } else { // Fallback when model catalog is not available: simple string matching // ["*"] = allow all models; [] = deny all models - if slices.Contains(config.AllowedModels, "*") { - isProviderAllowed = true - } else { - isProviderAllowed = slices.Contains(config.AllowedModels, modelStr) - } + isProviderAllowed = config.AllowedModels.IsAllowed(modelStr) } if isProviderAllowed { @@ -908,12 +903,12 @@ func (p *GovernancePlugin) addMCPIncludeTools(headers map[string]string, virtual executeOnlyTools := make([]string, 0) for _, vkMcpConfig := range virtualKey.MCPConfigs { - if len(vkMcpConfig.ToolsToExecute) == 0 { + if vkMcpConfig.ToolsToExecute.IsEmpty() { // No tools specified in virtual key config - skip this client entirely continue } // Handle wildcard in virtual key config - allow all tools from this client - if slices.Contains(vkMcpConfig.ToolsToExecute, "*") { + if vkMcpConfig.ToolsToExecute.IsUnrestricted() { // Virtual key uses wildcard - use client-specific wildcard executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name)) continue @@ -1121,18 +1116,18 @@ func isMCPToolAllowedByVK(vk *configstoreTables.TableVirtualKey, toolPattern str clientName := mcpConfig.MCPClient.Name // Wildcard pattern "clientName-*": VK just needs to have this client configured at all. if toolPattern == clientName+"-*" { - if len(mcpConfig.ToolsToExecute) > 0 { + if !mcpConfig.ToolsToExecute.IsEmpty() { return true } continue } // Specific tool "clientName-toolName" if strings.HasPrefix(toolPattern, clientName+"-") { - if slices.Contains(mcpConfig.ToolsToExecute, "*") { + if mcpConfig.ToolsToExecute.IsUnrestricted() { return true } toolSuffix := strings.TrimPrefix(toolPattern, clientName+"-") - if slices.Contains(mcpConfig.ToolsToExecute, toolSuffix) { + if mcpConfig.ToolsToExecute.Contains(toolSuffix) { return true } } diff --git a/plugins/governance/resolver.go b/plugins/governance/resolver.go index b1702a0dde..65607ed86e 100644 --- a/plugins/governance/resolver.go +++ b/plugins/governance/resolver.go @@ -4,7 +4,6 @@ package governance import ( "context" "fmt" - "slices" "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" @@ -277,10 +276,7 @@ func (r *BudgetResolver) isModelAllowed(vk *configstoreTables.TableVirtualKey, p } // Fallback when model catalog is not available: simple string matching // ["*"] = allow all models; [] = deny all models - if slices.Contains(pc.AllowedModels, "*") { - return true - } - return slices.Contains(pc.AllowedModels, model) + return pc.AllowedModels.IsAllowed(model) } } diff --git a/transports/bifrost-http/handlers/governance.go b/transports/bifrost-http/handlers/governance.go index 5d0a56ac76..731fd3b4b9 100644 --- a/transports/bifrost-http/handlers/governance.go +++ b/transports/bifrost-http/handlers/governance.go @@ -69,14 +69,14 @@ type CreateVirtualKeyRequest struct { ProviderConfigs []struct { Provider string `json:"provider" validate:"required"` Weight *float64 `json:"weight,omitempty"` - AllowedModels []string `json:"allowed_models,omitempty"` // ["*"] allows all models; empty denies all + AllowedModels schemas.WhiteList `json:"allowed_models,omitempty"` // ["*"] allows all models; empty denies all Budget *CreateBudgetRequest `json:"budget,omitempty"` // Provider-level budget RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` // Provider-level rate limit - KeyIDs []string `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this provider config + KeyIDs schemas.WhiteList `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this provider config } `json:"provider_configs,omitempty"` // Empty means no providers allowed (deny-by-default) MCPConfigs []struct { - MCPClientName string `json:"mcp_client_name" validate:"required"` - ToolsToExecute []string `json:"tools_to_execute,omitempty"` + MCPClientName string `json:"mcp_client_name" validate:"required"` + ToolsToExecute schemas.WhiteList `json:"tools_to_execute,omitempty"` } `json:"mcp_configs,omitempty"` // Empty means no MCP clients allowed (deny-by-default) TeamID *string `json:"team_id,omitempty"` // Mutually exclusive with CustomerID CustomerID *string `json:"customer_id,omitempty"` // Mutually exclusive with TeamID @@ -93,15 +93,15 @@ type UpdateVirtualKeyRequest struct { ID *uint `json:"id,omitempty"` // null for new entries Provider string `json:"provider" validate:"required"` Weight *float64 `json:"weight,omitempty"` - AllowedModels []string `json:"allowed_models,omitempty"` // ["*"] allows all models; empty denies all + AllowedModels schemas.WhiteList `json:"allowed_models,omitempty"` // ["*"] allows all models; empty denies all Budget *UpdateBudgetRequest `json:"budget,omitempty"` // Provider-level budget RateLimit *UpdateRateLimitRequest `json:"rate_limit,omitempty"` // Provider-level rate limit - KeyIDs []string `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this provider config + KeyIDs schemas.WhiteList `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this provider config } `json:"provider_configs,omitempty"` MCPConfigs []struct { - ID *uint `json:"id,omitempty"` // null for new entries - MCPClientName string `json:"mcp_client_name" validate:"required"` - ToolsToExecute []string `json:"tools_to_execute,omitempty"` + ID *uint `json:"id,omitempty"` // null for new entries + MCPClientName string `json:"mcp_client_name" validate:"required"` + ToolsToExecute schemas.WhiteList `json:"tools_to_execute,omitempty"` } `json:"mcp_configs,omitempty"` TeamID *string `json:"team_id,omitempty"` CustomerID *string `json:"customer_id,omitempty"` @@ -494,12 +494,19 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { } } + if err := pc.AllowedModels.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid allowed_models for provider %s: %w", pc.Provider, err)} + } + if err := pc.KeyIDs.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid key_ids for provider %s: %w", pc.Provider, err)} + } + // Get keys for this provider config if specified var keys []configstoreTables.TableKey allowAllKeys := false - if len(pc.KeyIDs) == 1 && pc.KeyIDs[0] == "*" { + if pc.KeyIDs.IsUnrestricted() { allowAllKeys = true - } else if len(pc.KeyIDs) > 0 { + } else if !pc.KeyIDs.IsEmpty() { var err error keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) if err != nil { @@ -566,12 +573,15 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { seenMCPClientNames := make(map[string]bool) for _, mc := range req.MCPConfigs { if seenMCPClientNames[mc.MCPClientName] { - return fmt.Errorf("duplicate mcp_client_name: %s", mc.MCPClientName) + return &badRequestError{err: fmt.Errorf("duplicate mcp_client_name: %s", mc.MCPClientName)} } seenMCPClientNames[mc.MCPClientName] = true } for _, mc := range req.MCPConfigs { + if err := mc.ToolsToExecute.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid tools_to_execute for mcp client %s: %w", mc.MCPClientName, err)} + } mcpClient, err := h.configStore.GetMCPClientByName(ctx, mc.MCPClientName) if err != nil { return fmt.Errorf("failed to get MCP client: %w", err) @@ -587,8 +597,8 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { } return nil }); err != nil { - // Check if this is a duplicate MCPClientName error and return 400 instead of 500 - if strings.Contains(err.Error(), "duplicate mcp_client_name:") { + var badReqErr *badRequestError + if errors.As(err, &badReqErr) { SendError(ctx, 400, err.Error()) return } @@ -835,12 +845,19 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { return fmt.Errorf("both max_limit and reset_duration are required when creating a new provider budget") } } + if err := pc.AllowedModels.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid allowed_models for provider %s: %w", pc.Provider, err)} + } + if err := pc.KeyIDs.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid key_ids for provider %s: %w", pc.Provider, err)} + } + // Get keys for this provider config if specified var keys []configstoreTables.TableKey allowAllKeys := false - if len(pc.KeyIDs) == 1 && pc.KeyIDs[0] == "*" { + if pc.KeyIDs.IsUnrestricted() { allowAllKeys = true - } else if len(pc.KeyIDs) > 0 { + } else if !pc.KeyIDs.IsEmpty() { var err error keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) if err != nil { @@ -906,6 +923,12 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { return fmt.Errorf("provider config %d does not belong to this virtual key", *pc.ID) } requestConfigsMap[*pc.ID] = true + if err := pc.AllowedModels.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid allowed_models for provider %s: %w", pc.Provider, err)} + } + if err := pc.KeyIDs.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid key_ids for provider %s: %w", pc.Provider, err)} + } existing.Provider = pc.Provider existing.Weight = pc.Weight existing.AllowedModels = pc.AllowedModels @@ -913,9 +936,9 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { // Get keys for this provider config if specified var keys []configstoreTables.TableKey allowAllKeys := false - if len(pc.KeyIDs) == 1 && pc.KeyIDs[0] == "*" { + if pc.KeyIDs.IsUnrestricted() { allowAllKeys = true - } else if len(pc.KeyIDs) > 0 { + } else if !pc.KeyIDs.IsEmpty() { var err error keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) if err != nil { @@ -1054,7 +1077,7 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { seenMCPClientNames := make(map[string]bool) for _, mc := range req.MCPConfigs { if seenMCPClientNames[mc.MCPClientName] { - return fmt.Errorf("duplicate mcp_client_name: %s", mc.MCPClientName) + return &badRequestError{err: fmt.Errorf("duplicate mcp_client_name: %s", mc.MCPClientName)} } seenMCPClientNames[mc.MCPClientName] = true } @@ -1071,6 +1094,9 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { requestMCPConfigsMap := make(map[uint]bool) // Process new configs: create new ones and update existing ones for _, mc := range req.MCPConfigs { + if err := mc.ToolsToExecute.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid tools_to_execute for mcp client %s: %w", mc.MCPClientName, err)} + } if mc.ID == nil { mcpClient, err := h.configStore.GetMCPClientByName(ctx, mc.MCPClientName) if err != nil { @@ -1130,11 +1156,10 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { return nil }); err != nil { - errMsg := err.Error() - // Check if this is a duplicate MCPClientName error and return 400 instead of 500 - if strings.Contains(errMsg, "duplicate mcp_client_name:") || - strings.Contains(errMsg, "already exists'") || - strings.Contains(errMsg, "duplicate key") { + var badReqErr *badRequestError + if errors.As(err, &badReqErr) || + strings.Contains(err.Error(), "already exists") || + strings.Contains(err.Error(), "duplicate key") { SendError(ctx, 400, fmt.Sprintf("Failed to update virtual key: %v", err)) return } diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index 7c6e948147..6332536fac 100644 --- a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -6,7 +6,6 @@ import ( "context" "encoding/json" "fmt" - "slices" "sort" "strconv" "strings" @@ -304,8 +303,8 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { } // Auto-clear tools_to_auto_execute if tools_to_execute is empty // If no tools are allowed to execute, no tools can be auto-executed - if len(req.ToolsToExecute) == 0 { - req.ToolsToAutoExecute = []string{} + if req.ToolsToExecute.IsEmpty() { + req.ToolsToAutoExecute = schemas.WhiteList{} } if err := validateToolsToAutoExecute(req.ToolsToAutoExecute, req.ToolsToExecute); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_auto_execute: %v", err)) @@ -506,8 +505,8 @@ func (h *MCPHandler) updateMCPClient(ctx *fasthttp.RequestCtx) { } // Auto-clear tools_to_auto_execute if tools_to_execute is empty // If no tools are allowed to execute, no tools can be auto-executed - if len(req.ToolsToExecute) == 0 { - req.ToolsToAutoExecute = []string{} + if req.ToolsToExecute.IsEmpty() { + req.ToolsToAutoExecute = schemas.WhiteList{} } // Validate tools_to_auto_execute if err := validateToolsToAutoExecute(req.ToolsToAutoExecute, req.ToolsToExecute); err != nil { @@ -648,64 +647,30 @@ func getIDFromCtx(ctx *fasthttp.RequestCtx) (string, error) { return idStr, nil } -func validateToolsToExecute(toolsToExecute []string) error { - if len(toolsToExecute) > 0 { - // Check if wildcard "*" is combined with other tool names - hasWildcard := slices.Contains(toolsToExecute, "*") - if hasWildcard && len(toolsToExecute) > 1 { - return fmt.Errorf("invalid tools_to_execute: wildcard '*' cannot be combined with other tool names") - } - - // Check for duplicate entries - seen := make(map[string]bool) - for _, tool := range toolsToExecute { - if seen[tool] { - return fmt.Errorf("invalid tools_to_execute: duplicate tool name '%s'", tool) - } - seen[tool] = true - } +func validateToolsToExecute(toolsToExecute schemas.WhiteList) error { + if err := toolsToExecute.Validate(); err != nil { + return fmt.Errorf("invalid tools_to_execute: %w", err) } - return nil } -func validateToolsToAutoExecute(toolsToAutoExecute []string, toolsToExecute []string) error { - if len(toolsToAutoExecute) > 0 { - // Check if wildcard "*" is combined with other tool names - hasWildcard := slices.Contains(toolsToAutoExecute, "*") - if hasWildcard && len(toolsToAutoExecute) > 1 { - return fmt.Errorf("wildcard '*' cannot be combined with other tool names") - } - - // Check for duplicate entries - seen := make(map[string]bool) - for _, tool := range toolsToAutoExecute { - if seen[tool] { - return fmt.Errorf("duplicate tool name '%s'", tool) - } - seen[tool] = true - } +func validateToolsToAutoExecute(toolsToAutoExecute schemas.WhiteList, toolsToExecute schemas.WhiteList) error { + if err := toolsToAutoExecute.Validate(); err != nil { + return fmt.Errorf("invalid tools_to_auto_execute: %w", err) + } - // Check that all tools in ToolsToAutoExecute are also in ToolsToExecute - // Create a set of allowed tools from ToolsToExecute - allowedTools := make(map[string]bool) - hasWildcardInExecute := slices.Contains(toolsToExecute, "*") - if hasWildcardInExecute { - // If "*" is in ToolsToExecute, all tools are allowed + if !toolsToAutoExecute.IsEmpty() { + // If ToolsToExecute allows all, no further cross-validation needed + if toolsToExecute.IsUnrestricted() { return nil } - for _, tool := range toolsToExecute { - allowedTools[tool] = true - } - // Validate each tool in ToolsToAutoExecute + // Check that all tools in ToolsToAutoExecute are also in ToolsToExecute for _, tool := range toolsToAutoExecute { if tool == "*" { - // Wildcard is allowed if "*" is in ToolsToExecute - if !hasWildcardInExecute { - return fmt.Errorf("tool '%s' in tools_to_auto_execute is not in tools_to_execute", tool) - } - } else if !allowedTools[tool] { + return fmt.Errorf("tool '*' in tools_to_auto_execute requires '*' in tools_to_execute") + } + if !toolsToExecute.Contains(tool) { return fmt.Errorf("tool '%s' in tools_to_auto_execute is not in tools_to_execute", tool) } } diff --git a/transports/bifrost-http/handlers/mcpserver.go b/transports/bifrost-http/handlers/mcpserver.go index 8cd077b4c9..2ab0dddc6f 100644 --- a/transports/bifrost-http/handlers/mcpserver.go +++ b/transports/bifrost-http/handlers/mcpserver.go @@ -6,7 +6,6 @@ import ( "bufio" "context" "fmt" - "slices" "strings" "sync" @@ -317,13 +316,13 @@ func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) ([]schema if len(vk.MCPConfigs) > 0 { for _, vkMcpConfig := range vk.MCPConfigs { - if len(vkMcpConfig.ToolsToExecute) == 0 { + if vkMcpConfig.ToolsToExecute.IsEmpty() { // No tools specified in virtual key config - skip this client entirely continue } // Handle wildcard in virtual key config - allow all tools from this client - if slices.Contains(vkMcpConfig.ToolsToExecute, "*") { + if vkMcpConfig.ToolsToExecute.IsUnrestricted() { // Virtual key uses wildcard - use client-specific wildcard executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name)) continue diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 32a00b50b9..a00ee3789a 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -752,10 +752,10 @@ func (h *ProviderHandler) filterModelsByKeys(provider schemas.ModelProvider, mod for _, keyID := range keyIDs { for _, key := range config.Keys { if key.ID == keyID { - if slices.Contains(key.Models, "*") { + if key.Models.IsUnrestricted() { // Key allows all models (wildcard) hasUnrestrictedKey = true - } else if len(key.Models) > 0 { + } else if !key.Models.IsEmpty() { // Key has specific model restrictions - add them to allowedModels hasRestrictedKey = true for _, model := range key.Models { diff --git a/transports/bifrost-http/handlers/utils.go b/transports/bifrost-http/handlers/utils.go index 554cf3aad3..bcc35d62fa 100644 --- a/transports/bifrost-http/handlers/utils.go +++ b/transports/bifrost-http/handlers/utils.go @@ -20,6 +20,13 @@ type pluginDisabledKey struct{} // PluginDisabledKey is the context key used to indicate a plugin is being disabled. var PluginDisabledKey pluginDisabledKey +// badRequestError wraps a client input validation error so that outer handlers +// can distinguish it from internal server errors and return HTTP 400. +type badRequestError struct{ err error } + +func (e *badRequestError) Error() string { return e.err.Error() } +func (e *badRequestError) Unwrap() error { return e.err } + // SendJSON sends a JSON response with 200 OK status func SendJSON(ctx *fasthttp.RequestCtx, data interface{}) { ctx.SetContentType("application/json") @@ -115,7 +122,7 @@ func IsOriginAllowed(origin string, allowedOrigins []string) bool { return true } - if allowedOrigin == "*" { + if allowedOrigin == "*" { return true } diff --git a/transports/bifrost-http/integrations/openai.go b/transports/bifrost-http/integrations/openai.go index 460db23796..86dbfb6871 100644 --- a/transports/bifrost-http/integrations/openai.go +++ b/transports/bifrost-http/integrations/openai.go @@ -149,7 +149,6 @@ func AzureEndpointPreHook(handlerStore lib.HandlerStore) func(ctx *fasthttp.Requ return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { hydrateOpenAIRequestFromLargePayloadMetadata(bifrostCtx, req) - azureKey := ctx.Request.Header.Peek("authorization") deploymentEndpoint := ctx.Request.Header.Peek("x-bf-azure-endpoint") apiVersion := string(ctx.QueryArgs().Peek("api-version")) @@ -280,7 +279,7 @@ func AzureEndpointPreHook(handlerStore lib.HandlerStore) func(ctx *fasthttp.Requ key := schemas.Key{ ID: uuid.New().String(), - Models: []string{}, + Models: schemas.WhiteList{"*"}, AzureKeyConfig: &schemas.AzureKeyConfig{}, } @@ -539,7 +538,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType { return schemas.ChatCompletionRequest }, - GetRequestTypeInstance: func(ctx context.Context) interface{} { + GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIChatRequest{} }, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { @@ -3294,4 +3293,3 @@ func parseContainerFileCreateMultipartRequest(ctx *fasthttp.RequestCtx, req inte return nil } - diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go index ea0c8b0a72..7131faab23 100644 --- a/transports/bifrost-http/lib/ctx.go +++ b/transports/bifrost-http/lib/ctx.go @@ -152,8 +152,8 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat "transfer-encoding": true, // prevent auth/key overrides via x-bf-eh-* - "x-api-key": true, - "x-goog-api-key": true, + "x-api-key": true, + "x-goog-api-key": true, "x-bf-api-key": true, "x-bf-api-key-id": true, "x-bf-vk": true, @@ -458,8 +458,8 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat key := schemas.Key{ ID: "header-provided", // Identifier for header-provided keys Value: *schemas.NewEnvVar(apiKey), - Models: []string{}, // Empty models list - will be validated by provider - Weight: 1.0, // Default weight + Models: schemas.WhiteList{"*"}, // Allow all models + Weight: 1.0, // Default weight } bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key) }