diff --git a/core/changelog.md b/core/changelog.md index 35606f57f4..375c51adbc 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -1,3 +1,4 @@ - feat: add DisableAutoToolInject to MCPToolManagerConfig to suppress automatic MCP tool injection per request - feat: add BifrostContextKeyMCPAddedTools to context to track MCP tools added to the request -- refactor: standardize empty array conventions in bifrost. Empty array means deny all, ["*"] means allow all for models/tools/keys. \ No newline at end of file +- refactor: standardize empty array conventions in bifrost. Empty array means deny all, ["*"] means allow all for models/tools/keys. +- feat: add support for request-level extra headers in MCP tool execution using BifrostContextKeyMCPExtraHeaders key in context. diff --git a/core/mcp/clientmanager.go b/core/mcp/clientmanager.go index d5e6ab9949..f2dd95f945 100644 --- a/core/mcp/clientmanager.go +++ b/core/mcp/clientmanager.go @@ -243,13 +243,14 @@ func (m *MCPManager) UpdateClient(id string, updatedConfig *schemas.MCPClientCon ConfigHash: client.ExecutionConfig.ConfigHash, ToolPricing: maps.Clone(client.ExecutionConfig.ToolPricing), // Updatable fields - copy from updated config with proper cloning - Name: updatedConfig.Name, - IsCodeModeClient: updatedConfig.IsCodeModeClient, - Headers: maps.Clone(updatedConfig.Headers), - ToolsToExecute: slices.Clone(updatedConfig.ToolsToExecute), - ToolsToAutoExecute: slices.Clone(updatedConfig.ToolsToAutoExecute), - IsPingAvailable: updatedConfig.IsPingAvailable, - ToolSyncInterval: updatedConfig.ToolSyncInterval, + Name: updatedConfig.Name, + IsCodeModeClient: updatedConfig.IsCodeModeClient, + Headers: maps.Clone(updatedConfig.Headers), + ToolsToExecute: slices.Clone(updatedConfig.ToolsToExecute), + ToolsToAutoExecute: slices.Clone(updatedConfig.ToolsToAutoExecute), + AllowedExtraHeaders: slices.Clone(updatedConfig.AllowedExtraHeaders), + IsPingAvailable: updatedConfig.IsPingAvailable, + ToolSyncInterval: updatedConfig.ToolSyncInterval, } // Atomically replace the config pointer @@ -663,7 +664,11 @@ func (m *MCPManager) connectToMCPClient(config *schemas.MCPClientConfig) error { } // Start health monitoring for the client - monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval, config.IsPingAvailable, m.logger) + isPingAvailable := true + if config.IsPingAvailable != nil { + isPingAvailable = *config.IsPingAvailable + } + monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval, isPingAvailable, m.logger) m.healthMonitorManager.StartMonitoring(monitor) // Start tool syncing for the client (skip for internal bifrost client) diff --git a/core/mcp/codemode.go b/core/mcp/codemode.go index e81c984195..fa11e52d0b 100644 --- a/core/mcp/codemode.go +++ b/core/mcp/codemode.go @@ -3,7 +3,6 @@ package mcp import ( - "context" "sync" "time" @@ -31,7 +30,7 @@ type CodeMode interface { // ExecuteTool handles a code mode tool call by name. // Returns the response message and any error that occurred. - ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) + ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) // IsCodeModeTool returns true if the given tool name is a code mode tool. IsCodeModeTool(toolName string) bool diff --git a/core/mcp/codemode/starlark/executecode.go b/core/mcp/codemode/starlark/executecode.go index 6fa2a28cf8..17f6ab8685 100644 --- a/core/mcp/codemode/starlark/executecode.go +++ b/core/mcp/codemode/starlark/executecode.go @@ -5,7 +5,6 @@ package starlark import ( "context" "fmt" - "net/http" "strings" "time" @@ -13,6 +12,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" codemcp "github.com/maximhq/bifrost/core/mcp" + "github.com/maximhq/bifrost/core/mcp/utils" "github.com/maximhq/bifrost/core/schemas" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" @@ -103,7 +103,7 @@ func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool { } // handleExecuteToolCode handles the executeToolCode tool call. -func (s *StarlarkCodeMode) handleExecuteToolCode(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { +func (s *StarlarkCodeMode) handleExecuteToolCode(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { toolName := "unknown" if toolCall.Function.Name != nil { toolName = *toolCall.Function.Name @@ -197,7 +197,7 @@ func (s *StarlarkCodeMode) handleExecuteToolCode(ctx context.Context, toolCall s } // executeCode executes Python (Starlark) code in a sandboxed interpreter with MCP tool bindings. -func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) ExecutionResult { +func (s *StarlarkCodeMode) executeCode(ctx *schemas.BifrostContext, code string) ExecutionResult { logs := []string{} s.logger.Debug("%s Starting Starlark code execution", codemcp.CodeModeLogPrefix) @@ -372,7 +372,7 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi } // callMCPTool calls an MCP tool and returns the result. -func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { +func (s *StarlarkCodeMode) callMCPTool(ctx *schemas.BifrostContext, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { // Get available tools per client availableToolsPerClient := s.clientManager.GetToolPerClient(ctx) @@ -400,29 +400,25 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName // Strip the client name prefix from tool name before calling MCP server originalToolName := stripClientPrefix(toolName, clientName) - // Get BifrostContext for plugin pipeline - var bifrostCtx *schemas.BifrostContext - var ok bool - if bifrostCtx, ok = ctx.(*schemas.BifrostContext); !ok { - return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string) + if !ok { + originalRequestID = "" } - originalRequestID, _ := bifrostCtx.Value(schemas.BifrostContextKeyRequestID).(string) - // Generate new request ID for this nested tool call var newRequestID string if s.fetchNewRequestIDFunc != nil { - newRequestID = s.fetchNewRequestIDFunc(bifrostCtx) + newRequestID = s.fetchNewRequestIDFunc(ctx) } else { newRequestID = fmt.Sprintf("exec_%d_%s", time.Now().UnixNano(), toolName) } // Create new child context - deadline, hasDeadline := bifrostCtx.Deadline() + deadline, hasDeadline := ctx.Deadline() if !hasDeadline { deadline = schemas.NoDeadline } - nestedCtx := schemas.NewBifrostContext(bifrostCtx, deadline) + nestedCtx := schemas.NewBifrostContext(ctx, deadline) nestedCtx.SetValue(schemas.BifrostContextKeyRequestID, newRequestID) if originalRequestID != "" { nestedCtx.SetValue(schemas.BifrostContextKeyParentMCPRequestID, originalRequestID) @@ -451,13 +447,17 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName // Check if plugin pipeline is available if s.pluginPipelineProvider == nil { - return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + // Should never happen, but just in case + s.logger.Warn("%s Plugin pipeline provider is nil", codemcp.CodeModeLogPrefix) + return nil, fmt.Errorf("plugin pipeline provider is nil") } // Get plugin pipeline and run hooks pipeline := s.pluginPipelineProvider() if pipeline == nil { - return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + // Should never happen, but just in case + s.logger.Warn("%s Plugin pipeline is nil", codemcp.CodeModeLogPrefix) + return nil, fmt.Errorf("plugin pipeline is nil") } defer s.releasePluginPipeline(pipeline) @@ -515,14 +515,7 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName Name: toolNameToCall, Arguments: args, }, - } - - if client.ExecutionConfig.Headers != nil { - headers := make(http.Header) - for key, value := range client.ExecutionConfig.Headers { - headers.Add(key, value.GetValue()) - } - callRequest.Header = headers + Header: utils.GetHeadersForToolExecution(nestedCtx, client), } toolExecutionTimeout := s.getToolExecutionTimeout() @@ -604,57 +597,3 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName return nil, fmt.Errorf("plugin post-hooks returned invalid response") } - -// callMCPToolDirect executes an MCP tool call directly without plugin hooks. -func (s *StarlarkCodeMode) callMCPToolDirect(ctx context.Context, client *schemas.MCPClientState, originalToolName, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { - callRequest := mcp.CallToolRequest{ - Request: mcp.Request{ - Method: string(mcp.MethodToolsCall), - }, - Params: mcp.CallToolParams{ - Name: originalToolName, - Arguments: args, - }, - } - - if client.ExecutionConfig.Headers != nil { - headers := make(http.Header) - for key, value := range client.ExecutionConfig.Headers { - headers.Add(key, value.GetValue()) - } - callRequest.Header = headers - } - - toolExecutionTimeout := s.getToolExecutionTimeout() - toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) - defer cancel() - - logToolName := stripClientPrefix(toolName, clientName) - logToolName = strings.ReplaceAll(logToolName, "-", "_") - - toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) - if callErr != nil { - s.logger.Debug("%s Tool call failed: %s.%s - %v", codemcp.CodeModeLogPrefix, clientName, logToolName, callErr) - appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, logToolName, callErr)) - return nil, fmt.Errorf("tool call failed for %s.%s: %v", clientName, logToolName, callErr) - } - - rawResult := extractTextFromMCPResponse(toolResponse, toolName) - - if after, ok := strings.CutPrefix(rawResult, "Error: "); ok { - errorMsg := after - s.logger.Debug("%s Tool returned error result: %s.%s - %s", codemcp.CodeModeLogPrefix, clientName, logToolName, errorMsg) - appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, logToolName, errorMsg)) - return nil, fmt.Errorf("%s", errorMsg) - } - - var finalResult interface{} - if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { - finalResult = rawResult - } - - resultStr := formatResultForLog(finalResult) - appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr)) - - return finalResult, nil -} diff --git a/core/mcp/codemode/starlark/starlark.go b/core/mcp/codemode/starlark/starlark.go index 0da1d2ccd9..348655b983 100644 --- a/core/mcp/codemode/starlark/starlark.go +++ b/core/mcp/codemode/starlark/starlark.go @@ -6,7 +6,6 @@ package starlark import ( - "context" "fmt" "sync" "sync/atomic" @@ -111,7 +110,7 @@ func (s *StarlarkCodeMode) GetTools() []schemas.ChatTool { // Returns: // - *schemas.ChatMessage: The tool response message // - error: Any error that occurred during execution -func (s *StarlarkCodeMode) ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { +func (s *StarlarkCodeMode) ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { if toolCall.Function.Name == nil { return nil, fmt.Errorf("tool call missing function name") } diff --git a/core/mcp/mcp.go b/core/mcp/mcp.go index a94afa28bb..d5e2248608 100644 --- a/core/mcp/mcp.go +++ b/core/mcp/mcp.go @@ -142,7 +142,11 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider manager.clientMap[clientConfig.ID].State = schemas.MCPConnectionStateDisconnected } manager.mu.Unlock() - monitor := NewClientHealthMonitor(manager, clientConfig.ID, DefaultHealthCheckInterval, clientConfig.IsPingAvailable, manager.logger) + isPingAvailable := true + if clientConfig.IsPingAvailable != nil { + isPingAvailable = *clientConfig.IsPingAvailable + } + monitor := NewClientHealthMonitor(manager, clientConfig.ID, DefaultHealthCheckInterval, isPingAvailable, manager.logger) manager.healthMonitorManager.StartMonitoring(monitor) } }(clientConfig) diff --git a/core/mcp/toolmanager.go b/core/mcp/toolmanager.go index a37d77d30d..00c1f30b99 100644 --- a/core/mcp/toolmanager.go +++ b/core/mcp/toolmanager.go @@ -6,12 +6,12 @@ import ( "context" "encoding/json" "fmt" - "net/http" "strings" "sync/atomic" "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/mcp/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -553,14 +553,7 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall Name: originalMCPToolName, Arguments: arguments, }, - } - - if client.ExecutionConfig.Headers != nil { - headers := make(http.Header) - for key, value := range client.ExecutionConfig.Headers { - headers.Add(key, value.GetValue()) - } - callRequest.Header = headers + Header: utils.GetHeadersForToolExecution(ctx, client), } // Create timeout context for tool execution diff --git a/core/mcp/utils/utils.go b/core/mcp/utils/utils.go new file mode 100644 index 0000000000..500792a09f --- /dev/null +++ b/core/mcp/utils/utils.go @@ -0,0 +1,49 @@ +package utils + +import ( + "net/http" + + "github.com/maximhq/bifrost/core/schemas" +) + +// GetHeadersForToolExecution sets additional headers for tool execution. +// It returns the headers for the tool execution. +func GetHeadersForToolExecution(ctx *schemas.BifrostContext, client *schemas.MCPClientState) http.Header { + if ctx == nil || client == nil || client.ExecutionConfig == nil { + return make(http.Header) + } + headers := make(http.Header) + if client.ExecutionConfig.Headers != nil { + for key, value := range client.ExecutionConfig.Headers { + headers.Add(key, value.GetValue()) + } + } + // Give priority to extra headers in the context + if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyMCPExtraHeaders).(map[string][]string); ok { + filteredHeaders := make(http.Header) + for key, values := range extraHeaders { + if client.ExecutionConfig.AllowedExtraHeaders.IsAllowed(key) { + for i, value := range values { + if i == 0 { + filteredHeaders.Set(key, value) + } else { + filteredHeaders.Add(key, value) + } + } + } + } + // Add the filtered headers to the headers + if len(filteredHeaders) > 0 { + for k, values := range filteredHeaders { + for i, v := range values { + if i == 0 { + headers.Set(k, v) + } else { + headers.Add(k, v) + } + } + } + } + } + return headers +} diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 7043484ffb..9279a21f81 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -252,6 +252,7 @@ const ( BifrostContextKeySSEReaderFactory BifrostContextKey = "bifrost-sse-reader-factory" // *providerUtils.SSEReaderFactory (set by enterprise — replaces default bufio.Scanner SSE readers with streaming readers) BifrostContextKeySessionID BifrostContextKey = "bifrost-session-id" // string session ID for the request (session stickiness) BifrostContextKeySessionTTL BifrostContextKey = "bifrost-session-ttl" // time.Duration session TTL for the request (session stickiness) + BifrostContextKeyMCPExtraHeaders BifrostContextKey = "bifrost-mcp-extra-headers" // map[string][]string (these headers are forwarded only to the MCP while tool execution if they are in the allowlist of the MCP client) ) const ( diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index 70cf436875..898353499f 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -77,18 +77,19 @@ const ( // MCPClientConfig defines tool filtering for an MCP client. type MCPClientConfig struct { - ID string `json:"client_id"` // Client ID - Name string `json:"name"` // Client name - IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client - ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) - ConnectionString *EnvVar `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) - StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) - AuthType MCPAuthType `json:"auth_type"` // Authentication type (none, headers, or oauth) - OauthConfigID *string `json:"oauth_config_id,omitempty"` // OAuth config ID (references oauth_configs table) - State string `json:"state,omitempty"` // Connection state (connected, disconnected, error) - Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type) - InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only) - ToolsToExecute WhiteList `json:"tools_to_execute,omitempty"` // Include-only list. + ID string `json:"client_id"` // Client ID + Name string `json:"name"` // Client name + IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client + ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) + ConnectionString *EnvVar `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) + StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) + AuthType MCPAuthType `json:"auth_type"` // Authentication type (none, headers, or oauth) + OauthConfigID *string `json:"oauth_config_id,omitempty"` // OAuth config ID (references oauth_configs table) + State string `json:"state,omitempty"` // Connection state (connected, disconnected, error) + Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type) + AllowedExtraHeaders WhiteList `json:"allowed_extra_headers,omitempty"` // Allowlist of request-level headers that callers may forward to this MCP server at execution time + InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only) + ToolsToExecute WhiteList `json:"tools_to_execute,omitempty"` // Include-only list. // ToolsToExecute semantics: // - ["*"] => all tools are included // - [] => no tools are included (deny-by-default) @@ -101,7 +102,7 @@ type MCPClientConfig struct { // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => auto-execute only the specified tools // Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped. - IsPingAvailable bool `json:"is_ping_available"` // Whether the MCP server supports ping for health checks (default: true). If false, uses listTools for health checks. + IsPingAvailable *bool `json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks (nil/true = ping; false = listTools). Defaults to true. ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled) ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution) ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) diff --git a/examples/mcps/auth-demo-server/main.go b/examples/mcps/auth-demo-server/main.go index ca6bdd6581..4f47d8124b 100644 --- a/examples/mcps/auth-demo-server/main.go +++ b/examples/mcps/auth-demo-server/main.go @@ -7,8 +7,11 @@ package main // tools/call). A missing or wrong key is rejected before the MCP server // sees the message at all. // -// 2. TOOL-LEVEL AUTH (X-Role header) -// Enforced inside individual sensitive tool handlers. Public tools ignore it. +// 2. TOOL-EXECUTION AUTH (X-Tool-Token header) +// A separate secret token checked exclusively inside sensitive tool handlers +// at call time. Public tools ignore it; the connection middleware does not +// inspect it at all. This lets you scope a second credential to tool +// execution only — distinct from the connection credential. // // HOW BIFROST SENDS HEADERS // @@ -22,7 +25,8 @@ package main // This means all configured headers are present on EVERY request — there is no // separate "connection-only" vs "tool-only" header mechanism in Bifrost. To // distinguish the two auth levels you simply use different header names, both -// configured in the same `headers` map. +// configured in the same `headers` map. The server then enforces each header +// at the appropriate layer (middleware vs. handler). // // Bifrost config example: // @@ -32,8 +36,8 @@ package main // "connection_string": "http://localhost:3002/", // "auth_type": "headers", // "headers": { -// "X-API-Key": "super-secret-key", -// "X-Role": "admin" +// "X-API-Key": "super-secret-key", +// "X-Tool-Token": "tool-exec-secret" // }, // "tools_to_execute": ["*"] // } @@ -50,14 +54,16 @@ import ( ) const ( - // connectionAPIKey is checked in HTTP middleware on every request. + // connectionAPIKey is checked in HTTP middleware on every request + // (initialize, tools/list, tools/call). // In production, load this from an environment variable or secrets manager. connectionAPIKey = "super-secret-key" - // requiredRole is checked inside the sensitive tool handler only. - // Both X-API-Key and X-Role are configured together in Bifrost's `headers` - // map and are forwarded on every HTTP request (connection and tool calls). - requiredRole = "admin" + // toolExecToken is checked exclusively inside sensitive tool handlers — + // never in the connection middleware. It acts as a second independent + // credential that gates tool execution only. + // In production, load this from an environment variable or secrets manager. + toolExecToken = "tool-exec-secret" ) // contextKey is a private type so we don't collide with other packages' context keys. @@ -69,7 +75,7 @@ func main() { s := server.NewMCPServer("auth-demo-server", "1.0.0") // public_info only requires connection-level auth (X-API-Key). - // Any authenticated client can call it regardless of role. + // Any authenticated client can call it without a tool execution token. publicTool := mcp.NewTool( "public_info", mcp.WithDescription("Returns non-sensitive public information. Requires connection auth (X-API-Key) only."), @@ -77,13 +83,14 @@ func main() { ) s.AddTool(publicTool, publicInfoHandler) - // secret_data requires BOTH connection-level auth (X-API-Key) AND - // a role check (X-Role: admin) inside the handler. + // secret_data requires BOTH connection-level auth (X-API-Key) AND a + // dedicated tool-execution token (X-Tool-Token) checked inside the handler. // In Bifrost both headers live in the same `headers` map and arrive on - // every request, so the handler just reads X-Role from the context. + // every request, so the handler reads X-Tool-Token from context and + // validates it independently of the connection credential. secretTool := mcp.NewTool( "secret_data", - mcp.WithDescription("Returns sensitive data. Requires connection auth (X-API-Key) AND role check (X-Role: admin)."), + mcp.WithDescription("Returns sensitive data. Requires connection auth (X-API-Key) AND tool-execution auth (X-Tool-Token)."), mcp.WithString("resource", mcp.Required(), mcp.Description("Resource name to fetch")), ) s.AddTool(secretTool, secretDataHandler) @@ -100,10 +107,11 @@ func main() { addr := "localhost:3002" log.Printf("auth-demo-server listening on http://%s/", addr) log.Printf("\nAuth layers:") - log.Printf(" Connection-level: X-API-Key: %s (middleware rejects all requests without it)", connectionAPIKey) - log.Printf(" Tool-level: X-Role: %s (only secret_data checks this, read from context)", requiredRole) + log.Printf(" Connection-level: X-API-Key: %s (middleware rejects all requests without it)", connectionAPIKey) + log.Printf(" Tool-execution: X-Tool-Token: %s (only secret_data checks this, validated inside the handler)", toolExecToken) log.Printf("\nNote: Bifrost sends all `headers` on both connection setup AND every tool call.") - log.Printf("Both X-API-Key and X-Role go in the same `headers` map.\n") + log.Printf("Both X-API-Key and X-Tool-Token go in the same `headers` map.") + log.Printf("The server enforces each at the right layer: middleware vs. handler.\n") log.Printf("Bifrost config:") log.Printf(` { @@ -112,12 +120,12 @@ func main() { "connection_string": "http://%s/", "auth_type": "headers", "headers": { - "X-API-Key": "%s", - "X-Role": "%s" + "X-API-Key": "%s", + "X-Tool-Token": "%s" }, "tools_to_execute": ["*"] } -`, addr, connectionAPIKey, requiredRole) +`, addr, connectionAPIKey, toolExecToken) if err := http.ListenAndServe(addr, handler); err != nil { log.Fatalf("Server error: %v", err) @@ -174,21 +182,22 @@ func publicInfoHandler(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT } // secretDataHandler handles "secret_data". Connection-level auth (X-API-Key) -// has already been verified by middleware. Here we additionally check X-Role, -// which Bifrost sends as part of the same `headers` map — so it is present on -// every request, including this tool call. +// has already been verified by middleware. Here we additionally check +// X-Tool-Token — a separate secret dedicated to authorizing tool execution. +// Bifrost sends it as part of the same `headers` map, so it arrives on every +// request including this tool call; the middleware intentionally ignores it. func secretDataHandler(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // ── Tool-level role check ──────────────────────────────────────────────── + // ── Tool-execution token check ─────────────────────────────────────────── headers, ok := ctx.Value(requestHeadersKey).(http.Header) if !ok { return mcp.NewToolResultError("tool auth error: request headers unavailable in context"), nil } - role := headers.Get("X-Role") - if role == "" { - return mcp.NewToolResultError("tool auth required: missing X-Role header"), nil + token := headers.Get("X-Tool-Token") + if token == "" { + return mcp.NewToolResultError("tool auth required: missing X-Tool-Token header"), nil } - if role != requiredRole { - return mcp.NewToolResultError(fmt.Sprintf("tool auth failed: role %q is not authorized for this tool", role)), nil + if token != toolExecToken { + return mcp.NewToolResultError("tool auth failed: invalid X-Tool-Token"), nil } // ── Auth passed, proceed ───────────────────────────────────────────────── @@ -200,7 +209,7 @@ func secretDataHandler(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT } return mcp.NewToolResultText(fmt.Sprintf( - "Secret data for resource %q: [classified content — X-API-Key + X-Role:%s verified]", args.Resource, role, + "Secret data for resource %q: [classified content — X-API-Key + X-Tool-Token verified]", args.Resource, )), nil } diff --git a/framework/changelog.md b/framework/changelog.md index 6e0b372a06..483066540b 100644 --- a/framework/changelog.md +++ b/framework/changelog.md @@ -1,3 +1,4 @@ - feat: migrate VK provider config allowed keys to explicit allow-list semantics — add AllowAllKeys bool to TableVirtualKeyProviderConfig; backfill existing configs with allow_all_keys=true; empty keys now denies all, ["*"] allows all - feat: add MCPDisableAutoToolInject column to TableClientConfig -- refactor: standardize empty array conventions in modelcatalog and tables. \ No newline at end of file +- refactor: standardize empty array conventions in modelcatalog and tables. +- feat: add AllowedExtraHeadersJSON column to TableMCPClient \ No newline at end of file diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 75af4b6b46..5a1bea3b76 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -326,6 +326,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationBackfillAllowedModelsWildcard(ctx, db); err != nil { return err } + if err := migrationAddMCPClientAllowedExtraHeadersJSONColumn(ctx, db); err != nil { + return err + } return nil } @@ -5030,6 +5033,37 @@ func migrationBackfillAllowedModelsWildcard(ctx context.Context, db *gorm.DB) er return nil } +// migrationAddMCPClientAllowedExtraHeadersJSONColumn adds the allowed_extra_headers_json column to the mcp_client table +func migrationAddMCPClientAllowedExtraHeadersJSONColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_mcp_client_allowed_extra_headers_json_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableMCPClient{}, "allowed_extra_headers_json") { + if err := migrator.AddColumn(&tables.TableMCPClient{}, "allowed_extra_headers_json"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if migrator.HasColumn(&tables.TableMCPClient{}, "allowed_extra_headers_json") { + if err := migrator.DropColumn(&tables.TableMCPClient{}, "allowed_extra_headers_json"); err != nil { + return err + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while running add_mcp_client_allowed_extra_headers_json_column migration: %s", err.Error()) + } + return nil +} + // migrationAddPluginOrderColumns adds placement and exec_order columns to config_plugins table func migrationAddPluginOrderColumns(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index af7ee85def..3c5d138f59 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -870,26 +870,22 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, // This will never happen, but just in case. clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients)) for i, dbClient := range dbMCPClients { - // Dereference IsPingAvailable pointer, defaulting to true if nil - isPingAvailable := true - if dbClient.IsPingAvailable != nil { - isPingAvailable = *dbClient.IsPingAvailable - } clientConfigs[i] = &schemas.MCPClientConfig{ - ID: dbClient.ClientID, - Name: dbClient.Name, - IsCodeModeClient: dbClient.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), - ConnectionString: dbClient.ConnectionString, - StdioConfig: dbClient.StdioConfig, - AuthType: schemas.MCPAuthType(dbClient.AuthType), - OauthConfigID: dbClient.OauthConfigID, - ToolsToExecute: dbClient.ToolsToExecute, - ToolsToAutoExecute: dbClient.ToolsToAutoExecute, - Headers: dbClient.Headers, - IsPingAvailable: isPingAvailable, - ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, - ToolPricing: dbClient.ToolPricing, + ID: dbClient.ClientID, + Name: dbClient.Name, + IsCodeModeClient: dbClient.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), + ConnectionString: dbClient.ConnectionString, + StdioConfig: dbClient.StdioConfig, + AuthType: schemas.MCPAuthType(dbClient.AuthType), + OauthConfigID: dbClient.OauthConfigID, + ToolsToExecute: dbClient.ToolsToExecute, + ToolsToAutoExecute: dbClient.ToolsToAutoExecute, + Headers: dbClient.Headers, + AllowedExtraHeaders: dbClient.AllowedExtraHeaders, + IsPingAvailable: dbClient.IsPingAvailable, + ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, + ToolPricing: dbClient.ToolPricing, } } return &schemas.MCPConfig{ @@ -910,26 +906,22 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, } clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients)) for i, dbClient := range dbMCPClients { - // Dereference IsPingAvailable pointer, defaulting to true if nil - isPingAvailable := true - if dbClient.IsPingAvailable != nil { - isPingAvailable = *dbClient.IsPingAvailable - } clientConfigs[i] = &schemas.MCPClientConfig{ - ID: dbClient.ClientID, - Name: dbClient.Name, - IsCodeModeClient: dbClient.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), - ConnectionString: dbClient.ConnectionString, - StdioConfig: dbClient.StdioConfig, - AuthType: schemas.MCPAuthType(dbClient.AuthType), - OauthConfigID: dbClient.OauthConfigID, - ToolsToExecute: dbClient.ToolsToExecute, - ToolsToAutoExecute: dbClient.ToolsToAutoExecute, - Headers: dbClient.Headers, - IsPingAvailable: isPingAvailable, - ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, - ToolPricing: dbClient.ToolPricing, + ID: dbClient.ClientID, + Name: dbClient.Name, + IsCodeModeClient: dbClient.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), + ConnectionString: dbClient.ConnectionString, + StdioConfig: dbClient.StdioConfig, + AuthType: schemas.MCPAuthType(dbClient.AuthType), + OauthConfigID: dbClient.OauthConfigID, + ToolsToExecute: dbClient.ToolsToExecute, + ToolsToAutoExecute: dbClient.ToolsToAutoExecute, + Headers: dbClient.Headers, + AllowedExtraHeaders: dbClient.AllowedExtraHeaders, + IsPingAvailable: dbClient.IsPingAvailable, + ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, + ToolPricing: dbClient.ToolPricing, } } return &schemas.MCPConfig{ @@ -1014,19 +1006,20 @@ func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig } // Create new client dbClient := tables.TableMCPClient{ - ClientID: clientConfigCopy.ID, - Name: clientConfigCopy.Name, - IsCodeModeClient: clientConfigCopy.IsCodeModeClient, - ConnectionType: string(clientConfigCopy.ConnectionType), - ConnectionString: clientConfigCopy.ConnectionString, - StdioConfig: clientConfigCopy.StdioConfig, - AuthType: string(clientConfigCopy.AuthType), - OauthConfigID: clientConfigCopy.OauthConfigID, - ToolsToExecute: clientConfigCopy.ToolsToExecute, - ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute, - Headers: clientConfigCopy.Headers, - IsPingAvailable: &clientConfigCopy.IsPingAvailable, - ToolSyncInterval: int(clientConfigCopy.ToolSyncInterval.Minutes()), + ClientID: clientConfigCopy.ID, + Name: clientConfigCopy.Name, + IsCodeModeClient: clientConfigCopy.IsCodeModeClient, + ConnectionType: string(clientConfigCopy.ConnectionType), + ConnectionString: clientConfigCopy.ConnectionString, + StdioConfig: clientConfigCopy.StdioConfig, + AuthType: string(clientConfigCopy.AuthType), + OauthConfigID: clientConfigCopy.OauthConfigID, + ToolsToExecute: clientConfigCopy.ToolsToExecute, + ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute, + Headers: clientConfigCopy.Headers, + AllowedExtraHeaders: clientConfigCopy.AllowedExtraHeaders, + IsPingAvailable: clientConfigCopy.IsPingAvailable, + ToolSyncInterval: int(clientConfigCopy.ToolSyncInterval.Minutes()), } if err := tx.WithContext(ctx).Create(&dbClient).Error; err != nil { return s.parseGormError(err) @@ -1085,6 +1078,13 @@ func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, c if err != nil { return fmt.Errorf("failed to marshal headers: %w", err) } + if clientConfigCopy.AllowedExtraHeaders == nil { + clientConfigCopy.AllowedExtraHeaders = []string{} + } + allowedExtraHeadersJSON, err := json.Marshal(clientConfigCopy.AllowedExtraHeaders) + if err != nil { + return fmt.Errorf("failed to marshal allowed_extra_headers: %w", err) + } if clientConfigCopy.ToolPricing == nil { clientConfigCopy.ToolPricing = map[string]float64{} @@ -1111,6 +1111,7 @@ func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, c "tools_to_execute_json": string(toolsToExecuteJSON), "tools_to_auto_execute_json": string(toolsToAutoExecuteJSON), "headers_json": headersJSONStr, + "allowed_extra_headers_json": string(allowedExtraHeadersJSON), "tool_pricing_json": string(toolPricingJSON), "tool_sync_interval": clientConfigCopy.ToolSyncInterval, "updated_at": time.Now(), diff --git a/framework/configstore/tables/mcp.go b/framework/configstore/tables/mcp.go index fedaf0f65a..bdecb80b3f 100644 --- a/framework/configstore/tables/mcp.go +++ b/framework/configstore/tables/mcp.go @@ -13,19 +13,20 @@ import ( // TableMCPClient represents an MCP client configuration in the database type TableMCPClient struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. - ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` - Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` - IsCodeModeClient bool `gorm:"default:false" json:"is_code_mode_client"` // Whether the client is a code mode client - ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType - ConnectionString *schemas.EnvVar `gorm:"type:text" json:"connection_string,omitempty"` - StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig - ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string - ToolsToAutoExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string - HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string - IsPingAvailable *bool `gorm:"default:true" json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks - ToolPricingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]float64 - ToolSyncInterval int `gorm:"default:0" json:"tool_sync_interval"` // Per-client tool sync interval in minutes (0 = use global, -1 = disabled) + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. + ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` + Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` + IsCodeModeClient bool `gorm:"default:false" json:"is_code_mode_client"` // Whether the client is a code mode client + ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType + ConnectionString *schemas.EnvVar `gorm:"type:text" json:"connection_string,omitempty"` + StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig + ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + ToolsToAutoExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string + AllowedExtraHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + IsPingAvailable *bool `gorm:"default:true" json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks + ToolPricingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]float64 + ToolSyncInterval int `gorm:"default:0" json:"tool_sync_interval"` // Per-client tool sync interval in minutes (0 = use global, -1 = disabled) // OAuth authentication fields AuthType string `gorm:"type:varchar(20);default:'headers'" json:"auth_type"` // "none", "headers", "oauth" @@ -42,11 +43,12 @@ type TableMCPClient struct { UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` // Virtual fields for runtime use (not stored in DB) - StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"` - ToolsToExecute schemas.WhiteList `gorm:"-" json:"tools_to_execute"` - ToolsToAutoExecute schemas.WhiteList `gorm:"-" json:"tools_to_auto_execute"` - Headers map[string]schemas.EnvVar `gorm:"-" json:"headers"` - ToolPricing map[string]float64 `gorm:"-" json:"tool_pricing"` + StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"` + ToolsToExecute schemas.WhiteList `gorm:"-" json:"tools_to_execute"` + ToolsToAutoExecute schemas.WhiteList `gorm:"-" json:"tools_to_auto_execute"` + Headers map[string]schemas.EnvVar `gorm:"-" json:"headers"` + AllowedExtraHeaders schemas.WhiteList `gorm:"-" json:"allowed_extra_headers"` + ToolPricing map[string]float64 `gorm:"-" json:"tool_pricing"` } // TableName sets the table name for each model @@ -111,6 +113,19 @@ func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { c.HeadersJSON = "{}" } + if c.AllowedExtraHeaders != nil { + if err := c.AllowedExtraHeaders.Validate(); err != nil { + return fmt.Errorf("invalid allowed_extra_headers: %w", err) + } + data, err := json.Marshal(c.AllowedExtraHeaders) + if err != nil { + return err + } + c.AllowedExtraHeadersJSON = string(data) + } else { + c.AllowedExtraHeadersJSON = "[]" + } + if c.ToolPricing != nil { data, err := json.Marshal(c.ToolPricing) if err != nil { @@ -189,7 +204,11 @@ func (c *TableMCPClient) AfterFind(tx *gorm.DB) error { return err } } - + if c.AllowedExtraHeadersJSON != "" { + if err := sonic.Unmarshal([]byte(c.AllowedExtraHeadersJSON), &c.AllowedExtraHeaders); err != nil { + return err + } + } if c.ToolPricingJSON != "" { if err := json.Unmarshal([]byte(c.ToolPricingJSON), &c.ToolPricing); err != nil { return err diff --git a/transports/bifrost-http/handlers/asyncinference.go b/transports/bifrost-http/handlers/asyncinference.go index 5d6d8a0626..010a74a110 100644 --- a/transports/bifrost-http/handlers/asyncinference.go +++ b/transports/bifrost-http/handlers/asyncinference.go @@ -108,7 +108,7 @@ func (h *AsyncHandler) asyncTextCompletion(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -146,7 +146,7 @@ func (h *AsyncHandler) asyncChatCompletion(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -184,7 +184,7 @@ func (h *AsyncHandler) asyncResponses(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -218,7 +218,7 @@ func (h *AsyncHandler) asyncEmbeddings(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -256,7 +256,7 @@ func (h *AsyncHandler) asyncSpeech(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -294,7 +294,7 @@ func (h *AsyncHandler) asyncTranscription(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -332,7 +332,7 @@ func (h *AsyncHandler) asyncImageGeneration(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -370,7 +370,7 @@ func (h *AsyncHandler) asyncImageEdit(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -403,7 +403,7 @@ func (h *AsyncHandler) asyncImageVariation(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -436,7 +436,7 @@ func (h *AsyncHandler) asyncRerank(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context") return @@ -473,7 +473,7 @@ func (h *AsyncHandler) getJob(operationType schemas.RequestType) fasthttp.Reques } // Get the requesting user's VK for auth check - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index 24c4961352..326ba7d342 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -682,7 +682,7 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { provider := string(ctx.QueryArgs().Peek("provider")) // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() // Ensure cleanup on function exit if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -809,7 +809,7 @@ func (h *CompletionHandler) textCompletion(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -920,7 +920,7 @@ func (h *CompletionHandler) chatCompletion(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -1014,7 +1014,7 @@ func (h *CompletionHandler) responses(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -1088,7 +1088,7 @@ func (h *CompletionHandler) embeddings(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -1181,7 +1181,7 @@ func (h *CompletionHandler) rerank(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -1253,7 +1253,7 @@ func (h *CompletionHandler) speech(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -1380,7 +1380,7 @@ func (h *CompletionHandler) transcription(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -1420,7 +1420,7 @@ func (h *CompletionHandler) countTokens(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -1780,7 +1780,7 @@ func (h *CompletionHandler) imageGeneration(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { cancel() SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -1993,7 +1993,7 @@ func (h *CompletionHandler) imageEdit(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -2136,7 +2136,7 @@ func (h *CompletionHandler) imageVariation(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -2205,7 +2205,7 @@ func (h *CompletionHandler) videoGeneration(ctx *fasthttp.RequestCtx) { Fallbacks: fallbacks, } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { cancel() SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2259,7 +2259,7 @@ func (h *CompletionHandler) videoRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2317,7 +2317,7 @@ func (h *CompletionHandler) videoDownload(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2379,7 +2379,7 @@ func (h *CompletionHandler) videoList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2430,7 +2430,7 @@ func (h *CompletionHandler) videoDelete(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2507,7 +2507,7 @@ func (h *CompletionHandler) videoRemix(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2575,7 +2575,7 @@ func (h *CompletionHandler) batchCreate(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2635,7 +2635,7 @@ func (h *CompletionHandler) batchList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2681,7 +2681,7 @@ func (h *CompletionHandler) batchRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2727,7 +2727,7 @@ func (h *CompletionHandler) batchCancel(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2773,7 +2773,7 @@ func (h *CompletionHandler) batchResults(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2862,7 +2862,7 @@ func (h *CompletionHandler) fileUpload(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2928,7 +2928,7 @@ func (h *CompletionHandler) fileList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2974,7 +2974,7 @@ func (h *CompletionHandler) fileRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3020,7 +3020,7 @@ func (h *CompletionHandler) fileDelete(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3066,7 +3066,7 @@ func (h *CompletionHandler) fileContent(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3127,7 +3127,7 @@ func (h *CompletionHandler) containerCreate(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3186,7 +3186,7 @@ func (h *CompletionHandler) containerList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3233,7 +3233,7 @@ func (h *CompletionHandler) containerRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3280,7 +3280,7 @@ func (h *CompletionHandler) containerDelete(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3377,7 +3377,7 @@ func (h *CompletionHandler) containerFileCreate(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3437,7 +3437,7 @@ func (h *CompletionHandler) containerFileList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3492,7 +3492,7 @@ func (h *CompletionHandler) containerFileRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3547,7 +3547,7 @@ func (h *CompletionHandler) containerFileContent(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3602,7 +3602,7 @@ func (h *CompletionHandler) containerFileDelete(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index 6332536fac..0caf6231c4 100644 --- a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -209,7 +209,7 @@ func (h *MCPHandler) getMCPClientsPaginated(ctx *fasthttp.RequestCtx, limitStr, ToolsToExecute: dbClient.ToolsToExecute, ToolsToAutoExecute: dbClient.ToolsToAutoExecute, Headers: dbClient.Headers, - IsPingAvailable: isPingAvailable, + IsPingAvailable: &isPingAvailable, ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, ToolPricing: dbClient.ToolPricing, } @@ -314,6 +314,10 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid client name: %v", err)) return } + if err := validateAllowedExtraHeaders(req.AllowedExtraHeaders); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid allowed_extra_headers: %v", err)) + return + } // Check if OAuth flow is needed if req.AuthType == "oauth" { @@ -375,27 +379,23 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { } } - isPingAvailable := true - if req.IsPingAvailable != nil { - isPingAvailable = *req.IsPingAvailable - } - // Store MCP client config in OAuth provider memory (not in database) // It will be stored in database only after OAuth completion pendingConfig := schemas.MCPClientConfig{ - ID: req.ClientID, - Name: req.Name, - IsCodeModeClient: req.IsCodeModeClient, - IsPingAvailable: isPingAvailable, - ToolSyncInterval: toolSyncInterval, - ConnectionType: schemas.MCPConnectionType(req.ConnectionType), - ConnectionString: req.ConnectionString, - StdioConfig: req.StdioConfig, - AuthType: schemas.MCPAuthType(req.AuthType), - OauthConfigID: &flowInitiation.OauthConfigID, - ToolsToExecute: req.ToolsToExecute, - ToolsToAutoExecute: req.ToolsToAutoExecute, - Headers: req.Headers, + ID: req.ClientID, + Name: req.Name, + IsCodeModeClient: req.IsCodeModeClient, + IsPingAvailable: req.IsPingAvailable, + ToolSyncInterval: toolSyncInterval, + ConnectionType: schemas.MCPConnectionType(req.ConnectionType), + ConnectionString: req.ConnectionString, + StdioConfig: req.StdioConfig, + AuthType: schemas.MCPAuthType(req.AuthType), + OauthConfigID: &flowInitiation.OauthConfigID, + ToolsToExecute: req.ToolsToExecute, + ToolsToAutoExecute: req.ToolsToAutoExecute, + Headers: req.Headers, + AllowedExtraHeaders: req.AllowedExtraHeaders, } // Store pending config in database (associated with oauth_config_id for multi-instance support) @@ -432,26 +432,22 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { } // Convert to schemas.MCPClientConfig for runtime bifrost client (without tool_pricing) - // Dereference IsPingAvailable pointer, defaulting to true if nil (new clients default to ping available) - isPingAvailable := true - if req.IsPingAvailable != nil { - isPingAvailable = *req.IsPingAvailable - } schemasConfig := &schemas.MCPClientConfig{ - ID: req.ClientID, - Name: req.Name, - IsCodeModeClient: req.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(req.ConnectionType), - ConnectionString: req.ConnectionString, - StdioConfig: req.StdioConfig, - ToolsToExecute: req.ToolsToExecute, - ToolsToAutoExecute: req.ToolsToAutoExecute, - Headers: req.Headers, - AuthType: schemas.MCPAuthType(req.AuthType), - OauthConfigID: req.OauthConfigID, - IsPingAvailable: isPingAvailable, - ToolSyncInterval: toolSyncInterval, - ToolPricing: req.ToolPricing, + ID: req.ClientID, + Name: req.Name, + IsCodeModeClient: req.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(req.ConnectionType), + ConnectionString: req.ConnectionString, + StdioConfig: req.StdioConfig, + ToolsToExecute: req.ToolsToExecute, + ToolsToAutoExecute: req.ToolsToAutoExecute, + Headers: req.Headers, + AllowedExtraHeaders: req.AllowedExtraHeaders, + AuthType: schemas.MCPAuthType(req.AuthType), + OauthConfigID: req.OauthConfigID, + IsPingAvailable: req.IsPingAvailable, + ToolSyncInterval: toolSyncInterval, + ToolPricing: req.ToolPricing, } // Creating MCP client config in config store @@ -518,6 +514,10 @@ func (h *MCPHandler) updateMCPClient(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid client name: %v", err)) return } + if err := validateAllowedExtraHeaders(req.AllowedExtraHeaders); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid allowed_extra_headers: %v", err)) + return + } // Get existing config to handle redacted values var existingConfig *schemas.MCPClientConfig if h.store.MCPConfig != nil { @@ -566,25 +566,22 @@ func (h *MCPHandler) updateMCPClient(ctx *fasthttp.RequestCtx) { } } // Convert to schemas.MCPClientConfig for runtime bifrost client (without tool_pricing) - isPingAvailable := true - if req.IsPingAvailable != nil { - isPingAvailable = *req.IsPingAvailable - } schemasConfig := &schemas.MCPClientConfig{ - ID: req.ClientID, - Name: req.Name, - IsCodeModeClient: req.IsCodeModeClient, - ConnectionType: existingConfig.ConnectionType, - ConnectionString: existingConfig.ConnectionString, - StdioConfig: existingConfig.StdioConfig, - ToolsToExecute: req.ToolsToExecute, - ToolsToAutoExecute: req.ToolsToAutoExecute, - Headers: req.Headers, - AuthType: existingConfig.AuthType, - OauthConfigID: existingConfig.OauthConfigID, - IsPingAvailable: isPingAvailable, - ToolSyncInterval: toolSyncInterval, - ToolPricing: req.ToolPricing, + ID: req.ClientID, + Name: req.Name, + IsCodeModeClient: req.IsCodeModeClient, + ConnectionType: existingConfig.ConnectionType, + ConnectionString: existingConfig.ConnectionString, + StdioConfig: existingConfig.StdioConfig, + ToolsToExecute: req.ToolsToExecute, + ToolsToAutoExecute: req.ToolsToAutoExecute, + Headers: req.Headers, + AllowedExtraHeaders: req.AllowedExtraHeaders, + AuthType: existingConfig.AuthType, + OauthConfigID: existingConfig.OauthConfigID, + IsPingAvailable: req.IsPingAvailable, + ToolSyncInterval: toolSyncInterval, + ToolPricing: req.ToolPricing, } // Update MCP client in memory if err := h.mcpManager.UpdateMCPClient(ctx, id, schemasConfig); err != nil { @@ -654,6 +651,13 @@ func validateToolsToExecute(toolsToExecute schemas.WhiteList) error { return nil } +func validateAllowedExtraHeaders(allowedExtraHeaders schemas.WhiteList) error { + if err := allowedExtraHeaders.Validate(); err != nil { + return fmt.Errorf("invalid allowed_extra_headers: %w", err) + } + return nil +} + func validateToolsToAutoExecute(toolsToAutoExecute schemas.WhiteList, toolsToExecute schemas.WhiteList) error { if err := toolsToAutoExecute.Validate(); err != nil { return fmt.Errorf("invalid tools_to_auto_execute: %w", err) @@ -738,7 +742,11 @@ func mergeMCPRedactedValues(incoming *configstoreTables.TableMCPClient, oldRaw, // Preserve IsPingAvailable if not explicitly set in incoming request // This prevents the zero-value (false) from overwriting the existing DB value if incoming.IsPingAvailable == nil { - merged.IsPingAvailable = bifrost.Ptr(oldRaw.IsPingAvailable) + merged.IsPingAvailable = oldRaw.IsPingAvailable + } + // Preserve AllowedExtraHeaders if not explicitly set in incoming request + if incoming.AllowedExtraHeaders == nil { + merged.AllowedExtraHeaders = oldRaw.AllowedExtraHeaders } return merged diff --git a/transports/bifrost-http/handlers/mcpinference.go b/transports/bifrost-http/handlers/mcpinference.go index 80856dd8a3..4e80e18d5d 100644 --- a/transports/bifrost-http/handlers/mcpinference.go +++ b/transports/bifrost-http/handlers/mcpinference.go @@ -14,14 +14,14 @@ import ( type MCPInferenceHandler struct { client *bifrost.Bifrost - store *lib.Config + config *lib.Config } // NewMCPInferenceHandler creates a new MCP inference handler instance -func NewMCPInferenceHandler(client *bifrost.Bifrost, store *lib.Config) *MCPInferenceHandler { +func NewMCPInferenceHandler(client *bifrost.Bifrost, config *lib.Config) *MCPInferenceHandler { return &MCPInferenceHandler{ client: client, - store: store, + config: config, } } @@ -60,7 +60,7 @@ func (h *MCPInferenceHandler) executeChatMCPTool(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.store.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() // Ensure cleanup on function exit if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -93,7 +93,7 @@ func (h *MCPInferenceHandler) executeResponsesMCPTool(ctx *fasthttp.RequestCtx) } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.store.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() // Ensure cleanup on function exit if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") diff --git a/transports/bifrost-http/handlers/mcpserver.go b/transports/bifrost-http/handlers/mcpserver.go index 2ab0dddc6f..2d49b9b379 100644 --- a/transports/bifrost-http/handlers/mcpserver.go +++ b/transports/bifrost-http/handlers/mcpserver.go @@ -84,7 +84,7 @@ func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() // Use mcp-go server to handle the request @@ -123,7 +123,7 @@ func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) { ctx.Response.Header.Set("Connection", "keep-alive") // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) // Use streaming response writer ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) { diff --git a/transports/bifrost-http/integrations/bedrock_test.go b/transports/bifrost-http/integrations/bedrock_test.go index b2050471a1..16ab7ad4d5 100644 --- a/transports/bifrost-http/integrations/bedrock_test.go +++ b/transports/bifrost-http/integrations/bedrock_test.go @@ -16,9 +16,10 @@ import ( // mockHandlerStore implements lib.HandlerStore for testing type mockHandlerStore struct { - allowDirectKeys bool - headerMatcher *lib.HeaderMatcher - availableProviders []schemas.ModelProvider + allowDirectKeys bool + headerMatcher *lib.HeaderMatcher + availableProviders []schemas.ModelProvider + mcpHeaderCombinedAllowlist schemas.WhiteList } func (m *mockHandlerStore) ShouldAllowDirectKeys() bool { @@ -49,6 +50,10 @@ func (m *mockHandlerStore) GetKVStore() *kvstore.Store { return nil } +func (m *mockHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { + return m.mcpHeaderCombinedAllowlist +} + // Ensure mockHandlerStore implements lib.HandlerStore var _ lib.HandlerStore = (*mockHandlerStore)(nil) diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index e1f76d3652..491161444c 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -611,7 +611,7 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle var rawBody []byte // Execute the request through Bifrost - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore.ShouldAllowDirectKeys(), g.handlerStore.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore.ShouldAllowDirectKeys(), g.handlerStore.GetHeaderMatcher(), g.handlerStore.GetMCPHeaderCombinedAllowlist()) // Set integration type to context bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, string(config.Type)) @@ -2653,7 +2653,7 @@ func (g *GenericRouter) handlePassthrough(ctx *fasthttp.RequestCtx) { return true }) - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore.ShouldAllowDirectKeys(), g.handlerStore.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore.ShouldAllowDirectKeys(), g.handlerStore.GetHeaderMatcher(), g.handlerStore.GetMCPHeaderCombinedAllowlist()) if directKey := ctx.UserValue(string(schemas.BifrostContextKeyDirectKey)); directKey != nil { if key, ok := directKey.(schemas.Key); ok { bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key) diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 802c36ca54..de53b1e8b0 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -74,6 +74,8 @@ type HandlerStore interface { // GetKVStore returns the shared in-memory kvstore instance. // Returns nil if not initialized. GetKVStore() *kvstore.Store + // GetMCPHeaderCombinedAllowlist returns the combined allowlist for MCP headers + GetMCPHeaderCombinedAllowlist() schemas.WhiteList } // Retry backoff constants for validation @@ -1726,17 +1728,18 @@ func mergePluginsFromFile(ctx context.Context, config *Config, configData *Confi // convertSchemasMCPClientConfigToTable converts schemas.MCPClientConfig to tables.TableMCPClient func convertSchemasMCPClientConfigToTable(clientConfig *schemas.MCPClientConfig) *configstoreTables.TableMCPClient { return &configstoreTables.TableMCPClient{ - ClientID: clientConfig.ID, - Name: clientConfig.Name, - IsCodeModeClient: clientConfig.IsCodeModeClient, - ConnectionType: string(clientConfig.ConnectionType), - ConnectionString: clientConfig.ConnectionString, - StdioConfig: clientConfig.StdioConfig, - ToolsToExecute: clientConfig.ToolsToExecute, - ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, - Headers: clientConfig.Headers, - AuthType: string(clientConfig.AuthType), - OauthConfigID: clientConfig.OauthConfigID, + ClientID: clientConfig.ID, + Name: clientConfig.Name, + IsCodeModeClient: clientConfig.IsCodeModeClient, + ConnectionType: string(clientConfig.ConnectionType), + ConnectionString: clientConfig.ConnectionString, + StdioConfig: clientConfig.StdioConfig, + ToolsToExecute: clientConfig.ToolsToExecute, + ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, + Headers: clientConfig.Headers, + AllowedExtraHeaders: clientConfig.AllowedExtraHeaders, + AuthType: string(clientConfig.AuthType), + OauthConfigID: clientConfig.OauthConfigID, } } @@ -2526,6 +2529,29 @@ func (c *Config) SetHeaderMatcher(m *HeaderMatcher) { c.headerMatcher.Store(m) } +// GetMCPHeaderCombinedAllowlist returns the combined allowlist for MCP headers across all MCP clients. +// This method acquires a muMCP read lock and is safe for concurrent access from hot paths. +func (c *Config) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { + c.muMCP.RLock() + defer c.muMCP.RUnlock() + + if c.MCPConfig == nil || len(c.MCPConfig.ClientConfigs) == 0 { + return schemas.WhiteList{} + } + + allowlist := schemas.WhiteList{} + for _, mcpClient := range c.MCPConfig.ClientConfigs { + if mcpClient == nil { + continue + } + if mcpClient.AllowedExtraHeaders.IsUnrestricted() { + return schemas.WhiteList{"*"} + } + allowlist = append(allowlist, mcpClient.AllowedExtraHeaders...) + } + return allowlist +} + // GetPluginOrder returns the names of all base plugins in their sorted placement order. // This method is lock-free and safe for concurrent access from hot paths. // Do not modify the returned slice; it is a shared snapshot and must be treated read-only. @@ -3398,6 +3424,7 @@ func (c *Config) UpdateMCPClient(ctx context.Context, id string, updatedConfig * c.MCPConfig.ClientConfigs[configIndex].Headers = updatedConfig.Headers c.MCPConfig.ClientConfigs[configIndex].ToolsToExecute = updatedConfig.ToolsToExecute c.MCPConfig.ClientConfigs[configIndex].ToolsToAutoExecute = updatedConfig.ToolsToAutoExecute + c.MCPConfig.ClientConfigs[configIndex].AllowedExtraHeaders = updatedConfig.AllowedExtraHeaders c.MCPConfig.ClientConfigs[configIndex].ToolPricing = updatedConfig.ToolPricing c.MCPConfig.ClientConfigs[configIndex].IsPingAvailable = updatedConfig.IsPingAvailable c.MCPConfig.ClientConfigs[configIndex].ToolSyncInterval = updatedConfig.ToolSyncInterval diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index 70037d0a50..74cbd3d2e7 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -397,7 +397,7 @@ func NewMockConfigStore() *MockConfigStore { func (m *MockConfigStore) Ping(ctx context.Context) error { return nil } func (m *MockConfigStore) EncryptPlaintextRows(ctx context.Context) error { return nil } func (m *MockConfigStore) Close(ctx context.Context) error { return nil } -func (m *MockConfigStore) DB() *gorm.DB { return nil } +func (m *MockConfigStore) DB() *gorm.DB { return nil } func (m *MockConfigStore) ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error { return fn(nil) } @@ -489,30 +489,32 @@ func (m *MockConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, if m.mcpConfig.ClientConfigs[i].ID == id { // Found the entry, update it with the new config m.mcpConfig.ClientConfigs[i] = &schemas.MCPClientConfig{ - ID: clientConfig.ClientID, - Name: clientConfig.Name, - IsCodeModeClient: clientConfig.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(clientConfig.ConnectionType), - ConnectionString: clientConfig.ConnectionString, - StdioConfig: clientConfig.StdioConfig, - Headers: clientConfig.Headers, - ToolsToExecute: clientConfig.ToolsToExecute, - ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, + ID: clientConfig.ClientID, + Name: clientConfig.Name, + IsCodeModeClient: clientConfig.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(clientConfig.ConnectionType), + ConnectionString: clientConfig.ConnectionString, + StdioConfig: clientConfig.StdioConfig, + Headers: clientConfig.Headers, + ToolsToExecute: clientConfig.ToolsToExecute, + ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, + AllowedExtraHeaders: clientConfig.AllowedExtraHeaders, } return nil } } // If not found, create a new entry (similar to CreateMCPClientConfig behavior) m.mcpConfig.ClientConfigs = append(m.mcpConfig.ClientConfigs, &schemas.MCPClientConfig{ - ID: clientConfig.ClientID, - Name: clientConfig.Name, - IsCodeModeClient: clientConfig.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(clientConfig.ConnectionType), - ConnectionString: clientConfig.ConnectionString, - StdioConfig: clientConfig.StdioConfig, - Headers: clientConfig.Headers, - ToolsToExecute: clientConfig.ToolsToExecute, - ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, + ID: clientConfig.ClientID, + Name: clientConfig.Name, + IsCodeModeClient: clientConfig.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(clientConfig.ConnectionType), + ConnectionString: clientConfig.ConnectionString, + StdioConfig: clientConfig.StdioConfig, + Headers: clientConfig.Headers, + ToolsToExecute: clientConfig.ToolsToExecute, + ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, + AllowedExtraHeaders: clientConfig.AllowedExtraHeaders, }) return nil @@ -12293,17 +12295,17 @@ func TestGenerateClientConfigHash(t *testing.T) { initTestLogger() cc1 := configstore.ClientConfig{ - DropExcessRequests: true, - InitialPoolSize: 300, - PrometheusLabels: []string{"label1", "label2"}, - EnableLogging: true, - DisableContentLogging: false, - LogRetentionDays: 30, + DropExcessRequests: true, + InitialPoolSize: 300, + PrometheusLabels: []string{"label1", "label2"}, + EnableLogging: true, + DisableContentLogging: false, + LogRetentionDays: 30, EnforceAuthOnInference: false, AllowDirectKeys: true, - AllowedOrigins: []string{"http://localhost:3000"}, - MaxRequestBodySizeMB: 100, - EnableLiteLLMFallbacks: false, + AllowedOrigins: []string{"http://localhost:3000"}, + MaxRequestBodySizeMB: 100, + EnableLiteLLMFallbacks: false, } hash1, err := cc1.GenerateClientConfigHash() @@ -13342,30 +13344,30 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { labels := []string{"provider", "model", "status"} ccToSave := tables.TableClientConfig{ - DropExcessRequests: true, - InitialPoolSize: 300, - PrometheusLabels: labels, - EnableLogging: true, - DisableContentLogging: false, - LogRetentionDays: 30, + DropExcessRequests: true, + InitialPoolSize: 300, + PrometheusLabels: labels, + EnableLogging: true, + DisableContentLogging: false, + LogRetentionDays: 30, EnforceAuthOnInference: false, - AllowDirectKeys: true, - MaxRequestBodySizeMB: 100, - EnableLiteLLMFallbacks: false, + AllowDirectKeys: true, + MaxRequestBodySizeMB: 100, + EnableLiteLLMFallbacks: false, } // Generate hash from config clientConfig := configstore.ClientConfig{ - DropExcessRequests: ccToSave.DropExcessRequests, - InitialPoolSize: ccToSave.InitialPoolSize, - PrometheusLabels: ccToSave.PrometheusLabels, - EnableLogging: ccToSave.EnableLogging, - DisableContentLogging: ccToSave.DisableContentLogging, - LogRetentionDays: ccToSave.LogRetentionDays, + DropExcessRequests: ccToSave.DropExcessRequests, + InitialPoolSize: ccToSave.InitialPoolSize, + PrometheusLabels: ccToSave.PrometheusLabels, + EnableLogging: ccToSave.EnableLogging, + DisableContentLogging: ccToSave.DisableContentLogging, + LogRetentionDays: ccToSave.LogRetentionDays, EnforceAuthOnInference: ccToSave.EnforceAuthOnInference, - AllowDirectKeys: ccToSave.AllowDirectKeys, - MaxRequestBodySizeMB: ccToSave.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: ccToSave.EnableLiteLLMFallbacks, + AllowDirectKeys: ccToSave.AllowDirectKeys, + MaxRequestBodySizeMB: ccToSave.MaxRequestBodySizeMB, + EnableLiteLLMFallbacks: ccToSave.EnableLiteLLMFallbacks, } hashBeforeSave, _ := clientConfig.GenerateClientConfigHash() @@ -13375,16 +13377,16 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { db.Where("id = ?", ccToSave.ID).First(&ccFromDB) clientConfigFromDB := configstore.ClientConfig{ - DropExcessRequests: ccFromDB.DropExcessRequests, - InitialPoolSize: ccFromDB.InitialPoolSize, - PrometheusLabels: ccFromDB.PrometheusLabels, - EnableLogging: ccFromDB.EnableLogging, - DisableContentLogging: ccFromDB.DisableContentLogging, - LogRetentionDays: ccFromDB.LogRetentionDays, + DropExcessRequests: ccFromDB.DropExcessRequests, + InitialPoolSize: ccFromDB.InitialPoolSize, + PrometheusLabels: ccFromDB.PrometheusLabels, + EnableLogging: ccFromDB.EnableLogging, + DisableContentLogging: ccFromDB.DisableContentLogging, + LogRetentionDays: ccFromDB.LogRetentionDays, EnforceAuthOnInference: ccFromDB.EnforceAuthOnInference, - AllowDirectKeys: ccFromDB.AllowDirectKeys, - MaxRequestBodySizeMB: ccFromDB.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: ccFromDB.EnableLiteLLMFallbacks, + AllowDirectKeys: ccFromDB.AllowDirectKeys, + MaxRequestBodySizeMB: ccFromDB.MaxRequestBodySizeMB, + EnableLiteLLMFallbacks: ccFromDB.EnableLiteLLMFallbacks, } hashAfterLoad, _ := clientConfigFromDB.GenerateClientConfigHash() @@ -15407,13 +15409,13 @@ var enterpriseSchemaPaths = map[string]bool{ var excludedGoFields = map[string]map[string]bool{ // ClientConfig - MCP fields are managed at MCP level, not client level "configstore.ClientConfig": { - "ConfigHash": true, - "allowed_headers": true, // Internal use - "mcp_agent_depth": true, // Managed via MCP config - "mcp_code_mode_binding_level": true, - "mcp_tool_execution_timeout": true, - "mcp_tool_sync_interval": true, - "mcp_disable_auto_tool_inject": true, + "ConfigHash": true, + "allowed_headers": true, // Internal use + "mcp_agent_depth": true, // Managed via MCP config + "mcp_code_mode_binding_level": true, + "mcp_tool_execution_timeout": true, + "mcp_tool_sync_interval": true, + "mcp_disable_auto_tool_inject": true, }, "configstore.ProviderConfig": {"ConfigHash": true}, // GovernanceConfig - some fields are internal/enterprise diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go index 7131faab23..ccc67cd9f2 100644 --- a/transports/bifrost-http/lib/ctx.go +++ b/transports/bifrost-http/lib/ctx.go @@ -92,7 +92,7 @@ const ( // // Maxim tracing data, MCP filters, governance keys, API keys, cache settings, // // session stickiness, and extra headers -func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, matcher *HeaderMatcher) (*schemas.BifrostContext, context.CancelFunc) { +func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, matcher *HeaderMatcher, mcpHeaderCombinedAllowlist schemas.WhiteList) (*schemas.BifrostContext, context.CancelFunc) { // Reuse a shared request-scoped context when available. var bifrostCtx *schemas.BifrostContext var cancel context.CancelFunc @@ -141,6 +141,8 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat maximTags := make(map[string]string) // Initialize extra headers map for headers prefixed with x-bf-eh- extraHeaders := make(map[string][]string) + // Initialize extra headers map for headers in the mcp header combined allowlist + mcpExtraHeaders := make(map[string][]string) // Security denylist of header names that should never be accepted (case-insensitive) // This denylist is always enforced regardless of user configuration securityDenylist := map[string]bool{ @@ -377,6 +379,11 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat return true } } + // Handle MCP extra headers + if mcpHeaderCombinedAllowlist.IsAllowed(keyStr) { + mcpExtraHeaders[keyStr] = append(mcpExtraHeaders[keyStr], string(value)) + return true + } // Send back raw response header if keyStr == "x-bf-send-back-raw-response" { if valueStr := string(value); valueStr == "true" { @@ -411,6 +418,11 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat bifrostCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders) } + // Store collected MCP extra headers in the context if any were found + if len(mcpExtraHeaders) > 0 { + bifrostCtx.SetValue(schemas.BifrostContextKeyMCPExtraHeaders, mcpExtraHeaders) + } + // Collect all request headers for downstream use (e.g., governance required headers check) // Keys are lowercased for case-insensitive lookup allHeaders := make(map[string]string) diff --git a/transports/bifrost-http/lib/ctx_test.go b/transports/bifrost-http/lib/ctx_test.go index 3f522a3548..396f7a57f8 100644 --- a/transports/bifrost-http/lib/ctx_test.go +++ b/transports/bifrost-http/lib/ctx_test.go @@ -16,7 +16,7 @@ func TestConvertToBifrostContext_ReusesSharedContext(t *testing.T) { base.SetValue(schemas.BifrostContextKeyRequestID, "req-shared") ctx.SetUserValue(FastHTTPUserValueBifrostContext, base) - converted, cancel := ConvertToBifrostContext(ctx, false, nil) + converted, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) defer cancel() if converted == nil { @@ -36,13 +36,13 @@ func TestConvertToBifrostContext_ReusesSharedContext(t *testing.T) { func TestConvertToBifrostContext_SecondCallReturnsSameSharedContext(t *testing.T) { ctx := &fasthttp.RequestCtx{} - first, cancelFirst := ConvertToBifrostContext(ctx, false, nil) + first, cancelFirst := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) defer cancelFirst() if first == nil { t.Fatal("expected first context to be non-nil") } - second, cancelSecond := ConvertToBifrostContext(ctx, false, nil) + second, cancelSecond := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) defer cancelSecond() if second == nil { t.Fatal("expected second context to be non-nil") @@ -69,7 +69,7 @@ func TestConvertToBifrostContext_StarAllowlistSecurityHeadersBlocked(t *testing. ctx.Request.Header.Set("x-bf-eh-connection", "should-be-blocked") ctx.Request.Header.Set("x-bf-eh-proxy-authorization", "should-be-blocked") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -103,7 +103,7 @@ func TestConvertToBifrostContext_StarAllowlistDirectForwardingSecurityBlocked(t // Security headers sent directly — should be blocked ctx.Request.Header.Set("proxy-authorization", "should-be-blocked") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -140,7 +140,7 @@ func TestConvertToBifrostContext_PrefixWildcardDirectForwarding(t *testing.T) { // Header not matching the pattern ctx.Request.Header.Set("openai-version", "should-not-forward") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -168,7 +168,7 @@ func TestConvertToBifrostContext_WildcardAllowlistFiltering(t *testing.T) { ctx.Request.Header.Set("x-bf-eh-anthropic-version", "2024-01-01") ctx.Request.Header.Set("x-bf-eh-openai-version", "should-be-blocked") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -196,7 +196,7 @@ func TestConvertToBifrostContext_WildcardDenylistBlocking(t *testing.T) { ctx.Request.Header.Set("x-bf-eh-x-internal-secret", "blocked-value") ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -217,7 +217,7 @@ func TestConvertToBifrostContext_NilMatcher(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) diff --git a/transports/changelog.md b/transports/changelog.md index 0a42a6b57e..e03eadeb53 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -1,4 +1,5 @@ - feat: VK provider config key_ids now supports ["*"] wildcard to allow all keys; empty key_ids denies all; handler resolves wildcard to AllowAllKeys flag without DB key lookups - feat: add option to disable automatic MCP tool injection per request - feat: virtual key MCP configs now act as an execution-time allow-list — tools not permitted by the VK are blocked at inference and MCP tool execution -- refactor: standardize empty array conventions in bifrost. Empty array means no tools/keys are allowed, ["*"] means all tools/keys are allowed. \ No newline at end of file +- refactor: standardize empty array conventions in bifrost. Empty array means no tools/keys are allowed, ["*"] means all tools/keys are allowed. +- feat: add support for request level extra headers in MCP tool execution. \ No newline at end of file diff --git a/ui/app/workspace/mcp-registry/views/mcpClientSheet.tsx b/ui/app/workspace/mcp-registry/views/mcpClientSheet.tsx index ee042c0414..8ea6a99ad6 100644 --- a/ui/app/workspace/mcp-registry/views/mcpClientSheet.tsx +++ b/ui/app/workspace/mcp-registry/views/mcpClientSheet.tsx @@ -71,6 +71,7 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: tools_to_auto_execute: mcpClient.config.tools_to_auto_execute || [], tool_pricing: mcpClient.config.tool_pricing || {}, tool_sync_interval: toolSyncIntervalToMinutes(mcpClient.config.tool_sync_interval), + allowed_extra_headers: mcpClient.config.allowed_extra_headers || [], }, }); @@ -85,6 +86,7 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: tools_to_auto_execute: mcpClient.config.tools_to_auto_execute || [], tool_pricing: mcpClient.config.tool_pricing || {}, tool_sync_interval: toolSyncIntervalToMinutes(mcpClient.config.tool_sync_interval), + allowed_extra_headers: mcpClient.config.allowed_extra_headers || [], }); }, [form, mcpClient]); @@ -101,6 +103,7 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: tools_to_auto_execute: data.tools_to_auto_execute, tool_pricing: data.tool_pricing, tool_sync_interval: data.tool_sync_interval ?? 0, + allowed_extra_headers: data.allowed_extra_headers, }, }).unwrap(); @@ -381,6 +384,48 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: )} /> + ( + +
+ Allowed Extra Headers + + + + + + +

+ Allowlist of headers that callers can forward to this MCP server at request time. +

+
+
+
+
+ + { + const raw = e.target.value; + const parsed = raw.trim() ? raw.split(",").map((h) => h.trim()).filter(Boolean) : []; + field.onChange(parsed); + }} + onBlur={field.onBlur} + /> + +

