Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion core/changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
- feat: add DisableAutoToolInject to MCPToolManagerConfig to suppress automatic MCP tool injection per request
- feat: add BifrostContextKeyMCPAddedTools to context to track MCP tools added to the request
- refactor: standardize empty array conventions in bifrost. Empty array means deny all, ["*"] means allow all for models/tools/keys.
- refactor: standardize empty array conventions in bifrost. Empty array means deny all, ["*"] means allow all for models/tools/keys.
- feat: add support for request-level extra headers in MCP tool execution using BifrostContextKeyMCPExtraHeaders key in context.
21 changes: 13 additions & 8 deletions core/mcp/clientmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,14 @@ func (m *MCPManager) UpdateClient(id string, updatedConfig *schemas.MCPClientCon
ConfigHash: client.ExecutionConfig.ConfigHash,
ToolPricing: maps.Clone(client.ExecutionConfig.ToolPricing),
// Updatable fields - copy from updated config with proper cloning
Name: updatedConfig.Name,
IsCodeModeClient: updatedConfig.IsCodeModeClient,
Headers: maps.Clone(updatedConfig.Headers),
ToolsToExecute: slices.Clone(updatedConfig.ToolsToExecute),
ToolsToAutoExecute: slices.Clone(updatedConfig.ToolsToAutoExecute),
IsPingAvailable: updatedConfig.IsPingAvailable,
ToolSyncInterval: updatedConfig.ToolSyncInterval,
Name: updatedConfig.Name,
IsCodeModeClient: updatedConfig.IsCodeModeClient,
Headers: maps.Clone(updatedConfig.Headers),
ToolsToExecute: slices.Clone(updatedConfig.ToolsToExecute),
ToolsToAutoExecute: slices.Clone(updatedConfig.ToolsToAutoExecute),
AllowedExtraHeaders: slices.Clone(updatedConfig.AllowedExtraHeaders),
IsPingAvailable: updatedConfig.IsPingAvailable,
ToolSyncInterval: updatedConfig.ToolSyncInterval,
}

// Atomically replace the config pointer
Expand Down Expand Up @@ -663,7 +664,11 @@ func (m *MCPManager) connectToMCPClient(config *schemas.MCPClientConfig) error {
}

// Start health monitoring for the client
monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval, config.IsPingAvailable, m.logger)
isPingAvailable := true
if config.IsPingAvailable != nil {
isPingAvailable = *config.IsPingAvailable
}
monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval, isPingAvailable, m.logger)
m.healthMonitorManager.StartMonitoring(monitor)

// Start tool syncing for the client (skip for internal bifrost client)
Expand Down
3 changes: 1 addition & 2 deletions core/mcp/codemode.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package mcp

import (
"context"
"sync"
"time"

Expand Down Expand Up @@ -31,7 +30,7 @@ type CodeMode interface {

// ExecuteTool handles a code mode tool call by name.
// Returns the response message and any error that occurred.
ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error)
ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error)

