diff --git a/core/go.mod b/core/go.mod index f0b7d8c55b..5a5d0150b7 100644 --- a/core/go.mod +++ b/core/go.mod @@ -17,6 +17,7 @@ require ( github.com/bytedance/sonic v1.15.0 github.com/fasthttp/websocket v1.5.12 github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 github.com/hajimehoshi/go-mp3 v0.3.4 github.com/klauspost/compress v1.18.2 github.com/mark3labs/mcp-go v0.43.2 diff --git a/core/internal/mcptests/agent_test_helpers.go b/core/internal/mcptests/agent_test_helpers.go index d19d953ca0..85512dcce6 100644 --- a/core/internal/mcptests/agent_test_helpers.go +++ b/core/internal/mcptests/agent_test_helpers.go @@ -131,11 +131,11 @@ func SetupAgentTest(t *testing.T, config AgentTestConfig) (*mcp.MCPManager, *Dyn // Create context with filtering baseCtx := context.Background() - if len(config.ClientFiltering) > 0 { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, config.ClientFiltering) + if config.ClientFiltering != nil { + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, config.ClientFiltering) } - if len(config.ToolFiltering) > 0 { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, config.ToolFiltering) + if config.ToolFiltering != nil { + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, config.ToolFiltering) } ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) @@ -192,11 +192,11 @@ func SetupAgentTestWithClients(t *testing.T, config AgentTestConfig, customClien // Create context with filtering baseCtx := context.Background() - if len(config.ClientFiltering) > 0 { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, config.ClientFiltering) + if config.ClientFiltering != nil { + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, config.ClientFiltering) } - if len(config.ToolFiltering) > 0 { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, config.ToolFiltering) + if config.ToolFiltering != nil { + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, config.ToolFiltering) } ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) diff --git a/core/internal/mcptests/codemode_stdio_test.go b/core/internal/mcptests/codemode_stdio_test.go index 8fe5841a82..aab3a15172 100644 --- a/core/internal/mcptests/codemode_stdio_test.go +++ b/core/internal/mcptests/codemode_stdio_test.go @@ -56,27 +56,27 @@ func setupCodeModeWithSTDIOServers(t *testing.T, serverNames ...string) (*mcp.MC config = GetTemperatureMCPClientConfig(bifrostRoot) config.IsCodeModeClient = true config.ID = "temperature-client" // Match test expectations - config.Name = "temperature" // Use lowercase to match test code + config.Name = "temperature" // Use lowercase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "go-test-server": config = GetGoTestServerConfig(bifrostRoot) config.ID = "goTestServer-client" // Match test expectations - config.Name = "goTestServer" // Use camelCase to match test code + config.Name = "goTestServer" // Use camelCase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "edge-case-server": config = GetEdgeCaseServerConfig(bifrostRoot) config.ID = "edgeCaseServer-client" // Match test expectations - config.Name = "edgeCaseServer" // Use camelCase to match test code + config.Name = "edgeCaseServer" // Use camelCase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "error-test-server": config = GetErrorTestServerConfig(bifrostRoot) config.ID = "errorTestServer-client" // Match test expectations - config.Name = "errorTestServer" // Use camelCase to match test code + config.Name = "errorTestServer" // Use camelCase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "parallel-test-server": config = GetParallelTestServerConfig(bifrostRoot) config.ID = "parallelTestServer-client" // Match test expectations - config.Name = "parallelTestServer" // Use camelCase to match test code + config.Name = "parallelTestServer" // Use camelCase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "test-tools-server": // test-tools-server doesn't have a fixture, set up manually @@ -367,9 +367,9 @@ func TestCodeMode_STDIO_ServerFiltering(t *testing.T) { expectedError string }{ { - name: "allow_only_test_tools_server", - includeClients: []string{"testToolsServer"}, - code: `result = testToolsServer.echo(message="allowed")`, + name: "allow_only_test_tools_server", + includeClients: []string{"testToolsServer"}, + code: `result = testToolsServer.echo(message="allowed")`, shouldSucceed: true, expectedInResult: "allowed", }, @@ -377,13 +377,13 @@ func TestCodeMode_STDIO_ServerFiltering(t *testing.T) { name: "block_test_tools_server", includeClients: []string{"temperature"}, code: `result = testToolsServer.echo(message="blocked")`, - shouldSucceed: false, - expectedError: "undefined: testToolsServer", + shouldSucceed: false, + expectedError: "undefined: testToolsServer", }, { - name: "allow_only_temperature_server", - includeClients: []string{"temperature"}, - code: `result = temperature.get_temperature(location="Paris")`, + name: "allow_only_temperature_server", + includeClients: []string{"temperature"}, + code: `result = temperature.get_temperature(location="Paris")`, shouldSucceed: true, expectedInResult: "Paris", }, @@ -391,8 +391,8 @@ func TestCodeMode_STDIO_ServerFiltering(t *testing.T) { name: "block_temperature_server", includeClients: []string{"testToolsServer"}, code: `result = temperature.get_temperature(location="blocked")`, - shouldSucceed: false, - expectedError: "undefined: temperature", + shouldSucceed: false, + expectedError: "undefined: temperature", }, { name: "allow_both_servers", @@ -409,7 +409,7 @@ result = {"echo": echo, "temp": temp}`, t.Run(tc.name, func(t *testing.T) { // Create context with client filtering baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, tc.includeClients) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, tc.includeClients) ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) // Verify filtering is applied at tool listing level @@ -524,7 +524,7 @@ result = {"echo": echo, "calc": calc}`, t.Run(tc.name, func(t *testing.T) { // Create context with tool filtering baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, tc.includeTools) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, tc.includeTools) ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) // Verify filtering is applied @@ -622,10 +622,10 @@ result = {"echo": echo, "temp": temp}`, // Create context with both client and tool filtering baseCtx := context.Background() if tc.includeClients != nil { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, tc.includeClients) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, tc.includeClients) } if tc.includeTools != nil { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, tc.includeTools) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, tc.includeTools) } ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) @@ -1692,7 +1692,7 @@ result = {"count": 3}`, for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, tc.includeClients) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, tc.includeClients) ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) toolCall := CreateExecuteToolCodeCall(fmt.Sprintf("call-%s", tc.name), tc.code) diff --git a/core/internal/mcptests/concurrency_advanced_test.go b/core/internal/mcptests/concurrency_advanced_test.go index a1c3823831..e3c5793df4 100644 --- a/core/internal/mcptests/concurrency_advanced_test.go +++ b/core/internal/mcptests/concurrency_advanced_test.go @@ -10,7 +10,6 @@ import ( "testing" "time" - "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/schemas" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -533,14 +532,14 @@ func TestConcurrent_FilteringChanges(t *testing.T) { if id%2 == 0 { // Even: allow all tools baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, []string{"*"}) - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, []string{"bifrostInternal-*"}) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, []string{"*"}) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, []string{"bifrostInternal-*"}) ctx = schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) } else { // Odd: allow only echo baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, []string{"*"}) - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, []string{"bifrostInternal-echo"}) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, []string{"*"}) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, []string{"bifrostInternal-echo"}) ctx = schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) } diff --git a/core/internal/mcptests/fixtures.go b/core/internal/mcptests/fixtures.go index 88b00a9f70..67ae429303 100644 --- a/core/internal/mcptests/fixtures.go +++ b/core/internal/mcptests/fixtures.go @@ -1984,10 +1984,10 @@ func AssertExecutionTimeUnder(t *testing.T, fn func(), maxDuration time.Duration func CreateTestContextWithMCPFilter(includeClients []string, includeTools []string) *schemas.BifrostContext { baseCtx := context.Background() if includeClients != nil { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, includeClients) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, includeClients) } if includeTools != nil { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, includeTools) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, includeTools) } return schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) } diff --git a/core/mcp/mcp.go b/core/mcp/mcp.go index b86409ef1a..a38dd99307 100644 --- a/core/mcp/mcp.go +++ b/core/mcp/mcp.go @@ -21,13 +21,6 @@ const ( BifrostMCPClientKey = "bifrostInternal" // Key for internal Bifrost client in clientMap MCPLogPrefix = "[Bifrost MCP]" // Consistent logging prefix MCPClientConnectionEstablishTimeout = 30 * time.Second // Timeout for MCP client connection establishment - - // Context keys for client filtering in requests - // NOTE: []string is used for both keys, and by default all clients/tools are included (when nil). - // If "*" is present, all clients/tools are included, and [] means no clients/tools are included. - // Request context filtering takes priority over client config - context can override client exclusions. - MCPContextKeyIncludeClients schemas.BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering - MCPContextKeyIncludeTools schemas.BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName-toolName" format for individual tools, or "clientName-*" for wildcard) ) // ============================================================================ diff --git a/core/mcp/utils.go b/core/mcp/utils.go index fad5d2a7bf..dcd39acd6c 100644 --- a/core/mcp/utils.go +++ b/core/mcp/utils.go @@ -65,7 +65,7 @@ func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas. var includeClients []string // Extract client filtering from request context - if existingIncludeClients, ok := ctx.Value(MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { + if existingIncludeClients, ok := ctx.Value(schemas.MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { includeClients = existingIncludeClients } @@ -439,7 +439,7 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool { // Context filtering can only NARROW the tools available, NOT expand beyond client configuration. // This is checked AFTER client-level filtering (shouldSkipToolForConfig). func shouldSkipToolForRequest(ctx context.Context, clientName, toolName string) bool { - includeTools := ctx.Value(MCPContextKeyIncludeTools) + includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools) if includeTools != nil { // Try []string first (preferred type) @@ -752,6 +752,7 @@ func hasToolCallsForChatResponse(response *schemas.BifrostChatResponse) bool { if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" { return true } + // Check if message has tool calls regardless of finish_reason. // Some providers (e.g. Gemini) return finish_reason "stop" even when tool calls are present, // so we cannot rely solely on finish_reason to detect tool calls. diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 60ebeb9702..07f8f33f46 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -158,13 +158,20 @@ type BifrostContextKey string // BifrostContextKeyRequestType is a context key for the request type. const ( - BifrostContextKeySessionToken BifrostContextKey = "bifrost-session-token" // string (session token for authentication - set by auth middleware) - BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string - BifrostContextKeyAPIKeyName BifrostContextKey = "x-bf-api-key" // string (explicit key name selection) - BifrostContextKeyAPIKeyID BifrostContextKey = "x-bf-api-key-id" // string (explicit key ID selection, takes priority over name) - BifrostContextKeyRequestID BifrostContextKey = "request-id" // string - BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string - BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct + BifrostContextKeySessionToken BifrostContextKey = "bifrost-session-token" // string (session token for authentication - set by auth middleware) + BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string + BifrostContextKeyAPIKeyName BifrostContextKey = "x-bf-api-key" // string (explicit key name selection) + BifrostContextKeyAPIKeyID BifrostContextKey = "x-bf-api-key-id" // string (explicit key ID selection, takes priority over name) + BifrostContextKeyRequestID BifrostContextKey = "request-id" // string + BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string + BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct + + // NOTE: []string is used for both keys, and by default all clients/tools are included (when nil). + // If "*" is present, all clients/tools are included, and [] means no clients/tools are included. + // Request context filtering takes priority over client config - context can override client exclusions. + MCPContextKeyIncludeClients BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering + MCPContextKeyIncludeTools BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName-toolName" format for individual tools, or "clientName-*" for wildcard) + BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) BifrostContextKeyGovernanceVirtualKeyID BifrostContextKey = "bifrost-governance-virtual-key-id" // string (to store the virtual key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) diff --git a/docs/features/governance/mcp-tools.mdx b/docs/features/governance/mcp-tools.mdx index 2f709ac6b7..a669f195f2 100644 --- a/docs/features/governance/mcp-tools.mdx +++ b/docs/features/governance/mcp-tools.mdx @@ -15,8 +15,8 @@ Make sure you have at least one MCP client set up. Read more about it [here](../ The filtering logic is determined by the Virtual Key's configuration: 1. **No MCP Configuration on Virtual Key (Default)** - - If a Virtual Key has no specific MCP configurations, all tools from all enabled MCP clients are available by default. - - In this state, a user can still manually filter tools for a single request by passing the `x-bf-mcp-include-tools` header. + - If a Virtual Key has no specific MCP configurations, **no MCP tools are available** (deny-by-default). + - You must explicitly add MCP client configurations to allow tools. 2. **With MCP Configuration on Virtual Key** - When you configure MCP clients on a Virtual Key, its settings take full precedence. diff --git a/docs/features/governance/routing.mdx b/docs/features/governance/routing.mdx index 1790b15160..937d48a6f7 100644 --- a/docs/features/governance/routing.mdx +++ b/docs/features/governance/routing.mdx @@ -28,8 +28,8 @@ This powerful feature enables key use cases like: Virtual Keys can be restricted to use only specific provider/models. When provider/model restrictions are configured, the VK can only access those designated provider/models, providing fine-grained control over which provider/models different users or applications can utilize. **How It Works:** -- **No Restrictions** (default): VK can use any available provider/models based on global configuration -- **With Restrictions**: VK limited to only the specified provider/models with weighted load balancing +- **No Provider Configs** (default): VK **blocks all providers** (deny-by-default). You must add provider configurations to allow traffic. +- **With Provider Configs**: VK limited to only the specified provider/models. Configured providers participate in weighted load balancing only if their `weight` is set to a numeric value, while providers with `weight: null` remain configured but are opted out of weighted selection. **Model Validation:** When you configure provider restrictions on a Virtual Key, Bifrost validates that the requested model is allowed for the selected provider: diff --git a/docs/mcp/filtering.mdx b/docs/mcp/filtering.mdx index b340a8e92a..c2e8a80397 100644 --- a/docs/mcp/filtering.mdx +++ b/docs/mcp/filtering.mdx @@ -227,7 +227,7 @@ This consistent naming convention ensures clear separation between tools from di Virtual Keys can have their own MCP tool access configuration, which **takes precedence** over request-level headers. -When a Virtual Key has MCP configurations, it generates the `x-bf-mcp-include-tools` header automatically, overriding any manually sent header. +When a Virtual Key has no MCP configurations, **no MCP tools are available** (deny-by-default). You must explicitly add MCP client configurations to allow tools. When a Virtual Key has MCP configurations, it generates the `x-bf-mcp-include-tools` header automatically, overriding any manually sent header. ### Configuration diff --git a/docs/openapi/schemas/management/governance.yaml b/docs/openapi/schemas/management/governance.yaml index 21be22ff47..7053f1a8dc 100644 --- a/docs/openapi/schemas/management/governance.yaml +++ b/docs/openapi/schemas/management/governance.yaml @@ -130,6 +130,8 @@ VirtualKeyProviderConfig: type: string weight: type: number + nullable: true + description: Weight for provider load balancing. Null means excluded from weighted routing. allowed_models: type: array items: @@ -195,6 +197,7 @@ CreateVirtualKeyRequest: type: string provider_configs: type: array + description: Provider configurations (empty means no providers allowed, deny-by-default) items: type: object properties: @@ -202,6 +205,8 @@ CreateVirtualKeyRequest: type: string weight: type: number + nullable: true + description: Weight for load balancing. Null means excluded from weighted routing. allowed_models: type: array items: @@ -216,6 +221,7 @@ CreateVirtualKeyRequest: type: string mcp_configs: type: array + description: MCP configurations (empty means no MCP tools allowed, deny-by-default) items: type: object properties: @@ -255,6 +261,8 @@ UpdateVirtualKeyRequest: type: string weight: type: number + nullable: true + description: Weight for load balancing. Null means excluded from weighted routing. allowed_models: type: array items: diff --git a/docs/providers/provider-routing.mdx b/docs/providers/provider-routing.mdx index b7a562196f..3f98c77e66 100644 --- a/docs/providers/provider-routing.mdx +++ b/docs/providers/provider-routing.mdx @@ -911,6 +911,8 @@ When a Virtual Key has `provider_configs` defined: **Empty `allowed_models`**: When left empty, Bifrost uses the Model Catalog (populated from pricing data and the provider's list models API) to determine which models are supported. See the [Model Catalog section](#the-model-catalog) above for how syncing works. For configuration instructions, see [Governance Routing](/features/governance/routing). + +**Empty `provider_configs`**: When `provider_configs` is empty (no providers configured), **all providers are blocked** (deny-by-default). You must explicitly add provider configurations to allow traffic through a Virtual Key. --- @@ -1158,7 +1160,8 @@ curl -X POST http://localhost:8080/v1/chat/completions \ **Setup:** -- No Virtual Key or Virtual Key without `provider_configs` +- **No Virtual Key** (do not send `x-bf-vk`) → this is the **Load Balancing–only** setup +- **Virtual Key with empty / missing `provider_configs`** → **blocks all providers** (deny-by-default) and therefore is **NOT** an LB-only setup - Adaptive load balancing enabled **Request:** @@ -1232,7 +1235,7 @@ curl -X POST http://localhost:8080/v1/chat/completions \ | Scenario | Provider Selection | Key Selection | |----------|-------------------|---------------| | VK with provider_configs | **Governance** (weighted random) | **Standard** or **Adaptive** (if enabled) | -| VK without provider_configs + LB | **Load Balancing Level 1** (performance) | **Load Balancing Level 2** (performance) | +| VK without provider_configs + LB | **Blocked** (empty = no providers allowed) | N/A | | No VK + LB | **Load Balancing Level 1** (performance) | **Load Balancing Level 2** (performance) | | Model with provider prefix + LB | **Skip** (already specified) | **Load Balancing Level 2** (performance) ✅ | | No Load Balancing enabled | **Governance** or **User** or **Model Catalog** | **Standard** (static weights) | diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index ca312225a5..6382689007 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -576,7 +576,7 @@ type VirtualKeyHashInput struct { // VirtualKeyProviderConfigHashInput represents provider config fields for hashing type VirtualKeyProviderConfigHashInput struct { Provider string - Weight float64 + Weight *float64 AllowedModels []string BudgetID *string RateLimitID *string @@ -651,7 +651,14 @@ func GenerateVirtualKeyHash(vk tables.TableVirtualKey) (string, error) { if ri != rj { return ri < rj } - return getWeight(sortedProviderConfigs[i].Weight) < getWeight(sortedProviderConfigs[j].Weight) + wi, wj := sortedProviderConfigs[i].Weight, sortedProviderConfigs[j].Weight + if (wi == nil) != (wj == nil) { + return wi == nil + } + if wi != nil && wj != nil && *wi != *wj { + return *wi < *wj + } + return false }) // Filter out provider configs that are not available providerConfigsForHash := make([]VirtualKeyProviderConfigHashInput, len(sortedProviderConfigs)) @@ -669,7 +676,7 @@ func GenerateVirtualKeyHash(vk tables.TableVirtualKey) (string, error) { sort.Strings(sortedAllowedModels) providerConfigsForHash[i] = VirtualKeyProviderConfigHashInput{ Provider: pc.Provider, - Weight: getWeight(pc.Weight), + Weight: pc.Weight, AllowedModels: sortedAllowedModels, BudgetID: pc.BudgetID, RateLimitID: pc.RateLimitID, diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 0b57e366c9..b7b58daedd 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -314,6 +314,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddPluginOrderColumns(ctx, db); err != nil { return err } + if err := migrationBackfillEmptyVirtualKeyConfigs(ctx, db); err != nil { + return err + } return nil } @@ -3690,6 +3693,125 @@ func migrationAddRateLimitToTeamsAndCustomers(ctx context.Context, db *gorm.DB) return nil } +// migrationBackfillEmptyVirtualKeyConfigs backfills existing virtual keys that have +// empty ProviderConfigs or MCPConfigs with all available providers/MCP clients. +// This preserves the previous "empty means all" behavior for existing VKs after +// the semantic change to "empty means none" (deny-by-default). +func migrationBackfillEmptyVirtualKeyConfigs(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "backfill_empty_virtual_key_configs", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + + // Step 1: Backfill ProviderConfigs for VKs that have none + // Find all virtual keys + var allVKs []tables.TableVirtualKey + if err := tx.Find(&allVKs).Error; err != nil { + return fmt.Errorf("failed to query virtual keys: %w", err) + } + + // Get all available providers + var allProviders []tables.TableProvider + if err := tx.Find(&allProviders).Error; err != nil { + return fmt.Errorf("failed to query providers: %w", err) + } + + // Track which VK IDs were modified so we can recompute their config_hash + modifiedVKIDs := make(map[string]struct{}) + + for _, vk := range allVKs { + // Check if this VK has any provider configs + var providerConfigCount int64 + if err := tx.Model(&tables.TableVirtualKeyProviderConfig{}).Where("virtual_key_id = ?", vk.ID).Count(&providerConfigCount).Error; err != nil { + return fmt.Errorf("failed to count provider configs for VK %s: %w", vk.ID, err) + } + + if providerConfigCount == 0 && len(allProviders) > 0 { + // VK has no provider configs - backfill with all available providers + for _, provider := range allProviders { + providerConfig := tables.TableVirtualKeyProviderConfig{ + VirtualKeyID: vk.ID, + Provider: provider.Name, + Weight: bifrost.Ptr(1.0), + AllowedModels: []string{}, + } + if err := tx.Create(&providerConfig).Error; err != nil { + return fmt.Errorf("failed to create provider config for VK %s, provider %s: %w", vk.ID, provider.Name, err) + } + } + modifiedVKIDs[vk.ID] = struct{}{} + log.Printf("[Migration] Backfilled VK '%s' with %d provider configs", vk.Name, len(allProviders)) + } + } + + // Step 2: Backfill MCPConfigs for VKs that have none + // Get all available MCP clients + var allMCPClients []tables.TableMCPClient + if err := tx.Find(&allMCPClients).Error; err != nil { + return fmt.Errorf("failed to query MCP clients: %w", err) + } + + for _, vk := range allVKs { + // Check if this VK has any MCP configs + var mcpConfigCount int64 + if err := tx.Model(&tables.TableVirtualKeyMCPConfig{}).Where("virtual_key_id = ?", vk.ID).Count(&mcpConfigCount).Error; err != nil { + return fmt.Errorf("failed to count MCP configs for VK %s: %w", vk.ID, err) + } + + if mcpConfigCount == 0 && len(allMCPClients) > 0 { + // VK has no MCP configs - backfill with all available MCP clients with wildcard + for _, mcpClient := range allMCPClients { + mcpConfig := tables.TableVirtualKeyMCPConfig{ + VirtualKeyID: vk.ID, + MCPClientID: mcpClient.ID, + ToolsToExecute: []string{"*"}, + } + if err := tx.Create(&mcpConfig).Error; err != nil { + return fmt.Errorf("failed to create MCP config for VK %s, client %d: %w", vk.ID, mcpClient.ID, err) + } + } + modifiedVKIDs[vk.ID] = struct{}{} + log.Printf("[Migration] Backfilled VK '%s' with %d MCP client configs", vk.Name, len(allMCPClients)) + } + } + + // Step 3: Recompute and persist config_hash for every VK that was modified. + // Without this, subsequent config-sync diff logic would see a stale hash and + // attempt to re-reconcile the VK (potentially undoing the backfill). + for vkID := range modifiedVKIDs { + var vk tables.TableVirtualKey + if err := tx. + Preload("ProviderConfigs"). + Preload("ProviderConfigs.Keys"). + Preload("MCPConfigs"). + First(&vk, "id = ?", vkID).Error; err != nil { + return fmt.Errorf("failed to reload VK %s for hash recomputation: %w", vkID, err) + } + newHash, err := GenerateVirtualKeyHash(vk) + if err != nil { + return fmt.Errorf("failed to generate hash for VK %s: %w", vkID, err) + } + if err := tx.Model(&tables.TableVirtualKey{}). + Where("id = ?", vkID). + Update("config_hash", newHash).Error; err != nil { + return fmt.Errorf("failed to update config_hash for VK %s: %w", vkID, err) + } + log.Printf("[Migration] Recomputed config_hash for VK '%s'", vk.Name) + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + // No rollback needed - the backfilled configs are valid data + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running backfill empty virtual key configs migration: %s", err.Error()) + } + return nil +} + // migrationAddRequiredHeadersJSONColumn adds the required_headers_json column to the config_client table func migrationAddRequiredHeadersJSONColumn(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ diff --git a/framework/configstore/tables/virtualkey.go b/framework/configstore/tables/virtualkey.go index 79594a0adc..4d2f039d35 100644 --- a/framework/configstore/tables/virtualkey.go +++ b/framework/configstore/tables/virtualkey.go @@ -191,7 +191,7 @@ type TableVirtualKey struct { Description string `gorm:"type:text" json:"description,omitempty"` Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:text;not null" json:"value"` // The virtual key value IsActive bool `gorm:"default:true" json:"is_active"` - ProviderConfigs []TableVirtualKeyProviderConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"provider_configs"` // Empty means all providers allowed + ProviderConfigs []TableVirtualKeyProviderConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"provider_configs"` // Empty means no providers allowed (deny-by-default) MCPConfigs []TableVirtualKeyMCPConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"mcp_configs"` // Foreign key relationships (mutually exclusive: either TeamID or CustomerID, not both) diff --git a/plugins/governance/main.go b/plugins/governance/main.go index 0098de2fe5..65ae1b3820 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -617,26 +617,40 @@ func (p *GovernancePlugin) loadBalanceProvider(ctx *schemas.BifrostContext, req // No allowed provider configs, continue without modification return body, nil } - // Weighted random selection from allowed providers for the main model - totalWeight := 0.0 + // Separate providers with weight set (participate in routing) from those without (nil weight = excluded from routing) + weightedConfigs := make([]configstoreTables.TableVirtualKeyProviderConfig, 0, len(allowedProviderConfigs)) for _, config := range allowedProviderConfigs { - totalWeight += getWeight(config.Weight) + if config.Weight != nil { + weightedConfigs = append(weightedConfigs, config) + } } - // Generate random number between 0 and totalWeight - randomValue := rand.Float64() * totalWeight - // Select provider based on weighted random selection + var selectedProvider schemas.ModelProvider - currentWeight := 0.0 - for _, config := range allowedProviderConfigs { - currentWeight += getWeight(config.Weight) - if randomValue <= currentWeight { - selectedProvider = schemas.ModelProvider(config.Provider) - break + + if len(weightedConfigs) > 0 { + // Weighted random selection from providers that have weight set + totalWeight := 0.0 + for _, config := range weightedConfigs { + totalWeight += getWeight(config.Weight) } - } - // Fallback: if no provider was selected (shouldn't happen but guard against FP issues) - if selectedProvider == "" && len(allowedProviderConfigs) > 0 { - selectedProvider = schemas.ModelProvider(allowedProviderConfigs[0].Provider) + // Generate random number between 0 and totalWeight + randomValue := rand.Float64() * totalWeight + // Select provider based on weighted random selection + currentWeight := 0.0 + for _, config := range weightedConfigs { + currentWeight += getWeight(config.Weight) + if randomValue <= currentWeight { + selectedProvider = schemas.ModelProvider(config.Provider) + break + } + } + // Fallback: if no provider was selected (shouldn't happen but guard against FP issues) + if selectedProvider == "" { + selectedProvider = schemas.ModelProvider(weightedConfigs[0].Provider) + } + } else { + // No providers have weight set + return body, nil } p.logger.Debug("[Governance] Selected provider: %s", selectedProvider) @@ -667,15 +681,17 @@ func (p *GovernancePlugin) loadBalanceProvider(ctx *schemas.BifrostContext, req // Check if fallbacks field is already present _, hasFallbacks := body["fallbacks"] - if !hasFallbacks && len(allowedProviderConfigs) > 1 { - // Sort allowed provider configs by weight (descending) - sort.Slice(allowedProviderConfigs, func(i, j int) bool { - return getWeight(allowedProviderConfigs[i].Weight) > getWeight(allowedProviderConfigs[j].Weight) + // Use the same candidate set that was used for primary selection + fallbackConfigs := weightedConfigs + if !hasFallbacks && len(fallbackConfigs) > 1 { + // Sort fallback configs by weight (descending) + sort.Slice(fallbackConfigs, func(i, j int) bool { + return getWeight(fallbackConfigs[i].Weight) > getWeight(fallbackConfigs[j].Weight) }) // Filter out the selected provider and create fallbacks array - fallbacks := make([]string, 0, len(allowedProviderConfigs)-1) - for _, config := range allowedProviderConfigs { + fallbacks := make([]string, 0, len(fallbackConfigs)-1) + for _, config := range fallbackConfigs { if config.Provider != string(selectedProvider) { var err error refinedModel := modelStr @@ -847,35 +863,40 @@ func (p *GovernancePlugin) applyRoutingRules(ctx *schemas.BifrostContext, req *s // - map[string]string: The updated request headers // - error: Any error that occurred during processing func (p *GovernancePlugin) addMCPIncludeTools(headers map[string]string, virtualKey *configstoreTables.TableVirtualKey) (map[string]string, error) { - if len(virtualKey.MCPConfigs) > 0 { - if headers == nil { - headers = make(map[string]string) + if headers == nil { + headers = make(map[string]string) + } + + // Empty MCPConfigs means no MCP tools are allowed (deny-by-default) + if len(virtualKey.MCPConfigs) == 0 { + headers["x-bf-mcp-include-tools"] = "" + return headers, nil + } + + executeOnlyTools := make([]string, 0) + for _, vkMcpConfig := range virtualKey.MCPConfigs { + if len(vkMcpConfig.ToolsToExecute) == 0 { + // No tools specified in virtual key config - skip this client entirely + continue + } + // Handle wildcard in virtual key config - allow all tools from this client + if slices.Contains(vkMcpConfig.ToolsToExecute, "*") { + // Virtual key uses wildcard - use client-specific wildcard + executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name)) + continue } - executeOnlyTools := make([]string, 0) - for _, vkMcpConfig := range virtualKey.MCPConfigs { - if len(vkMcpConfig.ToolsToExecute) == 0 { - // No tools specified in virtual key config - skip this client entirely - continue - } - // Handle wildcard in virtual key config - allow all tools from this client - if slices.Contains(vkMcpConfig.ToolsToExecute, "*") { - // Virtual key uses wildcard - use client-specific wildcard - executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name)) - continue - } - for _, tool := range vkMcpConfig.ToolsToExecute { - if tool != "" { - // Add the tool - client config filtering will be handled by mcp.go - executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool)) - } + for _, tool := range vkMcpConfig.ToolsToExecute { + if tool != "" { + // Add the tool - client config filtering will be handled by mcp.go + executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool)) } } - - // Set even when empty to exclude tools when no tools are present in the virtual key config - headers["x-bf-mcp-include-tools"] = strings.Join(executeOnlyTools, ",") } + // Set even when empty to exclude tools when no tools are present in the virtual key config + headers["x-bf-mcp-include-tools"] = strings.Join(executeOnlyTools, ",") + return headers, nil } diff --git a/plugins/governance/resolver.go b/plugins/governance/resolver.go index 8d3da777af..fe56192c5b 100644 --- a/plugins/governance/resolver.go +++ b/plugins/governance/resolver.go @@ -324,9 +324,9 @@ func (r *BudgetResolver) EvaluateVirtualKeyFiltering(ctx *schemas.BifrostContext // isModelAllowed checks if the requested model is allowed for this VK func (r *BudgetResolver) isModelAllowed(vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, model string) bool { - // Empty ProviderConfigs means all models are allowed + // Empty ProviderConfigs means no models are allowed (deny-by-default) if len(vk.ProviderConfigs) == 0 { - return true + return false } for _, pc := range vk.ProviderConfigs { @@ -350,9 +350,9 @@ func (r *BudgetResolver) isModelAllowed(vk *configstoreTables.TableVirtualKey, p // isProviderAllowed checks if the requested provider is allowed for this VK func (r *BudgetResolver) isProviderAllowed(vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) bool { - // Empty AllowedProviders means all providers are allowed + // Empty ProviderConfigs means no providers are allowed (deny-by-default) if len(vk.ProviderConfigs) == 0 { - return true + return false } for _, pc := range vk.ProviderConfigs { diff --git a/plugins/governance/resolver_test.go b/plugins/governance/resolver_test.go index 7e55c57328..f83c8a53c1 100644 --- a/plugins/governance/resolver_test.go +++ b/plugins/governance/resolver_test.go @@ -359,10 +359,10 @@ func TestBudgetResolver_IsProviderAllowed(t *testing.T) { shouldBeAllowed bool }{ { - name: "No provider configs (all allowed)", + name: "No provider configs (none allowed - deny-by-default)", vk: buildVirtualKey("vk1", "sk-bf-test", "Test", true), provider: schemas.OpenAI, - shouldBeAllowed: true, + shouldBeAllowed: false, }, { name: "Provider in allowlist", @@ -408,11 +408,11 @@ func TestBudgetResolver_IsModelAllowed(t *testing.T) { shouldBeAllowed bool }{ { - name: "No provider configs (all models allowed)", + name: "No provider configs (no models allowed - deny-by-default)", vk: buildVirtualKey("vk1", "sk-bf-test", "Test", true), provider: schemas.OpenAI, model: "gpt-4", - shouldBeAllowed: true, + shouldBeAllowed: false, }, { name: "Empty allowed models (all models allowed)", diff --git a/plugins/governance/utils.go b/plugins/governance/utils.go index f33abb3aa1..b15b6fb02e 100644 --- a/plugins/governance/utils.go +++ b/plugins/governance/utils.go @@ -34,7 +34,7 @@ func ParseVirtualKeyFromFastHTTPRequest(req *fasthttp.RequestCtx) *string { return bifrost.Ptr(xAPIKey) } xGoogleAPIKey := string(req.Request.Header.Peek("x-goog-api-key")) - if xGoogleAPIKey != "" && strings.HasPrefix(strings.ToLower(xGoogleAPIKey), VirtualKeyPrefix) { + if xGoogleAPIKey != "" && strings.HasPrefix(strings.ToLower(xGoogleAPIKey), VirtualKeyPrefix) { return bifrost.Ptr(xGoogleAPIKey) } return nil @@ -99,9 +99,9 @@ func (p *GovernancePlugin) filterModelsForVirtualKey( return []schemas.Model{} // VK not found, return empty list } - // Empty ProviderConfigs means all models are allowed + // Empty ProviderConfigs means no models are allowed (deny-by-default) if len(vk.ProviderConfigs) == 0 { - return models + return []schemas.Model{} } // Filter models based on ProviderConfigs diff --git a/transports/bifrost-http/handlers/governance.go b/transports/bifrost-http/handlers/governance.go index 5fa70f2b12..560d6ec993 100644 --- a/transports/bifrost-http/handlers/governance.go +++ b/transports/bifrost-http/handlers/governance.go @@ -68,16 +68,16 @@ type CreateVirtualKeyRequest struct { Description string `json:"description,omitempty"` ProviderConfigs []struct { Provider string `json:"provider" validate:"required"` - Weight float64 `json:"weight,omitempty"` + Weight *float64 `json:"weight,omitempty"` AllowedModels []string `json:"allowed_models,omitempty"` // Empty means all models allowed Budget *CreateBudgetRequest `json:"budget,omitempty"` // Provider-level budget RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` // Provider-level rate limit KeyIDs []string `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this provider config - } `json:"provider_configs,omitempty"` // Empty means all providers allowed + } `json:"provider_configs,omitempty"` // Empty means no providers allowed (deny-by-default) MCPConfigs []struct { MCPClientName string `json:"mcp_client_name" validate:"required"` ToolsToExecute []string `json:"tools_to_execute,omitempty"` - } `json:"mcp_configs,omitempty"` // Empty means all MCP clients allowed + } `json:"mcp_configs,omitempty"` // Empty means no MCP clients allowed (deny-by-default) TeamID *string `json:"team_id,omitempty"` // Mutually exclusive with CustomerID CustomerID *string `json:"customer_id,omitempty"` // Mutually exclusive with TeamID Budget *CreateBudgetRequest `json:"budget,omitempty"` @@ -92,7 +92,7 @@ type UpdateVirtualKeyRequest struct { ProviderConfigs []struct { ID *uint `json:"id,omitempty"` // null for new entries Provider string `json:"provider" validate:"required"` - Weight float64 `json:"weight,omitempty"` + Weight *float64 `json:"weight,omitempty"` AllowedModels []string `json:"allowed_models,omitempty"` // Empty means all models allowed Budget *UpdateBudgetRequest `json:"budget,omitempty"` // Provider-level budget RateLimit *UpdateRateLimitRequest `json:"rate_limit,omitempty"` // Provider-level rate limit @@ -510,7 +510,7 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ VirtualKeyID: vk.ID, Provider: pc.Provider, - Weight: &pc.Weight, + Weight: pc.Weight, AllowedModels: pc.AllowedModels, Keys: keys, } @@ -848,7 +848,7 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ VirtualKeyID: vk.ID, Provider: pc.Provider, - Weight: &pc.Weight, + Weight: pc.Weight, AllowedModels: pc.AllowedModels, Keys: keys, } @@ -899,7 +899,7 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { } requestConfigsMap[*pc.ID] = true existing.Provider = pc.Provider - existing.Weight = &pc.Weight + existing.Weight = pc.Weight existing.AllowedModels = pc.AllowedModels // Get keys for this provider config if specified @@ -1135,7 +1135,12 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { logger.Error("failed to load relationships for updated VK: %v", err) preloadedVk = vk } - h.governanceManager.ReloadVirtualKey(ctx, vk.ID) + if _, err := h.governanceManager.ReloadVirtualKey(ctx, vk.ID); err != nil { + // Should never happen but just in case + logger.Error("failed to reload virtual key after update: %v", err) + SendError(ctx, 500, "Virtual key updated in database but failed to reload in-memory state") + return + } SendJSON(ctx, map[string]interface{}{ "message": "Virtual key updated successfully", "virtual_key": preloadedVk, diff --git a/transports/bifrost-http/handlers/mcpserver.go b/transports/bifrost-http/handlers/mcpserver.go index cc60b40b77..8cd077b4c9 100644 --- a/transports/bifrost-http/handlers/mcpserver.go +++ b/transports/bifrost-http/handlers/mcpserver.go @@ -224,7 +224,7 @@ func (h *MCPServerHandler) syncServer(server *server.MCPServer, availableTools [ handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Inject tool filter into execution context if present if toolFilter != nil { - ctx = context.WithValue(ctx, schemas.BifrostContextKey("mcp-include-tools"), toolFilter) + ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, toolFilter) } // Convert to Bifrost tool call format toolCallType := "function" @@ -312,8 +312,10 @@ func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) ([]schema ctx := context.Background() var toolFilter []string + // Empty MCPConfigs means no MCP tools are allowed (deny-by-default) + executeOnlyTools := make([]string, 0) + if len(vk.MCPConfigs) > 0 { - executeOnlyTools := make([]string, 0) for _, vkMcpConfig := range vk.MCPConfigs { if len(vkMcpConfig.ToolsToExecute) == 0 { // No tools specified in virtual key config - skip this client entirely @@ -335,12 +337,12 @@ func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) ([]schema } } } - - // Set even when empty to exclude tools when no tools are present in the virtual key config - ctx = context.WithValue(ctx, schemas.BifrostContextKey("mcp-include-tools"), executeOnlyTools) - toolFilter = executeOnlyTools } + // Always set the include-tools filter (empty = deny-all when no MCPConfigs) + ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, executeOnlyTools) + toolFilter = executeOnlyTools + return h.toolManager.GetAvailableMCPTools(ctx), toolFilter } diff --git a/transports/config.schema.json b/transports/config.schema.json index 605b5a0e41..04edba6dfc 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -465,14 +465,14 @@ }, "provider_configs": { "type": "array", - "description": "Provider configurations for this virtual key (empty means all providers allowed)", + "description": "Provider configurations for this virtual key (empty means no providers allowed, deny-by-default)", "items": { "$ref": "#/$defs/virtual_key_provider_config" } }, "mcp_configs": { "type": "array", - "description": "MCP configurations for this virtual key", + "description": "MCP configurations for this virtual key (empty array means no MCP tools allowed, deny-by-default)", "items": { "$ref": "#/$defs/virtual_key_mcp_config" } @@ -1505,9 +1505,9 @@ "description": "Provider name" }, "weight": { - "type": "number", - "description": "Weight for load balancing", - "default": 1.0 + "type": ["number", "null"], + "description": "Weight for load balancing (null opts out of weighted routing)", + "default": null }, "allowed_models": { "type": "array", diff --git a/ui/app/workspace/logs/sheets/logDetailsSheet.tsx b/ui/app/workspace/logs/sheets/logDetailsSheet.tsx index 689b9c1ce2..527010e3e5 100644 --- a/ui/app/workspace/logs/sheets/logDetailsSheet.tsx +++ b/ui/app/workspace/logs/sheets/logDetailsSheet.tsx @@ -15,7 +15,7 @@ import { } from "@/components/ui/alertDialog"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; -import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger } from "@/components/ui/dropdownMenu"; +import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuSeparator, DropdownMenuTrigger } from "@/components/ui/dropdownMenu"; import { DottedSeparator } from "@/components/ui/separator"; import { Sheet, SheetContent, SheetHeader, SheetTitle } from "@/components/ui/sheet"; import { ProviderIconType, RenderProviderIcon, RoutingEngineUsedIcons } from "@/lib/constants/icons"; @@ -28,7 +28,7 @@ import { StatusColors, } from "@/lib/constants/logs"; import { LogEntry } from "@/lib/types/logs"; -import { Clipboard, MoreVertical, Trash2 } from "lucide-react"; +import { Clipboard, DollarSign, FileText, MoreVertical, Timer, Trash2 } from "lucide-react"; import moment from "moment"; import { toast } from "sonner"; import BlockHeader from "../views/blockHeader"; @@ -168,10 +168,11 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet - copyRequestBody(log)} data-testid="logdetails-copy-request-body-button"> + copyRequestBody(log)}> Copy request body + diff --git a/ui/app/workspace/virtual-keys/views/virtualKeyDetailsSheet.tsx b/ui/app/workspace/virtual-keys/views/virtualKeyDetailsSheet.tsx index dd8d831d91..10276938a8 100644 --- a/ui/app/workspace/virtual-keys/views/virtualKeyDetailsSheet.tsx +++ b/ui/app/workspace/virtual-keys/views/virtualKeyDetailsSheet.tsx @@ -85,7 +85,7 @@ export default function VirtualKeyDetailSheet({ virtualKey, onClose }: VirtualKe - + {/* Provider Configurations */}

Provider Configurations

diff --git a/ui/app/workspace/virtual-keys/views/virtualKeySheet.tsx b/ui/app/workspace/virtual-keys/views/virtualKeySheet.tsx index 321b37fc9a..9ad01fb082 100644 --- a/ui/app/workspace/virtual-keys/views/virtualKeySheet.tsx +++ b/ui/app/workspace/virtual-keys/views/virtualKeySheet.tsx @@ -53,7 +53,19 @@ interface VirtualKeySheetProps { const providerConfigSchema = z.object({ id: z.number().optional(), provider: z.string().min(1, "Provider is required"), - weight: z.union([z.number().min(0, "Weight must be at least 0").max(1, "Weight must be at most 1"), z.string()]), + weight: z + .union([ + z.literal("").transform(() => undefined as undefined), + z + .string() + .transform((v) => { + const n = Number.parseFloat(v); + return Number.isNaN(n) ? undefined : n; + }) + .pipe(z.number().min(0, "Weight must be at least 0").max(1, "Weight must be at most 1").optional()), + z.number().min(0, "Weight must be at least 0").max(1, "Weight must be at most 1"), + ]) + .optional(), allowed_models: z.array(z.string()).optional(), key_ids: z.array(z.string()).optional(), // Keys associated with this provider config // Provider-level budget @@ -156,7 +168,7 @@ export default function VirtualKeySheet({ virtualKey, teams, customers, onSave, const availableProviders = providersData || []; // Form setup - const form = useForm({ + const form = useForm, unknown, FormData>({ resolver: zodResolver(formSchema), defaultValues: { name: virtualKey?.name || "", @@ -164,20 +176,21 @@ export default function VirtualKeySheet({ virtualKey, teams, customers, onSave, providerConfigs: virtualKey?.provider_configs?.map((config) => ({ ...config, + weight: config.weight ?? "", key_ids: config.keys?.map((key) => key.key_id) || [], budget: config.budget ? { - max_limit: String(config.budget.max_limit), - reset_duration: config.budget.reset_duration, - } + max_limit: String(config.budget.max_limit), + reset_duration: config.budget.reset_duration, + } : undefined, rate_limit: config.rate_limit ? { - token_max_limit: config.rate_limit.token_max_limit ? String(config.rate_limit.token_max_limit) : undefined, - token_reset_duration: config.rate_limit.token_reset_duration, - request_max_limit: config.rate_limit.request_max_limit ? String(config.rate_limit.request_max_limit) : undefined, - request_reset_duration: config.rate_limit.request_reset_duration, - } + token_max_limit: config.rate_limit.token_max_limit ? String(config.rate_limit.token_max_limit) : undefined, + token_reset_duration: config.rate_limit.token_reset_duration, + request_max_limit: config.rate_limit.request_max_limit ? String(config.rate_limit.request_max_limit) : undefined, + request_reset_duration: config.rate_limit.request_reset_duration, + } : undefined, })) || [], mcpConfigs: @@ -260,7 +273,7 @@ export default function VirtualKeySheet({ virtualKey, teams, customers, onSave, const newConfig = { provider: provider, - weight: 0.5, // Default weight, user can adjust + weight: "" as string | number, // Default empty string = excluded from weighted routing until user sets a weight allowed_models: [], key_ids: [], }; @@ -335,23 +348,24 @@ export default function VirtualKeySheet({ virtualKey, teams, customers, onSave, ): any[] => { return configs.map((config) => ({ ...config, - weight: typeof config.weight === "string" ? parseFloat(config.weight) || 0 : config.weight, - budget: (() => { - const budgetMaxLimit = normalizeNumericField(config.budget?.max_limit); - if (budgetMaxLimit !== undefined) { - return { - max_limit: budgetMaxLimit, - reset_duration: config.budget?.reset_duration || "1M", - }; - } - - const existingConfig = existingConfigs?.find((item) => (config.id ? item.id === config.id : item.provider === config.provider)); - if (existingConfig?.budget) { - return {}; - } - - return undefined; - })(), + weight: config.weight === "" || config.weight === undefined || config.weight === null + ? null + : typeof config.weight === "string" ? (Number.isNaN(parseFloat(config.weight)) ? null : parseFloat(config.weight)) : config.weight, budget: (() => { + const budgetMaxLimit = normalizeNumericField(config.budget?.max_limit); + if (budgetMaxLimit !== undefined) { + return { + max_limit: budgetMaxLimit, + reset_duration: config.budget?.reset_duration || "1M", + }; + } + + const existingConfig = existingConfigs?.find((item) => (config.id ? item.id === config.id : item.provider === config.provider)); + if (existingConfig?.budget) { + return {}; + } + + return undefined; + })(), rate_limit: (() => { const tokenMaxLimit = normalizeIntegerField(config.rate_limit?.token_max_limit); const requestMaxLimit = normalizeIntegerField(config.rate_limit?.request_max_limit); @@ -553,8 +567,8 @@ export default function VirtualKeySheet({ virtualKey, teams, customers, onSave,

- Configure which providers this virtual key can use and their specific settings. Leave empty to allow all - providers. + Configure which providers this virtual key can use and their specific settings. Leave empty to block all + providers. Add providers to allow them.

@@ -643,7 +657,7 @@ export default function VirtualKeySheet({ virtualKey, teams, customers, onSave, : ProviderLabels[config.provider as ProviderName]}
- handleRemoveProvider(index)} className="h-4 w-4 opacity-75" /> + handleRemoveProvider(index)} className="h-4 w-4 opacity-75" data-testid={`vk-delete-provider-${index}`} />
@@ -652,9 +666,10 @@ export default function VirtualKeySheet({ virtualKey, teams, customers, onSave,
{ const inputValue = e.target.value; // Allow empty string, numbers, and partial decimal inputs like "0." @@ -683,6 +698,7 @@ export default function VirtualKeySheet({ virtualKey, teams, customers, onSave, Allowed Models type to search { const providerKeys = availableKeys.filter((key) => key.provider === config.provider); @@ -894,8 +910,9 @@ export default function VirtualKeySheet({ virtualKey, teams, customers, onSave,

- Configure which MCP clients this virtual key can use and their allowed tools. Leave empty to allow all MCP - clients and tools. + Configure which MCP clients this virtual key can use and their allowed tools. Leaving this section empty + blocks all MCP tools. After adding an MCP client, you must select specific tools or choose{" "} + Allow All Tools to grant tool access.

@@ -986,20 +1003,39 @@ export default function VirtualKeySheet({ virtualKey, teams, customers, onSave, {config.mcp_client_name} arr.findIndex((t) => t.name === tool.name) === index) - .map((tool) => ({ - label: tool.name, - value: tool.name, - description: tool.description, - }))} + options={[ + { + label: "Allow All Tools", + value: "*", + description: "Allow all current and future tools (including dynamically fetched ones)", + }, + ...[...availableTools, ...enabledToolsByConfig] + .filter((tool, index, arr) => arr.findIndex((t) => t.name === tool.name) === index) + .map((tool) => ({ + label: tool.name, + value: tool.name, + description: tool.description, + })), + ]} defaultValue={selectedTools} - onValueChange={(tools: string[]) => handleUpdateMCPConfig(index, "tools_to_execute", tools)} + onValueChange={(tools: string[]) => { + const hadStar = selectedTools.includes("*"); + const hasStar = tools.includes("*"); + if (!hadStar && hasStar) { + // Just selected "Allow All Tools" — set to ["*"] only + handleUpdateMCPConfig(index, "tools_to_execute", ["*"]); + } else if (hadStar && hasStar && tools.length > 1) { + // Had "*", still has "*", but user also selected a specific tool — drop "*" + handleUpdateMCPConfig(index, "tools_to_execute", tools.filter((t) => t !== "*")); + } else { + handleUpdateMCPConfig(index, "tools_to_execute", tools); + } + }} placeholder={ selectedTools.length === 0 ? "No tools selected" : selectedTools.includes("*") - ? "All tools selected" + ? "All tools allowed" : "Select tools..." } variant="inverted" @@ -1010,7 +1046,7 @@ export default function VirtualKeySheet({ virtualKey, teams, customers, onSave, /> - @@ -1172,7 +1208,7 @@ export default function VirtualKeySheet({ virtualKey, teams, customers, onSave, - + No Assignment {teams?.length > 0 && Assign to Team} {customers?.length > 0 && Assign to Customer} diff --git a/ui/components/ui/numberAndSelect.tsx b/ui/components/ui/numberAndSelect.tsx index 5eeae048f1..a443be2331 100644 --- a/ui/components/ui/numberAndSelect.tsx +++ b/ui/components/ui/numberAndSelect.tsx @@ -12,6 +12,7 @@ const NumberAndSelect = ({ onChangeSelect, options, labelClassName, + placeholder = "100", dataTestId, }: { id: string; @@ -22,6 +23,7 @@ const NumberAndSelect = ({ onChangeSelect: (value: string) => void; options: { label: string; value: string }[]; labelClassName?: string; + placeholder?: string; dataTestId?: string; }) => { return ( @@ -33,7 +35,7 @@ const NumberAndSelect = ({ { const inputValue = e.target.value; diff --git a/ui/lib/types/governance.ts b/ui/lib/types/governance.ts index d9293c341f..cb35cfef3f 100644 --- a/ui/lib/types/governance.ts +++ b/ui/lib/types/governance.ts @@ -87,7 +87,7 @@ export interface VirtualKey { export interface VirtualKeyProviderConfig { id?: number; provider: string; - weight: number; + weight: number | null; allowed_models: string[]; budget?: Budget; rate_limit?: RateLimit; @@ -130,7 +130,7 @@ export interface UsageStats { // Request interfaces for provider config operations export interface VirtualKeyProviderConfigRequest { provider: string; - weight?: number; + weight?: number | null; allowed_models?: string[]; budget?: CreateBudgetRequest; rate_limit?: CreateRateLimitRequest; @@ -140,7 +140,7 @@ export interface VirtualKeyProviderConfigRequest { export interface VirtualKeyProviderConfigUpdateRequest { id?: number; provider: string; - weight?: number; + weight?: number | null; allowed_models?: string[]; budget?: UpdateBudgetRequest; rate_limit?: UpdateRateLimitRequest;