Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -6103,9 +6103,9 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p
// - If key.Models is empty → exclude key (deny-by-default)
// - If key.Models is non-empty → only include if model is in list
if model != nil && *model != "" {
if slices.Contains(k.Models, "*") {
if k.Models.IsUnrestricted() {
// wildcard: allow all models
} else if len(k.Models) == 0 || !slices.Contains(k.Models, *model) {
} else if !k.Models.IsAllowed(*model) {
continue
}
}
Expand Down Expand Up @@ -6199,7 +6199,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex
}
hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType)
// ["*"] = allow all models; [] = deny all; specific list = allow only listed
modelSupported := hasValue && (slices.Contains(key.Models, "*") || slices.Contains(key.Models, model))
modelSupported := hasValue && key.Models.IsAllowed(model)
// Additional deployment checks for Azure, Bedrock and Vertex
deploymentSupported := true
if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil {
Expand Down
27 changes: 12 additions & 15 deletions core/mcp/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@ import (
"fmt"
"strings"
"sync"

"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
)


type AgentModeExecutor struct {
logger schemas.Logger
}
Expand Down Expand Up @@ -39,7 +38,7 @@ func (a *AgentModeExecutor) ExecuteAgentForChatRequest(
makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError),
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
clientManager ClientManager,
clientManager ClientManager,
) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
// Create adapter for Chat API
adapter := &chatAPIAdapter{
Expand Down Expand Up @@ -142,7 +141,7 @@ func (a *AgentModeExecutor) executeAgent(
adapter agentAPIAdapter,
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string,
executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error),
clientManager ClientManager,
clientManager ClientManager,
) (interface{}, *schemas.BifrostError) {
// Get initial response from adapter
currentResponse := adapter.getInitialResponse()
Expand Down Expand Up @@ -454,25 +453,23 @@ func buildAllowedAutoExecutionTools(ctx *schemas.BifrostContext, clientManager C

// Get auto-executable tools from config
toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute
if len(toolsToAutoExecute) == 0 {
if toolsToAutoExecute.IsEmpty() {
// No auto-executable tools configured for this client
continue
}

// Parse tool names (as they appear in JavaScript code)
autoExecutableTools := []string{}
for _, originalToolName := range toolsToAutoExecute {
// Handle wildcard "*" - means all tools are auto-executable
if originalToolName == "*" {
autoExecutableTools = append(autoExecutableTools, "*")
continue
if toolsToAutoExecute.IsUnrestricted() {
autoExecutableTools = append(autoExecutableTools, "*")
} else {
for _, originalToolName := range toolsToAutoExecute {
// Replace - with _ for code mode compatibility, then parse for JS compatibility
toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_")
parsedToolName := parseToolName(toolNameForCode)
autoExecutableTools = append(autoExecutableTools, parsedToolName)
}
// Replace - with _ for code mode compatibility, then parse for JS compatibility
toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_")
parsedToolName := parseToolName(toolNameForCode)
autoExecutableTools = append(autoExecutableTools, parsedToolName)
}

// Add to map if there are auto-executable tools
if len(autoExecutableTools) > 0 {
allowedTools[clientName] = autoExecutableTools
Expand Down
12 changes: 6 additions & 6 deletions core/mcp/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,12 +381,12 @@ func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) b
// If ToolsToExecute is specified (not nil), apply filtering
if config.ToolsToExecute != nil {
// Handle empty array [] - means no tools are allowed
if len(config.ToolsToExecute) == 0 {
if config.ToolsToExecute.IsEmpty() {
return true // No tools allowed
}

// Handle wildcard "*" - if present, all tools are allowed
if slices.Contains(config.ToolsToExecute, "*") {
if config.ToolsToExecute.IsUnrestricted() {
return false // All tools allowed
}

Expand All @@ -396,7 +396,7 @@ func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) b
unprefixedToolName := stripClientPrefix(toolName, config.Name)

// Check if specific tool is in the allowed list
return !slices.Contains(config.ToolsToExecute, unprefixedToolName) // Tool not in allowed list
return !config.ToolsToExecute.Contains(unprefixedToolName) // Tool not in allowed list
}

return true // Tool is skipped (nil is treated as [] - no tools)
Expand All @@ -413,12 +413,12 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool {
// If ToolsToAutoExecute is specified (not nil), apply filtering
if config.ToolsToAutoExecute != nil {
// Handle empty array [] - means no tools are auto-executed
if len(config.ToolsToAutoExecute) == 0 {
if config.ToolsToAutoExecute.IsEmpty() {
return false // No tools auto-executed
}

// Handle wildcard "*" - if present, all tools are auto-executed
if slices.Contains(config.ToolsToAutoExecute, "*") {
if config.ToolsToAutoExecute.IsUnrestricted() {
return true // All tools auto-executed
}

Expand All @@ -428,7 +428,7 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool {
unprefixedToolName := stripClientPrefix(toolName, config.Name)

// Check if specific tool is in the auto-execute list
return slices.Contains(config.ToolsToAutoExecute, unprefixedToolName)
return config.ToolsToAutoExecute.Contains(unprefixedToolName)
}

return false // Tool is not auto-executed (nil is treated as [] - no tools)
Expand Down
10 changes: 7 additions & 3 deletions core/providers/anthropic/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/maximhq/bifrost/core/schemas"
)

func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand All @@ -24,10 +24,14 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide
bifrostResponse.NextPageToken = *response.LastID
}

if !unfiltered && allowedModels.IsEmpty() {
return bifrostResponse
}

includedModels := make(map[string]bool)
for _, model := range response.Data {
modelID := model.ID
if !unfiltered && len(allowedModels) > 0 {
if !unfiltered && allowedModels.IsRestricted() {
allowed := false
for _, allowedModel := range allowedModels {
if schemas.SameBaseModel(model.ID, allowedModel) {
Expand All @@ -49,7 +53,7 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide
}

// Backfill allowed models that were not in the response
if !unfiltered && len(allowedModels) > 0 {
if !unfiltered && allowedModels.IsRestricted() {
for _, allowedModel := range allowedModels {
if !includedModels[allowedModel] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
Expand Down
50 changes: 28 additions & 22 deletions core/providers/azure/models.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
package azure

import (
"slices"
"strings"

"github.com/maximhq/bifrost/core/schemas"
)

// findMatchingAllowedModel finds a matching item in a slice, considering both
// findMatchingAllowedModel finds a matching item in a whitelist, considering both
// exact match and base model matches (ignoring version suffixes).
// Returns the matched item from the slice if found, empty string otherwise.
// If matched via base model, returns the item from slice (not the value parameter).
func findMatchingAllowedModel(slice []string, value string) string {
// First check exact match
if slices.Contains(slice, value) {
// Returns the matched item from the whitelist if found, empty string otherwise.
// If matched via base model, returns the item from whitelist (not the value parameter).
func findMatchingAllowedModel(wl schemas.WhiteList, value string) string {
// First check exact match (case-insensitive)
if wl.Contains(value) {
return value
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
}

// Additional layer: check base model matches (ignoring version suffixes)
// This handles cases where model versions differ but base model is the same
// Return the item from slice (not value) to use the actual name from allowedModels
for _, item := range slice {
// Return the item from whitelist (not value) to use the actual name from allowedModels
for _, item := range wl {
if schemas.SameBaseModel(item, value) {
return item
}
Expand Down Expand Up @@ -58,7 +58,7 @@ func findDeploymentMatch(deployments map[string]string, modelID string) (deploym
return "", ""
}

func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand All @@ -67,6 +67,12 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
Data: make([]schemas.Model, 0, len(response.Data)),
}

if !unfiltered && allowedModels.IsEmpty() && len(deployments) == 0 {
return bifrostResponse
}

restrictAllowed := !unfiltered && allowedModels.IsRestricted()

includedModels := make(map[string]bool)
for _, model := range response.Data {
modelID := model.ID
Expand All @@ -78,7 +84,7 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
// Empty lists mean "allow all" for that dimension
// Check considering base model matches (ignoring version suffixes)
shouldFilter := false
if !unfiltered && len(allowedModels) > 0 && len(deployments) > 0 {
if restrictAllowed && len(deployments) > 0 {
// Both lists are present: model must be in allowedModels AND deployments
// AND the deployment alias must also be in allowedModels
matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ID)
Expand All @@ -88,12 +94,12 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
// Check if deployment alias is also in allowedModels (direct string match)
deploymentAliasInAllowedModels := false
if deploymentAlias != "" {
deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias)
deploymentAliasInAllowedModels = allowedModels.Contains(deploymentAlias)
}

// Filter if: model not in deployments OR deployment alias not in allowedModels
shouldFilter = !inDeployments || !deploymentAliasInAllowedModels
} else if !unfiltered && len(allowedModels) > 0 {
} else if restrictAllowed {
// Only allowedModels is present: filter if model is not in allowedModels
matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ID)
shouldFilter = matchedAllowedModel == ""
Expand All @@ -102,7 +108,7 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ID)
shouldFilter = deploymentValue == ""
}
// If both are empty, shouldFilter remains false (allow all)
// If both are empty (or allowedModels is unrestricted and no deployments), shouldFilter remains false

if shouldFilter {
continue
Expand All @@ -124,9 +130,9 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
if deploymentValue != "" && deploymentAlias != "" {
modelEntry.ID = string(schemas.Azure) + "/" + deploymentAlias
modelEntry.Deployment = schemas.Ptr(deploymentValue)
includedModels[deploymentAlias] = true
includedModels[strings.ToLower(deploymentAlias)] = true
} else {
includedModels[modelID] = true
includedModels[strings.ToLower(modelID)] = true
}

bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
Expand All @@ -135,26 +141,26 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
// Backfill deployments that were not matched from the API response
if !unfiltered && len(deployments) > 0 {
for alias, deploymentValue := range deployments {
if includedModels[alias] {
if includedModels[strings.ToLower(alias)] {
continue
}
// If allowedModels is non-empty, only include if alias is in the list
if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) {
// If allowedModels is restricted, only include if alias is in the list
if restrictAllowed && !allowedModels.Contains(alias) {
continue
}
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(schemas.Azure) + "/" + alias,
Name: schemas.Ptr(alias),
Deployment: schemas.Ptr(deploymentValue),
})
includedModels[alias] = true
includedModels[strings.ToLower(alias)] = true
}
}

// Backfill allowed models that were not in the response
if !unfiltered && len(allowedModels) > 0 {
if restrictAllowed {
for _, allowedModel := range allowedModels {
if !includedModels[allowedModel] {
if !includedModels[strings.ToLower(allowedModel)] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(schemas.Azure) + "/" + allowedModel,
Name: schemas.Ptr(allowedModel),
Expand Down
Loading