// IsCodeModeTool returns true if the given tool name is a code mode tool.
IsCodeModeTool(toolName string) bool
Expand Down
95 changes: 17 additions & 78 deletions core/mcp/codemode/starlark/executecode.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ package starlark
import (
"context"
"fmt"
"net/http"
"strings"
"time"

"github.com/bytedance/sonic"
"github.com/mark3labs/mcp-go/mcp"

codemcp "github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/mcp/utils"
"github.com/maximhq/bifrost/core/schemas"
"go.starlark.net/starlark"
"go.starlark.net/starlarkstruct"
Expand Down Expand Up @@ -103,7 +103,7 @@ func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool {
}

// handleExecuteToolCode handles the executeToolCode tool call.
func (s *StarlarkCodeMode) handleExecuteToolCode(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
func (s *StarlarkCodeMode) handleExecuteToolCode(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
toolName := "unknown"
if toolCall.Function.Name != nil {
toolName = *toolCall.Function.Name
Expand Down Expand Up @@ -197,7 +197,7 @@ func (s *StarlarkCodeMode) handleExecuteToolCode(ctx context.Context, toolCall s
}

// executeCode executes Python (Starlark) code in a sandboxed interpreter with MCP tool bindings.
func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) ExecutionResult {
func (s *StarlarkCodeMode) executeCode(ctx *schemas.BifrostContext, code string) ExecutionResult {
logs := []string{}

s.logger.Debug("%s Starting Starlark code execution", codemcp.CodeModeLogPrefix)
Expand Down Expand Up @@ -372,7 +372,7 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi
}

// callMCPTool calls an MCP tool and returns the result.
func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) {
func (s *StarlarkCodeMode) callMCPTool(ctx *schemas.BifrostContext, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) {
// Get available tools per client
availableToolsPerClient := s.clientManager.GetToolPerClient(ctx)

Expand Down Expand Up @@ -400,29 +400,25 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName
// Strip the client name prefix from tool name before calling MCP server
originalToolName := stripClientPrefix(toolName, clientName)

// Get BifrostContext for plugin pipeline
var bifrostCtx *schemas.BifrostContext
var ok bool
if bifrostCtx, ok = ctx.(*schemas.BifrostContext); !ok {
return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog)
originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string)
if !ok {
originalRequestID = ""
}

originalRequestID, _ := bifrostCtx.Value(schemas.BifrostContextKeyRequestID).(string)

// Generate new request ID for this nested tool call
var newRequestID string
if s.fetchNewRequestIDFunc != nil {
newRequestID = s.fetchNewRequestIDFunc(bifrostCtx)
newRequestID = s.fetchNewRequestIDFunc(ctx)
} else {
newRequestID = fmt.Sprintf("exec_%d_%s", time.Now().UnixNano(), toolName)
}

// Create new child context
deadline, hasDeadline := bifrostCtx.Deadline()
deadline, hasDeadline := ctx.Deadline()
if !hasDeadline {
deadline = schemas.NoDeadline
}
nestedCtx := schemas.NewBifrostContext(bifrostCtx, deadline)
nestedCtx := schemas.NewBifrostContext(ctx, deadline)
nestedCtx.SetValue(schemas.BifrostContextKeyRequestID, newRequestID)
if originalRequestID != "" {
nestedCtx.SetValue(schemas.BifrostContextKeyParentMCPRequestID, originalRequestID)
Expand Down Expand Up @@ -451,13 +447,17 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName

// Check if plugin pipeline is available
if s.pluginPipelineProvider == nil {
return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog)
// Should never happen, but just in case
s.logger.Warn("%s Plugin pipeline provider is nil", codemcp.CodeModeLogPrefix)
return nil, fmt.Errorf("plugin pipeline provider is nil")
}

// Get plugin pipeline and run hooks
pipeline := s.pluginPipelineProvider()
if pipeline == nil {
return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog)
// Should never happen, but just in case
s.logger.Warn("%s Plugin pipeline is nil", codemcp.CodeModeLogPrefix)
return nil, fmt.Errorf("plugin pipeline is nil")
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
defer s.releasePluginPipeline(pipeline)

Expand Down Expand Up @@ -515,14 +515,7 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName
Name: toolNameToCall,
Arguments: args,
},
}

if client.ExecutionConfig.Headers != nil {
headers := make(http.Header)
for key, value := range client.ExecutionConfig.Headers {
headers.Add(key, value.GetValue())
}
callRequest.Header = headers
Header: utils.GetHeadersForToolExecution(nestedCtx, client),
}

toolExecutionTimeout := s.getToolExecutionTimeout()
Expand Down Expand Up @@ -604,57 +597,3 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName

return nil, fmt.Errorf("plugin post-hooks returned invalid response")
}

// callMCPToolDirect executes an MCP tool call directly without plugin hooks.
func (s *StarlarkCodeMode) callMCPToolDirect(ctx context.Context, client *schemas.MCPClientState, originalToolName, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) {
callRequest := mcp.CallToolRequest{
Request: mcp.Request{
Method: string(mcp.MethodToolsCall),
},
Params: mcp.CallToolParams{
Name: originalToolName,
Arguments: args,
},
}

if client.ExecutionConfig.Headers != nil {
headers := make(http.Header)
for key, value := range client.ExecutionConfig.Headers {
headers.Add(key, value.GetValue())
}
callRequest.Header = headers
}

toolExecutionTimeout := s.getToolExecutionTimeout()
toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout)
defer cancel()

logToolName := stripClientPrefix(toolName, clientName)
logToolName = strings.ReplaceAll(logToolName, "-", "_")

toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest)
if callErr != nil {
s.logger.Debug("%s Tool call failed: %s.%s - %v", codemcp.CodeModeLogPrefix, clientName, logToolName, callErr)
appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, logToolName, callErr))
return nil, fmt.Errorf("tool call failed for %s.%s: %v", clientName, logToolName, callErr)
}