+ Comma-separated header names, or * to allow all. Leave empty to block all extra headers. +

+ +
+ )} + /> {/* Client Configuration */}
diff --git a/ui/lib/types/mcp.ts b/ui/lib/types/mcp.ts index 14ec3996df..77288daff4 100644 --- a/ui/lib/types/mcp.ts +++ b/ui/lib/types/mcp.ts @@ -40,6 +40,7 @@ export interface MCPClientConfig { is_ping_available?: boolean; tool_pricing?: Record; tool_sync_interval?: number; // Per-client override in minutes (0 = use global, -1 = disabled) + allowed_extra_headers?: string[]; // Allowlist of x-bf-eh-* headers forwarded to this MCP server. ["*"] = allow all. } export interface MCPClient { @@ -90,6 +91,7 @@ export interface UpdateMCPClientRequest { is_ping_available?: boolean; tool_pricing?: Record; tool_sync_interval?: number; // Per-client override in minutes (0 = use global, -1 = disabled) + allowed_extra_headers?: string[]; // Allowlist of x-bf-eh-* headers forwarded to this MCP server. ["*"] = allow all. } // Pagination params for MCP clients list diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index 8569c92a4c..fb7dfc178c 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -935,6 +935,17 @@ export const mcpClientUpdateSchema = z.object({ ), tool_pricing: z.record(z.string(), z.number().min(0, "Cost must be non-negative")).optional(), tool_sync_interval: z.number().optional(), // -1 = disabled, 0 = use global, >0 = custom interval in minutes + allowed_extra_headers: z + .array(z.string()) + .optional() + .refine( + (headers) => { + if (!headers || headers.length === 0) return true; + const hasWildcard = headers.includes("*"); + return !hasWildcard || headers.length === 1; + }, + { message: "Wildcard '*' cannot be combined with specific header names" }, + ), }); // Global proxy type schema