rawResult := extractTextFromMCPResponse(toolResponse, toolName)

if after, ok := strings.CutPrefix(rawResult, "Error: "); ok {
errorMsg := after
s.logger.Debug("%s Tool returned error result: %s.%s - %s", codemcp.CodeModeLogPrefix, clientName, logToolName, errorMsg)
appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, logToolName, errorMsg))
return nil, fmt.Errorf("%s", errorMsg)
}

var finalResult interface{}
if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil {
finalResult = rawResult
}

resultStr := formatResultForLog(finalResult)
appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr))

return finalResult, nil
}
3 changes: 1 addition & 2 deletions core/mcp/codemode/starlark/starlark.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package starlark

import (
"context"
"fmt"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -111,7 +110,7 @@ func (s *StarlarkCodeMode) GetTools() []schemas.ChatTool {
// Returns:
// - *schemas.ChatMessage: The tool response message
// - error: Any error that occurred during execution
func (s *StarlarkCodeMode) ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
func (s *StarlarkCodeMode) ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
if toolCall.Function.Name == nil {
return nil, fmt.Errorf("tool call missing function name")
}
Expand Down
6 changes: 5 additions & 1 deletion core/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,11 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider
manager.clientMap[clientConfig.ID].State = schemas.MCPConnectionStateDisconnected
}
manager.mu.Unlock()
monitor := NewClientHealthMonitor(manager, clientConfig.ID, DefaultHealthCheckInterval, clientConfig.IsPingAvailable, manager.logger)
isPingAvailable := true
if clientConfig.IsPingAvailable != nil {
isPingAvailable = *clientConfig.IsPingAvailable
}
monitor := NewClientHealthMonitor(manager, clientConfig.ID, DefaultHealthCheckInterval, isPingAvailable, manager.logger)
manager.healthMonitorManager.StartMonitoring(monitor)
}
}(clientConfig)
Expand Down
11 changes: 2 additions & 9 deletions core/mcp/toolmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync/atomic"
"time"

"github.com/mark3labs/mcp-go/mcp"
"github.com/maximhq/bifrost/core/mcp/utils"
"github.com/maximhq/bifrost/core/schemas"
)

Expand Down Expand Up @@ -553,14 +553,7 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall
Name: originalMCPToolName,
Arguments: arguments,
},
}

if client.ExecutionConfig.Headers != nil {
headers := make(http.Header)
for key, value := range client.ExecutionConfig.Headers {
headers.Add(key, value.GetValue())
}
callRequest.Header = headers
Header: utils.GetHeadersForToolExecution(ctx, client),
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

// Create timeout context for tool execution
Expand Down
49 changes: 49 additions & 0 deletions core/mcp/utils/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package utils

import (
"net/http"

"github.com/maximhq/bifrost/core/schemas"
)

// GetHeadersForToolExecution sets additional headers for tool execution.
// It returns the headers for the tool execution.
func GetHeadersForToolExecution(ctx *schemas.BifrostContext, client *schemas.MCPClientState) http.Header {
if ctx == nil || client == nil || client.ExecutionConfig == nil {
return make(http.Header)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
headers := make(http.Header)
if client.ExecutionConfig.Headers != nil {
for key, value := range client.ExecutionConfig.Headers {
headers.Add(key, value.GetValue())
}
}
// Give priority to extra headers in the context
if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyMCPExtraHeaders).(map[string][]string); ok {
filteredHeaders := make(http.Header)
for key, values := range extraHeaders {
if client.ExecutionConfig.AllowedExtraHeaders.IsAllowed(key) {
for i, value := range values {
if i == 0 {
filteredHeaders.Set(key, value)
} else {
filteredHeaders.Add(key, value)
}
}
}
}
// Add the filtered headers to the headers
if len(filteredHeaders) > 0 {
for k, values := range filteredHeaders {
for i, v := range values {
if i == 0 {
headers.Set(k, v)
} else {
headers.Add(k, v)
}
}
}
}
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
return headers
}
1 change: 1 addition & 0 deletions core/schemas/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ const (
BifrostContextKeySSEReaderFactory BifrostContextKey = "bifrost-sse-reader-factory" // *providerUtils.SSEReaderFactory (set by enterprise — replaces default bufio.Scanner SSE readers with streaming readers)
BifrostContextKeySessionID BifrostContextKey = "bifrost-session-id" // string session ID for the request (session stickiness)
BifrostContextKeySessionTTL BifrostContextKey = "bifrost-session-ttl" // time.Duration session TTL for the request (session stickiness)
BifrostContextKeyMCPExtraHeaders BifrostContextKey = "bifrost-mcp-extra-headers" // map[string][]string (these headers are forwarded only to the MCP while tool execution if they are in the allowlist of the MCP client)
)

const (
Expand Down
27 changes: 14 additions & 13 deletions core/schemas/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,19 @@ const (

// MCPClientConfig defines tool filtering for an MCP client.
type MCPClientConfig struct {
ID string `json:"client_id"` // Client ID
Name string `json:"name"` // Client name
IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client
ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess)
ConnectionString *EnvVar `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections)
StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections)
AuthType MCPAuthType `json:"auth_type"` // Authentication type (none, headers, or oauth)
OauthConfigID *string `json:"oauth_config_id,omitempty"` // OAuth config ID (references oauth_configs table)
State string `json:"state,omitempty"` // Connection state (connected, disconnected, error)
Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type)
InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only)
ToolsToExecute WhiteList `json:"tools_to_execute,omitempty"` // Include-only list.
ID string `json:"client_id"` // Client ID
Name string `json:"name"` // Client name
IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client
ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess)
ConnectionString *EnvVar `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections)
StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections)
AuthType MCPAuthType `json:"auth_type"` // Authentication type (none, headers, or oauth)
OauthConfigID *string `json:"oauth_config_id,omitempty"` // OAuth config ID (references oauth_configs table)
State string `json:"state,omitempty"` // Connection state (connected, disconnected, error)
Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type)
AllowedExtraHeaders WhiteList `json:"allowed_extra_headers,omitempty"` // Allowlist of request-level headers that callers may forward to this MCP server at execution time
InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only)
ToolsToExecute WhiteList `json:"tools_to_execute,omitempty"` // Include-only list.
// ToolsToExecute semantics:
// - ["*"] => all tools are included
// - [] => no tools are included (deny-by-default)
Expand All @@ -101,7 +102,7 @@ type MCPClientConfig struct {
// - nil/omitted => treated as [] (no tools)
// - ["tool1", "tool2"] => auto-execute only the specified tools
// Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped.
IsPingAvailable bool `json:"is_ping_available"` // Whether the MCP server supports ping for health checks (default: true). If false, uses listTools for health checks.
IsPingAvailable *bool `json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks (nil/true = ping; false = listTools). Defaults to true.
ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled)
ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution)
ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized)
Expand Down
Loading