diff --git a/core/bifrost.go b/core/bifrost.go index 0bd3718711..d5d80f2381 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -15,6 +15,7 @@ import ( "time" "github.com/google/uuid" + "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/providers/anthropic" "github.com/maximhq/bifrost/core/providers/azure" "github.com/maximhq/bifrost/core/providers/bedrock" @@ -63,7 +64,8 @@ type Bifrost struct { pluginPipelinePool sync.Pool // Pool for PluginPipeline objects bifrostRequestPool sync.Pool // Pool for BifrostRequest objects logger schemas.Logger // logger instance, default logger is used if not provided - mcpManager *MCPManager // MCP integration manager (nil if MCP not configured) + mcpManager *mcp.MCPManager // MCP integration manager (nil if MCP not configured) + mcpInitOnce sync.Once // Ensures MCP manager is initialized only once dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. keySelector schemas.KeySelector // Custom key selector function } @@ -176,13 +178,10 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { // Initialize MCP manager if configured if config.MCPConfig != nil { - mcpManager, err := newMCPManager(bifrostCtx, *config.MCPConfig, bifrost.logger) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("failed to initialize MCP manager: %v", err)) - } else { - bifrost.mcpManager = mcpManager + bifrost.mcpInitOnce.Do(func() { + bifrost.mcpManager = mcp.NewMCPManager(bifrostCtx, *config.MCPConfig, bifrost.logger) bifrost.logger.Info("MCP integration initialized successfully") - } + }) } // Create buffered channels for each provider and start workers @@ -492,8 +491,7 @@ func (bifrost *Bifrost) TextCompletionStreamRequest(ctx context.Context, req *sc return bifrost.handleStreamRequest(ctx, bifrostReq) } -// ChatCompletionRequest sends a chat completion request to the specified provider. -func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) makeChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -519,10 +517,35 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. if err != nil { return nil, err } - //TODO: Release the response + return response.ChatResponse, nil } +// ChatCompletionRequest sends a chat completion request to the specified provider. +func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + // If ctx is nil, use the bifrost context (defensive check for mcp agent mode) + if ctx == nil { + ctx = bifrost.ctx + } + + response, err := bifrost.makeChatCompletionRequest(ctx, req) + if err != nil { + return nil, err + } + + // Check if we should enter agent mode + if bifrost.mcpManager != nil { + return bifrost.mcpManager.CheckAndExecuteAgentForChatRequest( + &ctx, + req, + response, + bifrost.makeChatCompletionRequest, + ) + } + + return response, nil +} + // ChatCompletionStreamRequest sends a chat completion stream request to the specified provider. func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx context.Context, req *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if req == nil { @@ -549,8 +572,7 @@ func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx context.Context, req *sc return bifrost.handleStreamRequest(ctx, bifrostReq) } -// ResponsesRequest sends a responses request to the specified provider. -func (bifrost *Bifrost) ResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) makeResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -576,10 +598,34 @@ func (bifrost *Bifrost) ResponsesRequest(ctx context.Context, req *schemas.Bifro if err != nil { return nil, err } - //TODO: Release the response return response.ResponsesResponse, nil } +// ResponsesRequest sends a responses request to the specified provider. +func (bifrost *Bifrost) ResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + // If ctx is nil, use the bifrost context (defensive check for mcp agent mode) + if ctx == nil { + ctx = bifrost.ctx + } + + response, err := bifrost.makeResponsesRequest(ctx, req) + if err != nil { + return nil, err + } + + // Check if we should enter agent mode + if bifrost.mcpManager != nil { + return bifrost.mcpManager.CheckAndExecuteAgentForResponsesRequest( + &ctx, + req, + response, + bifrost.makeResponsesRequest, + ) + } + + return response, nil +} + // ResponsesStreamRequest sends a responses stream request to the specified provider. func (bifrost *Bifrost) ResponsesStreamRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if req == nil { @@ -1089,7 +1135,7 @@ func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(a return fmt.Errorf("MCP is not configured in this Bifrost instance") } - return bifrost.mcpManager.registerTool(name, description, handler, toolSchema) + return bifrost.mcpManager.RegisterTool(name, description, handler, toolSchema) } // ExecuteMCPTool executes an MCP tool call and returns the result as a tool message. @@ -1112,13 +1158,12 @@ func (bifrost *Bifrost) ExecuteMCPTool(ctx context.Context, toolCall schemas.Cha } } - result, err := bifrost.mcpManager.executeTool(ctx, toolCall) + result, err := bifrost.mcpManager.ExecuteTool(ctx, toolCall) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: err.Error(), - Error: err, }, } } @@ -1141,12 +1186,9 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { return nil, fmt.Errorf("MCP is not configured in this Bifrost instance") } - clients, err := bifrost.mcpManager.GetClients() - if err != nil { - return nil, err - } - + clients := bifrost.mcpManager.GetClients() clientsInConfig := make([]schemas.MCPClient, 0, len(clients)) + for _, client := range clients { tools := make([]schemas.ChatToolFunction, 0, len(client.ToolMap)) for _, tool := range client.ToolMap { @@ -1192,13 +1234,17 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { // }) func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error { if bifrost.mcpManager == nil { - manager := &MCPManager{ - ctx: bifrost.ctx, - clientMap: make(map[string]*MCPClient), - logger: bifrost.logger, - } + // Use sync.Once to ensure thread-safe initialization + bifrost.mcpInitOnce.Do(func() { + bifrost.mcpManager = mcp.NewMCPManager(bifrost.ctx, schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{config}, + }, bifrost.logger) + }) + } - bifrost.mcpManager = manager + // Handle case where initialization succeeded elsewhere but manager is still nil + if bifrost.mcpManager == nil { + return fmt.Errorf("MCP manager is not initialized") } return bifrost.mcpManager.AddClient(config) @@ -1266,6 +1312,20 @@ func (bifrost *Bifrost) ReconnectMCPClient(id string) error { return bifrost.mcpManager.ReconnectClient(id) } +// UpdateToolManagerConfig updates the tool manager config for the MCP manager. +// This allows for hot-reloading of the tool manager config at runtime. +func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int) error { + if bifrost.mcpManager == nil { + return fmt.Errorf("MCP is not configured in this Bifrost instance") + } + + bifrost.mcpManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{ + MaxAgentDepth: maxAgentDepth, + ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second, + }) + return nil +} + // PROVIDER MANAGEMENT // createBaseProvider creates a provider based on the base provider type @@ -1764,11 +1824,8 @@ func (bifrost *Bifrost) tryRequest(ctx context.Context, req *schemas.BifrostRequ } // Add MCP tools to request if MCP is configured and requested - if req.RequestType != schemas.EmbeddingRequest && - req.RequestType != schemas.SpeechRequest && - req.RequestType != schemas.TranscriptionRequest && - bifrost.mcpManager != nil { - req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) + if bifrost.mcpManager != nil { + req = bifrost.mcpManager.AddToolsToRequest(ctx, req) } pipeline := bifrost.getPluginPipeline() @@ -1854,7 +1911,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx context.Context, req *schemas.Bifro // Add MCP tools to request if MCP is configured and requested if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.mcpManager != nil { - req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) + req = bifrost.mcpManager.AddToolsToRequest(ctx, req) } pipeline := bifrost.getPluginPipeline() @@ -2596,7 +2653,7 @@ func (bifrost *Bifrost) Shutdown() { // Cleanup MCP manager if bifrost.mcpManager != nil { - err := bifrost.mcpManager.cleanup() + err := bifrost.mcpManager.Cleanup() if err != nil { bifrost.logger.Warn(fmt.Sprintf("Error cleaning up MCP manager: %s", err.Error())) } diff --git a/core/chatbot_test.go b/core/chatbot_test.go index 9f5ad7679b..6a832b7897 100644 --- a/core/chatbot_test.go +++ b/core/chatbot_test.go @@ -515,7 +515,7 @@ func (s *ChatSession) SendMessage(message string) (string, error) { s.history = append(s.history, *assistantMessage) // Check if assistant wants to use tools - if assistantMessage.ToolCalls != nil && len(assistantMessage.ToolCalls) > 0 { + if len(assistantMessage.ToolCalls) > 0 { return s.handleToolCalls(*assistantMessage) } @@ -523,7 +523,7 @@ func (s *ChatSession) SendMessage(message string) (string, error) { var responseText string if assistantMessage.Content.ContentStr != nil { responseText = *assistantMessage.Content.ContentStr - } else if assistantMessage.Content.ContentBlocks != nil && len(assistantMessage.Content.ContentBlocks) > 0 { + } else if len(assistantMessage.Content.ContentBlocks) > 0 { var textParts []string for _, block := range assistantMessage.Content.ContentBlocks { if block.Text != nil { @@ -633,7 +633,7 @@ func (s *ChatSession) synthesizeToolResults() (string, error) { synthesisRequest := &schemas.BifrostChatRequest{ Provider: s.config.Provider, Model: s.config.Model, - Input: conversationWithSynthesis, + Input: conversationWithSynthesis, Params: &schemas.ChatParameters{ Temperature: s.config.Temperature, MaxCompletionTokens: s.config.MaxTokens, diff --git a/core/go.mod b/core/go.mod index c9421accb4..de9bad610b 100644 --- a/core/go.mod +++ b/core/go.mod @@ -8,6 +8,8 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.31.13 github.com/aws/smithy-go v1.23.1 github.com/bytedance/sonic v1.14.1 + github.com/clarkmcc/go-typescript v0.7.0 + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 github.com/google/uuid v1.6.0 github.com/hajimehoshi/go-mp3 v0.3.4 github.com/mark3labs/mcp-go v0.41.1 @@ -37,6 +39,9 @@ require ( github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect + github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/klauspost/compress v1.18.1 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect diff --git a/core/go.sum b/core/go.sum index 51836b2d67..a3a4077fa0 100644 --- a/core/go.sum +++ b/core/go.sum @@ -1,5 +1,7 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= +github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.39.5 h1:e/SXuia3rkFtapghJROrydtQpfQaaUgd1cUvyO1mp2w= @@ -40,6 +42,8 @@ github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7 github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -47,11 +51,19 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0 h1:e+8XbKB6IMn8A4OAyZccO4pYfB3s7bt6azNIPE7AnPg= +github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= @@ -129,6 +141,8 @@ golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/core/mcp.go b/core/mcp.go deleted file mode 100644 index d2b611a3fe..0000000000 --- a/core/mcp.go +++ /dev/null @@ -1,1171 +0,0 @@ -package bifrost - -import ( - "context" - "encoding/json" - "fmt" - "maps" - "os" - "slices" - "strings" - "sync" - "time" - - "github.com/maximhq/bifrost/core/schemas" - - "github.com/mark3labs/mcp-go/client" - "github.com/mark3labs/mcp-go/client/transport" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" -) - -// ============================================================================ -// CONSTANTS -// ============================================================================ - -const ( - // MCP defaults and identifiers - BifrostMCPVersion = "1.0.0" // Version identifier for Bifrost - BifrostMCPClientName = "BifrostClient" // Name for internal Bifrost MCP client - BifrostMCPClientKey = "bifrost-internal" // Key for internal Bifrost client in clientMap - MCPLogPrefix = "[Bifrost MCP]" // Consistent logging prefix - MCPClientConnectionEstablishTimeout = 30 * time.Second // Timeout for MCP client connection establishment - - // Context keys for client filtering in requests - // NOTE: []string is used for both keys, and by default all clients/tools are included (when nil). - // If "*" is present, all clients/tools are included, and [] means no clients/tools are included. - // Request context filtering takes priority over client config - context can override client exclusions. - MCPContextKeyIncludeClients schemas.BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering - MCPContextKeyIncludeTools schemas.BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName/toolName" format) -) - -// ============================================================================ -// TYPE DEFINITIONS -// ============================================================================ - -// MCPManager manages MCP integration for Bifrost core. -// It provides a bridge between Bifrost and various MCP servers, supporting -// both local tool hosting and external MCP server connections. -type MCPManager struct { - ctx context.Context - server *server.MCPServer // Local MCP server instance for hosting tools (STDIO-based) - clientMap map[string]*MCPClient // Map of MCP client names to their configurations - mu sync.RWMutex // Read-write mutex for thread-safe operations - serverRunning bool // Track whether local MCP server is running - logger schemas.Logger // Logger instance for structured logging -} - -// MCPClient represents a connected MCP client with its configuration and tools. -type MCPClient struct { - // Name string // Unique name for this client - Conn *client.Client // Active MCP client connection - ExecutionConfig schemas.MCPClientConfig // Tool filtering settings - ToolMap map[string]schemas.ChatTool // Available tools mapped by name - ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management - cancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized) -} - -// MCPClientConnectionInfo stores metadata about how a client is connected. -type MCPClientConnectionInfo struct { - Type schemas.MCPConnectionType `json:"type"` // Connection type (HTTP, STDIO, SSE, or InProcess) - ConnectionURL *string `json:"connection_url,omitempty"` // HTTP/SSE endpoint URL (for HTTP/SSE connections) - StdioCommandString *string `json:"stdio_command_string,omitempty"` // Command string for display (for STDIO connections) -} - -// MCPToolHandler is a generic function type for handling tool calls with typed arguments. -// T represents the expected argument structure for the tool. -type MCPToolHandler[T any] func(args T) (string, error) - -// ============================================================================ -// CONSTRUCTOR AND INITIALIZATION -// ============================================================================ - -// newMCPManager creates and initializes a new MCP manager instance. -// -// Parameters: -// - config: MCP configuration including server port and client configs -// - logger: Logger instance for structured logging (uses default if nil) -// -// Returns: -// - *MCPManager: Initialized manager instance -// - error: Any initialization error -func newMCPManager(ctx context.Context, config schemas.MCPConfig, logger schemas.Logger) (*MCPManager, error) { - // Creating new instance - manager := &MCPManager{ - ctx: ctx, - clientMap: make(map[string]*MCPClient), - logger: logger, - } - // Process client configs: create client map entries and establish connections - for _, clientConfig := range config.ClientConfigs { - if err := manager.AddClient(clientConfig); err != nil { - manager.logger.Warn(fmt.Sprintf("%s Failed to add MCP client %s: %v", MCPLogPrefix, clientConfig.Name, err)) - } - } - manager.logger.Info(MCPLogPrefix + " MCP Manager initialized") - return manager, nil -} - -// GetClients returns all MCP clients managed by the manager. -// -// Returns: -// - []*MCPClient: List of all MCP clients -// - error: Any retrieval error -func (m *MCPManager) GetClients() ([]MCPClient, error) { - m.mu.RLock() - defer m.mu.RUnlock() - - clients := make([]MCPClient, 0, len(m.clientMap)) - for _, client := range m.clientMap { - clients = append(clients, *client) - } - - return clients, nil -} - -// ReconnectClient attempts to reconnect an MCP client if it is disconnected. -func (m *MCPManager) ReconnectClient(id string) error { - m.mu.Lock() - - client, ok := m.clientMap[id] - if !ok { - m.mu.Unlock() - return fmt.Errorf("client %s not found", id) - } - - m.mu.Unlock() - - // connectToMCPClient handles locking internally - err := m.connectToMCPClient(client.ExecutionConfig) - if err != nil { - return fmt.Errorf("failed to connect to MCP client %s: %w", id, err) - } - - return nil -} - -// AddClient adds a new MCP client to the manager. -// It validates the client configuration and establishes a connection. -// -// Parameters: -// - config: MCP client configuration -// -// Returns: -func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error { - if err := validateMCPClientConfig(&config); err != nil { - return fmt.Errorf("invalid MCP client configuration: %w", err) - } - - // Make a copy of the config to use after unlocking - configCopy := config - - m.mu.Lock() - - if _, ok := m.clientMap[config.ID]; ok { - m.mu.Unlock() - return fmt.Errorf("client %s already exists", config.Name) - } - - // Create placeholder entry - m.clientMap[config.ID] = &MCPClient{ - ExecutionConfig: config, - ToolMap: make(map[string]schemas.ChatTool), - } - - // Temporarily unlock for the connection attempt - // This is to avoid deadlocks when the connection attempt is made - m.mu.Unlock() - - // Connect using the copied config - if err := m.connectToMCPClient(configCopy); err != nil { - // Re-lock to clean up the failed entry - m.mu.Lock() - delete(m.clientMap, config.ID) - m.mu.Unlock() - return fmt.Errorf("failed to connect to MCP client %s: %w", config.Name, err) - } - - return nil -} - -// RemoveClient removes an MCP client from the manager. -// It handles cleanup for all transport types (HTTP, STDIO, SSE). -// -// Parameters: -// - id: ID of the client to remove -func (m *MCPManager) RemoveClient(id string) error { - m.mu.Lock() - defer m.mu.Unlock() - - return m.removeClientUnsafe(id) -} - -func (m *MCPManager) removeClientUnsafe(id string) error { - client, ok := m.clientMap[id] - if !ok { - return fmt.Errorf("client %s not found", id) - } - - m.logger.Info(fmt.Sprintf("%s Disconnecting MCP client: %s", MCPLogPrefix, client.ExecutionConfig.Name)) - - // Cancel SSE context if present (required for proper SSE cleanup) - if client.cancelFunc != nil { - client.cancelFunc() - client.cancelFunc = nil - } - - // Close the client transport connection - // This handles cleanup for all transport types (HTTP, STDIO, SSE) - if client.Conn != nil { - if err := client.Conn.Close(); err != nil { - m.logger.Error("%s Failed to close MCP client %s: %v", MCPLogPrefix, client.ExecutionConfig.Name, err) - } - client.Conn = nil - } - - // Clear client tool map - client.ToolMap = make(map[string]schemas.ChatTool) - - delete(m.clientMap, id) - return nil -} - -func (m *MCPManager) EditClient(id string, updatedConfig schemas.MCPClientConfig) error { - m.mu.Lock() - defer m.mu.Unlock() - - client, ok := m.clientMap[id] - if !ok { - return fmt.Errorf("client %s not found", id) - } - - // Update the client's execution config with new tool filters - config := client.ExecutionConfig - config.Name = updatedConfig.Name - config.Headers = updatedConfig.Headers - config.ToolsToExecute = updatedConfig.ToolsToExecute - - // Store the updated config - client.ExecutionConfig = config - - if client.Conn == nil { - return nil // Client is not connected, so no tools to update - } - - // Clear current tool map - client.ToolMap = make(map[string]schemas.ChatTool) - - // Temporarily unlock for the network call - m.mu.Unlock() - - // Retrieve tools with updated configuration - tools, err := m.retrieveExternalTools(m.ctx, client.Conn, config) - - // Re-lock to update the tool map - m.mu.Lock() - - // Verify client still exists - if _, ok := m.clientMap[id]; !ok { - return fmt.Errorf("client %s was removed during tool update", id) - } - - if err != nil { - return fmt.Errorf("failed to retrieve external tools: %w", err) - } - - // Store discovered tools - maps.Copy(client.ToolMap, tools) - - return nil -} - -// ============================================================================ -// TOOL REGISTRATION AND DISCOVERY -// ============================================================================ - -// getAvailableTools returns all tools from connected MCP clients. -// Applies client filtering if specified in the context. -func (m *MCPManager) getAvailableTools(ctx context.Context) []schemas.ChatTool { - m.mu.RLock() - defer m.mu.RUnlock() - - var includeClients []string - - // Extract client filtering from request context - if existingIncludeClients, ok := ctx.Value(MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { - includeClients = existingIncludeClients - } - - tools := make([]schemas.ChatTool, 0) - for id, client := range m.clientMap { - // Apply client filtering logic - if !m.shouldIncludeClient(id, includeClients) { - m.logger.Debug(fmt.Sprintf("%s Skipping MCP client %s: not in include clients list", MCPLogPrefix, id)) - continue - } - - m.logger.Debug(fmt.Sprintf("Checking tools for MCP client %s with tools to execute: %v", id, client.ExecutionConfig.ToolsToExecute)) - - // Add all tools from this client - for toolName, tool := range client.ToolMap { - // Check if tool should be skipped based on client configuration - if m.shouldSkipToolForConfig(toolName, client.ExecutionConfig) { - m.logger.Debug(fmt.Sprintf("%s Skipping MCP tool %s: not in tools to execute list", MCPLogPrefix, toolName)) - continue - } - - // Check if tool should be skipped based on request context - if m.shouldSkipToolForRequest(id, toolName, ctx) { - m.logger.Debug(fmt.Sprintf("%s Skipping MCP tool %s: not in include tools list", MCPLogPrefix, toolName)) - continue - } - - tools = append(tools, tool) - } - } - return tools -} - -// registerTool registers a typed tool handler with the local MCP server. -// This is a convenience function that handles the conversion between typed Go -// handlers and the MCP protocol. -// -// Type Parameters: -// - T: The expected argument type for the tool (must be JSON-deserializable) -// -// Parameters: -// - name: Unique tool name -// - description: Human-readable tool description -// - handler: Typed function that handles tool execution -// - toolSchema: Bifrost tool schema for function calling -// -// Returns: -// - error: Any registration error -// -// Example: -// -// type EchoArgs struct { -// Message string `json:"message"` -// } -// -// err := bifrost.RegisterMCPTool("echo", "Echo a message", -// func(args EchoArgs) (string, error) { -// return args.Message, nil -// }, toolSchema) -func (m *MCPManager) registerTool(name, description string, handler MCPToolHandler[any], toolSchema schemas.ChatTool) error { - // Ensure local server is set up - if err := m.setupLocalHost(); err != nil { - return fmt.Errorf("failed to setup local host: %w", err) - } - - // Verify internal client exists - if _, ok := m.clientMap[BifrostMCPClientKey]; !ok { - return fmt.Errorf("bifrost client not found") - } - - m.mu.Lock() - defer m.mu.Unlock() - - // Check if tool name already exists to prevent silent overwrites - if _, exists := m.clientMap[BifrostMCPClientKey].ToolMap[name]; exists { - return fmt.Errorf("tool '%s' is already registered", name) - } - - m.logger.Info(fmt.Sprintf("%s Registering typed tool: %s", MCPLogPrefix, name)) - - // Create MCP handler wrapper that converts between typed and MCP interfaces - mcpHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // Extract arguments from the request using the request's methods - args := request.GetArguments() - result, err := handler(args) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Error: %s", err.Error())), nil - } - return mcp.NewToolResultText(result), nil - } - - // Register the tool with the local MCP server using AddTool - if m.server != nil { - tool := mcp.NewTool(name, mcp.WithDescription(description)) - m.server.AddTool(tool, mcpHandler) - } - - // Store tool definition for Bifrost integration - m.clientMap[BifrostMCPClientKey].ToolMap[name] = toolSchema - - return nil -} - -// setupLocalHost initializes the local MCP server and client if not already running. -// This creates a STDIO-based server for local tool hosting and a corresponding client. -// This is called automatically when tools are registered or when the server is needed. -// -// Returns: -// - error: Any setup error -func (m *MCPManager) setupLocalHost() error { - // Check if server is already running - if m.server != nil && m.serverRunning { - return nil - } - - // Create and configure local MCP server (STDIO-based) - server, err := m.createLocalMCPServer() - if err != nil { - return fmt.Errorf("failed to create local MCP server: %w", err) - } - m.server = server - - // Create and configure local MCP client (STDIO-based) - client, err := m.createLocalMCPClient() - if err != nil { - return fmt.Errorf("failed to create local MCP client: %w", err) - } - m.clientMap[BifrostMCPClientKey] = client - - // Start the server and initialize client connection - return m.startLocalMCPServer() -} - -// createLocalMCPServer creates a new local MCP server instance with STDIO transport. -// This server will host tools registered via RegisterTool function. -// -// Returns: -// - *server.MCPServer: Configured MCP server instance -// - error: Any creation error -func (m *MCPManager) createLocalMCPServer() (*server.MCPServer, error) { - // Create MCP server - mcpServer := server.NewMCPServer( - "Bifrost-MCP-Server", - "1.0.0", - server.WithToolCapabilities(true), - ) - - return mcpServer, nil -} - -// createLocalMCPClient creates a placeholder client entry for the local MCP server. -// The actual in-process client connection will be established in startLocalMCPServer. -// -// Returns: -// - *MCPClient: Placeholder client for local server -// - error: Any creation error -func (m *MCPManager) createLocalMCPClient() (*MCPClient, error) { - // Don't create the actual client connection here - it will be created - // after the server is ready using NewInProcessClient - return &MCPClient{ - ExecutionConfig: schemas.MCPClientConfig{ - Name: BifrostMCPClientName, - }, - ToolMap: make(map[string]schemas.ChatTool), - ConnectionInfo: MCPClientConnectionInfo{ - Type: schemas.MCPConnectionTypeInProcess, // Accurate: in-process (in-memory) transport - }, - }, nil -} - -// startLocalMCPServer creates an in-process connection between the local server and client. -// -// Returns: -// - error: Any startup error -func (m *MCPManager) startLocalMCPServer() error { - m.mu.Lock() - defer m.mu.Unlock() - - // Check if server is already running - if m.server != nil && m.serverRunning { - return nil - } - - if m.server == nil { - return fmt.Errorf("server not initialized") - } - - // Create in-process client directly connected to the server - inProcessClient, err := client.NewInProcessClient(m.server) - if err != nil { - return fmt.Errorf("failed to create in-process MCP client: %w", err) - } - - // Update the client connection - clientEntry, ok := m.clientMap[BifrostMCPClientKey] - if !ok { - return fmt.Errorf("bifrost client not found") - } - clientEntry.Conn = inProcessClient - - // Initialize the in-process client - ctx, cancel := context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout) - defer cancel() - - // Create proper initialize request with correct structure - initRequest := mcp.InitializeRequest{ - Params: mcp.InitializeParams{ - ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, - Capabilities: mcp.ClientCapabilities{}, - ClientInfo: mcp.Implementation{ - Name: BifrostMCPClientName, - Version: BifrostMCPVersion, - }, - }, - } - - _, err = inProcessClient.Initialize(ctx, initRequest) - if err != nil { - return fmt.Errorf("failed to initialize MCP client: %w", err) - } - - // Mark server as running - m.serverRunning = true - - return nil -} - -// executeTool executes a tool call and returns the result as a tool message. -// -// Parameters: -// - ctx: Execution context -// - toolCall: The tool call to execute (from assistant message) -// -// Returns: -// - schemas.ChatMessage: Tool message with execution result -// - error: Any execution error -func (m *MCPManager) executeTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { - if toolCall.Function.Name == nil { - return nil, fmt.Errorf("tool call missing function name") - } - toolName := *toolCall.Function.Name - - // Parse tool arguments - var arguments map[string]interface{} - if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { - return nil, fmt.Errorf("failed to parse tool arguments for '%s': %v", toolName, err) - } - - // Find which client has this tool - client := m.findMCPClientForTool(toolName) - if client == nil { - return nil, fmt.Errorf("tool '%s' not found in any connected MCP client", toolName) - } - - if client.Conn == nil { - return nil, fmt.Errorf("client '%s' has no active connection", client.ExecutionConfig.Name) - } - - // Call the tool via MCP client -> MCP server - callRequest := mcp.CallToolRequest{ - Request: mcp.Request{ - Method: string(mcp.MethodToolsCall), - }, - Params: mcp.CallToolParams{ - Name: toolName, - Arguments: arguments, - }, - } - - m.logger.Debug(fmt.Sprintf("%s Starting tool execution: %s via client: %s", MCPLogPrefix, toolName, client.ExecutionConfig.Name)) - - toolResponse, callErr := client.Conn.CallTool(ctx, callRequest) - if callErr != nil { - m.logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr) - return nil, fmt.Errorf("MCP tool call failed: %v", callErr) - } - - m.logger.Debug(fmt.Sprintf("%s Tool execution completed: %s", MCPLogPrefix, toolName)) - - // Extract text from MCP response - responseText := m.extractTextFromMCPResponse(toolResponse, toolName) - - // Create tool response message - return m.createToolResponseMessage(toolCall, responseText), nil -} - -// ============================================================================ -// EXTERNAL MCP CONNECTION MANAGEMENT -// ============================================================================ - -// connectToMCPClient establishes a connection to an external MCP server and -// registers its available tools with the manager. -func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { - // First lock: Initialize or validate client entry - m.mu.Lock() - - // Initialize or validate client entry - if existingClient, exists := m.clientMap[config.ID]; exists { - // Client entry exists from config, check for existing connection, if it does then close - if existingClient.cancelFunc != nil { - existingClient.cancelFunc() - existingClient.cancelFunc = nil - } - if existingClient.Conn != nil { - existingClient.Conn.Close() - } - // Update connection type for this connection attempt - existingClient.ConnectionInfo.Type = config.ConnectionType - } - // Create new client entry with configuration - m.clientMap[config.ID] = &MCPClient{ - ExecutionConfig: config, - ToolMap: make(map[string]schemas.ChatTool), - ConnectionInfo: MCPClientConnectionInfo{ - Type: config.ConnectionType, - }, - } - m.mu.Unlock() - - // Heavy operations performed outside lock - var externalClient *client.Client - var connectionInfo MCPClientConnectionInfo - var err error - - // Create appropriate transport based on connection type - switch config.ConnectionType { - case schemas.MCPConnectionTypeHTTP: - externalClient, connectionInfo, err = m.createHTTPConnection(config) - case schemas.MCPConnectionTypeSTDIO: - externalClient, connectionInfo, err = m.createSTDIOConnection(config) - case schemas.MCPConnectionTypeSSE: - externalClient, connectionInfo, err = m.createSSEConnection(config) - case schemas.MCPConnectionTypeInProcess: - externalClient, connectionInfo, err = m.createInProcessConnection(config) - default: - return fmt.Errorf("unknown connection type: %s", config.ConnectionType) - } - - if err != nil { - return fmt.Errorf("failed to create connection: %w", err) - } - - // Initialize the external client with timeout - // For SSE connections, we need a long-lived context, for others we can use timeout - var ctx context.Context - var cancel context.CancelFunc - - if config.ConnectionType == schemas.MCPConnectionTypeSSE { - // SSE connections need a long-lived context for the persistent stream - ctx, cancel = context.WithCancel(m.ctx) - // Don't defer cancel here - SSE needs the context to remain active - } else { - // Other connection types can use timeout context - ctx, cancel = context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout) - defer cancel() - } - - // Start the transport first (required for STDIO and SSE clients) - if err := externalClient.Start(ctx); err != nil { - if config.ConnectionType == schemas.MCPConnectionTypeSSE { - cancel() // Cancel SSE context only on error - } - return fmt.Errorf("failed to start MCP client transport %s: %v", config.Name, err) - } - - // Create proper initialize request for external client - extInitRequest := mcp.InitializeRequest{ - Params: mcp.InitializeParams{ - ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, - Capabilities: mcp.ClientCapabilities{}, - ClientInfo: mcp.Implementation{ - Name: fmt.Sprintf("Bifrost-%s", config.Name), - Version: "1.0.0", - }, - }, - } - - _, err = externalClient.Initialize(ctx, extInitRequest) - if err != nil { - if config.ConnectionType == schemas.MCPConnectionTypeSSE { - cancel() // Cancel SSE context only on error - } - return fmt.Errorf("failed to initialize MCP client %s: %v", config.Name, err) - } - - // Retrieve tools from the external server (this also requires network I/O) - tools, err := m.retrieveExternalTools(ctx, externalClient, config) - if err != nil { - m.logger.Warn(fmt.Sprintf("%s Failed to retrieve tools from %s: %v", MCPLogPrefix, config.Name, err)) - // Continue with connection even if tool retrieval fails - tools = make(map[string]schemas.ChatTool) - } - - // Second lock: Update client with final connection details and tools - m.mu.Lock() - defer m.mu.Unlock() - - // Verify client still exists (could have been cleaned up during heavy operations) - if client, exists := m.clientMap[config.ID]; exists { - // Store the external client connection and details - client.Conn = externalClient - client.ConnectionInfo = connectionInfo - - // Store cancel function for SSE connections to enable proper cleanup - if config.ConnectionType == schemas.MCPConnectionTypeSSE { - client.cancelFunc = cancel - } - - // Store discovered tools - for toolName, tool := range tools { - client.ToolMap[toolName] = tool - } - - m.logger.Info(fmt.Sprintf("%s Connected to MCP client: %s", MCPLogPrefix, config.Name)) - } else { - return fmt.Errorf("client %s was removed during connection setup", config.Name) - } - - return nil -} - -// retrieveExternalTools retrieves and filters tools from an external MCP server without holding locks. -func (m *MCPManager) retrieveExternalTools(ctx context.Context, client *client.Client, config schemas.MCPClientConfig) (map[string]schemas.ChatTool, error) { - // Get available tools from external server - listRequest := mcp.ListToolsRequest{ - PaginatedRequest: mcp.PaginatedRequest{ - Request: mcp.Request{ - Method: string(mcp.MethodToolsList), - }, - }, - } - - toolsResponse, err := client.ListTools(ctx, listRequest) - if err != nil { - return nil, fmt.Errorf("failed to list tools: %v", err) - } - - if toolsResponse == nil { - return make(map[string]schemas.ChatTool), nil // No tools available - } - - m.logger.Debug(fmt.Sprintf("%s Retrieved %d tools from %s", MCPLogPrefix, len(toolsResponse.Tools), config.Name)) - - tools := make(map[string]schemas.ChatTool) - - // toolsResponse is already a ListToolsResult - for _, mcpTool := range toolsResponse.Tools { - // Convert MCP tool schema to Bifrost format - bifrostTool := m.convertMCPToolToBifrostSchema(&mcpTool) - tools[mcpTool.Name] = bifrostTool - } - - return tools, nil -} - -// shouldSkipToolForConfig checks if a tool should be skipped based on client configuration (without accessing clientMap). -func (m *MCPManager) shouldSkipToolForConfig(toolName string, config schemas.MCPClientConfig) bool { - // If ToolsToExecute is specified (not nil), apply filtering - if config.ToolsToExecute != nil { - // Handle empty array [] - means no tools are allowed - if len(config.ToolsToExecute) == 0 { - return true // No tools allowed - } - - // Handle wildcard "*" - if present, all tools are allowed - if slices.Contains(config.ToolsToExecute, "*") { - return false // All tools allowed - } - - // Check if specific tool is in the allowed list - for _, allowedTool := range config.ToolsToExecute { - if allowedTool == toolName { - return false // Tool is allowed - } - } - return true // Tool not in allowed list - } - - return true // Tool is skipped (nil is treated as [] - no tools) -} - -// shouldSkipToolForRequest checks if a tool should be skipped based on the request context. -func (m *MCPManager) shouldSkipToolForRequest(clientID, toolName string, ctx context.Context) bool { - includeTools := ctx.Value(MCPContextKeyIncludeTools) - - if includeTools != nil { - // Try []string first (preferred type) - if includeToolsList, ok := includeTools.([]string); ok { - // Handle empty array [] - means no tools are included - if len(includeToolsList) == 0 { - return true // No tools allowed - } - - // Handle wildcard "clientName/*" - if present, all tools are included for this client - if slices.Contains(includeToolsList, fmt.Sprintf("%s/*", clientID)) { - return false // All tools allowed - } - - // Check if specific tool is in the list (format: clientName/toolName) - fullToolName := fmt.Sprintf("%s/%s", clientID, toolName) - if slices.Contains(includeToolsList, fullToolName) { - return false // Tool is explicitly allowed - } - - // If includeTools is specified but this tool is not in it, skip it - return true - } - } - - return false // Tool is allowed (default when no filtering specified) -} - -// convertMCPToolToBifrostSchema converts an MCP tool definition to Bifrost format. -func (m *MCPManager) convertMCPToolToBifrostSchema(mcpTool *mcp.Tool) schemas.ChatTool { - return schemas.ChatTool{ - Type: schemas.ChatToolTypeFunction, - Function: &schemas.ChatToolFunction{ - Name: mcpTool.Name, - Description: Ptr(mcpTool.Description), - Parameters: &schemas.ToolFunctionParameters{ - Type: mcpTool.InputSchema.Type, - Properties: Ptr(mcpTool.InputSchema.Properties), - Required: mcpTool.InputSchema.Required, - }, - }, - } -} - -// extractTextFromMCPResponse extracts text content from an MCP tool response. -func (m *MCPManager) extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string { - if toolResponse == nil { - return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) - } - - var result strings.Builder - for _, contentBlock := range toolResponse.Content { - // Handle typed content - switch content := contentBlock.(type) { - case mcp.TextContent: - result.WriteString(content.Text) - case mcp.ImageContent: - result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) - case mcp.AudioContent: - result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) - case mcp.EmbeddedResource: - result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type)) - default: - // Fallback: try to extract from map structure - if jsonBytes, err := json.Marshal(contentBlock); err == nil { - var contentMap map[string]interface{} - if json.Unmarshal(jsonBytes, &contentMap) == nil { - if text, ok := contentMap["text"].(string); ok { - result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text)) - continue - } - } - // Final fallback: serialize as JSON - result.WriteString(string(jsonBytes)) - } - } - } - - if result.Len() > 0 { - return strings.TrimSpace(result.String()) - } - return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) -} - -// createToolResponseMessage creates a tool response message with the execution result. -func (m *MCPManager) createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage { - return &schemas.ChatMessage{ - Role: schemas.ChatMessageRoleTool, - Content: &schemas.ChatMessageContent{ - ContentStr: &responseText, - }, - ChatToolMessage: &schemas.ChatToolMessage{ - ToolCallID: toolCall.ID, - }, - } -} - -func (m *MCPManager) addMCPToolsToBifrostRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { - mcpTools := m.getAvailableTools(ctx) - if len(mcpTools) > 0 { - m.logger.Debug(fmt.Sprintf("%s Adding %d MCP tools to request", MCPLogPrefix, len(mcpTools))) - switch req.RequestType { - case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: - // Only allocate new Params if it's nil to preserve caller-supplied settings - if req.ChatRequest.Params == nil { - req.ChatRequest.Params = &schemas.ChatParameters{} - } - - tools := req.ChatRequest.Params.Tools - - // Create a map of existing tool names for O(1) lookup - existingToolsMap := make(map[string]bool) - for _, tool := range tools { - if tool.Function != nil && tool.Function.Name != "" { - existingToolsMap[tool.Function.Name] = true - } - } - - // Add MCP tools that are not already present - for _, mcpTool := range mcpTools { - // Skip tools with nil Function or empty Name - if mcpTool.Function == nil || mcpTool.Function.Name == "" { - continue - } - - if !existingToolsMap[mcpTool.Function.Name] { - tools = append(tools, mcpTool) - // Update the map to prevent duplicates within MCP tools as well - existingToolsMap[mcpTool.Function.Name] = true - } - } - req.ChatRequest.Params.Tools = tools - case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: - // Only allocate new Params if it's nil to preserve caller-supplied settings - if req.ResponsesRequest.Params == nil { - req.ResponsesRequest.Params = &schemas.ResponsesParameters{} - } - - tools := req.ResponsesRequest.Params.Tools - - // Create a map of existing tool names for O(1) lookup - existingToolsMap := make(map[string]bool) - for _, tool := range tools { - if tool.Name != nil { - existingToolsMap[*tool.Name] = true - } - } - - // Add MCP tools that are not already present - for _, mcpTool := range mcpTools { - // Skip tools with nil Function or empty Name - if mcpTool.Function == nil || mcpTool.Function.Name == "" { - continue - } - - if !existingToolsMap[mcpTool.Function.Name] { - responsesTool := mcpTool.ToResponsesTool() - // Skip if the converted tool has nil Name - if responsesTool.Name == nil { - continue - } - - tools = append(tools, *responsesTool) - // Update the map to prevent duplicates within MCP tools as well - existingToolsMap[*responsesTool.Name] = true - } - } - req.ResponsesRequest.Params.Tools = tools - } - } - return req -} - -func validateMCPClientConfig(config *schemas.MCPClientConfig) error { - if strings.TrimSpace(config.ID) == "" { - return fmt.Errorf("id is required for MCP client config") - } - - if strings.TrimSpace(config.Name) == "" { - return fmt.Errorf("name is required for MCP client config") - } - - if config.ConnectionType == "" { - return fmt.Errorf("connection type is required for MCP client config") - } - - switch config.ConnectionType { - case schemas.MCPConnectionTypeHTTP: - if config.ConnectionString == nil { - return fmt.Errorf("ConnectionString is required for HTTP connection type in client '%s'", config.Name) - } - case schemas.MCPConnectionTypeSSE: - if config.ConnectionString == nil { - return fmt.Errorf("ConnectionString is required for SSE connection type in client '%s'", config.Name) - } - case schemas.MCPConnectionTypeSTDIO: - if config.StdioConfig == nil { - return fmt.Errorf("StdioConfig is required for STDIO connection type in client '%s'", config.Name) - } - case schemas.MCPConnectionTypeInProcess: - // InProcess requires a server instance to be provided programmatically - // This cannot be validated from JSON config - the server must be set when using the Go package - if config.InProcessServer == nil { - return fmt.Errorf("InProcessServer is required for InProcess connection type in client '%s' (Go package only)", config.Name) - } - default: - return fmt.Errorf("unknown connection type '%s' in client '%s'", config.ConnectionType, config.Name) - } - - return nil -} - -// ============================================================================ -// HELPER METHODS -// ============================================================================ - -// findMCPClientForTool safely finds a client that has the specified tool. -func (m *MCPManager) findMCPClientForTool(toolName string) *MCPClient { - m.mu.RLock() - defer m.mu.RUnlock() - - for _, client := range m.clientMap { - if _, exists := client.ToolMap[toolName]; exists { - return client - } - } - return nil -} - -// shouldIncludeClient determines if a client should be included based on filtering rules. -func (m *MCPManager) shouldIncludeClient(clientID string, includeClients []string) bool { - // If includeClients is specified (not nil), apply whitelist filtering - if includeClients != nil { - // Handle empty array [] - means no clients are included - if len(includeClients) == 0 { - return false // No clients allowed - } - - // Handle wildcard "*" - if present, all clients are included - if slices.Contains(includeClients, "*") { - return true // All clients allowed - } - - // Check if specific client is in the list - return slices.Contains(includeClients, clientID) - } - - // Default: include all clients when no filtering specified (nil case) - return true -} - -// createHTTPConnection creates an HTTP-based MCP client connection without holding locks. -func (m *MCPManager) createHTTPConnection(config schemas.MCPClientConfig) (*client.Client, MCPClientConnectionInfo, error) { - if config.ConnectionString == nil { - return nil, MCPClientConnectionInfo{}, fmt.Errorf("HTTP connection string is required") - } - - // Prepare connection info - connectionInfo := MCPClientConnectionInfo{ - Type: config.ConnectionType, - ConnectionURL: config.ConnectionString, - } - - // Create StreamableHTTP transport - httpTransport, err := transport.NewStreamableHTTP(*config.ConnectionString, transport.WithHTTPHeaders(config.Headers)) - if err != nil { - return nil, MCPClientConnectionInfo{}, fmt.Errorf("failed to create HTTP transport: %w", err) - } - - client := client.NewClient(httpTransport) - - return client, connectionInfo, nil -} - -// createSTDIOConnection creates a STDIO-based MCP client connection without holding locks. -func (m *MCPManager) createSTDIOConnection(config schemas.MCPClientConfig) (*client.Client, MCPClientConnectionInfo, error) { - if config.StdioConfig == nil { - return nil, MCPClientConnectionInfo{}, fmt.Errorf("stdio config is required") - } - - // Prepare STDIO command info for display - cmdString := fmt.Sprintf("%s %s", config.StdioConfig.Command, strings.Join(config.StdioConfig.Args, " ")) - - // Check if environment variables are set - for _, env := range config.StdioConfig.Envs { - if os.Getenv(env) == "" { - return nil, MCPClientConnectionInfo{}, fmt.Errorf("environment variable %s is not set for MCP client %s", env, config.Name) - } - } - - // Create STDIO transport - stdioTransport := transport.NewStdio( - config.StdioConfig.Command, - config.StdioConfig.Envs, - config.StdioConfig.Args..., - ) - - // Prepare connection info - connectionInfo := MCPClientConnectionInfo{ - Type: config.ConnectionType, - StdioCommandString: &cmdString, - } - - client := client.NewClient(stdioTransport) - - // Return nil for cmd since mark3labs/mcp-go manages the process internally - return client, connectionInfo, nil -} - -// createSSEConnection creates a SSE-based MCP client connection without holding locks. -func (m *MCPManager) createSSEConnection(config schemas.MCPClientConfig) (*client.Client, MCPClientConnectionInfo, error) { - if config.ConnectionString == nil { - return nil, MCPClientConnectionInfo{}, fmt.Errorf("SSE connection string is required") - } - - // Prepare connection info - connectionInfo := MCPClientConnectionInfo{ - Type: config.ConnectionType, - ConnectionURL: config.ConnectionString, // Reuse HTTPConnectionURL field for SSE URL display - } - - // Create SSE transport - sseTransport, err := transport.NewSSE(*config.ConnectionString, transport.WithHeaders(config.Headers)) - if err != nil { - return nil, MCPClientConnectionInfo{}, fmt.Errorf("failed to create SSE transport: %w", err) - } - - client := client.NewClient(sseTransport) - - return client, connectionInfo, nil -} - -// createInProcessConnection creates an in-process MCP client connection without holding locks. -// This allows direct connection to an MCP server running in the same process, providing -// the lowest latency and highest performance for tool execution. -func (m *MCPManager) createInProcessConnection(config schemas.MCPClientConfig) (*client.Client, MCPClientConnectionInfo, error) { - if config.InProcessServer == nil { - return nil, MCPClientConnectionInfo{}, fmt.Errorf("InProcess connection requires a server instance") - } - - // Type assert to ensure we have a proper MCP server - mcpServer, ok := config.InProcessServer.(*server.MCPServer) - if !ok { - return nil, MCPClientConnectionInfo{}, fmt.Errorf("InProcessServer must be a *server.MCPServer instance") - } - - // Create in-process client directly connected to the provided server - inProcessClient, err := client.NewInProcessClient(mcpServer) - if err != nil { - return nil, MCPClientConnectionInfo{}, fmt.Errorf("failed to create in-process client: %w", err) - } - - // Prepare connection info - connectionInfo := MCPClientConnectionInfo{ - Type: config.ConnectionType, - } - - return inProcessClient, connectionInfo, nil -} - -// cleanup performs cleanup of all MCP resources including clients and local server. -// This function safely disconnects all MCP clients (HTTP, STDIO, and SSE) and -// cleans up the local MCP server. It handles proper cancellation of SSE contexts -// and closes all transport connections. -// -// Returns: -// - error: Always returns nil, but maintains error interface for consistency -func (m *MCPManager) cleanup() error { - m.mu.Lock() - defer m.mu.Unlock() - - // Disconnect all external MCP clients - for id := range m.clientMap { - if err := m.removeClientUnsafe(id); err != nil { - m.logger.Error("%s Failed to remove MCP client %s: %v", MCPLogPrefix, id, err) - } - } - - // Clear the client map - m.clientMap = make(map[string]*MCPClient) - - // Clear local server reference - // Note: mark3labs/mcp-go STDIO server cleanup is handled automatically - if m.server != nil { - m.logger.Info(MCPLogPrefix + " Clearing local MCP server reference") - m.server = nil - m.serverRunning = false - } - - m.logger.Info(MCPLogPrefix + " MCP cleanup completed") - return nil -} diff --git a/core/mcp/agent.go b/core/mcp/agent.go new file mode 100644 index 0000000000..bd2b6da6d3 --- /dev/null +++ b/core/mcp/agent.go @@ -0,0 +1,473 @@ +package mcp + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) + +// ExecuteAgentForChatRequest handles the agent mode execution loop for Chat API. +// It orchestrates iterative tool execution up to the maximum depth, handling +// auto-executable and non-auto-executable tools appropriately. +// +// Parameters: +// - ctx: Context for agent execution +// - maxAgentDepth: Maximum number of agent iterations allowed +// - originalReq: The original chat request +// - initialResponse: The initial chat response containing tool calls +// - makeReq: Function to make subsequent chat requests during agent execution +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration +// - executeToolFunc: Function to execute individual tool calls +// - clientManager: Client manager for accessing MCP clients and tools +// +// Returns: +// - *schemas.BifrostChatResponse: The final response after agent execution +// - *schemas.BifrostError: Any error that occurred during agent execution +func ExecuteAgentForChatRequest( + ctx *context.Context, + maxAgentDepth int, + originalReq *schemas.BifrostChatRequest, + initialResponse *schemas.BifrostChatResponse, + makeReq func(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), + fetchNewRequestIDFunc func(ctx context.Context) string, + executeToolFunc func(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error), + clientManager ClientManager, +) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + // Create adapter for Chat API + adapter := &chatAPIAdapter{ + originalReq: originalReq, + initialResponse: initialResponse, + makeReq: makeReq, + } + + result, err := executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager) + if err != nil { + return nil, err + } + + chatResponse, ok := result.(*schemas.BifrostChatResponse) + // Should never happen, but just in case + if !ok { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "Failed to convert result to schemas.BifrostChatResponse", + }, + } + } + + return chatResponse, nil +} + +// ExecuteAgentForResponsesRequest handles the agent mode execution loop for Responses API. +// It orchestrates iterative tool execution up to the maximum depth, handling +// auto-executable and non-auto-executable tools appropriately. +// +// Parameters: +// - ctx: Context for agent execution +// - maxAgentDepth: Maximum number of agent iterations allowed +// - originalReq: The original responses request +// - initialResponse: The initial responses response containing tool calls +// - makeReq: Function to make subsequent responses requests during agent execution +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration +// - executeToolFunc: Function to execute individual tool calls +// - clientManager: Client manager for accessing MCP clients and tools +// +// Returns: +// - *schemas.BifrostResponsesResponse: The final response after agent execution +// - *schemas.BifrostError: Any error that occurred during agent execution +func ExecuteAgentForResponsesRequest( + ctx *context.Context, + maxAgentDepth int, + originalReq *schemas.BifrostResponsesRequest, + initialResponse *schemas.BifrostResponsesResponse, + makeReq func(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError), + fetchNewRequestIDFunc func(ctx context.Context) string, + executeToolFunc func(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error), + clientManager ClientManager, +) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + // Create adapter for Responses API + adapter := &responsesAPIAdapter{ + originalReq: originalReq, + initialResponse: initialResponse, + makeReq: makeReq, + } + + result, err := executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager) + if err != nil { + return nil, err + } + + responsesResponse, ok := result.(*schemas.BifrostResponsesResponse) + // Should never happen, but just in case + if !ok { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "Failed to convert result to schemas.BifrostResponsesResponse", + }, + } + } + + return responsesResponse, nil +} + +// executeAgent handles the generic agent mode execution loop using an API adapter pattern. +// It iteratively executes tools, separates auto-executable from non-auto-executable tools, +// executes auto-executable tools in parallel, and continues the loop until no more tool +// calls are present or the maximum depth is reached. +// +// Parameters: +// - ctx: Context for agent execution (may be modified to add request IDs) +// - maxAgentDepth: Maximum number of agent iterations allowed +// - adapter: API adapter that abstracts differences between Chat and Responses APIs +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration +// - executeToolFunc: Function to execute individual tool calls +// - clientManager: Client manager for accessing MCP clients and tools +// +// Returns: +// - interface{}: The final response after agent execution (type depends on adapter) +// - *schemas.BifrostError: Any error that occurred during agent execution +func executeAgent( + ctx *context.Context, + maxAgentDepth int, + adapter agentAPIAdapter, + fetchNewRequestIDFunc func(ctx context.Context) string, + executeToolFunc func(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error), + clientManager ClientManager, +) (interface{}, *schemas.BifrostError) { + logger.Debug("Entering agent mode - detected tool calls in response") + + // Get initial response from adapter + currentResponse := adapter.getInitialResponse() + + // Create conversation history starting with original messages + conversationHistory := adapter.getConversationHistory() + + depth := 0 + + // Track all executed tool results and tool calls across all iterations + allExecutedToolResults := make([]*schemas.ChatMessage, 0) + allExecutedToolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + + originalRequestID, ok := (*ctx).Value(schemas.BifrostContextKeyRequestID).(string) + if ok { + *ctx = context.WithValue(*ctx, schemas.BifrostMCPAgentOriginalRequestID, originalRequestID) + } + + for depth < maxAgentDepth { + depth++ + toolCalls := adapter.extractToolCalls(currentResponse) + if len(toolCalls) == 0 { + logger.Debug("No more tool calls found, exiting agent mode") + break + } + + logger.Debug(fmt.Sprintf("Agent mode depth %d: executing %d tool calls", depth, len(toolCalls))) + + // Separate tools into auto-executable and non-auto-executable groups + var autoExecutableTools []schemas.ChatAssistantMessageToolCall + var nonAutoExecutableTools []schemas.ChatAssistantMessageToolCall + + for _, toolCall := range toolCalls { + if toolCall.Function.Name == nil { + // Skip tools without names + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + toolName := *toolCall.Function.Name + client := clientManager.GetClientForTool(toolName) + if client == nil { + // Allow code mode list and read tool tools + if toolName == ToolTypeListToolFiles || toolName == ToolTypeReadToolFile { + autoExecutableTools = append(autoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s can be auto-executed", toolName)) + continue + } else if toolName == ToolTypeExecuteToolCode { + // Build allowed auto-execution tools map for code mode validation + allClientNames, allowedAutoExecutionTools := buildAllowedAutoExecutionTools(*ctx, clientManager) + + // Parse tool arguments + var arguments map[string]interface{} + if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + logger.Debug(fmt.Sprintf("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err)) + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + code, ok := arguments["code"].(string) + if !ok || code == "" { + logger.Debug(fmt.Sprintf("%s Code parameter missing or empty", CodeModeLogPrefix)) + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + // Step 1: Convert literal \n escape sequences to actual newlines for parsing + codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") + if len(codeWithNewlines) != len(code) { + logger.Debug(fmt.Sprintf("%s Converted literal \\n escape sequences to actual newlines", CodeModeLogPrefix)) + } + + // Step 2: Extract tool calls from code during AST formation + extractedToolCalls, err := extractToolCallsFromCode(codeWithNewlines) + if err != nil { + logger.Debug(fmt.Sprintf("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err)) + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + logger.Debug(fmt.Sprintf("%s Extracted %d tool call(s) from code", CodeModeLogPrefix, len(extractedToolCalls))) + + // Step 3: Validate all tool calls against allowedAutoExecutionTools + canAutoExecute := true + if len(extractedToolCalls) > 0 { + // If there are tool calls, we need allowedAutoExecutionTools to validate them + if len(allowedAutoExecutionTools) == 0 { + logger.Debug(fmt.Sprintf("%s Validation failed: no allowed auto-execution tools configured", CodeModeLogPrefix)) + canAutoExecute = false + } else { + logger.Debug(fmt.Sprintf("%s Validating %d tool call(s) against %d allowed server(s)", CodeModeLogPrefix, len(extractedToolCalls), len(allowedAutoExecutionTools))) + + // Validate each tool call + for _, extractedToolCall := range extractedToolCalls { + isAllowed := isToolCallAllowedForCodeMode(extractedToolCall.serverName, extractedToolCall.toolName, allClientNames, allowedAutoExecutionTools) + if !isAllowed { + logger.Debug(fmt.Sprintf("%s Tool call %s.%s: allowed=%v", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName, isAllowed)) + logger.Debug(fmt.Sprintf("%s Validation failed: tool call %s.%s not in auto-execute list", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName)) + canAutoExecute = false + break + } + } + if canAutoExecute { + logger.Debug(fmt.Sprintf("%s All tool calls validated successfully", CodeModeLogPrefix)) + } + } + } else { + logger.Debug(fmt.Sprintf("%s No tool calls found in code, skipping validation", CodeModeLogPrefix)) + } + + // Add to appropriate list based on validation result + if canAutoExecute { + autoExecutableTools = append(autoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s can be auto-executed (validation passed)", toolName)) + } else { + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s cannot be auto-executed (validation failed)", toolName)) + } + continue + } + // Else, if client not found, treat as non-auto-executable (can be a manually passed tool) + logger.Debug(fmt.Sprintf("Client not found for tool %s, treating as non-auto-executable", toolName)) + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + // Check if tool can be auto-executed + if canAutoExecuteTool(toolName, client.ExecutionConfig) { + autoExecutableTools = append(autoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s can be auto-executed", toolName)) + } else { + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s cannot be auto-executed", toolName)) + } + } + + logger.Debug(fmt.Sprintf("Auto-executable tools: %d", len(autoExecutableTools))) + logger.Debug(fmt.Sprintf("Non-auto-executable tools: %d", len(nonAutoExecutableTools))) + + // Execute auto-executable tools first + var executedToolResults []*schemas.ChatMessage + if len(autoExecutableTools) > 0 { + // Add assistant message with auto-executable tool calls to conversation + conversationHistory = adapter.addAssistantMessage(conversationHistory, currentResponse) + + // Execute all auto-executable tool calls parallelly + wg := sync.WaitGroup{} + wg.Add(len(autoExecutableTools)) + channelToolResults := make(chan *schemas.ChatMessage, len(autoExecutableTools)) + for _, toolCall := range autoExecutableTools { + go func(toolCall schemas.ChatAssistantMessageToolCall) { + defer wg.Done() + toolResult, toolErr := executeToolFunc(*ctx, toolCall) + if toolErr != nil { + logger.Warn(fmt.Sprintf("Tool execution failed: %v", toolErr)) + channelToolResults <- createToolResultMessage(toolCall, "", toolErr) + } else { + channelToolResults <- toolResult + } + }(toolCall) + } + wg.Wait() + close(channelToolResults) + + // Collect tool results + executedToolResults = make([]*schemas.ChatMessage, 0, len(autoExecutableTools)) + for toolResult := range channelToolResults { + executedToolResults = append(executedToolResults, toolResult) + } + + // Track executed tool results and calls across all iterations + allExecutedToolResults = append(allExecutedToolResults, executedToolResults...) + allExecutedToolCalls = append(allExecutedToolCalls, autoExecutableTools...) + + // Add tool results to conversation history + conversationHistory = adapter.addToolResults(conversationHistory, executedToolResults) + } + + // If there are non-auto-executable tools, return them immediately without continuing the loop + if len(nonAutoExecutableTools) > 0 { + logger.Debug(fmt.Sprintf("Found %d non-auto-executable tools, returning them immediately without continuing the loop", len(nonAutoExecutableTools))) + // Return as is if its the first iteration + if depth == 1 && len(allExecutedToolResults) == 0 { + return currentResponse, nil + } + // Create response with all executed tool results from all iterations, and non-auto-executable tool calls + return adapter.createResponseWithExecutedTools(currentResponse, allExecutedToolResults, allExecutedToolCalls, nonAutoExecutableTools), nil + } + + // Create new request with updated conversation history + newReq := adapter.createNewRequest(conversationHistory) + + if fetchNewRequestIDFunc != nil { + newID := fetchNewRequestIDFunc(*ctx) + if newID != "" { + *ctx = context.WithValue(*ctx, schemas.BifrostContextKeyRequestID, newID) + } + } + + // Make new LLM request + response, err := adapter.makeLLMCall(*ctx, newReq) + if err != nil { + logger.Error("Agent mode: LLM request failed: %v", err) + return nil, err + } + + currentResponse = response + } + + logger.Debug(fmt.Sprintf("Agent mode completed after %d iterations", depth)) + return currentResponse, nil +} + +// extractToolCalls extracts all tool calls from a chat response. +// It iterates through all choices in the response and collects tool calls +// from assistant messages. +// +// Parameters: +// - response: The chat response to extract tool calls from +// +// Returns: +// - []schemas.ChatAssistantMessageToolCall: List of extracted tool calls, or nil if none found +func extractToolCalls(response *schemas.BifrostChatResponse) []schemas.ChatAssistantMessageToolCall { + if !hasToolCallsForChatResponse(response) { + return nil + } + + var toolCalls []schemas.ChatAssistantMessageToolCall + for _, choice := range response.Choices { + if choice.ChatNonStreamResponseChoice != nil && + choice.ChatNonStreamResponseChoice.Message != nil && + choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil { + toolCalls = append(toolCalls, choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls...) + } + } + + return toolCalls +} + +// createToolResultMessage creates a tool result message from tool execution. +// It formats the result or error into a chat message with the appropriate tool call ID. +// +// Parameters: +// - toolCall: The original tool call that was executed +// - result: The successful execution result (ignored if err is not nil) +// - err: Any error that occurred during tool execution +// +// Returns: +// - *schemas.ChatMessage: A tool message containing the execution result or error +func createToolResultMessage(toolCall schemas.ChatAssistantMessageToolCall, result string, err error) *schemas.ChatMessage { + var content string + if err != nil { + content = fmt.Sprintf("Error executing tool %s: %s", + func() string { + if toolCall.Function.Name != nil { + return *toolCall.Function.Name + } + return "unknown" + }(), err.Error()) + } else { + content = result + } + + return &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: &content, + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: toolCall.ID, + }, + } +} + +// buildAllowedAutoExecutionTools builds a map of client names to their auto-executable tools. +// It processes code mode clients and parses their ToolsToAutoExecute configuration to create +// a map of allowed tools. Tool names are parsed to match their appearance in JavaScript code. +// +// Parameters: +// - ctx: Context for accessing client tools +// - clientManager: Client manager for accessing MCP clients +// +// Returns: +// - []string: List of all client names +// - map[string][]string: Map of client names to their auto-executable tool names (as they appear in code) +func buildAllowedAutoExecutionTools(ctx context.Context, clientManager ClientManager) ([]string, map[string][]string) { + allowedTools := make(map[string][]string) + availableToolsPerClient := clientManager.GetToolPerClient(ctx) + allClientNames := []string{} + + for clientName := range availableToolsPerClient { + client := clientManager.GetClientByName(clientName) + if client == nil { + continue + } + allClientNames = append(allClientNames, clientName) + + // Only include code mode clients + if !client.ExecutionConfig.IsCodeModeClient { + continue + } + + // Get auto-executable tools from config + toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute + if len(toolsToAutoExecute) == 0 { + // No auto-executable tools configured for this client + continue + } + + // Parse tool names (as they appear in JavaScript code) + autoExecutableTools := []string{} + for _, originalToolName := range toolsToAutoExecute { + // Handle wildcard "*" - means all tools are auto-executable + if originalToolName == "*" { + autoExecutableTools = append(autoExecutableTools, "*") + continue + } + // Use parsed tool name (as it appears in code) + parsedToolName := parseToolName(originalToolName) + autoExecutableTools = append(autoExecutableTools, parsedToolName) + } + + // Add to map if there are auto-executable tools + if len(autoExecutableTools) > 0 { + allowedTools[clientName] = autoExecutableTools + } + } + + return allClientNames, allowedTools +} diff --git a/core/mcp/agent_adaptors.go b/core/mcp/agent_adaptors.go new file mode 100644 index 0000000000..9aa99b31f8 --- /dev/null +++ b/core/mcp/agent_adaptors.go @@ -0,0 +1,529 @@ +package mcp + +import ( + "context" + "fmt" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) + +// agentAPIAdapter defines the interface for API-specific operations +type agentAPIAdapter interface { + // Extract conversation history from the original request + getConversationHistory() []interface{} + + // Get original request + getOriginalRequest() interface{} + + // Get initial response + getInitialResponse() interface{} + + // Check if response has tool calls + hasToolCalls(response interface{}) bool + + // Extract tool calls from response + extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall + + // Add assistant message with tool calls to conversation + addAssistantMessage(conversation []interface{}, response interface{}) []interface{} + + // Add tool results to conversation + addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} + + // Create new request with updated conversation + createNewRequest(conversation []interface{}) interface{} + + // Make LLM call + makeLLMCall(ctx context.Context, request interface{}) (interface{}, *schemas.BifrostError) + + // Create response with executed tools and non-auto-executable calls + createResponseWithExecutedTools( + response interface{}, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, + ) interface{} +} + +// chatAPIAdapter implements agentAPIAdapter for Chat API +type chatAPIAdapter struct { + originalReq *schemas.BifrostChatRequest + initialResponse *schemas.BifrostChatResponse + makeReq func(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) +} + +// responsesAPIAdapter implements agentAPIAdapter for Responses API +type responsesAPIAdapter struct { + originalReq *schemas.BifrostResponsesRequest + initialResponse *schemas.BifrostResponsesResponse + makeReq func(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) +} + +// Chat API adapter implementations +func (c *chatAPIAdapter) getConversationHistory() []interface{} { + history := make([]interface{}, 0) + if c.originalReq.Input != nil { + for _, msg := range c.originalReq.Input { + history = append(history, msg) + } + } + return history +} + +func (c *chatAPIAdapter) getOriginalRequest() interface{} { + return c.originalReq +} + +func (c *chatAPIAdapter) getInitialResponse() interface{} { + return c.initialResponse +} + +func (c *chatAPIAdapter) hasToolCalls(response interface{}) bool { + chatResponse := response.(*schemas.BifrostChatResponse) + return hasToolCallsForChatResponse(chatResponse) +} + +func (c *chatAPIAdapter) extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall { + chatResponse := response.(*schemas.BifrostChatResponse) + return extractToolCalls(chatResponse) +} + +func (c *chatAPIAdapter) addAssistantMessage(conversation []interface{}, response interface{}) []interface{} { + chatResponse := response.(*schemas.BifrostChatResponse) + for _, choice := range chatResponse.Choices { + if choice.ChatNonStreamResponseChoice != nil && choice.ChatNonStreamResponseChoice.Message != nil { + conversation = append(conversation, *choice.ChatNonStreamResponseChoice.Message) + } + } + return conversation +} + +func (c *chatAPIAdapter) addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} { + for _, toolResult := range toolResults { + conversation = append(conversation, *toolResult) + } + return conversation +} + +func (c *chatAPIAdapter) createNewRequest(conversation []interface{}) interface{} { + // Convert conversation back to ChatMessage slice + chatMessages := make([]schemas.ChatMessage, 0, len(conversation)) + for _, msg := range conversation { + chatMessages = append(chatMessages, msg.(schemas.ChatMessage)) + } + + return &schemas.BifrostChatRequest{ + Provider: c.originalReq.Provider, + Model: c.originalReq.Model, + Fallbacks: c.originalReq.Fallbacks, + Params: c.originalReq.Params, + Input: chatMessages, + } +} + +func (c *chatAPIAdapter) makeLLMCall(ctx context.Context, request interface{}) (interface{}, *schemas.BifrostError) { + chatRequest := request.(*schemas.BifrostChatRequest) + return c.makeReq(ctx, chatRequest) +} + +func (c *chatAPIAdapter) createResponseWithExecutedTools( + response interface{}, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, +) interface{} { + chatResponse := response.(*schemas.BifrostChatResponse) + return createChatResponseWithExecutedToolsAndNonAutoExecutableCalls( + chatResponse, + executedToolResults, + executedToolCalls, + nonAutoExecutableToolCalls, + ) +} + +// createChatResponseWithExecutedToolsAndNonAutoExecutableCalls creates a chat response +// that includes executed tool results and non-auto-executable tool calls. The response +// contains a formatted text summary of executed tool results and includes the non-auto-executable +// tool calls for the caller to handle. The finish reason is set to "stop" to prevent +// further agent loop iterations. +// +// Parameters: +// - originalResponse: The original chat response to copy metadata from +// - executedToolResults: List of tool execution results from auto-executable tools +// - executedToolCalls: List of tool calls that were executed +// - nonAutoExecutableToolCalls: List of tool calls that require manual execution +// +// Returns: +// - *schemas.BifrostChatResponse: A new chat response with executed results and pending tool calls +func createChatResponseWithExecutedToolsAndNonAutoExecutableCalls( + originalResponse *schemas.BifrostChatResponse, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, +) *schemas.BifrostChatResponse { + // Start with a copy of the original response metadata + response := &schemas.BifrostChatResponse{ + ID: originalResponse.ID, + Object: originalResponse.Object, + Created: originalResponse.Created, + Model: originalResponse.Model, + Choices: make([]schemas.BifrostResponseChoice, 0), + ServiceTier: originalResponse.ServiceTier, + SystemFingerprint: originalResponse.SystemFingerprint, + Usage: originalResponse.Usage, + ExtraFields: originalResponse.ExtraFields, + SearchResults: originalResponse.SearchResults, + Videos: originalResponse.Videos, + Citations: originalResponse.Citations, + } + + // Build a map from tool call ID to tool name for easy lookup + toolCallIDToName := make(map[string]string) + for _, toolCall := range executedToolCalls { + if toolCall.ID != nil && toolCall.Function.Name != nil { + toolCallIDToName[*toolCall.ID] = *toolCall.Function.Name + } + } + + // Build content text showing executed tool results + var contentText string + if len(executedToolResults) > 0 { + // Format tool results as JSON-like structure + toolResultsMap := make(map[string]interface{}) + for _, toolResult := range executedToolResults { + // Get tool name from tool call ID mapping + var toolName string + if toolResult.ChatToolMessage != nil && toolResult.ChatToolMessage.ToolCallID != nil { + toolCallID := *toolResult.ChatToolMessage.ToolCallID + if name, ok := toolCallIDToName[toolCallID]; ok { + toolName = name + } else { + toolName = toolCallID // Fallback to tool call ID if name not found + } + } else { + toolName = "unknown_tool" + } + + // Extract output from tool result + var output interface{} + if toolResult.Content != nil { + if toolResult.Content.ContentStr != nil { + output = *toolResult.Content.ContentStr + } else if toolResult.Content.ContentBlocks != nil { + // Convert content blocks to a readable format + blocks := make([]map[string]interface{}, 0) + for _, block := range toolResult.Content.ContentBlocks { + blockMap := make(map[string]interface{}) + blockMap["type"] = string(block.Type) + if block.Text != nil { + blockMap["text"] = *block.Text + } + blocks = append(blocks, blockMap) + } + output = blocks + } + } + toolResultsMap[toolName] = output + } + + // Convert to JSON string for display + jsonBytes, err := sonic.Marshal(toolResultsMap) + if err != nil { + // Fallback to simple string representation + contentText = fmt.Sprintf("The Output from allowed tools calls is - %v\n\nNow I shall call these tools next...", toolResultsMap) + } else { + contentText = fmt.Sprintf("The Output from allowed tools calls is - %s\n\nNow I shall call these tools next...", string(jsonBytes)) + } + } else { + contentText = "Now I shall call these tools next..." + } + + // Create content with the formatted text + content := &schemas.ChatMessageContent{ + ContentStr: &contentText, + } + + // Determine finish reason + // Note: We set finish_reason to "stop" (not "tool_calls") for non-auto-executable tools + // to prevent the agent loop from retrying. The tool calls are still included in the response + // for the caller to handle, but setting finish_reason to "stop" ensures hasToolCalls returns false + // and the agent loop exits properly. + finishReason := "stop" + + // Create a single choice with the formatted content and non-auto-executable tool calls + response.Choices = append(response.Choices, schemas.BifrostResponseChoice{ + Index: 0, + FinishReason: &finishReason, + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: content, + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: nonAutoExecutableToolCalls, + }, + }, + }, + }) + + return response +} + +// Responses API adapter implementations +func (r *responsesAPIAdapter) getConversationHistory() []interface{} { + history := make([]interface{}, 0) + if r.originalReq.Input != nil { + for _, msg := range r.originalReq.Input { + history = append(history, msg) + } + } + return history +} + +func (r *responsesAPIAdapter) getOriginalRequest() interface{} { + return r.originalReq +} + +func (r *responsesAPIAdapter) getInitialResponse() interface{} { + return r.initialResponse +} + +func (r *responsesAPIAdapter) hasToolCalls(response interface{}) bool { + responsesResponse := response.(*schemas.BifrostResponsesResponse) + return hasToolCallsForResponsesResponse(responsesResponse) +} + +func (r *responsesAPIAdapter) extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall { + responsesResponse := response.(*schemas.BifrostResponsesResponse) + // Convert to Chat format and extract tool calls using existing logic + chatResponse := responsesResponse.ToBifrostChatResponse() + return extractToolCalls(chatResponse) +} + +func (r *responsesAPIAdapter) addAssistantMessage(conversation []interface{}, response interface{}) []interface{} { + responsesResponse := response.(*schemas.BifrostResponsesResponse) + for _, output := range responsesResponse.Output { + conversation = append(conversation, output) + } + return conversation +} + +func (r *responsesAPIAdapter) addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} { + for _, toolResult := range toolResults { + // Convert using existing converter + responsesMessages := toolResult.ToResponsesMessages() + for _, respMsg := range responsesMessages { + conversation = append(conversation, respMsg) + } + } + return conversation +} + +func (r *responsesAPIAdapter) createNewRequest(conversation []interface{}) interface{} { + // Convert conversation back to ResponsesMessage slice + responsesMessages := make([]schemas.ResponsesMessage, 0, len(conversation)) + for _, msg := range conversation { + responsesMessages = append(responsesMessages, msg.(schemas.ResponsesMessage)) + } + + return &schemas.BifrostResponsesRequest{ + Provider: r.originalReq.Provider, + Model: r.originalReq.Model, + Fallbacks: r.originalReq.Fallbacks, + Params: r.originalReq.Params, + Input: responsesMessages, + } +} + +func (r *responsesAPIAdapter) makeLLMCall(ctx context.Context, request interface{}) (interface{}, *schemas.BifrostError) { + responsesRequest := request.(*schemas.BifrostResponsesRequest) + return r.makeReq(ctx, responsesRequest) +} + +func (r *responsesAPIAdapter) createResponseWithExecutedTools( + response interface{}, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, +) interface{} { + responsesResponse := response.(*schemas.BifrostResponsesResponse) + + // Create response with executed tools directly on Responses schema + return createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls( + responsesResponse, + executedToolResults, + executedToolCalls, + nonAutoExecutableToolCalls, + ) +} + +// createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls creates a responses response +// that includes executed tool results and non-auto-executable tool calls. The response +// contains a formatted text summary of executed tool results and includes the non-auto-executable +// tool calls for the caller to handle. All Response-specific fields are preserved. +// +// Parameters: +// - originalResponse: The original responses response to copy metadata from +// - executedToolResults: List of tool execution results from auto-executable tools +// - executedToolCalls: List of tool calls that were executed +// - nonAutoExecutableToolCalls: List of tool calls that require manual execution +// +// Returns: +// - *schemas.BifrostResponsesResponse: A new responses response with executed results and pending tool calls +func createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls( + originalResponse *schemas.BifrostResponsesResponse, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, +) *schemas.BifrostResponsesResponse { + // Start with a copy of the original response, preserving all Response-specific fields + response := &schemas.BifrostResponsesResponse{ + ID: originalResponse.ID, + Background: originalResponse.Background, + Conversation: originalResponse.Conversation, + CreatedAt: originalResponse.CreatedAt, + Error: originalResponse.Error, + Include: originalResponse.Include, + IncompleteDetails: originalResponse.IncompleteDetails, + Instructions: originalResponse.Instructions, + MaxOutputTokens: originalResponse.MaxOutputTokens, + MaxToolCalls: originalResponse.MaxToolCalls, + Metadata: originalResponse.Metadata, + ParallelToolCalls: originalResponse.ParallelToolCalls, + PreviousResponseID: originalResponse.PreviousResponseID, + Prompt: originalResponse.Prompt, + PromptCacheKey: originalResponse.PromptCacheKey, + Reasoning: originalResponse.Reasoning, + SafetyIdentifier: originalResponse.SafetyIdentifier, + ServiceTier: originalResponse.ServiceTier, + StreamOptions: originalResponse.StreamOptions, + Store: originalResponse.Store, + Temperature: originalResponse.Temperature, + Text: originalResponse.Text, + TopLogProbs: originalResponse.TopLogProbs, + TopP: originalResponse.TopP, + ToolChoice: originalResponse.ToolChoice, + Tools: originalResponse.Tools, + Truncation: originalResponse.Truncation, + Usage: originalResponse.Usage, + ExtraFields: originalResponse.ExtraFields, + // Perplexity-specific fields + SearchResults: originalResponse.SearchResults, + Videos: originalResponse.Videos, + Citations: originalResponse.Citations, + Output: make([]schemas.ResponsesMessage, 0), + } + + // Build a map from tool call ID to tool name for easy lookup + toolCallIDToName := make(map[string]string) + for _, toolCall := range executedToolCalls { + if toolCall.ID != nil && toolCall.Function.Name != nil { + toolCallIDToName[*toolCall.ID] = *toolCall.Function.Name + } + } + + // Build content text showing executed tool results + var contentText string + if len(executedToolResults) > 0 { + // Format tool results as JSON-like structure + toolResultsMap := make(map[string]interface{}) + for _, toolResult := range executedToolResults { + // Get tool name from tool call ID mapping + var toolName string + if toolResult.ChatToolMessage != nil && toolResult.ChatToolMessage.ToolCallID != nil { + toolCallID := *toolResult.ChatToolMessage.ToolCallID + if name, ok := toolCallIDToName[toolCallID]; ok { + toolName = name + } else { + toolName = toolCallID // Fallback to tool call ID if name not found + } + } else { + toolName = "unknown_tool" + } + + // Extract output from tool result + var output interface{} + if toolResult.Content != nil { + if toolResult.Content.ContentStr != nil { + output = *toolResult.Content.ContentStr + } else if toolResult.Content.ContentBlocks != nil { + // Convert content blocks to a readable format + blocks := make([]map[string]interface{}, 0) + for _, block := range toolResult.Content.ContentBlocks { + blockMap := make(map[string]interface{}) + blockMap["type"] = string(block.Type) + if block.Text != nil { + blockMap["text"] = *block.Text + } + blocks = append(blocks, blockMap) + } + output = blocks + } + } + toolResultsMap[toolName] = output + } + + // Convert to JSON string for display + jsonBytes, err := sonic.Marshal(toolResultsMap) + if err != nil { + // Fallback to simple string representation + contentText = fmt.Sprintf("The Output from allowed tools calls is - %v\n\nNow I shall call these tools next...", toolResultsMap) + } else { + contentText = fmt.Sprintf("The Output from allowed tools calls is - %s\n\nNow I shall call these tools next...", string(jsonBytes)) + } + } else { + contentText = "Now I shall call these tools next..." + } + + // Create assistant message with the formatted text content + messageType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + assistantMessage := schemas.ResponsesMessage{ + Type: &messageType, + Role: &role, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: &contentText, + }, + }, + }, + } + response.Output = append(response.Output, assistantMessage) + + // Add non-auto-executable tool calls as separate function_call messages + for _, toolCall := range nonAutoExecutableToolCalls { + functionCallType := schemas.ResponsesMessageTypeFunctionCall + assistantRole := schemas.ResponsesInputMessageRoleAssistant + + var callID *string + if toolCall.ID != nil && *toolCall.ID != "" { + callID = toolCall.ID + } + + var namePtr *string + if toolCall.Function.Name != nil && *toolCall.Function.Name != "" { + namePtr = toolCall.Function.Name + } + + var argumentsPtr *string + if toolCall.Function.Arguments != "" { + argumentsPtr = &toolCall.Function.Arguments + } + + toolCallMessage := schemas.ResponsesMessage{ + Type: &functionCallType, + Role: &assistantRole, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: callID, + Name: namePtr, + Arguments: argumentsPtr, + }, + } + + response.Output = append(response.Output, toolCallMessage) + } + + return response +} diff --git a/core/mcp/agent_test.go b/core/mcp/agent_test.go new file mode 100644 index 0000000000..97ba0c630d --- /dev/null +++ b/core/mcp/agent_test.go @@ -0,0 +1,480 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +// MockLLMCaller implements schemas.BifrostLLMCaller for testing +type MockLLMCaller struct { + chatResponses []*schemas.BifrostChatResponse + responsesResponses []*schemas.BifrostResponsesResponse + chatCallCount int + responsesCallCount int +} + +func (m *MockLLMCaller) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + if m.chatCallCount >= len(m.chatResponses) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "no more mock chat responses available", + }, + } + } + + response := m.chatResponses[m.chatCallCount] + m.chatCallCount++ + return response, nil +} + +func (m *MockLLMCaller) ResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + if m.responsesCallCount >= len(m.responsesResponses) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "no more mock responses api responses available", + }, + } + } + + response := m.responsesResponses[m.responsesCallCount] + m.responsesCallCount++ + return response, nil +} + +// MockLogger implements schemas.Logger for testing +type MockLogger struct{} + +func (m *MockLogger) Debug(msg string, args ...any) {} +func (m *MockLogger) Info(msg string, args ...any) {} +func (m *MockLogger) Warn(msg string, args ...any) {} +func (m *MockLogger) Error(msg string, args ...any) {} +func (m *MockLogger) Fatal(msg string, args ...any) {} +func (m *MockLogger) SetLevel(level schemas.LogLevel) {} +func (m *MockLogger) SetOutputType(outputType schemas.LoggerOutputType) {} + +// MockClientManager implements ClientManager for testing +type MockClientManager struct{} + +func (m *MockClientManager) GetClientForTool(toolName string) *schemas.MCPClientState { + return nil // Return nil to simulate no client found +} + +func (m *MockClientManager) GetClientByName(clientName string) *schemas.MCPClientState { + return nil +} + +func (m *MockClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool { + return make(map[string][]schemas.ChatTool) +} + +func TestHasToolCallsForChatResponse(t *testing.T) { + // Test nil response + if hasToolCallsForChatResponse(nil) { + t.Error("Should return false for nil response") + } + + // Test empty choices + emptyResponse := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{}, + } + if hasToolCallsForChatResponse(emptyResponse) { + t.Error("Should return false for response with empty choices") + } + + // Test response with tool_calls finish reason + toolCallsResponse := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("tool_calls"), + }, + }, + } + if !hasToolCallsForChatResponse(toolCallsResponse) { + t.Error("Should return true for response with tool_calls finish reason") + } + + // Test response with actual tool calls + responseWithToolCalls := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("test_tool"), + }, + }, + }, + }, + }, + }, + }, + }, + } + if !hasToolCallsForChatResponse(responseWithToolCalls) { + t.Error("Should return true for response with tool calls in message") + } + + // Test response with stop finish reason (should return false even with tool calls) + responseWithStopReason := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("stop"), + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("test_tool"), + }, + }, + }, + }, + }, + }, + }, + }, + } + if hasToolCallsForChatResponse(responseWithStopReason) { + t.Error("Should return false for response with stop finish reason even with tool calls") + } +} + +func TestExtractToolCalls(t *testing.T) { + // Test response without tool calls + responseNoTools := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("stop"), + }, + }, + } + + toolCalls := extractToolCalls(responseNoTools) + if len(toolCalls) != 0 { + t.Error("Should return empty slice for response without tool calls") + } + + // Test response with tool calls + expectedToolCalls := []schemas.ChatAssistantMessageToolCall{ + { + ID: schemas.Ptr("call_123"), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("test_tool"), + Arguments: `{"param": "value"}`, + }, + }, + } + + responseWithTools := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: expectedToolCalls, + }, + }, + }, + }, + }, + } + + actualToolCalls := extractToolCalls(responseWithTools) + if len(actualToolCalls) != 1 { + t.Errorf("Expected 1 tool call, got %d", len(actualToolCalls)) + } + + if actualToolCalls[0].Function.Name == nil || *actualToolCalls[0].Function.Name != "test_tool" { + t.Error("Tool call name mismatch") + } +} + +func TestExecuteAgentForChatRequest(t *testing.T) { + // Set up logger for the test + SetLogger(&MockLogger{}) + + // Test with response that has no tool calls - should return immediately + responseNoTools := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("stop"), + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Hello, how can I help you?"), + }, + }, + }, + }, + }, + } + + llmCaller := &MockLLMCaller{} + makeReq := func(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + return llmCaller.ChatCompletionRequest(ctx, req) + } + originalReq := &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Hello"), + }, + }, + }, + } + + ctx := context.Background() + + result, err := ExecuteAgentForChatRequest(&ctx, 10, originalReq, responseNoTools, makeReq, nil, nil, &MockClientManager{}) + if err != nil { + t.Errorf("Expected no error for response without tool calls, got: %v", err) + } + if result != responseNoTools { + t.Error("Expected same response to be returned for response without tool calls") + } +} + +func TestExecuteAgentForChatRequest_WithNonAutoExecutableTools(t *testing.T) { + // Set up logger for the test + SetLogger(&MockLogger{}) + + // Create a response with tool calls that will NOT be auto-executed + responseWithNonAutoTools := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("tool_calls"), + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("I need to call a tool"), + }, + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + ID: schemas.Ptr("call_123"), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("non_auto_executable_tool"), + Arguments: `{"param": "value"}`, + }, + }, + }, + }, + }, + }, + }, + }, + } + + llmCaller := &MockLLMCaller{} + makeReq := func(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + return llmCaller.ChatCompletionRequest(ctx, req) + } + originalReq := &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Test message"), + }, + }, + }, + } + + ctx := context.Background() + + // Execute agent mode - should return immediately with non-auto-executable tools + result, err := ExecuteAgentForChatRequest(&ctx, 10, originalReq, responseWithNonAutoTools, makeReq, nil, nil, &MockClientManager{}) + + // Should not return error for non-auto-executable tools + if err != nil { + t.Errorf("Expected no error for non-auto-executable tools, got: %v", err) + } + + // Should return a response with the non-auto-executable tool calls + if result == nil { + t.Error("Expected result to be returned for non-auto-executable tools") + } + + // Verify that no LLM calls were made (since tools are non-auto-executable) + if llmCaller.chatCallCount != 0 { + t.Errorf("Expected 0 LLM calls for non-auto-executable tools, got %d", llmCaller.chatCallCount) + } +} + +func TestHasToolCallsForResponsesResponse(t *testing.T) { + // Test nil response + if hasToolCallsForResponsesResponse(nil) { + t.Error("Should return false for nil response") + } + + // Test empty output + emptyResponse := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{}, + } + if hasToolCallsForResponsesResponse(emptyResponse) { + t.Error("Should return false for response with empty output") + } + + // Test response with function call + responseWithFunctionCall := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr("call_123"), + Name: schemas.Ptr("test_tool"), + }, + }, + }, + } + if !hasToolCallsForResponsesResponse(responseWithFunctionCall) { + t.Error("Should return true for response with function call") + } + + // Test response with function call but no ResponsesToolMessage + responseWithoutToolMessage := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + // No ResponsesToolMessage + }, + }, + } + if hasToolCallsForResponsesResponse(responseWithoutToolMessage) { + t.Error("Should return false for response with function call type but no ResponsesToolMessage") + } + + // Test response with regular message + responseWithRegularMessage := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Hello"), + }, + }, + }, + } + if hasToolCallsForResponsesResponse(responseWithRegularMessage) { + t.Error("Should return false for response with regular message") + } +} + +func TestExecuteAgentForResponsesRequest(t *testing.T) { + // Set up logger for the test + SetLogger(&MockLogger{}) + + // Test with response that has no tool calls - should return immediately + responseNoTools := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Hello, how can I help you?"), + }, + }, + }, + } + + llmCaller := &MockLLMCaller{} + makeReq := func(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + return llmCaller.ResponsesRequest(ctx, req) + } + originalReq := &schemas.BifrostResponsesRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Hello"), + }, + }, + }, + } + + ctx := context.Background() + + result, err := ExecuteAgentForResponsesRequest(&ctx, 10, originalReq, responseNoTools, makeReq, nil, nil, &MockClientManager{}) + if err != nil { + t.Errorf("Expected no error for response without tool calls, got: %v", err) + } + if result != responseNoTools { + t.Error("Expected same response to be returned for response without tool calls") + } +} + +func TestExecuteAgentForResponsesRequest_WithNonAutoExecutableTools(t *testing.T) { + // Set up logger for the test + SetLogger(&MockLogger{}) + + // Create a response with tool calls that will NOT be auto-executed + responseWithNonAutoTools := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr("call_123"), + Name: schemas.Ptr("non_auto_executable_tool"), + Arguments: schemas.Ptr(`{"param": "value"}`), + }, + }, + }, + } + + llmCaller := &MockLLMCaller{} + makeReq := func(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + return llmCaller.ResponsesRequest(ctx, req) + } + originalReq := &schemas.BifrostResponsesRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Test message"), + }, + }, + }, + } + + ctx := context.Background() + + // Execute agent mode - should return immediately with non-auto-executable tools + result, err := ExecuteAgentForResponsesRequest(&ctx, 10, originalReq, responseWithNonAutoTools, makeReq, nil, nil, &MockClientManager{}) + + // Should not return error for non-auto-executable tools + if err != nil { + t.Errorf("Expected no error for non-auto-executable tools, got: %v", err) + } + + // Should return a response with the non-auto-executable tool calls + if result == nil { + t.Error("Expected result to be returned for non-auto-executable tools") + } + + // Verify that no LLM calls were made (since tools are non-auto-executable) + if llmCaller.responsesCallCount != 0 { + t.Errorf("Expected 0 LLM calls for non-auto-executable tools, got %d", llmCaller.responsesCallCount) + } +} diff --git a/core/mcp/clientmanager.go b/core/mcp/clientmanager.go new file mode 100644 index 0000000000..b0456c5f94 --- /dev/null +++ b/core/mcp/clientmanager.go @@ -0,0 +1,679 @@ +package mcp + +import ( + "context" + "fmt" + "maps" + "os" + "strings" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/maximhq/bifrost/core/schemas" +) + +// GetClients returns all MCP clients managed by the manager. +// +// Returns: +// - []*schemas.MCPClientState: List of all MCP clients +func (m *MCPManager) GetClients() []schemas.MCPClientState { + m.mu.RLock() + defer m.mu.RUnlock() + + clients := make([]schemas.MCPClientState, 0, len(m.clientMap)) + for _, client := range m.clientMap { + snapshot := *client + if client.ToolMap != nil { + snapshot.ToolMap = make(map[string]schemas.ChatTool, len(client.ToolMap)) + maps.Copy(snapshot.ToolMap, client.ToolMap) + } + clients = append(clients, snapshot) + } + + return clients +} + +// ReconnectClient attempts to reconnect an MCP client if it is disconnected. +// It validates that the client exists and then establishes a new connection using +// the client's existing configuration. +// +// Parameters: +// - id: ID of the client to reconnect +// +// Returns: +// - error: Any error that occurred during reconnection +func (m *MCPManager) ReconnectClient(id string) error { + m.mu.Lock() + client, ok := m.clientMap[id] + if !ok { + m.mu.Unlock() + return fmt.Errorf("client %s not found", id) + } + config := client.ExecutionConfig + m.mu.Unlock() + + // connectToMCPClient handles locking internally + err := m.connectToMCPClient(config) + if err != nil { + return fmt.Errorf("failed to connect to MCP client %s: %w", id, err) + } + + return nil +} + +// AddClient adds a new MCP client to the manager. +// It validates the client configuration and establishes a connection. +// If connection fails, the client entry is automatically cleaned up. +// +// Parameters: +// - config: MCP client configuration +// +// Returns: +// - error: Any error that occurred during client addition or connection +func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error { + if err := validateMCPClientConfig(&config); err != nil { + return fmt.Errorf("invalid MCP client configuration: %w", err) + } + + // Make a copy of the config to use after unlocking + configCopy := config + + m.mu.Lock() + + if _, ok := m.clientMap[config.ID]; ok { + m.mu.Unlock() + return fmt.Errorf("client %s already exists", config.Name) + } + + // Create placeholder entry + m.clientMap[config.ID] = &schemas.MCPClientState{ + ExecutionConfig: config, + ToolMap: make(map[string]schemas.ChatTool), + } + + // Temporarily unlock for the connection attempt + // This is to avoid deadlocks when the connection attempt is made + m.mu.Unlock() + + // Connect using the copied config + if err := m.connectToMCPClient(configCopy); err != nil { + // Re-lock to clean up the failed entry + m.mu.Lock() + delete(m.clientMap, config.ID) + m.mu.Unlock() + return fmt.Errorf("failed to connect to MCP client %s: %w", config.Name, err) + } + + return nil +} + +// RemoveClient removes an MCP client from the manager. +// It handles cleanup for all transport types (HTTP, STDIO, SSE). +// +// Parameters: +// - id: ID of the client to remove +func (m *MCPManager) RemoveClient(id string) error { + m.mu.Lock() + defer m.mu.Unlock() + + return m.removeClientUnsafe(id) +} + +// removeClientUnsafe removes an MCP client from the manager without acquiring locks. +// This is an internal method that should only be called when the caller already holds +// the appropriate lock. It handles cleanup for all transport types including cancellation +// of SSE contexts and closing of transport connections. +// +// Parameters: +// - id: ID of the client to remove +// +// Returns: +// - error: Any error that occurred during client removal +func (m *MCPManager) removeClientUnsafe(id string) error { + client, ok := m.clientMap[id] + if !ok { + return fmt.Errorf("client %s not found", id) + } + + logger.Info(fmt.Sprintf("%s Disconnecting MCP client: %s", MCPLogPrefix, client.ExecutionConfig.Name)) + + // Cancel SSE context if present (required for proper SSE cleanup) + if client.CancelFunc != nil { + client.CancelFunc() + client.CancelFunc = nil + } + + // Close the client transport connection + // This handles cleanup for all transport types (HTTP, STDIO, SSE) + if client.Conn != nil { + if err := client.Conn.Close(); err != nil { + logger.Error("%s Failed to close MCP client %s: %v", MCPLogPrefix, client.ExecutionConfig.Name, err) + } + client.Conn = nil + } + + // Clear client tool map + client.ToolMap = make(map[string]schemas.ChatTool) + + delete(m.clientMap, id) + return nil +} + +// EditClient updates an existing MCP client's configuration and refreshes its tool list. +// It updates the client's execution config with new settings and retrieves updated tools +// from the MCP server if the client is connected. +// This method does not refresh the client's tool list. +// To refresh the client's tool list, use the ReconnectClient method. +// +// Parameters: +// - id: ID of the client to edit +// - updatedConfig: Updated client configuration with new settings +// +// Returns: +// - error: Any error that occurred during client update or tool retrieval +func (m *MCPManager) EditClient(id string, updatedConfig schemas.MCPClientConfig) error { + m.mu.Lock() + defer m.mu.Unlock() + + client, ok := m.clientMap[id] + if !ok { + return fmt.Errorf("client %s not found", id) + } + + if err := validateMCPClientName(updatedConfig.Name); err != nil { + return fmt.Errorf("invalid MCP client configuration: %w", err) + } + + // Update the client's execution config with new tool filters + config := client.ExecutionConfig + config.Name = updatedConfig.Name + config.IsCodeModeClient = updatedConfig.IsCodeModeClient + config.Headers = updatedConfig.Headers + config.ToolsToExecute = updatedConfig.ToolsToExecute + config.ToolsToAutoExecute = updatedConfig.ToolsToAutoExecute + + // Store the updated config + client.ExecutionConfig = config + return nil +} + +// registerTool registers a typed tool handler with the local MCP server. +// This is a convenience function that handles the conversion between typed Go +// handlers and the MCP protocol. +// +// Type Parameters: +// - T: The expected argument type for the tool (must be JSON-deserializable) +// +// Parameters: +// - name: Unique tool name +// - description: Human-readable tool description +// - handler: Typed function that handles tool execution +// - toolSchema: Bifrost tool schema for function calling +// +// Returns: +// - error: Any registration error +// +// Example: +// +// type EchoArgs struct { +// Message string `json:"message"` +// } +// +// err := bifrost.RegisterMCPTool("echo", "Echo a message", +// func(args EchoArgs) (string, error) { +// return args.Message, nil +// }, toolSchema) +func (m *MCPManager) RegisterTool(name, description string, toolFunction MCPToolFunction[any], toolSchema schemas.ChatTool) error { + // Ensure local server is set up + if err := m.setupLocalHost(); err != nil { + return fmt.Errorf("failed to setup local host: %w", err) + } + + // Validate tool name + if strings.TrimSpace(name) == "" { + return fmt.Errorf("tool name is required") + } + if strings.Contains(name, "-") { + return fmt.Errorf("tool name cannot contain hyphens") + } + if strings.Contains(name, " ") { + return fmt.Errorf("tool name cannot contain spaces") + } + if len(name) > 0 && name[0] >= '0' && name[0] <= '9' { + return fmt.Errorf("tool name cannot start with a number") + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Verify internal client exists + internalClient, ok := m.clientMap[BifrostMCPClientKey] + if !ok { + return fmt.Errorf("bifrost client not found") + } + + // Check if tool name already exists to prevent silent overwrites + if _, exists := internalClient.ToolMap[name]; exists { + return fmt.Errorf("tool '%s' is already registered", name) + } + + logger.Info(fmt.Sprintf("%s Registering typed tool: %s", MCPLogPrefix, name)) + + // Create MCP handler wrapper that converts between typed and MCP interfaces + mcpHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract arguments from the request using the request's methods + args := request.GetArguments() + result, err := toolFunction(args) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Error: %s", err.Error())), nil + } + return mcp.NewToolResultText(result), nil + } + + // Register the tool with the local MCP server using AddTool + if m.server != nil { + tool := mcp.NewTool(name, mcp.WithDescription(description)) + m.server.AddTool(tool, mcpHandler) + } + + // Store tool definition for Bifrost integration + internalClient.ToolMap[name] = toolSchema + + return nil +} + +// ============================================================================ +// CONNECTION HELPER METHODS +// ============================================================================ + +// connectToMCPClient establishes a connection to an external MCP server and +// registers its available tools with the manager. +func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { + // First lock: Initialize or validate client entry + m.mu.Lock() + + // Initialize or validate client entry + if existingClient, exists := m.clientMap[config.ID]; exists { + // Client entry exists from config, check for existing connection, if it does then close + if existingClient.CancelFunc != nil { + existingClient.CancelFunc() + existingClient.CancelFunc = nil + } + if existingClient.Conn != nil { + existingClient.Conn.Close() + } + // Update connection type for this connection attempt + existingClient.ConnectionInfo.Type = config.ConnectionType + } + // Create new client entry with configuration + m.clientMap[config.ID] = &schemas.MCPClientState{ + ExecutionConfig: config, + ToolMap: make(map[string]schemas.ChatTool), + ConnectionInfo: schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + }, + } + m.mu.Unlock() + + // Heavy operations performed outside lock + var externalClient *client.Client + var connectionInfo schemas.MCPClientConnectionInfo + var err error + + // Create appropriate transport based on connection type + switch config.ConnectionType { + case schemas.MCPConnectionTypeHTTP: + externalClient, connectionInfo, err = m.createHTTPConnection(config) + case schemas.MCPConnectionTypeSTDIO: + externalClient, connectionInfo, err = m.createSTDIOConnection(config) + case schemas.MCPConnectionTypeSSE: + externalClient, connectionInfo, err = m.createSSEConnection(config) + case schemas.MCPConnectionTypeInProcess: + externalClient, connectionInfo, err = m.createInProcessConnection(config) + default: + return fmt.Errorf("unknown connection type: %s", config.ConnectionType) + } + + if err != nil { + return fmt.Errorf("failed to create connection: %w", err) + } + + // Initialize the external client with timeout + // For SSE connections, we need a long-lived context, for others we can use timeout + var ctx context.Context + var cancel context.CancelFunc + + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + // SSE connections need a long-lived context for the persistent stream + ctx, cancel = context.WithCancel(m.ctx) + // Don't defer cancel here - SSE needs the context to remain active + } else { + // Other connection types can use timeout context + ctx, cancel = context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout) + defer cancel() + } + + // Start the transport first (required for STDIO and SSE clients) + if err := externalClient.Start(ctx); err != nil { + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + cancel() // Cancel SSE context only on error + } + return fmt.Errorf("failed to start MCP client transport %s: %v", config.Name, err) + } + + // Create proper initialize request for external client + extInitRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: fmt.Sprintf("Bifrost-%s", config.Name), + Version: "1.0.0", + }, + }, + } + + _, err = externalClient.Initialize(ctx, extInitRequest) + if err != nil { + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + cancel() // Cancel SSE context only on error + } + return fmt.Errorf("failed to initialize MCP client %s: %v", config.Name, err) + } + + // Retrieve tools from the external server (this also requires network I/O) + tools, err := retrieveExternalTools(ctx, externalClient, config.Name) + if err != nil { + logger.Warn(fmt.Sprintf("%s Failed to retrieve tools from %s: %v", MCPLogPrefix, config.Name, err)) + // Continue with connection even if tool retrieval fails + tools = make(map[string]schemas.ChatTool) + } + + // Second lock: Update client with final connection details and tools + m.mu.Lock() + defer m.mu.Unlock() + + // Verify client still exists (could have been cleaned up during heavy operations) + if client, exists := m.clientMap[config.ID]; exists { + // Store the external client connection and details + client.Conn = externalClient + client.ConnectionInfo = connectionInfo + + // Store cancel function for SSE connections to enable proper cleanup + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + client.CancelFunc = cancel + } + + // Store discovered tools + for toolName, tool := range tools { + client.ToolMap[toolName] = tool + } + + logger.Info(fmt.Sprintf("%s Connected to MCP client: %s", MCPLogPrefix, config.Name)) + } else { + // Clean up resources before returning error: client was removed during connection setup + // Cancel SSE context if it was created + if config.ConnectionType == schemas.MCPConnectionTypeSSE && cancel != nil { + cancel() + } + // Close external client connection to prevent transport/goroutine leaks + if externalClient != nil { + if err := externalClient.Close(); err != nil { + logger.Warn(fmt.Sprintf("%s Failed to close external client during cleanup: %v", MCPLogPrefix, err)) + } + } + return fmt.Errorf("client %s was removed during connection setup", config.Name) + } + + return nil +} + +// createHTTPConnection creates an HTTP-based MCP client connection without holding locks. +func (m *MCPManager) createHTTPConnection(config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { + if config.ConnectionString == nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("HTTP connection string is required") + } + + // Prepare connection info + connectionInfo := schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + ConnectionURL: config.ConnectionString, + } + + // Create StreamableHTTP transport + httpTransport, err := transport.NewStreamableHTTP(*config.ConnectionString, transport.WithHTTPHeaders(config.Headers)) + if err != nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create HTTP transport: %w", err) + } + + client := client.NewClient(httpTransport) + + return client, connectionInfo, nil +} + +// createSTDIOConnection creates a STDIO-based MCP client connection without holding locks. +func (m *MCPManager) createSTDIOConnection(config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { + if config.StdioConfig == nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("stdio config is required") + } + + // Prepare STDIO command info for display + cmdString := fmt.Sprintf("%s %s", config.StdioConfig.Command, strings.Join(config.StdioConfig.Args, " ")) + + // Check if environment variables are set + for _, env := range config.StdioConfig.Envs { + if os.Getenv(env) == "" { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("environment variable %s is not set for MCP client %s", env, config.Name) + } + } + + // Create STDIO transport + stdioTransport := transport.NewStdio( + config.StdioConfig.Command, + config.StdioConfig.Envs, + config.StdioConfig.Args..., + ) + + // Prepare connection info + connectionInfo := schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + StdioCommandString: &cmdString, + } + + client := client.NewClient(stdioTransport) + + // Return nil for cmd since mark3labs/mcp-go manages the process internally + return client, connectionInfo, nil +} + +// createSSEConnection creates a SSE-based MCP client connection without holding locks. +func (m *MCPManager) createSSEConnection(config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { + if config.ConnectionString == nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("SSE connection string is required") + } + + // Prepare connection info + connectionInfo := schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + ConnectionURL: config.ConnectionString, // Reuse HTTPConnectionURL field for SSE URL display + } + + // Create SSE transport + sseTransport, err := transport.NewSSE(*config.ConnectionString, transport.WithHeaders(config.Headers)) + if err != nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create SSE transport: %w", err) + } + + client := client.NewClient(sseTransport) + + return client, connectionInfo, nil +} + +// createInProcessConnection creates an in-process MCP client connection without holding locks. +// This allows direct connection to an MCP server running in the same process, providing +// the lowest latency and highest performance for tool execution. +func (m *MCPManager) createInProcessConnection(config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { + if config.InProcessServer == nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("InProcess connection requires a server instance") + } + + // Create in-process client directly connected to the provided server + inProcessClient, err := client.NewInProcessClient(config.InProcessServer) + if err != nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create in-process client: %w", err) + } + + // Prepare connection info + connectionInfo := schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + } + + return inProcessClient, connectionInfo, nil +} + +// ============================================================================ +// LOCAL MCP SERVER AND CLIENT MANAGEMENT +// ============================================================================ + +// setupLocalHost initializes the local MCP server and client if not already running. +// This creates a STDIO-based server for local tool hosting and a corresponding client. +// This is called automatically when tools are registered or when the server is needed. +// +// Returns: +// - error: Any setup error +func (m *MCPManager) setupLocalHost() error { + // First check: fast path if already initialized + m.mu.Lock() + if m.server != nil && m.serverRunning { + m.mu.Unlock() + return nil + } + m.mu.Unlock() + + // Create server and client into local variables (outside lock to avoid + // holding lock during object creation, even though it's lightweight) + server, err := m.createLocalMCPServer() + if err != nil { + return fmt.Errorf("failed to create local MCP server: %w", err) + } + + client, err := m.createLocalMCPClient() + if err != nil { + return fmt.Errorf("failed to create local MCP client: %w", err) + } + + // Second check and assignment: hold lock for atomic check-and-set + m.mu.Lock() + // Double-check: another goroutine might have initialized while we were creating + if m.server != nil && m.serverRunning { + m.mu.Unlock() + return nil + } + + // Assign server and client atomically while holding the lock + m.server = server + m.clientMap[BifrostMCPClientKey] = client + m.mu.Unlock() + + // Start the server and initialize client connection + // (startLocalMCPServer already locks internally) + return m.startLocalMCPServer() +} + +// createLocalMCPServer creates a new local MCP server instance with STDIO transport. +// This server will host tools registered via RegisterTool function. +// +// Returns: +// - *server.MCPServer: Configured MCP server instance +// - error: Any creation error +func (m *MCPManager) createLocalMCPServer() (*server.MCPServer, error) { + // Create MCP server + mcpServer := server.NewMCPServer( + "Bifrost-MCP-Server", + "1.0.0", + server.WithToolCapabilities(true), + ) + + return mcpServer, nil +} + +// createLocalMCPClient creates a placeholder client entry for the local MCP server. +// The actual in-process client connection will be established in startLocalMCPServer. +// +// Returns: +// - *schemas.MCPClientState: Placeholder client for local server +// - error: Any creation error +func (m *MCPManager) createLocalMCPClient() (*schemas.MCPClientState, error) { + // Don't create the actual client connection here - it will be created + // after the server is ready using NewInProcessClient + return &schemas.MCPClientState{ + ExecutionConfig: schemas.MCPClientConfig{ + ID: BifrostMCPClientKey, + Name: BifrostMCPClientName, + ToolsToExecute: []string{"*"}, // Allow all tools for internal client + }, + ToolMap: make(map[string]schemas.ChatTool), + ConnectionInfo: schemas.MCPClientConnectionInfo{ + Type: schemas.MCPConnectionTypeInProcess, // Accurate: in-process (in-memory) transport + }, + }, nil +} + +// startLocalMCPServer creates an in-process connection between the local server and client. +// +// Returns: +// - error: Any startup error +func (m *MCPManager) startLocalMCPServer() error { + m.mu.Lock() + defer m.mu.Unlock() + + // Check if server is already running + if m.server != nil && m.serverRunning { + return nil + } + + if m.server == nil { + return fmt.Errorf("server not initialized") + } + + // Create in-process client directly connected to the server + inProcessClient, err := client.NewInProcessClient(m.server) + if err != nil { + return fmt.Errorf("failed to create in-process MCP client: %w", err) + } + + // Update the client connection + clientEntry, ok := m.clientMap[BifrostMCPClientKey] + if !ok { + return fmt.Errorf("bifrost client not found") + } + clientEntry.Conn = inProcessClient + + // Initialize the in-process client + ctx, cancel := context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout) + defer cancel() + + // Create proper initialize request with correct structure + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: BifrostMCPClientName, + Version: BifrostMCPVersion, + }, + }, + } + + _, err = inProcessClient.Initialize(ctx, initRequest) + if err != nil { + return fmt.Errorf("failed to initialize MCP client: %w", err) + } + + // Mark server as running + m.serverRunning = true + + return nil +} diff --git a/core/mcp/codemode_executecode.go b/core/mcp/codemode_executecode.go new file mode 100644 index 0000000000..3d6b3360f8 --- /dev/null +++ b/core/mcp/codemode_executecode.go @@ -0,0 +1,1035 @@ +package mcp + +import ( + "context" + "fmt" + "regexp" + "strings" + "time" + + "github.com/bytedance/sonic" + "github.com/clarkmcc/go-typescript" + "github.com/dop251/goja" + "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/schemas" +) + +// toolBinding represents a tool binding for the VM +type toolBinding struct { + toolName string + clientName string +} + +// toolCallInfo represents a tool call extracted from code +type toolCallInfo struct { + serverName string + toolName string +} + +// ExecutionResult represents the result of code execution +type ExecutionResult struct { + Result interface{} `json:"result"` + Logs []string `json:"logs"` + Errors *ExecutionError `json:"errors,omitempty"` + Environment ExecutionEnvironment `json:"environment"` +} + +type ExecutionErrorType string + +const ( + ExecutionErrorTypeCompile ExecutionErrorType = "compile" + ExecutionErrorTypeTypescript ExecutionErrorType = "typescript" + ExecutionErrorTypeRuntime ExecutionErrorType = "runtime" +) + +// ExecutionError represents an error during code execution +type ExecutionError struct { + Kind ExecutionErrorType `json:"kind"` // "compile", "typescript", or "runtime" + Message string `json:"message"` + Hints []string `json:"hints"` +} + +// ExecutionEnvironment contains information about the execution environment +type ExecutionEnvironment struct { + ServerKeys []string `json:"serverKeys"` + ImportsStripped bool `json:"importsStripped"` + StrippedLines []int `json:"strippedLines"` + TypeScriptUsed bool `json:"typescriptUsed"` +} + +const ( + CodeModeLogPrefix = "[CODE MODE]" +) + +// createExecuteToolCodeTool creates the executeToolCode tool definition for code mode. +// This tool allows executing TypeScript code in a sandboxed VM with access to MCP server tools. +// +// Returns: +// - schemas.ChatTool: The tool definition for executing tool code +func (m *ToolsManager) createExecuteToolCodeTool() schemas.ChatTool { + executeToolCodeProps := map[string]interface{}{ + "code": map[string]interface{}{ + "type": "string", + "description": "TypeScript code to execute. The code will be transpiled to JavaScript and validated before execution. Import/export statements will be stripped. You can use async/await syntax for async operations. For simple use cases, directly return results. Check keys and value types only for debugging. Do not print entire outputs in console logs - only print structure (keys, types) when debugging. ALWAYS retry if code fails. Example (simple): const result = await serverName.toolName({arg: 'value'}); return result; Example (debugging): const result = await serverName.toolName({arg: 'value'}); const getStruct = (o, d=0) => d>2 ? '...' : o===null ? 'null' : Array.isArray(o) ? `Array[${o.length}]` : typeof o !== 'object' ? typeof o : Object.keys(o).reduce((a,k) => (a[k]=getStruct(o[k],d+1), a), {}); console.log('Structure:', getStruct(result)); return result;", + }, + } + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: ToolTypeExecuteToolCode, + Description: schemas.Ptr( + "Executes TypeScript code inside a sandboxed goja-based VM with access to all connected MCP servers' tools. " + + "TypeScript code is automatically transpiled to JavaScript and validated before execution, providing type checking and validation. " + + "All connected servers are exposed as global objects named after their configuration keys, and each server " + + "provides async (Promise-returning) functions for every tool available on that server. The canonical usage " + + "pattern is: const result = await .({ ...args }); Both and " + + "should be discovered using listToolFiles and readToolFile. " + + + "IMPORTANT WORKFLOW: Always follow this order — first use listToolFiles to see available servers and tools, " + + "then use readToolFile to understand the tool definitions and their parameters, and finally use executeToolCode " + + "to execute your code. Check listToolFiles whenever you are unsure about what tools you have available or if you want to verify available servers and their tools. " + + + "LOGGING GUIDELINES: For simple use cases, you can directly return results without logging. Check for keys and value types only " + + "for debugging purposes when you need to understand the response structure. Do not print the entire output in console logs. " + + "When debugging, use console logs to print just the output structure to understand its type. For nested objects, use a recursive helper to show types at all levels. " + + "For example: const getStruct = (o, d=0) => d>2 ? '...' : o===null ? 'null' : Array.isArray(o) ? `Array[${o.length}]` : typeof o !== 'object' ? typeof o : Object.keys(o).reduce((a,k) => (a[k]=getStruct(o[k],d+1), a), {}); " + + "console.log('Structure:', getStruct(result)); Only print the entire data if absolutely necessary for debugging. " + + "This helps understand the response structure without cluttering the output with full object contents. " + + + "RETRY POLICY: ALWAYS retry if a code block fails. If execution produces an error or unexpected result, analyze the error, " + + "adjust your code accordingly for better results or debugging, and retry the execution. Do not give up after a single failure — iterate and improve your code until it succeeds. " + + + "The environment is intentionally minimal and has several constraints: " + + "• ES modules are not supported — any leading import/export statements are automatically stripped and imported symbols will not exist. " + + "• Browser and Node APIs such as fetch, XMLHttpRequest, axios, require, setTimeout, setInterval, window, and document do not exist. " + + "• async/await syntax is supported and automatically transpiled to Promise chains compatible with goja. " + + "• Using undefined server names or tool names will result in reference or function errors. " + + "• The VM does not emulate a browser or Node.js environment — no DOM, timers, modules, or network APIs are available. " + + "• Only ES5.1+ features supported by goja are guaranteed to work. " + + "• TypeScript type checking occurs during transpilation — type errors will prevent execution. " + + + "If you want a value returned from the code, write a top-level 'return '; otherwise the return value will be null. " + + "Console output (log, error, warn, info) is captured and returned. " + + "Long-running or blocked operations are interrupted via execution timeout. " + + "This tool is designed specifically for orchestrating MCP tool calls and lightweight TypeScript computation.", + ), + + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &executeToolCodeProps, + Required: []string{"code"}, + }, + }, + } +} + +// handleExecuteToolCode handles the executeToolCode tool call. +// It parses the code argument, executes it in a sandboxed VM, and formats the response +// with execution results, logs, errors, and environment information. +// +// Parameters: +// - ctx: Context for code execution +// - toolCall: The tool call request containing the TypeScript code to execute +// +// Returns: +// - *schemas.ChatMessage: A tool response message containing execution results +// - error: Any error that occurred during processing +func (m *ToolsManager) handleExecuteToolCode(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + toolName := "unknown" + if toolCall.Function.Name != nil { + toolName = *toolCall.Function.Name + } + logger.Debug(fmt.Sprintf("%s Handling executeToolCode tool call: %s", CodeModeLogPrefix, toolName)) + + // Parse tool arguments + var arguments map[string]interface{} + if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + logger.Debug(fmt.Sprintf("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err)) + return nil, fmt.Errorf("failed to parse tool arguments: %v", err) + } + + code, ok := arguments["code"].(string) + if !ok || code == "" { + logger.Debug(fmt.Sprintf("%s Code parameter missing or empty", CodeModeLogPrefix)) + return nil, fmt.Errorf("code parameter is required and must be a non-empty string") + } + + logger.Debug(fmt.Sprintf("%s Starting code execution", CodeModeLogPrefix)) + result := m.executeCode(ctx, code) + logger.Debug(fmt.Sprintf("%s Code execution completed. Success: %v, Has errors: %v, Log count: %d", CodeModeLogPrefix, result.Errors == nil, result.Errors != nil, len(result.Logs))) + + // Format response text + var responseText string + var executionSuccess bool = true // Track if execution was successful (has data) + if result.Errors != nil { + logger.Debug(fmt.Sprintf("%s Formatting error response. Error kind: %s, Message length: %d, Hints count: %d", CodeModeLogPrefix, result.Errors.Kind, len(result.Errors.Message), len(result.Errors.Hints))) + logsText := "" + if len(result.Logs) > 0 { + logsText = fmt.Sprintf("\n\nConsole/Log Output:\n%s\n", + strings.Join(result.Logs, "\n")) + } + errorKindLabel := result.Errors.Kind + + responseText = fmt.Sprintf( + "Execution %s error:\n\n%s\n\nHints:\n%s%s\n\nEnvironment:\n Available server keys: %s\n TypeScript used: %s\n Imports stripped: %s", + errorKindLabel, + result.Errors.Message, + strings.Join(result.Errors.Hints, "\n"), + logsText, + strings.Join(result.Environment.ServerKeys, ", "), + map[bool]string{true: "Yes", false: "No"}[result.Environment.TypeScriptUsed], + map[bool]string{true: "Yes", false: "No"}[result.Environment.ImportsStripped], + ) + if len(result.Environment.StrippedLines) > 0 { + strippedStr := make([]string, len(result.Environment.StrippedLines)) + for i, line := range result.Environment.StrippedLines { + strippedStr[i] = fmt.Sprintf("%d", line) + } + responseText += fmt.Sprintf("\n Stripped lines: %s", strings.Join(strippedStr, ", ")) + } + logger.Debug(fmt.Sprintf("%s Error response formatted. Response length: %d chars", CodeModeLogPrefix, len(responseText))) + } else { + // Success case - check if execution produced any data + hasLogs := len(result.Logs) > 0 + hasResult := result.Result != nil + logger.Debug(fmt.Sprintf("%s Formatting success response. Has logs: %v, Has result: %v", CodeModeLogPrefix, hasLogs, hasResult)) + + // If execution completed but produced no data (no logs, no return value), treat as failure + if !hasLogs && !hasResult { + executionSuccess = false + logger.Debug(fmt.Sprintf("%s Execution completed with no data (no logs, no result), marking as failure", CodeModeLogPrefix)) + hints := []string{ + "Add console.log() statements throughout your code to debug and see what's happening at each step", + "Ensure your code has a top-level return statement if you want to return a value", + "Check that your tool calls are actually executing and returning data", + "Verify that async operations (like await) are properly handled", + } + responseText = fmt.Sprintf( + "Execution completed but produced no data:\n\n"+ + "The code executed without errors but returned no output (no console logs and no return value).\n\n"+ + "Hints:\n%s\n\n"+ + "Environment:\n Available server keys: %s\n TypeScript used: %s\n Imports stripped: %s", + strings.Join(hints, "\n"), + strings.Join(result.Environment.ServerKeys, ", "), + map[bool]string{true: "Yes", false: "No"}[result.Environment.TypeScriptUsed], + map[bool]string{true: "Yes", false: "No"}[result.Environment.ImportsStripped], + ) + if len(result.Environment.StrippedLines) > 0 { + strippedStr := make([]string, len(result.Environment.StrippedLines)) + for i, line := range result.Environment.StrippedLines { + strippedStr[i] = fmt.Sprintf("%d", line) + } + responseText += fmt.Sprintf("\n Stripped lines: %s", strings.Join(strippedStr, ", ")) + } + logger.Debug(fmt.Sprintf("%s No-data failure response formatted. Response length: %d chars", CodeModeLogPrefix, len(responseText))) + } else { + // Normal success case with data + if hasLogs { + responseText = fmt.Sprintf("Console output:\n%s\n\nExecution completed successfully.", + strings.Join(result.Logs, "\n")) + } else { + responseText = "Execution completed successfully." + } + if hasResult { + resultJSON, err := sonic.MarshalIndent(result.Result, "", " ") + if err == nil { + responseText += fmt.Sprintf("\nReturn value: %s", string(resultJSON)) + logger.Debug(fmt.Sprintf("%s Added return value to response (JSON length: %d chars)", CodeModeLogPrefix, len(resultJSON))) + } else { + logger.Debug(fmt.Sprintf("%s Failed to marshal result to JSON: %v", CodeModeLogPrefix, err)) + } + } + + // Add environment information for successful executions + responseText += fmt.Sprintf("\n\nEnvironment:\n Available server keys: %s\n TypeScript used: %s\n Imports stripped: %s", + strings.Join(result.Environment.ServerKeys, ", "), + map[bool]string{true: "Yes", false: "No"}[result.Environment.TypeScriptUsed], + map[bool]string{true: "Yes", false: "No"}[result.Environment.ImportsStripped]) + if len(result.Environment.StrippedLines) > 0 { + strippedStr := make([]string, len(result.Environment.StrippedLines)) + for i, line := range result.Environment.StrippedLines { + strippedStr[i] = fmt.Sprintf("%d", line) + } + responseText += fmt.Sprintf("\n Stripped lines: %s", strings.Join(strippedStr, ", ")) + } + responseText += "\nNote: Browser APIs like fetch, setTimeout are not available. Use MCP tools for external interactions." + logger.Debug(fmt.Sprintf("%s Success response formatted. Response length: %d chars, Server keys: %v", CodeModeLogPrefix, len(responseText), result.Environment.ServerKeys)) + } + } + + logger.Debug(fmt.Sprintf("%s Returning tool response message. Execution success: %v", CodeModeLogPrefix, executionSuccess)) + return createToolResponseMessage(toolCall, responseText), nil +} + +// executeCode executes TypeScript code in a sandboxed VM with MCP tool bindings. +// It handles code preprocessing (stripping imports/exports), TypeScript transpilation, +// VM setup with tool bindings, and promise-based async execution with timeout handling. +// +// Parameters: +// - ctx: Context for code execution (used for timeout and tool access) +// - code: TypeScript code string to execute +// +// Returns: +// - ExecutionResult: Result containing execution output, logs, errors, and environment info +func (m *ToolsManager) executeCode(ctx context.Context, code string) ExecutionResult { + logs := []string{} + strippedLines := []int{} + + logger.Debug(fmt.Sprintf("%s Starting TypeScript code execution", CodeModeLogPrefix)) + + // Step 1: Convert literal \n escape sequences to actual newlines first + // This ensures multiline code and import/export stripping work correctly + codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") + + // Step 2: Strip import/export statements + cleanedCode, strippedLineNumbers := stripImportsAndExports(codeWithNewlines) + strippedLines = append(strippedLines, strippedLineNumbers...) + if len(strippedLineNumbers) > 0 { + logger.Debug(fmt.Sprintf("%s Stripped %d import/export lines", CodeModeLogPrefix, len(strippedLineNumbers))) + } + + // Step 3: Handle empty code after stripping (in case stripping made it empty) + trimmedCode := strings.TrimSpace(cleanedCode) + if trimmedCode == "" { + // Empty code should return null - return early without VM execution + return ExecutionResult{ + Result: nil, + Logs: logs, + Errors: nil, + Environment: ExecutionEnvironment{ + ServerKeys: []string{}, // Will be populated below if needed, but empty code doesn't need tools + ImportsStripped: len(strippedLines) > 0, + StrippedLines: strippedLines, + TypeScriptUsed: true, + }, + } + } + + // Step 4: Wrap code in async function for proper await transpilation + // TypeScript needs an async function context to properly transpile await expressions + // Check if code is already an async IIFE - if so, await it + trimmedLower := strings.ToLower(strings.TrimSpace(trimmedCode)) + isAsyncIIFE := strings.HasPrefix(trimmedLower, "(async") && strings.Contains(trimmedCode, ")()") + + var codeToTranspile string + if isAsyncIIFE { + // Code is already an async IIFE - await it to get the result + codeToTranspile = fmt.Sprintf("async function __execute__() {\nreturn await %s\n}", trimmedCode) + } else { + // Regular code - wrap in async function + codeToTranspile = fmt.Sprintf("async function __execute__() {\n%s\n}", trimmedCode) + } + + // Step 5: Transpile TypeScript to JavaScript with validation + // Configure TypeScript compiler to transpile async/await to Promise chains (ES5 compatible) + logger.Debug(fmt.Sprintf("%s Transpiling TypeScript code", CodeModeLogPrefix)) + compileOptions := map[string]interface{}{ + "target": "ES5", // Target ES5 for goja compatibility + "module": "None", // No module system + "lib": []string{}, // No lib (minimal environment) + "downlevelIteration": true, // Support async/await transpilation + } + jsCode, transpileErr := typescript.TranspileString(codeToTranspile, typescript.WithCompileOptions(compileOptions)) + if transpileErr != nil { + logger.Debug(fmt.Sprintf("%s TypeScript transpilation failed: %v", CodeModeLogPrefix, transpileErr)) + // Build bindings to get server keys for error hints + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + serverKeys := make([]string, 0, len(availableToolsPerClient)) + for clientName := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if !client.ExecutionConfig.IsCodeModeClient { + continue + } + serverKeys = append(serverKeys, clientName) + } + + errorMessage := transpileErr.Error() + hints := generateTypeScriptErrorHints(errorMessage, serverKeys) + + return ExecutionResult{ + Result: nil, + Logs: logs, + Errors: &ExecutionError{ + Kind: ExecutionErrorTypeTypescript, + Message: fmt.Sprintf("TypeScript compilation error: %s", errorMessage), + Hints: hints, + }, + Environment: ExecutionEnvironment{ + ServerKeys: serverKeys, + ImportsStripped: len(strippedLines) > 0, + StrippedLines: strippedLines, + TypeScriptUsed: true, + }, + } + } + + logger.Debug(fmt.Sprintf("%s TypeScript transpiled successfully", CodeModeLogPrefix)) + + // Step 5: Create timeout context early so goroutines can use it + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + timeoutCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + // Step 6: Build bindings for all connected servers + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + bindings := make(map[string]map[string]toolBinding) + serverKeys := make([]string, 0, len(availableToolsPerClient)) + + for clientName, tools := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 { + continue + } + serverKeys = append(serverKeys, clientName) + + toolFunctions := make(map[string]toolBinding) + + // Create a function for each tool + for _, tool := range tools { + if tool.Function == nil || tool.Function.Name == "" { + continue + } + + originalToolName := tool.Function.Name + // Parse tool name for property name compatibility (used as property name in the runtime) + parsedToolName := parseToolName(originalToolName) + + // Store tool binding + toolFunctions[parsedToolName] = toolBinding{ + toolName: originalToolName, + clientName: clientName, + } + } + + bindings[clientName] = toolFunctions + } + + if len(serverKeys) > 0 { + logger.Debug(fmt.Sprintf("%s Bound %d servers with tools", CodeModeLogPrefix, len(serverKeys))) + } + + // Step 7: Wrap transpiled code to execute the async function and return its result + // The transpiled code contains an async function __execute__() that we need to call + // Trim trailing newlines to avoid issues when wrapping + codeToWrap := strings.TrimRight(jsCode, "\n\r") + // Wrap in IIFE that calls the transpiled async function and returns the promise + wrappedCode := fmt.Sprintf("(function() {\n%s\nreturn __execute__();\n})()", codeToWrap) + + // Step 8: Create goja runtime + vm := goja.New() + + // Step 9: Set up thread-safe logging + appendLog := func(msg string) { + m.logMu.Lock() + defer m.logMu.Unlock() + logs = append(logs, msg) + } + + // Step 10: Set up console + consoleObj := vm.NewObject() + consoleObj.Set("log", func(args ...interface{}) { + message := formatConsoleArgs(args) + appendLog(message) + }) + consoleObj.Set("error", func(args ...interface{}) { + message := formatConsoleArgs(args) + appendLog(fmt.Sprintf("[ERROR] %s", message)) + }) + consoleObj.Set("warn", func(args ...interface{}) { + message := formatConsoleArgs(args) + appendLog(fmt.Sprintf("[WARN] %s", message)) + }) + consoleObj.Set("info", func(args ...interface{}) { + message := formatConsoleArgs(args) + appendLog(fmt.Sprintf("[INFO] %s", message)) + }) + vm.Set("console", consoleObj) + + // Step 11: Set up server bindings + for serverKey, tools := range bindings { + serverObj := vm.NewObject() + for toolName, binding := range tools { + // Capture variables for closure + toolNameFinal := binding.toolName + clientNameFinal := binding.clientName + + serverObj.Set(toolName, func(call goja.FunctionCall) goja.Value { + args := call.Argument(0).Export() + + // Convert args to map[string]interface{} + argsMap, ok := args.(map[string]interface{}) + if !ok { + logger.Debug(fmt.Sprintf("%s Invalid args type for %s.%s: expected object, got %T", + CodeModeLogPrefix, clientNameFinal, toolNameFinal, args)) + // Return rejected promise for invalid args + promise, _, reject := vm.NewPromise() + err := fmt.Errorf("expected object argument, got %T", args) + reject(vm.ToValue(err)) + return vm.ToValue(promise) + } + + // Create promise on VM goroutine (thread-safe) + promise, resolve, reject := vm.NewPromise() + + // Define result struct for channel communication + type toolResult struct { + result interface{} + err error + } + + // Create buffered channel for worker communication + resultChan := make(chan toolResult, 1) + + // Call tool asynchronously with timeout context and panic recovery + // Worker goroutine - NO VM calls allowed here + go func() { + defer func() { + if r := recover(); r != nil { + logger.Debug(fmt.Sprintf("%s Panic in tool call goroutine for %s.%s: %v", + CodeModeLogPrefix, clientNameFinal, toolNameFinal, r)) + // Send panic as error through channel (no VM calls in worker) + select { + case resultChan <- toolResult{nil, fmt.Errorf("tool call panic: %v", r)}: + case <-timeoutCtx.Done(): + // Context cancelled, ignore + } + } + }() + + // Check if context is already cancelled before starting + select { + case <-timeoutCtx.Done(): + // Send timeout error through channel (no VM calls in worker) + select { + case resultChan <- toolResult{nil, fmt.Errorf("execution timeout")}: + case <-timeoutCtx.Done(): + // Already cancelled, ignore + } + return + default: + } + + result, err := m.callMCPTool(timeoutCtx, clientNameFinal, toolNameFinal, argsMap, appendLog) + + // Check if context was cancelled during execution + select { + case <-timeoutCtx.Done(): + // Send timeout error through channel (no VM calls in worker) + select { + case resultChan <- toolResult{nil, fmt.Errorf("execution timeout")}: + case <-timeoutCtx.Done(): + // Already cancelled, ignore + } + return + default: + } + + // Send result through channel (no VM calls in worker) + select { + case resultChan <- toolResult{result, err}: + case <-timeoutCtx.Done(): + // Context cancelled, ignore + } + }() + + // Process result synchronously on VM goroutine to ensure thread safety + // This blocks the VM goroutine until the tool call completes, but ensures + // all VM operations (vm.ToValue, resolve, reject) happen on the correct thread + select { + case res := <-resultChan: + if res.err != nil { + logger.Debug(fmt.Sprintf("%s Tool call failed: %s.%s - %v", + CodeModeLogPrefix, clientNameFinal, toolNameFinal, res.err)) + reject(vm.ToValue(res.err)) + } else { + resolve(vm.ToValue(res.result)) + } + case <-timeoutCtx.Done(): + reject(vm.ToValue(fmt.Errorf("execution timeout"))) + } + + return vm.ToValue(promise) + }) + } + vm.Set(serverKey, serverObj) + } + + // Step 12: Set up environment info + envObj := vm.NewObject() + envObj.Set("serverKeys", serverKeys) + envObj.Set("version", "1.0.0") + vm.Set("__MCP_ENV__", envObj) + + // Step 13: Execute code with timeout + + // Set up interrupt handler + interruptDone := make(chan struct{}) + go func() { + select { + case <-timeoutCtx.Done(): + logger.Debug(fmt.Sprintf("%s Execution timeout reached", CodeModeLogPrefix)) + vm.Interrupt("execution timeout") + case <-interruptDone: + } + }() + + var result interface{} + var executionErr error + + func() { + defer close(interruptDone) + val, err := vm.RunString(wrappedCode) + if err != nil { + logger.Debug(fmt.Sprintf("%s VM execution error: %v", CodeModeLogPrefix, err)) + executionErr = err + return + } + + // Check if the result is a promise by checking its type + // First check if val is nil or undefined (these can't be converted to objects) + if val == nil || val == goja.Undefined() { + result = nil + return + } + + // Try to convert to object to check if it's a promise + // Use recover to safely handle null values that can't be converted to objects + var valObj *goja.Object + func() { + defer func() { + if r := recover(); r != nil { + // Value is null or can't be converted to object, just export it + valObj = nil + } + }() + valObj = val.ToObject(vm) + }() + + if valObj != nil { + // Check if it has a 'then' method (Promise-like) + if then := valObj.Get("then"); then != nil && then != goja.Undefined() { + // It's a promise, we need to await it + // Use buffered channels to prevent blocking if handlers are called after timeout + resultChan := make(chan interface{}, 1) + errChan := make(chan error, 1) + + // Set up promise handlers + thenFunc, ok := goja.AssertFunction(then) + if ok { + // Call then with resolve and reject handlers + _, err := thenFunc(val, + vm.ToValue(func(res goja.Value) { + select { + case resultChan <- res.Export(): + case <-timeoutCtx.Done(): + // Timeout already occurred, ignore result + } + }), + vm.ToValue(func(err goja.Value) { + var errMsg string + if err == nil || err == goja.Undefined() { + errMsg = "unknown error" + } else { + // Try to get error message from Error object + if errObj := err.ToObject(vm); errObj != nil { + if msg := errObj.Get("message"); msg != nil && msg != goja.Undefined() { + errMsg = msg.String() + } else if name := errObj.Get("name"); name != nil && name != goja.Undefined() { + errMsg = name.String() + } else { + errMsg = err.String() + } + } else { + // Fallback to string conversion + errMsg = err.String() + } + } + select { + case errChan <- fmt.Errorf("%s", errMsg): + case <-timeoutCtx.Done(): + // Timeout already occurred, ignore error + } + }), + ) + if err != nil { + executionErr = err + return + } + + // Wait for result or error with timeout + select { + case res := <-resultChan: + result = res + case err := <-errChan: + logger.Debug(fmt.Sprintf("%s Promise rejected: %v", CodeModeLogPrefix, err)) + executionErr = err + case <-timeoutCtx.Done(): + logger.Debug(fmt.Sprintf("%s Promise timeout while waiting for result", CodeModeLogPrefix)) + executionErr = fmt.Errorf("execution timeout") + } + } else { + result = val.Export() + } + } else { + result = val.Export() + } + } else { + // Not an object (or null/undefined), just export the value + result = val.Export() + } + }() + + if executionErr != nil { + errorMessage := executionErr.Error() + hints := generateErrorHints(errorMessage, serverKeys) + logger.Debug(fmt.Sprintf("%s Execution failed: %s", CodeModeLogPrefix, errorMessage)) + + return ExecutionResult{ + Result: nil, + Logs: logs, + Errors: &ExecutionError{ + Kind: ExecutionErrorTypeRuntime, + Message: errorMessage, + Hints: hints, + }, + Environment: ExecutionEnvironment{ + ServerKeys: serverKeys, + ImportsStripped: len(strippedLines) > 0, + StrippedLines: strippedLines, + TypeScriptUsed: true, + }, + } + } + + logger.Debug(fmt.Sprintf("%s Execution completed successfully", CodeModeLogPrefix)) + return ExecutionResult{ + Result: result, + Logs: logs, + Errors: nil, + Environment: ExecutionEnvironment{ + ServerKeys: serverKeys, + ImportsStripped: len(strippedLines) > 0, + StrippedLines: strippedLines, + TypeScriptUsed: true, + }, + } +} + +// callMCPTool calls an MCP tool and returns the result. +// It locates the client by name, constructs the MCP tool call request, executes it +// with timeout handling, and parses the response as JSON or returns it as a string. +// +// Parameters: +// - ctx: Context for tool execution (used for timeout) +// - clientName: Name of the MCP client/server to call +// - toolName: Name of the tool to execute +// - args: Tool arguments as a map +// - appendLog: Function to append log messages during execution +// +// Returns: +// - interface{}: Parsed tool result (JSON object or string) +// - error: Any error that occurred during tool execution +func (m *ToolsManager) callMCPTool(ctx context.Context, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { + // Get available tools per client + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + + // Find the client by name + tools, exists := availableToolsPerClient[clientName] + if !exists || len(tools) == 0 { + return nil, fmt.Errorf("client not found for server name: %s", clientName) + } + + // Get client using a tool from this client + // Find the first tool with a valid Function to use for client lookup + var client *schemas.MCPClientState + for _, tool := range tools { + if tool.Function != nil && tool.Function.Name != "" { + client = m.clientManager.GetClientForTool(tool.Function.Name) + if client != nil { + break + } + } + } + + if client == nil { + return nil, fmt.Errorf("client not found for server name: %s", clientName) + } + + // Strip the client name prefix from tool name before calling MCP server + // The MCP server expects the original tool name, not the prefixed version + originalToolName := stripClientPrefix(toolName, clientName) + + // Call the tool via MCP client + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: originalToolName, + Arguments: args, + }, + } + + // Create timeout context + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) + if callErr != nil { + logger.Debug(fmt.Sprintf("%s Tool call failed: %s.%s - %v", CodeModeLogPrefix, clientName, toolName, callErr)) + appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, toolName, callErr)) + return nil, fmt.Errorf("tool call failed for %s.%s: %v", clientName, toolName, callErr) + } + + // Extract result + rawResult := extractTextFromMCPResponse(toolResponse, toolName) + + // Check if this is an error result (from NewToolResultError) + // Error results start with "Error: " prefix + if after, ok := strings.CutPrefix(rawResult, "Error: "); ok { + errorMsg := after + logger.Debug(fmt.Sprintf("%s Tool returned error result: %s.%s - %s", CodeModeLogPrefix, clientName, toolName, errorMsg)) + appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, toolName, errorMsg)) + return nil, fmt.Errorf("%s", errorMsg) + } + + // Try to parse as JSON, otherwise use as string + var finalResult interface{} + if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { + // Not JSON, use as string + finalResult = rawResult + } + + // Log the result + resultStr := formatResultForLog(finalResult) + appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, toolName, resultStr)) + + return finalResult, nil +} + +// HELPER FUNCTIONS + +// formatResultForLog formats a result value for logging purposes. +// It attempts to marshal to JSON for structured output, falling back to string representation. +// +// Parameters: +// - result: The result value to format +// +// Returns: +// - string: Formatted string representation of the result +func formatResultForLog(result interface{}) string { + var resultStr string + if result == nil { + resultStr = "null" + } else if resultBytes, err := sonic.Marshal(result); err == nil { + resultStr = string(resultBytes) + } else { + resultStr = fmt.Sprintf("%v", result) + } + return resultStr +} + +// formatConsoleArgs formats console arguments for logging. +// It formats each argument as JSON if possible, otherwise uses string representation. +// +// Parameters: +// - args: Array of console arguments to format +// +// Returns: +// - string: Formatted string with all arguments joined by spaces +func formatConsoleArgs(args []interface{}) string { + parts := make([]string, len(args)) + for i, arg := range args { + if argBytes, err := sonic.MarshalIndent(arg, "", " "); err == nil { + parts[i] = string(argBytes) + } else { + parts[i] = fmt.Sprintf("%v", arg) + } + } + return strings.Join(parts, " ") +} + +// stripImportsAndExports strips import and export statements from code. +// It removes lines that start with import or export keywords and returns +// the cleaned code along with 1-based line numbers of stripped lines. +// +// Parameters: +// - code: Source code string to process +// +// Returns: +// - string: Code with import/export statements removed +// - []int: 1-based line numbers of stripped lines +func stripImportsAndExports(code string) (string, []int) { + lines := strings.Split(code, "\n") + keptLines := []string{} + strippedLineNumbers := []int{} + + importExportRegex := regexp.MustCompile(`^\s*(import|export)\b`) + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + + // Skip empty lines + if trimmed == "" { + keptLines = append(keptLines, line) + continue + } + + // Check if this is an import or export statement + isImportOrExport := importExportRegex.MatchString(line) + + if isImportOrExport { + strippedLineNumbers = append(strippedLineNumbers, i+1) // 1-based line numbers + continue // Skip import/export lines + } + + // Keep comment lines and all other non-import/export lines + keptLines = append(keptLines, line) + } + + return strings.Join(keptLines, "\n"), strippedLineNumbers +} + +// generateTypeScriptErrorHints generates helpful hints for TypeScript compilation errors. +// It analyzes the error message and provides context-specific guidance based on error patterns. +// +// Parameters: +// - errorMessage: The TypeScript compilation error message +// - serverKeys: List of available MCP server keys for context +// +// Returns: +// - []string: Array of helpful hint messages +func generateTypeScriptErrorHints(errorMessage string, serverKeys []string) []string { + hints := []string{} + + // TypeScript-specific error patterns + if strings.Contains(errorMessage, "Cannot find name") || strings.Contains(errorMessage, "is not defined") { + hints = append(hints, "TypeScript compilation error: undefined variable or identifier.") + hints = append(hints, "Check that all variables are properly declared and typed.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + hints = append(hints, "Use server keys to access MCP tools: .(args)") + } + } else if strings.Contains(errorMessage, "Type") && (strings.Contains(errorMessage, "is not assignable") || strings.Contains(errorMessage, "does not exist")) { + hints = append(hints, "TypeScript type error detected.") + hints = append(hints, "Check that variable types match their usage.") + hints = append(hints, "Ensure function arguments match the expected types.") + } else if strings.Contains(errorMessage, "Expected") { + hints = append(hints, "TypeScript syntax error detected.") + hints = append(hints, "Check for missing parentheses, brackets, or semicolons.") + hints = append(hints, "Ensure all code blocks are properly closed.") + } else if strings.Contains(errorMessage, "async") || strings.Contains(errorMessage, "await") { + hints = append(hints, "async/await syntax should be supported. If you see this error, it may be a TypeScript compilation issue.") + hints = append(hints, "Ensure async functions are properly declared: async function myFunction() { ... }") + hints = append(hints, "Example: const result = await serverName.toolName({...});") + } else { + hints = append(hints, "TypeScript compilation error detected.") + hints = append(hints, "Review the error message above for specific details.") + hints = append(hints, "Ensure your TypeScript code follows valid syntax and type rules.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + } + + return hints +} + +// generateErrorHints generates helpful hints based on runtime error messages. +// It analyzes common runtime error patterns (undefined variables, missing functions, etc.) +// and provides context-specific guidance including available server keys and usage examples. +// +// Parameters: +// - errorMessage: The runtime error message +// - serverKeys: List of available MCP server keys for context +// +// Returns: +// - []string: Array of helpful hint messages +func generateErrorHints(errorMessage string, serverKeys []string) []string { + hints := []string{} + + if strings.Contains(errorMessage, "is not defined") { + re := regexp.MustCompile(`(\w+)\s+is not defined`) + if match := re.FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar := match[1] + + // Special handling for common browser/Node.js APIs + if undefinedVar == "fetch" { + hints = append(hints, "The 'fetch' API is not available in this runtime environment.") + hints = append(hints, "Instead of using fetch for HTTP requests, use the available MCP tools.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + hints = append(hints, fmt.Sprintf("Example: const result = await %s.({ url: 'https://example.com' });", serverKeys[0])) + } + hints = append(hints, "MCP tools handle HTTP requests, file operations, and other external interactions.") + return hints + } else if undefinedVar == "XMLHttpRequest" || undefinedVar == "axios" { + hints = append(hints, fmt.Sprintf("The '%s' API is not available in this runtime environment.", undefinedVar)) + hints = append(hints, "Use MCP tools instead for HTTP requests and external API calls.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + return hints + } else if undefinedVar == "setTimeout" || undefinedVar == "setInterval" { + hints = append(hints, fmt.Sprintf("The '%s' API is not available in this runtime environment.", undefinedVar)) + hints = append(hints, "This is a sandboxed environment focused on MCP tool interactions.") + hints = append(hints, "Use Promise chains with MCP tools instead of timing functions.") + return hints + } else if undefinedVar == "require" || undefinedVar == "import" { + hints = append(hints, "Module imports are not supported in this runtime environment.") + hints = append(hints, "Use the available MCP tools for external functionality.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + return hints + } + + // Generic undefined variable handling + hints = append(hints, fmt.Sprintf("Variable or identifier '%s' is not defined.", undefinedVar)) + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Use one of the available server keys as the object name: %s", strings.Join(serverKeys, ", "))) + hints = append(hints, "Then access tools using: .(args)") + hints = append(hints, fmt.Sprintf("For example: const result = await %s.({ ... });", serverKeys[0])) + } + } + } else if strings.Contains(errorMessage, "is not a function") { + re := regexp.MustCompile(`(\w+(?:\.\w+)?)\s+is not a function`) + if match := re.FindStringSubmatch(errorMessage); len(match) > 1 { + notFunction := match[1] + hints = append(hints, fmt.Sprintf("'%s' is not a function.", notFunction)) + hints = append(hints, "Ensure you're using the correct server key and tool name.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + hints = append(hints, "To see available tools for a server, use listToolFiles and readToolFile.") + } + } else if strings.Contains(errorMessage, "Cannot read property") || + strings.Contains(errorMessage, "Cannot read properties") || + strings.Contains(errorMessage, "is not an object") { + hints = append(hints, "You're trying to access a property that doesn't exist or is undefined.") + hints = append(hints, "The tool response structure might be different than expected.") + hints = append(hints, "Check the console logs above to see the actual response structure from the tool.") + hints = append(hints, "Add console.log() statements to inspect the response before accessing properties.") + hints = append(hints, "Example: console.log('searchResults:', searchResults);") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + } else { + hints = append(hints, "Check the error message above for details.") + hints = append(hints, "Check the console logs above to see tool responses and debug the issue.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + hints = append(hints, "Ensure you're using the correct syntax: const result = await .({ ...args });") + } + + return hints +} diff --git a/core/mcp/codemode_listfiles.go b/core/mcp/codemode_listfiles.go new file mode 100644 index 0000000000..a285d2bb99 --- /dev/null +++ b/core/mcp/codemode_listfiles.go @@ -0,0 +1,83 @@ +package mcp + +import ( + "context" + "fmt" + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +// createListToolFilesTool creates the listToolFiles tool definition for code mode. +// This tool allows listing all available virtual .d.ts declaration files for connected MCP servers. +// +// Returns: +// - schemas.ChatTool: The tool definition for listing tool files +func (m *ToolsManager) createListToolFilesTool() schemas.ChatTool { + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: ToolTypeListToolFiles, + Description: schemas.Ptr( + "Returns a tree structure listing all virtual .d.ts declaration files available for connected MCP servers. " + + "Each connected server has a corresponding virtual file that can be read using readToolFile. " + + "The filenames follow the pattern .d.ts where serverDisplayName is the human-readable " + + "name reported by each connected server. Note that the code-level bindings (used in executeToolCode) use " + + "configuration keys from SERVER_CONFIGS, which may differ from these display names. " + + "This tool is generic and works with any set of servers connected at runtime. " + + "Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools. " + + "If you have even the SLIGHTEST DOUBT that the current tools might not be useful for the task, check listToolFiles to discover all available tools.", + ), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{}, + Required: []string{}, + }, + }, + } +} + +// handleListToolFiles handles the listToolFiles tool call. +// It builds a tree structure listing all virtual .d.ts files available for code mode clients. +// +// Parameters: +// - ctx: Context for accessing client tools +// - toolCall: The tool call request containing no arguments +// +// Returns: +// - *schemas.ChatMessage: A tool response message containing the file tree structure +// - error: Any error that occurred during processing +func (m *ToolsManager) handleListToolFiles(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + + if len(availableToolsPerClient) == 0 { + responseText := "No servers are currently connected. There are no virtual .d.ts files available. " + + "Please ensure servers are connected before using this tool." + return createToolResponseMessage(toolCall, responseText), nil + } + + // Build tree structure + treeLines := []string{"servers/"} + codeModeServerCount := 0 + for clientName := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if !client.ExecutionConfig.IsCodeModeClient { + continue + } + codeModeServerCount++ + treeLines = append(treeLines, fmt.Sprintf(" %s.d.ts", clientName)) + } + + if codeModeServerCount == 0 { + responseText := "Servers are connected but none are configured for code mode. " + + "There are no virtual .d.ts files available." + return createToolResponseMessage(toolCall, responseText), nil + } + + responseText := strings.Join(treeLines, "\n") + return createToolResponseMessage(toolCall, responseText), nil +} diff --git a/core/mcp/codemode_readfile.go b/core/mcp/codemode_readfile.go new file mode 100644 index 0000000000..60fdb2b025 --- /dev/null +++ b/core/mcp/codemode_readfile.go @@ -0,0 +1,396 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +// createReadToolFileTool creates the readToolFile tool definition for code mode. +// This tool allows reading virtual .d.ts declaration files for specific MCP servers, +// generating TypeScript type definitions from the server's tool schemas. +// +// Returns: +// - schemas.ChatTool: The tool definition for reading tool files +func (m *ToolsManager) createReadToolFileTool() schemas.ChatTool { + readToolFileProps := map[string]interface{}{ + "fileName": map[string]interface{}{ + "type": "string", + "description": "The virtual filename (e.g., 'calculator-server.d.ts') from listToolFiles", + }, + "startLine": map[string]interface{}{ + "type": "number", + "description": "Optional 1-based starting line number for partial file read (inclusive). Note: Line numbers start at 1, not 0. The first line is line 1.", + }, + "endLine": map[string]interface{}{ + "type": "number", + "description": "Optional 1-based ending line number for partial file read (inclusive)", + }, + } + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: ToolTypeReadToolFile, + Description: schemas.Ptr( + "Reads a virtual .d.ts declaration file for a specific MCP server, generating TypeScript type definitions " + + "from the server's tool schemas. The fileName should match one of the virtual files listed by listToolFiles. " + + "The function removes the .d.ts extension and performs case-insensitive matching against both the server's " + + "display name and configuration key. Optionally, you can specify startLine and endLine (1-based, inclusive) " + + "to read only a portion of the file. IMPORTANT: Line numbers are 1-based, not 0-based. The first line is line 1, not line 0. " + + "This tool generates pseudo declaration files that describe the available " + + "tools and their argument types, enabling code-mode execution as described in the MCP code execution pattern. " + + "The generated file includes interfaces for each tool's arguments and corresponding function declarations. " + + "Always follow this workflow: first use listToolFiles to see available servers, then use readToolFile to understand " + + "the tool definitions, and finally use executeToolCode to execute your code.", + ), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &readToolFileProps, + Required: []string{"fileName"}, + }, + }, + } +} + +// handleReadToolFile handles the readToolFile tool call. +// It reads a virtual .d.ts file for a specific MCP server, generates TypeScript type definitions, +// and optionally returns a portion of the file based on line range parameters. +// +// Parameters: +// - ctx: Context for accessing client tools +// - toolCall: The tool call request containing fileName and optional startLine/endLine +// +// Returns: +// - *schemas.ChatMessage: A tool response message containing the TypeScript definitions +// - error: Any error that occurred during processing +func (m *ToolsManager) handleReadToolFile(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + // Parse tool arguments + var arguments map[string]interface{} + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + return nil, fmt.Errorf("failed to parse tool arguments: %v", err) + } + + fileName, ok := arguments["fileName"].(string) + if !ok || fileName == "" { + return nil, fmt.Errorf("fileName parameter is required and must be a string") + } + + // Get available tools per client + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + + // Remove .d.ts extension and normalize to lowercase for matching + baseName := strings.ToLower(strings.TrimSuffix(fileName, ".d.ts")) + + // Find matching client + var matchedClientName string + var matchedTools []schemas.ChatTool + for clientName, tools := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 { + continue + } + clientNameLower := strings.ToLower(clientName) + if clientNameLower == baseName { + if matchedClientName != "" { + // Multiple matches found + availableFiles := make([]string, 0, len(availableToolsPerClient)) + for name := range availableToolsPerClient { + availableFiles = append(availableFiles, fmt.Sprintf("%s.d.ts", name)) + } + errorMsg := fmt.Sprintf("Multiple servers match filename '%s':\n", fileName) + for name := range availableToolsPerClient { + if strings.ToLower(name) == baseName { + errorMsg += fmt.Sprintf(" - %s\n", name) + } + } + errorMsg += "\nPlease use a more specific filename. Use the exact display name from listToolFiles to avoid ambiguity." + return createToolResponseMessage(toolCall, errorMsg), nil + } + matchedClientName = clientName + matchedTools = tools + } + } + + if matchedClientName == "" { + availableFiles := make([]string, 0, len(availableToolsPerClient)) + for name := range availableToolsPerClient { + availableFiles = append(availableFiles, fmt.Sprintf("%s.d.ts", name)) + } + errorMsg := fmt.Sprintf("No server found matching filename '%s'. Available virtual files are:\n", fileName) + for _, f := range availableFiles { + errorMsg += fmt.Sprintf(" - %s\n", f) + } + errorMsg += "\nPlease use one of the exact filenames listed above. The matching is case-insensitive and works with both display names and configuration keys." + return createToolResponseMessage(toolCall, errorMsg), nil + } + + // Generate TypeScript definitions + fileContent := generateTypeDefinitions(matchedClientName, matchedTools) + lines := strings.Split(fileContent, "\n") + totalLines := len(lines) + + // Handle line slicing if provided + var startLine, endLine *int + if sl, ok := arguments["startLine"].(float64); ok { + slInt := int(sl) + startLine = &slInt + } + if el, ok := arguments["endLine"].(float64); ok { + elInt := int(el) + endLine = &elInt + } + + if startLine != nil || endLine != nil { + start := 1 + if startLine != nil { + start = *startLine + } + end := totalLines + if endLine != nil { + end = *endLine + } + + // Validate line numbers + if start < 1 || start > totalLines { + errorMsg := fmt.Sprintf("Invalid startLine: %d. Must be between 1 and %d (total lines in file). Provided: startLine=%d, endLine=%v, totalLines=%d", + start, totalLines, start, endLine, totalLines) + return createToolResponseMessage(toolCall, errorMsg), nil + } + if end < 1 || end > totalLines { + errorMsg := fmt.Sprintf("Invalid endLine: %d. Must be between 1 and %d (total lines in file). Provided: startLine=%d, endLine=%d, totalLines=%d", + end, totalLines, start, end, totalLines) + return createToolResponseMessage(toolCall, errorMsg), nil + } + if start > end { + errorMsg := fmt.Sprintf("Invalid line range: startLine (%d) must be less than or equal to endLine (%d). Total lines in file: %d", + start, end, totalLines) + return createToolResponseMessage(toolCall, errorMsg), nil + } + + // Slice lines (convert to 0-based indexing) + selectedLines := lines[start-1 : end] + fileContent = strings.Join(selectedLines, "\n") + } + + return createToolResponseMessage(toolCall, fileContent), nil +} + +// HELPER FUNCTIONS + +// generateTypeDefinitions generates TypeScript type definitions from ChatTool schemas +// with comprehensive comments to help LLMs understand how to use the tools. +// It creates interfaces for tool inputs and responses, along with function declarations. +// +// Parameters: +// - clientName: Name of the MCP client/server +// - tools: List of chat tools to generate definitions for +// +// Returns: +// - string: Complete TypeScript declaration file content +func generateTypeDefinitions(clientName string, tools []schemas.ChatTool) string { + var sb strings.Builder + + // Write comprehensive header comment + sb.WriteString("// ============================================================================\n") + sb.WriteString(fmt.Sprintf("// Type definitions for %s MCP server\n", clientName)) + sb.WriteString("// ============================================================================\n") + sb.WriteString("//\n") + sb.WriteString("// This file contains TypeScript type definitions for all tools available on this MCP server.\n") + sb.WriteString("// These definitions enable code-mode execution as described in the MCP code execution pattern.\n") + sb.WriteString("//\n") + sb.WriteString("// USAGE INSTRUCTIONS:\n") + sb.WriteString("// 1. Each tool has an input interface (e.g., ToolNameInput) that defines the required parameters\n") + sb.WriteString("// 2. Each tool has a function declaration showing how to call it\n") + sb.WriteString("// 3. To use these tools in executeToolCode, you would call them like:\n") + sb.WriteString("// const result = await .({ ...args });\n") + sb.WriteString("//\n") + sb.WriteString("// NOTE: The server name used in executeToolCode is the same as the display name shown here.\n") + sb.WriteString("// ============================================================================\n\n") + + // Generate interfaces and function declarations for each tool + for _, tool := range tools { + if tool.Function == nil || tool.Function.Name == "" { + continue + } + + originalToolName := tool.Function.Name + // Parse tool name for property name compatibility (used in virtual TypeScript files) + toolName := parseToolName(originalToolName) + description := "" + if tool.Function.Description != nil { + description = *tool.Function.Description + } + + // Generate input interface with detailed comments + inputInterfaceName := toPascalCase(toolName) + "Input" + sb.WriteString("// ----------------------------------------------------------------------------\n") + sb.WriteString(fmt.Sprintf("// Tool: %s\n", toolName)) + sb.WriteString("// ----------------------------------------------------------------------------\n") + if description != "" { + sb.WriteString(fmt.Sprintf("// Description: %s\n", description)) + } + sb.WriteString(fmt.Sprintf("// Input interface for %s\n", toolName)) + sb.WriteString(fmt.Sprintf("// This interface defines all parameters that can be passed to the %s tool.\n", toolName)) + sb.WriteString(fmt.Sprintf("interface %s {\n", inputInterfaceName)) + + if tool.Function.Parameters != nil && tool.Function.Parameters.Properties != nil { + props := *tool.Function.Parameters.Properties + required := make(map[string]bool) + if tool.Function.Parameters.Required != nil { + for _, req := range tool.Function.Parameters.Required { + required[req] = true + } + } + + // Sort properties for consistent output + propNames := make([]string, 0, len(props)) + for name := range props { + propNames = append(propNames, name) + } + // Simple alphabetical sort + for i := 0; i < len(propNames)-1; i++ { + for j := i + 1; j < len(propNames); j++ { + if propNames[i] > propNames[j] { + propNames[i], propNames[j] = propNames[j], propNames[i] + } + } + } + + for _, propName := range propNames { + prop := props[propName] + propMap, ok := prop.(map[string]interface{}) + if !ok { + continue + } + + tsType := jsonSchemaToTypeScript(propMap) + optional := "" + if !required[propName] { + optional = "?" + } + + propDesc := "" + if desc, ok := propMap["description"].(string); ok && desc != "" { + propDesc = fmt.Sprintf(" // %s", desc) + } else { + propDesc = fmt.Sprintf(" // %s parameter", propName) + } + + requiredNote := "" + if required[propName] { + requiredNote = " (required)" + } else { + requiredNote = " (optional)" + } + + sb.WriteString(fmt.Sprintf(" %s%s: %s;%s%s\n", propName, optional, tsType, propDesc, requiredNote)) + } + } + + sb.WriteString("}\n\n") + + // Generate response interface with helpful comments + responseInterfaceName := toPascalCase(toolName) + "Response" + sb.WriteString(fmt.Sprintf("// Response interface for %s\n", toolName)) + sb.WriteString("// The actual response structure depends on the tool implementation.\n") + sb.WriteString("// This is a placeholder interface - the actual response may contain different fields.\n") + sb.WriteString(fmt.Sprintf("interface %s {\n", responseInterfaceName)) + sb.WriteString(" // Response structure depends on the tool implementation\n") + sb.WriteString(" // Common fields may include: result, error, data, etc.\n") + sb.WriteString(" [key: string]: any;\n") + sb.WriteString("}\n\n") + + // Generate function declaration with usage example + sb.WriteString(fmt.Sprintf("// Function declaration for %s\n", toolName)) + if description != "" { + sb.WriteString(fmt.Sprintf("// %s\n", description)) + } + sb.WriteString("//\n") + sb.WriteString("// Usage example in executeToolCode:\n") + sb.WriteString(fmt.Sprintf("// const result = await .%s({ ... });\n", toolName)) + sb.WriteString("// // Replace with the actual server name/ID\n") + sb.WriteString(fmt.Sprintf("// // Replace { ... } with the appropriate %sInput object\n", inputInterfaceName)) + sb.WriteString(fmt.Sprintf("export async function %s(input: %s): Promise<%s>;\n\n", toolName, inputInterfaceName, responseInterfaceName)) + } + + return sb.String() +} + +// jsonSchemaToTypeScript converts a JSON Schema type definition to a TypeScript type string. +// It handles basic types, arrays, enums, and defaults to "any" for unknown types. +// +// Parameters: +// - prop: JSON Schema property definition map +// +// Returns: +// - string: TypeScript type string representation +func jsonSchemaToTypeScript(prop map[string]interface{}) string { + // Check for explicit type + if typeVal, ok := prop["type"].(string); ok { + switch typeVal { + case "string": + return "string" + case "number", "integer": + return "number" + case "boolean": + return "boolean" + case "array": + itemsType := "any" + if items, ok := prop["items"].(map[string]interface{}); ok { + itemsType = jsonSchemaToTypeScript(items) + } + return fmt.Sprintf("%s[]", itemsType) + case "object": + return "object" + case "null": + return "null" + } + } + + // Check for enum + if enum, ok := prop["enum"].([]interface{}); ok && len(enum) > 0 { + enumStrs := make([]string, 0, len(enum)) + for _, e := range enum { + enumStrs = append(enumStrs, fmt.Sprintf("%q", e)) + } + return strings.Join(enumStrs, " | ") + } + + // Default to any + return "any" +} + +// toPascalCase converts a string to PascalCase format. +// It splits on underscores, hyphens, and spaces, then capitalizes the first letter +// of each word and lowercases the rest. +// +// Parameters: +// - s: Input string to convert +// +// Returns: +// - string: PascalCase formatted string +func toPascalCase(s string) string { + if s == "" { + return s + } + parts := strings.FieldsFunc(s, func(r rune) bool { + return r == '_' || r == '-' || r == ' ' + }) + result := "" + for _, part := range parts { + if len(part) > 0 { + result += strings.ToUpper(part[:1]) + strings.ToLower(part[1:]) + } + } + if result == "" { + return strings.ToUpper(s[:1]) + strings.ToLower(s[1:]) + } + return result +} diff --git a/core/mcp/init.go b/core/mcp/init.go new file mode 100644 index 0000000000..d0eb389c18 --- /dev/null +++ b/core/mcp/init.go @@ -0,0 +1,9 @@ +package mcp + +import "github.com/maximhq/bifrost/core/schemas" + +var logger schemas.Logger + +func SetLogger(l schemas.Logger) { + logger = l +} diff --git a/core/mcp/mcp.go b/core/mcp/mcp.go new file mode 100644 index 0000000000..848f3bf6a0 --- /dev/null +++ b/core/mcp/mcp.go @@ -0,0 +1,232 @@ +package mcp + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/maximhq/bifrost/core/schemas" + + "github.com/mark3labs/mcp-go/server" +) + +// ============================================================================ +// CONSTANTS +// ============================================================================ + +const ( + // MCP defaults and identifiers + BifrostMCPVersion = "1.0.0" // Version identifier for Bifrost + BifrostMCPClientName = "BifrostClient" // Name for internal Bifrost MCP client + BifrostMCPClientKey = "bifrostInternal" // Key for internal Bifrost client in clientMap + MCPLogPrefix = "[Bifrost MCP]" // Consistent logging prefix + MCPClientConnectionEstablishTimeout = 30 * time.Second // Timeout for MCP client connection establishment + + // Context keys for client filtering in requests + // NOTE: []string is used for both keys, and by default all clients/tools are included (when nil). + // If "*" is present, all clients/tools are included, and [] means no clients/tools are included. + // Request context filtering takes priority over client config - context can override client exclusions. + MCPContextKeyIncludeClients schemas.BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering + MCPContextKeyIncludeTools schemas.BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName/toolName" format) +) + +// ============================================================================ +// TYPE DEFINITIONS +// ============================================================================ + +// MCPManager manages MCP integration for Bifrost core. +// It provides a bridge between Bifrost and various MCP servers, supporting +// both local tool hosting and external MCP server connections. +type MCPManager struct { + ctx context.Context + toolsHandler *ToolsManager // Handler for MCP tools + server *server.MCPServer // Local MCP server instance for hosting tools (STDIO-based) + clientMap map[string]*schemas.MCPClientState // Map of MCP client names to their configurations + mu sync.RWMutex // Read-write mutex for thread-safe operations + serverRunning bool // Track whether local MCP server is running +} + +// MCPToolFunction is a generic function type for handling tool calls with typed arguments. +// T represents the expected argument structure for the tool. +type MCPToolFunction[T any] func(args T) (string, error) + +// ============================================================================ +// CONSTRUCTOR AND INITIALIZATION +// ============================================================================ + +// NewMCPManager creates and initializes a new MCP manager instance. +// +// Parameters: +// - config: MCP configuration including server port and client configs +// - logger: Logger instance for structured logging (uses default if nil) +// +// Returns: +// - *MCPManager: Initialized manager instance +// - error: Any initialization error +func NewMCPManager(ctx context.Context, config schemas.MCPConfig, logger schemas.Logger) *MCPManager { + SetLogger(logger) + // Set default values + if config.ToolManagerConfig == nil { + config.ToolManagerConfig = &schemas.MCPToolManagerConfig{ + ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout, + MaxAgentDepth: schemas.DefaultMaxAgentDepth, + } + } + // Creating new instance + manager := &MCPManager{ + ctx: ctx, + clientMap: make(map[string]*schemas.MCPClientState), + } + manager.toolsHandler = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc) + // Process client configs: create client map entries and establish connections + for _, clientConfig := range config.ClientConfigs { + if err := manager.AddClient(clientConfig); err != nil { + logger.Warn(fmt.Sprintf("%s Failed to add MCP client %s: %v", MCPLogPrefix, clientConfig.Name, err)) + } + } + logger.Info(MCPLogPrefix + " MCP Manager initialized") + return manager +} + +// AddToolsToRequest parses available MCP tools from the context and adds them to the request. +// It respects context-based filtering for clients and tools, and returns the modified request +// with tools attached. +// +// Parameters: +// - ctx: Context containing optional client/tool filtering keys +// - req: The Bifrost request to add tools to +// +// Returns: +// - *schemas.BifrostRequest: The request with tools added +func (m *MCPManager) AddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { + return m.toolsHandler.ParseAndAddToolsToRequest(ctx, req) +} + +// ExecuteTool executes a single tool call from a chat assistant message. +// It handles tool execution, error handling, and returns the result as a chat message. +// +// Parameters: +// - ctx: Context for the tool execution +// - toolCall: The tool call to execute, containing tool name and arguments +// +// Returns: +// - *schemas.ChatMessage: The result message containing tool execution output +// - error: Any error that occurred during tool execution +func (m *MCPManager) ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + return m.toolsHandler.ExecuteTool(ctx, toolCall) +} + +// UpdateToolManagerConfig updates the configuration for the tool manager. +// This allows runtime updates to settings like execution timeout and max agent depth. +// +// Parameters: +// - config: The new tool manager configuration to apply +func (m *MCPManager) UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig) { + m.toolsHandler.UpdateConfig(config) +} + +// CheckAndExecuteAgentForChatRequest checks if the chat response contains tool calls, +// and if so, executes agent mode to handle the tool calls iteratively. If no tool calls +// are present, it returns the original response unchanged. +// +// Parameters: +// - ctx: Context for the agent execution +// - req: The original chat request +// - response: The initial chat response that may contain tool calls +// - makeReq: Function to make subsequent chat requests during agent execution +// +// Returns: +// - *schemas.BifrostChatResponse: The final response after agent execution (or original if no tool calls) +// - *schemas.BifrostError: Any error that occurred during agent execution +func (m *MCPManager) CheckAndExecuteAgentForChatRequest( + ctx *context.Context, + req *schemas.BifrostChatRequest, + response *schemas.BifrostChatResponse, + makeReq func(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), +) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + if makeReq == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "makeReq is required to execute agent mode", + }, + } + } + // Check if initial response has tool calls + if !hasToolCallsForChatResponse(response) { + logger.Debug("No tool calls detected, returning response") + return response, nil + } + // Execute agent mode + return m.toolsHandler.ExecuteAgentForChatRequest(ctx, req, response, makeReq) +} + +// CheckAndExecuteAgentForResponsesRequest checks if the responses response contains tool calls, +// and if so, executes agent mode to handle the tool calls iteratively. If no tool calls +// are present, it returns the original response unchanged. +// +// Parameters: +// - ctx: Context for the agent execution +// - req: The original responses request +// - response: The initial responses response that may contain tool calls +// - makeReq: Function to make subsequent responses requests during agent execution +// +// Returns: +// - *schemas.BifrostResponsesResponse: The final response after agent execution (or original if no tool calls) +// - *schemas.BifrostError: Any error that occurred during agent execution +func (m *MCPManager) CheckAndExecuteAgentForResponsesRequest( + ctx *context.Context, + req *schemas.BifrostResponsesRequest, + response *schemas.BifrostResponsesResponse, + makeReq func(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError), +) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + if makeReq == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "makeReq is required to execute agent mode", + }, + } + } + // Check if initial response has tool calls + if !hasToolCallsForResponsesResponse(response) { + logger.Debug("No tool calls detected, returning response") + return response, nil + } + // Execute agent mode + return m.toolsHandler.ExecuteAgentForResponsesRequest(ctx, req, response, makeReq) +} + +// Cleanup performs cleanup of all MCP resources including clients and local server. +// This function safely disconnects all MCP clients (HTTP, STDIO, and SSE) and +// cleans up the local MCP server. It handles proper cancellation of SSE contexts +// and closes all transport connections. +// +// Returns: +// - error: Always returns nil, but maintains error interface for consistency +func (m *MCPManager) Cleanup() error { + m.mu.Lock() + defer m.mu.Unlock() + + // Disconnect all external MCP clients + for id := range m.clientMap { + if err := m.removeClientUnsafe(id); err != nil { + logger.Error("%s Failed to remove MCP client %s: %v", MCPLogPrefix, id, err) + } + } + + // Clear the client map + m.clientMap = make(map[string]*schemas.MCPClientState) + + // Clear local server reference + // Note: mark3labs/mcp-go STDIO server cleanup is handled automatically + if m.server != nil { + logger.Info(MCPLogPrefix + " Clearing local MCP server reference") + m.server = nil + m.serverRunning = false + } + + logger.Info(MCPLogPrefix + " MCP cleanup completed") + return nil +} diff --git a/core/mcp/toolmanager.go b/core/mcp/toolmanager.go new file mode 100644 index 0000000000..78fbde2e85 --- /dev/null +++ b/core/mcp/toolmanager.go @@ -0,0 +1,386 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/schemas" +) + +type ClientManager interface { + GetClientByName(clientName string) *schemas.MCPClientState + GetClientForTool(toolName string) *schemas.MCPClientState + GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool +} + +type ToolsManager struct { + toolExecutionTimeout atomic.Value + maxAgentDepth atomic.Int32 + clientManager ClientManager + logMu sync.Mutex // Protects concurrent access to logs slice in codemode execution + + // Function to fetch a new request ID for each tool call result message in agent mode, + // this is used to ensure that the tool call result messages are unique and can be tracked in plugins or by the user. + // This id is attached to ctx.Value(schemas.BifrostContextKeyRequestID) in the agent mode. + // If not provider, same request ID is used for all tool call result messages without any overrides. + fetchNewRequestIDFunc func(ctx context.Context) string +} + +const ( + ToolTypeListToolFiles string = "listToolFiles" + ToolTypeReadToolFile string = "readToolFile" + ToolTypeExecuteToolCode string = "executeToolCode" +) + +// NewToolsManager creates and initializes a new tools manager instance. +// It validates the configuration, sets defaults if needed, and initializes atomic values +// for thread-safe configuration updates. +// +// Parameters: +// - config: Tool manager configuration with execution timeout and max agent depth +// - clientManager: Client manager interface for accessing MCP clients and tools +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for agent mode +// +// Returns: +// - *ToolsManager: Initialized tools manager instance +func NewToolsManager(config *schemas.MCPToolManagerConfig, clientManager ClientManager, fetchNewRequestIDFunc func(ctx context.Context) string) *ToolsManager { + if config == nil { + config = &schemas.MCPToolManagerConfig{ + ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout, + MaxAgentDepth: schemas.DefaultMaxAgentDepth, + } + } + if config.MaxAgentDepth <= 0 { + config.MaxAgentDepth = schemas.DefaultMaxAgentDepth + } + if config.ToolExecutionTimeout <= 0 { + config.ToolExecutionTimeout = schemas.DefaultToolExecutionTimeout + } + manager := &ToolsManager{ + clientManager: clientManager, + fetchNewRequestIDFunc: fetchNewRequestIDFunc, + } + // Initialize atomic values + manager.toolExecutionTimeout.Store(config.ToolExecutionTimeout) + manager.maxAgentDepth.Store(int32(config.MaxAgentDepth)) + + logger.Info(fmt.Sprintf("%s tool manager initialized with tool execution timeout: %v and max agent depth: %d", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth)) + return manager +} + +// ParseAndAddToolsToRequest parses the available tools per client and adds them to the Bifrost request. +// +// Parameters: +// - ctx: Execution context +// - req: Bifrost request +// - availableToolsPerClient: Map of client name to its available tools +// +// Returns: +// - *schemas.BifrostRequest: Bifrost request with MCP tools added +func (m *ToolsManager) ParseAndAddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { + // MCP is only supported for chat and responses requests + if req.ChatRequest == nil && req.ResponsesRequest == nil { + return req + } + + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + // Flatten tools from all clients into a single slice, avoiding duplicates + var availableTools []schemas.ChatTool + var includeCodeModeTools bool + // Track tool names to prevent duplicates + seenToolNames := make(map[string]bool) + + for clientName, clientTools := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if client.ExecutionConfig.IsCodeModeClient { + includeCodeModeTools = true + } else { + // Add tools from this client, checking for duplicates + for _, tool := range clientTools { + if tool.Function != nil && tool.Function.Name != "" { + if !seenToolNames[tool.Function.Name] { + availableTools = append(availableTools, tool) + seenToolNames[tool.Function.Name] = true + } + } + } + } + } + + if includeCodeModeTools { + codeModeTools := []schemas.ChatTool{ + m.createListToolFilesTool(), + m.createReadToolFileTool(), + m.createExecuteToolCodeTool(), + } + // Add code mode tools, checking for duplicates + for _, tool := range codeModeTools { + if tool.Function != nil && tool.Function.Name != "" { + if !seenToolNames[tool.Function.Name] { + availableTools = append(availableTools, tool) + seenToolNames[tool.Function.Name] = true + } + } + } + } + + if len(availableTools) > 0 { + logger.Debug(fmt.Sprintf("%s Adding %d MCP tools to request from %d clients", MCPLogPrefix, len(availableTools), len(availableToolsPerClient))) + switch req.RequestType { + case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: + // Only allocate new Params if it's nil to preserve caller-supplied settings + if req.ChatRequest.Params == nil { + req.ChatRequest.Params = &schemas.ChatParameters{} + } + + tools := req.ChatRequest.Params.Tools + + // Create a map of existing tool names for O(1) lookup + existingToolsMap := make(map[string]bool) + for _, tool := range tools { + if tool.Function != nil && tool.Function.Name != "" { + existingToolsMap[tool.Function.Name] = true + } + } + + // Add MCP tools that are not already present + for _, mcpTool := range availableTools { + // Skip tools with nil Function or empty Name + if mcpTool.Function == nil || mcpTool.Function.Name == "" { + continue + } + + if !existingToolsMap[mcpTool.Function.Name] { + tools = append(tools, mcpTool) + // Update the map to prevent duplicates within MCP tools as well + existingToolsMap[mcpTool.Function.Name] = true + } + } + req.ChatRequest.Params.Tools = tools + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + // Only allocate new Params if it's nil to preserve caller-supplied settings + if req.ResponsesRequest.Params == nil { + req.ResponsesRequest.Params = &schemas.ResponsesParameters{} + } + + tools := req.ResponsesRequest.Params.Tools + + // Create a map of existing tool names for O(1) lookup + existingToolsMap := make(map[string]bool) + for _, tool := range tools { + if tool.Name != nil { + existingToolsMap[*tool.Name] = true + } + } + + // Add MCP tools that are not already present + for _, mcpTool := range availableTools { + // Skip tools with nil Function or empty Name + if mcpTool.Function == nil || mcpTool.Function.Name == "" { + continue + } + + if !existingToolsMap[mcpTool.Function.Name] { + responsesTool := mcpTool.ToResponsesTool() + // Skip if the converted tool has nil Name + if responsesTool.Name == nil { + continue + } + + tools = append(tools, *responsesTool) + // Update the map to prevent duplicates within MCP tools as well + existingToolsMap[*responsesTool.Name] = true + } + } + req.ResponsesRequest.Params.Tools = tools + } + } + return req +} + +// ============================================================================ +// TOOL REGISTRATION AND DISCOVERY +// ============================================================================ + +// executeTool executes a tool call and returns the result as a tool message. +// +// Parameters: +// - ctx: Execution context +// - toolCall: The tool call to execute (from assistant message) +// +// Returns: +// - schemas.ChatMessage: Tool message with execution result +// - error: Any execution error +func (m *ToolsManager) ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + if toolCall.Function.Name == nil { + return nil, fmt.Errorf("tool call missing function name") + } + toolName := *toolCall.Function.Name + + // Handle code mode tools + switch toolName { + case ToolTypeListToolFiles: + return m.handleListToolFiles(ctx, toolCall) + case ToolTypeReadToolFile: + return m.handleReadToolFile(ctx, toolCall) + case ToolTypeExecuteToolCode: + return m.handleExecuteToolCode(ctx, toolCall) + default: + // Check if the user has permission to execute the tool call + availableTools := m.clientManager.GetToolPerClient(ctx) + toolFound := false + for _, tools := range availableTools { + for _, mcpTool := range tools { + if mcpTool.Function != nil && mcpTool.Function.Name == toolName { + toolFound = true + break + } + } + if toolFound { + break + } + } + + if !toolFound { + return nil, fmt.Errorf("tool '%s' is not available or not permitted", toolName) + } + + client := m.clientManager.GetClientForTool(toolName) + if client == nil { + return nil, fmt.Errorf("client not found for tool %s", toolName) + } + + // Parse tool arguments + var arguments map[string]interface{} + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + return nil, fmt.Errorf("failed to parse tool arguments for '%s': %v", toolName, err) + } + + // Strip the client name prefix from tool name before calling MCP server + // The MCP server expects the original tool name, not the prefixed version + originalToolName := stripClientPrefix(toolName, client.ExecutionConfig.Name) + + // Call the tool via MCP client -> MCP server + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: originalToolName, + Arguments: arguments, + }, + } + + logger.Debug(fmt.Sprintf("%s Starting tool execution: %s via client: %s", MCPLogPrefix, toolName, client.ExecutionConfig.Name)) + + // Create timeout context for tool execution + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) + if callErr != nil { + // Check if it was a timeout error + if toolCtx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName) + } + logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr) + return nil, fmt.Errorf("MCP tool call failed: %v", callErr) + } + + logger.Debug(fmt.Sprintf("%s Tool execution completed: %s", MCPLogPrefix, toolName)) + + // Extract text from MCP response + responseText := extractTextFromMCPResponse(toolResponse, toolName) + + // Create tool response message + return createToolResponseMessage(toolCall, responseText), nil + } +} + +// ExecuteAgentForChatRequest executes agent mode for a chat request, handling +// iterative tool calls up to the configured maximum depth. It delegates to the +// shared agent execution logic with the manager's configuration and dependencies. +// +// Parameters: +// - ctx: Context for agent execution +// - req: The original chat request +// - resp: The initial chat response containing tool calls +// - makeReq: Function to make subsequent chat requests during agent execution +// +// Returns: +// - *schemas.BifrostChatResponse: The final response after agent execution +// - *schemas.BifrostError: Any error that occurred during agent execution +func (m *ToolsManager) ExecuteAgentForChatRequest( + ctx *context.Context, + req *schemas.BifrostChatRequest, + resp *schemas.BifrostChatResponse, + makeReq func(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), +) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + return ExecuteAgentForChatRequest( + ctx, + int(m.maxAgentDepth.Load()), + req, + resp, + makeReq, + m.fetchNewRequestIDFunc, + m.ExecuteTool, + m.clientManager, + ) +} + +// ExecuteAgentForResponsesRequest executes agent mode for a responses request, handling +// iterative tool calls up to the configured maximum depth. It delegates to the +// shared agent execution logic with the manager's configuration and dependencies. +// +// Parameters: +// - ctx: Context for agent execution +// - req: The original responses request +// - resp: The initial responses response containing tool calls +// - makeReq: Function to make subsequent responses requests during agent execution +// +// Returns: +// - *schemas.BifrostResponsesResponse: The final response after agent execution +// - *schemas.BifrostError: Any error that occurred during agent execution +func (m *ToolsManager) ExecuteAgentForResponsesRequest( + ctx *context.Context, + req *schemas.BifrostResponsesRequest, + resp *schemas.BifrostResponsesResponse, + makeReq func(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError), +) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + return ExecuteAgentForResponsesRequest( + ctx, + int(m.maxAgentDepth.Load()), + req, + resp, + makeReq, + m.fetchNewRequestIDFunc, + m.ExecuteTool, + m.clientManager, + ) +} + +// UpdateConfig updates both tool execution timeout and max agent depth atomically. +// This method is safe to call concurrently from multiple goroutines. +func (m *ToolsManager) UpdateConfig(config *schemas.MCPToolManagerConfig) { + if config == nil { + return + } + if config.ToolExecutionTimeout > 0 { + m.toolExecutionTimeout.Store(config.ToolExecutionTimeout) + } + if config.MaxAgentDepth > 0 { + m.maxAgentDepth.Store(int32(config.MaxAgentDepth)) + } + + logger.Info(fmt.Sprintf("%s tool manager configuration updated with tool execution timeout: %v and max agent depth: %d", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth)) +} diff --git a/core/mcp/utils.go b/core/mcp/utils.go new file mode 100644 index 0000000000..cfed7993f0 --- /dev/null +++ b/core/mcp/utils.go @@ -0,0 +1,562 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "regexp" + "slices" + "strings" + "unicode" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/schemas" +) + +// GetClientForTool safely finds a client that has the specified tool. +// Returns a copy of the client state to avoid data races. Callers should be aware +// that fields like Conn and ToolMap are still shared references and may be modified +// by other goroutines, but the struct itself is safe from concurrent modification. +func (m *MCPManager) GetClientForTool(toolName string) *schemas.MCPClientState { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, client := range m.clientMap { + if _, exists := client.ToolMap[toolName]; exists { + // Return a copy to prevent TOCTOU race conditions + // The caller receives a snapshot of the client state at this point in time + clientCopy := *client + return &clientCopy + } + } + return nil +} + +// GetToolPerClient returns all tools from connected MCP clients. +// Applies client filtering if specified in the context. +// Returns a map of client name to its available tools. +// Parameters: +// - ctx: Execution context +// +// Returns: +// - map[string][]schemas.ChatTool: Map of client name to its available tools +func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool { + m.mu.RLock() + defer m.mu.RUnlock() + + var includeClients []string + + // Extract client filtering from request context + if existingIncludeClients, ok := ctx.Value(MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { + includeClients = existingIncludeClients + } + + tools := make(map[string][]schemas.ChatTool) + for _, client := range m.clientMap { + // Use client name as the key (not ID) + clientName := client.ExecutionConfig.Name + + // Apply client filtering logic + if !shouldIncludeClient(clientName, includeClients) { + logger.Debug(fmt.Sprintf("%s Skipping MCP client %s: not in include clients list", MCPLogPrefix, clientName)) + continue + } + + logger.Debug(fmt.Sprintf("Checking tools for MCP client %s with tools to execute: %v", clientName, client.ExecutionConfig.ToolsToExecute)) + + // Add all tools from this client + for toolName, tool := range client.ToolMap { + // Check if tool should be skipped based on client configuration + if shouldSkipToolForConfig(toolName, client.ExecutionConfig) { + logger.Debug(fmt.Sprintf("%s Skipping MCP tool %s: not in tools to execute list", MCPLogPrefix, toolName)) + continue + } + + // Check if tool should be skipped based on request context + if shouldSkipToolForRequest(ctx, clientName, toolName) { + logger.Debug(fmt.Sprintf("%s Skipping MCP tool %s: not in include tools list", MCPLogPrefix, toolName)) + continue + } + + tools[clientName] = append(tools[clientName], tool) + } + if len(tools[clientName]) > 0 { + logger.Debug(fmt.Sprintf("%s Added %d tools for MCP client %s", MCPLogPrefix, len(tools[clientName]), clientName)) + } + } + return tools +} + +// GetClientByName returns a client by name. +// +// Parameters: +// - clientName: Name of the client to get +// +// Returns: +// - *schemas.MCPClientState: Client state if found, nil otherwise +func (m *MCPManager) GetClientByName(clientName string) *schemas.MCPClientState { + m.mu.RLock() + defer m.mu.RUnlock() + for _, client := range m.clientMap { + if client.ExecutionConfig.Name == clientName { + // Return a copy to prevent TOCTOU race conditions + // The caller receives a snapshot of the client state at this point in time + clientCopy := *client + return &clientCopy + } + } + return nil +} + +// retrieveExternalTools retrieves and filters tools from an external MCP server without holding locks. +func retrieveExternalTools(ctx context.Context, client *client.Client, clientName string) (map[string]schemas.ChatTool, error) { + // Get available tools from external server + listRequest := mcp.ListToolsRequest{ + PaginatedRequest: mcp.PaginatedRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsList), + }, + }, + } + + toolsResponse, err := client.ListTools(ctx, listRequest) + if err != nil { + return nil, fmt.Errorf("failed to list tools: %v", err) + } + + if toolsResponse == nil { + return make(map[string]schemas.ChatTool), nil // No tools available + } + + tools := make(map[string]schemas.ChatTool) + + // toolsResponse is already a ListToolsResult + for _, mcpTool := range toolsResponse.Tools { + // Convert MCP tool schema to Bifrost format + bifrostTool := convertMCPToolToBifrostSchema(&mcpTool) + // Prefix tool name with client name to make it permanent + prefixedToolName := fmt.Sprintf("%s_%s", clientName, mcpTool.Name) + // Update the tool's function name to match the prefixed name + if bifrostTool.Function != nil { + bifrostTool.Function.Name = prefixedToolName + } + tools[prefixedToolName] = bifrostTool + } + + return tools, nil +} + +// shouldIncludeClient determines if a client should be included based on filtering rules. +func shouldIncludeClient(clientName string, includeClients []string) bool { + // If includeClients is specified (not nil), apply whitelist filtering + if includeClients != nil { + // Handle empty array [] - means no clients are included + if len(includeClients) == 0 { + return false // No clients allowed + } + + // Handle wildcard "*" - if present, all clients are included + if slices.Contains(includeClients, "*") { + return true // All clients allowed + } + + // Check if specific client is in the list + return slices.Contains(includeClients, clientName) + } + + // Default: include all clients when no filtering specified (nil case) + return true +} + +// shouldSkipToolForConfig checks if a tool should be skipped based on client configuration (without accessing clientMap). +func shouldSkipToolForConfig(toolName string, config schemas.MCPClientConfig) bool { + // If ToolsToExecute is specified (not nil), apply filtering + if config.ToolsToExecute != nil { + // Handle empty array [] - means no tools are allowed + if len(config.ToolsToExecute) == 0 { + return true // No tools allowed + } + + // Handle wildcard "*" - if present, all tools are allowed + if slices.Contains(config.ToolsToExecute, "*") { + return false // All tools allowed + } + + // Check if specific tool is in the allowed list + return !slices.Contains(config.ToolsToExecute, toolName) // Tool not in allowed list + } + + return true // Tool is skipped (nil is treated as [] - no tools) +} + +// canAutoExecuteTool checks if a tool can be auto-executed based on client configuration. +// Returns true if the tool can be auto-executed, false otherwise. +func canAutoExecuteTool(toolName string, config schemas.MCPClientConfig) bool { + // First check if tool is in ToolsToExecute (must be executable first) + if shouldSkipToolForConfig(toolName, config) { + return false // Tool is not in ToolsToExecute, so it cannot be auto-executed + } + + // If ToolsToAutoExecute is specified (not nil), apply filtering + if config.ToolsToAutoExecute != nil { + // Handle empty array [] - means no tools are auto-executed + if len(config.ToolsToAutoExecute) == 0 { + return false // No tools auto-executed + } + + // Handle wildcard "*" - if present, all tools are auto-executed + if slices.Contains(config.ToolsToAutoExecute, "*") { + return true // All tools auto-executed + } + + // Check if specific tool is in the auto-execute list + return slices.Contains(config.ToolsToAutoExecute, toolName) + } + + return false // Tool is not auto-executed (nil is treated as [] - no tools) +} + +// shouldSkipToolForRequest checks if a tool should be skipped based on the request context. +func shouldSkipToolForRequest(ctx context.Context, clientName, toolName string) bool { + includeTools := ctx.Value(MCPContextKeyIncludeTools) + + logger.Debug(fmt.Sprintf("%s Checking if tool %s should be skipped for request: %v", MCPLogPrefix, toolName, includeTools)) + + if includeTools != nil { + // Try []string first (preferred type) + if includeToolsList, ok := includeTools.([]string); ok { + // Handle empty array [] - means no tools are included + if len(includeToolsList) == 0 { + return true // No tools allowed + } + + // Handle wildcard "clientName/*" - if present, all tools are included for this client + if slices.Contains(includeToolsList, fmt.Sprintf("%s/*", clientName)) { + return false // All tools allowed + } + + // Check if specific tool is in the list (format: clientName/toolName) + fullToolName := fmt.Sprintf("%s/%s", clientName, toolName) + if slices.Contains(includeToolsList, fullToolName) { + return false // Tool is explicitly allowed + } + + // If includeTools is specified but this tool is not in it, skip it + return true + } + } + + return false // Tool is allowed (default when no filtering specified) +} + +// convertMCPToolToBifrostSchema converts an MCP tool definition to Bifrost format. +func convertMCPToolToBifrostSchema(mcpTool *mcp.Tool) schemas.ChatTool { + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: mcpTool.Name, + Description: schemas.Ptr(mcpTool.Description), + Parameters: &schemas.ToolFunctionParameters{ + Type: mcpTool.InputSchema.Type, + Properties: schemas.Ptr(mcpTool.InputSchema.Properties), + Required: mcpTool.InputSchema.Required, + }, + }, + } +} + +// extractTextFromMCPResponse extracts text content from an MCP tool response. +func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string { + if toolResponse == nil { + return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) + } + + var result strings.Builder + for _, contentBlock := range toolResponse.Content { + // Handle typed content + switch content := contentBlock.(type) { + case mcp.TextContent: + result.WriteString(content.Text) + case mcp.ImageContent: + result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) + case mcp.AudioContent: + result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) + case mcp.EmbeddedResource: + result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type)) + default: + // Fallback: try to extract from map structure + if jsonBytes, err := json.Marshal(contentBlock); err == nil { + var contentMap map[string]interface{} + if json.Unmarshal(jsonBytes, &contentMap) == nil { + if text, ok := contentMap["text"].(string); ok { + result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text)) + continue + } + } + // Final fallback: serialize as JSON + result.WriteString(string(jsonBytes)) + } + } + } + + if result.Len() > 0 { + return strings.TrimSpace(result.String()) + } + return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) +} + +// createToolResponseMessage creates a tool response message with the execution result. +func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage { + return &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: &responseText, + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: toolCall.ID, + }, + } +} + +// validateMCPClientConfig validates an MCP client configuration. +func validateMCPClientConfig(config *schemas.MCPClientConfig) error { + if strings.TrimSpace(config.ID) == "" { + return fmt.Errorf("id is required for MCP client config") + } + if err := validateMCPClientName(config.Name); err != nil { + return fmt.Errorf("invalid name for MCP client: %w", err) + } + if config.ConnectionType == "" { + return fmt.Errorf("connection type is required for MCP client config") + } + switch config.ConnectionType { + case schemas.MCPConnectionTypeHTTP: + if config.ConnectionString == nil { + return fmt.Errorf("ConnectionString is required for HTTP connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeSSE: + if config.ConnectionString == nil { + return fmt.Errorf("ConnectionString is required for SSE connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeSTDIO: + if config.StdioConfig == nil { + return fmt.Errorf("StdioConfig is required for STDIO connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeInProcess: + // InProcess requires a server instance to be provided programmatically + // This cannot be validated from JSON config - the server must be set when using the Go package + if config.InProcessServer == nil { + return fmt.Errorf("InProcessServer is required for InProcess connection type in client '%s' (Go package only)", config.Name) + } + default: + return fmt.Errorf("unknown connection type '%s' in client '%s'", config.ConnectionType, config.Name) + } + return nil +} + +func validateMCPClientName(name string) error { + if strings.TrimSpace(name) == "" { + return fmt.Errorf("name is required for MCP client") + } + for _, r := range name { + if r > 127 { // non-ASCII + return fmt.Errorf("name must contain only ASCII characters") + } + } + if strings.Contains(name, "-") { + return fmt.Errorf("name cannot contain hyphens") + } + if strings.Contains(name, " ") { + return fmt.Errorf("name cannot contain spaces") + } + if len(name) > 0 && name[0] >= '0' && name[0] <= '9' { + return fmt.Errorf("name cannot start with a number") + } + return nil +} + +// parseToolName parses the tool name to be JavaScript-compatible. +// It converts spaces and hyphens to underscores, removes invalid characters, and ensures +// the name starts with a valid JavaScript identifier character. +func parseToolName(toolName string) string { + if toolName == "" { + return "" + } + + var result strings.Builder + runes := []rune(toolName) + + // Process first character - must be letter, underscore, or dollar sign + if len(runes) > 0 { + first := runes[0] + if unicode.IsLetter(first) || first == '_' || first == '$' { + result.WriteRune(unicode.ToLower(first)) + } else { + // If first char is invalid, prefix with underscore + result.WriteRune('_') + if unicode.IsDigit(first) { + result.WriteRune(first) + } + } + } + + // Process remaining characters + for i := 1; i < len(runes); i++ { + r := runes[i] + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' { + result.WriteRune(unicode.ToLower(r)) + } else if unicode.IsSpace(r) || r == '-' { + // Replace spaces and hyphens with single underscore + // Avoid consecutive underscores + if result.Len() > 0 && result.String()[result.Len()-1] != '_' { + result.WriteRune('_') + } + } + // Skip other invalid characters + } + + parsed := result.String() + + // Remove trailing underscores + parsed = strings.TrimRight(parsed, "_") + + // Ensure we have at least one character + // Should never happen, but just in case + if parsed == "" { + return "tool" + } + + return parsed +} + +// extractToolCallsFromCode extracts tool calls from TypeScript code +// Tool calls are in the format: serverName.toolName(...) or await serverName.toolName(...) +func extractToolCallsFromCode(code string) ([]toolCallInfo, error) { + toolCalls := []toolCallInfo{} + + // Regex pattern to match tool calls: + // - Optional "await" keyword + // - Server name (identifier) + // - Dot + // - Tool name (identifier) + // - Opening parenthesis + // This pattern matches: await serverName.toolName( or serverName.toolName( + toolCallPattern := regexp.MustCompile(`(?:await\s+)?([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\.\s*([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\(`) + + // Find all matches + matches := toolCallPattern.FindAllStringSubmatch(code, -1) + for _, match := range matches { + if len(match) >= 3 { + serverName := match[1] + toolName := match[2] + toolCalls = append(toolCalls, toolCallInfo{ + serverName: serverName, + toolName: toolName, + }) + } + } + + return toolCalls, nil +} + +// isToolCallAllowedForCodeMode checks if a tool call is allowed based on allowedAutoExecutionTools map +func isToolCallAllowedForCodeMode(serverName, toolName string, allClientNames []string, allowedAutoExecutionTools map[string][]string) bool { + // Check if the server name is in the list of all client names + if !slices.Contains(allClientNames, serverName) { + // It can be a built-in JavaScript/TypeScript object, if not then downstream execution will fail with a runtime error. + return true + } + + // Get allowed tools for this server + allowedTools, exists := allowedAutoExecutionTools[serverName] + if !exists { + // Server not in allowed list, return false to prevent downstream execution. + return false + } + + // Check if wildcard "*" is present (all tools allowed) + if slices.Contains(allowedTools, "*") { + return true + } + + // Check if specific tool is in the allowed list + if slices.Contains(allowedTools, toolName) { + return true + } + + return false // Tool not in allowed list +} + +// hasToolCalls checks if a chat response contains tool calls that need to be executed +func hasToolCallsForChatResponse(response *schemas.BifrostChatResponse) bool { + if response == nil || len(response.Choices) == 0 { + return false + } + + choice := response.Choices[0] + + // If finish_reason is "stop", this indicates non-auto-executable tools that require user approval. + // Don't return true even if tool calls are present, as the agent loop should not process them. + if choice.FinishReason != nil && *choice.FinishReason == "stop" { + return false + } + + // Check finish reason + if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" { + return true + } + + // Check if message has tool calls + if choice.ChatNonStreamResponseChoice != nil && + choice.ChatNonStreamResponseChoice.Message != nil && + choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil && + len(choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls) > 0 { + return true + } + + return false +} + +func hasToolCallsForResponsesResponse(response *schemas.BifrostResponsesResponse) bool { + if response == nil || len(response.Output) == 0 { + return false + } + + // Check if any output message is a tool call + for _, output := range response.Output { + if output.Type == nil { + continue + } + + // Check for tool call types + switch *output.Type { + case schemas.ResponsesMessageTypeFunctionCall, schemas.ResponsesMessageTypeCustomToolCall: + // Verify that ResponsesToolMessage is actually set + if output.ResponsesToolMessage != nil { + return true + } + } + } + + return false +} + +// stripClientPrefix removes the client name prefix from a tool name. +// Tool names are stored with format "{clientName}_{toolName}", but when calling +// the MCP server, we need the original tool name without the prefix. +// +// Parameters: +// - prefixedToolName: Tool name with client prefix (e.g., "calculator_add") +// - clientName: Client name to strip (e.g., "calculator") +// +// Returns: +// - string: Original tool name without prefix (e.g., "add") +func stripClientPrefix(prefixedToolName, clientName string) string { + prefix := clientName + "_" + if strings.HasPrefix(prefixedToolName, prefix) { + return strings.TrimPrefix(prefixedToolName, prefix) + } + // If prefix doesn't match, return as-is (shouldn't happen, but be safe) + return prefixedToolName +} diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index fdd82cc0ea..8a64309616 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -1820,6 +1820,8 @@ func (provider *OpenAIProvider) Transcription(ctx context.Context, key schemas.K return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) } + //TODO: add HandleProviderResponse here + // Parse raw response for RawResponse field var rawResponse interface{} if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index c7a9346339..b2b8c15e94 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -106,17 +106,18 @@ const ( BifrostContextKeyRequestID BifrostContextKey = "request-id" // string BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct - BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost)) - BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost)) - BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost)) - BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost)) 0 for primary, 1 for first fallback, etc. - BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost) + BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc. + BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string]string BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" // bool BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool - BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost) + BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostMCPAgentOriginalRequestID BifrostContextKey = "bifrost-mcp-agent-original-request-id" // string (to store the original request ID for MCP agent mode) ) // NOTE: for custom plugin implementation dealing with streaming short circuit, @@ -358,7 +359,7 @@ type BifrostError struct { Error *ErrorField `json:"error"` AllowFallbacks *bool `json:"-"` // Optional: Controls fallback behavior (nil = true by default) StreamControl *StreamControl `json:"-"` // Optional: Controls stream behavior - ExtraFields BifrostErrorExtraFields `json:"extra_fields,omitempty"` + ExtraFields BifrostErrorExtraFields `json:"extra_fields"` } // StreamControl represents stream control options. diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index e26409e122..53ae358eb0 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -1,32 +1,60 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas -// MCPServerInstance represents an MCP server instance for InProcess connections. -// This should be a *github.com/mark3labs/mcp-go/server.MCPServer instance. -// We use interface{} to avoid creating a dependency on the mcp-go package in schemas. -type MCPServerInstance interface{} +import ( + "context" + "time" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/server" +) // MCPConfig represents the configuration for MCP integration in Bifrost. // It enables tool auto-discovery and execution from local and external MCP servers. type MCPConfig struct { - ClientConfigs []MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations + ClientConfigs []MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations + ToolManagerConfig *MCPToolManagerConfig `json:"tool_manager_config,omitempty"` // MCP tool manager configuration + + // Function to fetch a new request ID for each tool call result message in agent mode, + // this is used to ensure that the tool call result messages are unique and can be tracked in plugins or by the user. + // This id is attached to ctx.Value(schemas.BifrostContextKeyRequestID) in the agent mode. + // If not provider, same request ID is used for all tool call result messages without any overrides. + FetchNewRequestIDFunc func(ctx context.Context) string `json:"-"` } +type MCPToolManagerConfig struct { + ToolExecutionTimeout time.Duration `json:"tool_execution_timeout"` + MaxAgentDepth int `json:"max_agent_depth"` +} + +const ( + DefaultMaxAgentDepth = 10 + DefaultToolExecutionTimeout = 30 * time.Second +) + // MCPClientConfig defines tool filtering for an MCP client. type MCPClientConfig struct { ID string `json:"id"` // Client ID Name string `json:"name"` // Client name + IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) ConnectionString *string `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) Headers map[string]string `json:"headers,omitempty"` // Headers to send with the request - InProcessServer MCPServerInstance `json:"-"` // MCP server instance for in-process connections (Go package only) + InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only) ToolsToExecute []string `json:"tools_to_execute,omitempty"` // Include-only list. // ToolsToExecute semantics: // - ["*"] => all tools are included // - [] => no tools are included (deny-by-default) // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => include only the specified tools + ToolsToAutoExecute []string `json:"tools_to_auto_execute,omitempty"` // Auto-execute list. + // ToolsToAutoExecute semantics: + // - ["*"] => all tools are auto-executed + // - [] => no tools are auto-executed (deny-by-default) + // - nil/omitted => treated as [] (no tools) + // - ["tool1", "tool2"] => auto-execute only the specified tools + // Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped. } // MCPConnectionType defines the communication protocol for MCP connections @@ -54,9 +82,27 @@ const ( MCPConnectionStateError MCPConnectionState = "error" // Client is in an error state, and cannot be used ) +// MCPClientState represents a connected MCP client with its configuration and tools. +// It is used internally by the MCP manager to track the state of a connected MCP client. +type MCPClientState struct { + Name string // Unique name for this client + Conn *client.Client // Active MCP client connection + ExecutionConfig MCPClientConfig // Tool filtering settings + ToolMap map[string]ChatTool // Available tools mapped by name + ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management + CancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized) +} + +// MCPClientConnectionInfo stores metadata about how a client is connected. +type MCPClientConnectionInfo struct { + Type MCPConnectionType `json:"type"` // Connection type (HTTP, STDIO, SSE, or InProcess) + ConnectionURL *string `json:"connection_url,omitempty"` // HTTP/SSE endpoint URL (for HTTP/SSE connections) + StdioCommandString *string `json:"stdio_command_string,omitempty"` // Command string for display (for STDIO connections) +} + // MCPClient represents a connected MCP client with its configuration and tools, // and connection information, after it has been initialized. -// It is returned by GetMCPClients() method. +// It is returned by GetMCPClients() method in bifrost. type MCPClient struct { Config MCPClientConfig `json:"config"` // Tool filtering settings Tools []ChatToolFunction `json:"tools"` // Available tools diff --git a/core/schemas/mux.go b/core/schemas/mux.go index f90c1c9bed..fd3c4d252c 100644 --- a/core/schemas/mux.go +++ b/core/schemas/mux.go @@ -324,14 +324,17 @@ func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { role = ResponsesInputMessageRoleSystem case ChatMessageRoleTool: messageType = ResponsesMessageTypeFunctionCallOutput - role = ResponsesInputMessageRoleUser // Tool messages are typically user role in responses + role = "" // tool call output messages don't include a role field case ChatMessageRoleDeveloper: role = ResponsesInputMessageRoleDeveloper } rm := ResponsesMessage{ Type: &messageType, - Role: &role, + } + + if role != "" { + rm.Role = &role } // Handle refusal content specifically - use content blocks with ResponsesOutputMessageContentRefusal @@ -347,7 +350,10 @@ func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { } } else if cm.Content != nil && cm.Content.ContentStr != nil { // Convert regular string content (if input message then ContentStr, else ContentBlocks) - if cm.Role == ChatMessageRoleAssistant { + // Skip setting content for function_call_output - content should only be in output field + if messageType == ResponsesMessageTypeFunctionCallOutput { + // Don't set content for function_call_output - it will be set in ResponsesToolMessage.Output + } else if cm.Role == ChatMessageRoleAssistant { rm.Content = &ResponsesMessageContent{ ContentBlocks: []ResponsesMessageContentBlock{ {Type: ResponsesOutputMessageContentTypeText, Text: cm.Content.ContentStr}, @@ -360,57 +366,62 @@ func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { } } else if cm.Content != nil && cm.Content.ContentBlocks != nil { // Convert content blocks - responseBlocks := make([]ResponsesMessageContentBlock, len(cm.Content.ContentBlocks)) - for i, block := range cm.Content.ContentBlocks { - blockType := ResponsesMessageContentBlockType(block.Type) - - switch block.Type { - case ChatContentBlockTypeText: - if cm.Role == ChatMessageRoleAssistant { - blockType = ResponsesOutputMessageContentTypeText - } else { - blockType = ResponsesInputMessageContentBlockTypeText + // Skip setting content blocks for function_call_output + if messageType == ResponsesMessageTypeFunctionCallOutput { + // Don't set content for function_call_output - it will be set in ResponsesToolMessage.Output + } else { + responseBlocks := make([]ResponsesMessageContentBlock, len(cm.Content.ContentBlocks)) + for i, block := range cm.Content.ContentBlocks { + blockType := ResponsesMessageContentBlockType(block.Type) + + switch block.Type { + case ChatContentBlockTypeText: + if cm.Role == ChatMessageRoleAssistant { + blockType = ResponsesOutputMessageContentTypeText + } else { + blockType = ResponsesInputMessageContentBlockTypeText + } + case ChatContentBlockTypeImage: + blockType = ResponsesInputMessageContentBlockTypeImage + case ChatContentBlockTypeFile: + blockType = ResponsesInputMessageContentBlockTypeFile + case ChatContentBlockTypeInputAudio: + blockType = ResponsesInputMessageContentBlockTypeAudio } - case ChatContentBlockTypeImage: - blockType = ResponsesInputMessageContentBlockTypeImage - case ChatContentBlockTypeFile: - blockType = ResponsesInputMessageContentBlockTypeFile - case ChatContentBlockTypeInputAudio: - blockType = ResponsesInputMessageContentBlockTypeAudio - } - responseBlocks[i] = ResponsesMessageContentBlock{ - Type: blockType, - Text: block.Text, - } - - // Convert specific block types - if block.ImageURLStruct != nil { - responseBlocks[i].ResponsesInputMessageContentBlockImage = &ResponsesInputMessageContentBlockImage{ - ImageURL: &block.ImageURLStruct.URL, - Detail: block.ImageURLStruct.Detail, + responseBlocks[i] = ResponsesMessageContentBlock{ + Type: blockType, + Text: block.Text, } - } - if block.File != nil { - responseBlocks[i].ResponsesInputMessageContentBlockFile = &ResponsesInputMessageContentBlockFile{ - FileData: block.File.FileData, - Filename: block.File.Filename, + + // Convert specific block types + if block.ImageURLStruct != nil { + responseBlocks[i].ResponsesInputMessageContentBlockImage = &ResponsesInputMessageContentBlockImage{ + ImageURL: &block.ImageURLStruct.URL, + Detail: block.ImageURLStruct.Detail, + } } - responseBlocks[i].FileID = block.File.FileID - } - if block.InputAudio != nil { - format := "" - if block.InputAudio.Format != nil { - format = *block.InputAudio.Format + if block.File != nil { + responseBlocks[i].ResponsesInputMessageContentBlockFile = &ResponsesInputMessageContentBlockFile{ + FileData: block.File.FileData, + Filename: block.File.Filename, + } + responseBlocks[i].FileID = block.File.FileID } - responseBlocks[i].Audio = &ResponsesInputMessageContentBlockAudio{ - Data: block.InputAudio.Data, - Format: format, + if block.InputAudio != nil { + format := "" + if block.InputAudio.Format != nil { + format = *block.InputAudio.Format + } + responseBlocks[i].Audio = &ResponsesInputMessageContentBlockAudio{ + Data: block.InputAudio.Data, + Format: format, + } } } - } - rm.Content = &ResponsesMessageContent{ - ContentBlocks: responseBlocks, + rm.Content = &ResponsesMessageContent{ + ContentBlocks: responseBlocks, + } } } @@ -422,9 +433,18 @@ func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { } // If tool output content exists, add it to function_call_output - if rm.Content != nil && rm.Content.ContentStr != nil && *rm.Content.ContentStr != "" { + // For function_call_output, get content from cm.Content since rm.Content is not set + var outputContent *string + if messageType == ResponsesMessageTypeFunctionCallOutput { + // Get content directly from ChatMessage for function_call_output + if cm.Content != nil && cm.Content.ContentStr != nil && *cm.Content.ContentStr != "" { + outputContent = cm.Content.ContentStr + } + } + + if outputContent != nil { rm.ResponsesToolMessage.Output = &ResponsesToolMessageOutputStruct{ - ResponsesToolCallOutputStr: rm.Content.ContentStr, + ResponsesToolCallOutputStr: outputContent, } } } diff --git a/core/schemas/utils.go b/core/schemas/utils.go index 406edb49c0..ab680ba440 100644 --- a/core/schemas/utils.go +++ b/core/schemas/utils.go @@ -702,6 +702,91 @@ func deepCopyChatContentBlock(original ChatContentBlock) ChatContentBlock { return copy } +// DeepCopyChatTool creates a deep copy of a ChatTool +// to prevent shared data mutation between different plugin accumulators +func DeepCopyChatTool(original ChatTool) ChatTool { + copyTool := ChatTool{ + Type: original.Type, + } + + // Deep copy Function if present + if original.Function != nil { + copyTool.Function = &ChatToolFunction{ + Name: original.Function.Name, + } + + if original.Function.Description != nil { + copyDescription := *original.Function.Description + copyTool.Function.Description = ©Description + } + + if original.Function.Parameters != nil { + copyParams := &ToolFunctionParameters{ + Type: original.Function.Parameters.Type, + } + + if original.Function.Parameters.Description != nil { + copyParamDesc := *original.Function.Parameters.Description + copyParams.Description = ©ParamDesc + } + + if original.Function.Parameters.Required != nil { + copyParams.Required = make([]string, len(original.Function.Parameters.Required)) + copy(copyParams.Required, original.Function.Parameters.Required) + } + + if original.Function.Parameters.Properties != nil { + // Deep copy the map + copyProps := make(map[string]interface{}, len(*original.Function.Parameters.Properties)) + for k, v := range *original.Function.Parameters.Properties { + copyProps[k] = DeepCopy(v) + } + copyParams.Properties = ©Props + } + + if original.Function.Parameters.Enum != nil { + copyParams.Enum = make([]string, len(original.Function.Parameters.Enum)) + copy(copyParams.Enum, original.Function.Parameters.Enum) + } + + if original.Function.Parameters.AdditionalProperties != nil { + copyAdditionalProps := *original.Function.Parameters.AdditionalProperties + copyParams.AdditionalProperties = ©AdditionalProps + } + + copyTool.Function.Parameters = copyParams + } + + if original.Function.Strict != nil { + copyStrict := *original.Function.Strict + copyTool.Function.Strict = ©Strict + } + } + + // Deep copy Custom if present + if original.Custom != nil { + copyTool.Custom = &ChatToolCustom{} + + if original.Custom.Format != nil { + copyFormat := &ChatToolCustomFormat{ + Type: original.Custom.Format.Type, + } + + if original.Custom.Format.Grammar != nil { + copyGrammar := &ChatToolCustomGrammarFormat{ + Definition: original.Custom.Format.Grammar.Definition, + Syntax: original.Custom.Format.Grammar.Syntax, + } + copyFormat.Grammar = copyGrammar + } + + copyTool.Custom.Format = copyFormat + } + } + + return copyTool +} + // DeepCopyResponsesMessage creates a deep copy of a ResponsesMessage // to prevent shared data mutation between different plugin accumulators func DeepCopyResponsesMessage(original ResponsesMessage) ResponsesMessage { diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index c89d5a393f..3b9469378c 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -28,18 +28,20 @@ type EnvKeyInfo struct { // ClientConfig represents the core configuration for Bifrost HTTP transport and the Bifrost Client. // It includes settings for excess request handling, Prometheus metrics, and initial pool size. type ClientConfig struct { - DropExcessRequests bool `json:"drop_excess_requests"` // Drop excess requests if the provider queue is full - InitialPoolSize int `json:"initial_pool_size"` // The initial pool size for the bifrost client - PrometheusLabels []string `json:"prometheus_labels"` // The labels to be used for prometheus metrics + DropExcessRequests bool `json:"drop_excess_requests"` // Drop excess requests if the provider queue is full + InitialPoolSize int `json:"initial_pool_size"` // The initial pool size for the bifrost client + PrometheusLabels []string `json:"prometheus_labels"` // The labels to be used for prometheus metrics EnableLogging bool `json:"enable_logging"` // Enable logging of requests and responses DisableContentLogging bool `json:"disable_content_logging"` // Disable logging of content LogRetentionDays int `json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day) - EnableGovernance bool `json:"enable_governance"` // Enable governance on all requests - EnforceGovernanceHeader bool `json:"enforce_governance_header"` // Enforce governance on all requests - AllowDirectKeys bool `json:"allow_direct_keys"` // Allow direct keys to be used for requests - AllowedOrigins []string `json:"allowed_origins,omitempty"` // Additional allowed origins for CORS and WebSocket (localhost is always allowed) - MaxRequestBodySizeMB int `json:"max_request_body_size_mb"` // The maximum request body size in MB - EnableLiteLLMFallbacks bool `json:"enable_litellm_fallbacks"` // Enable litellm-specific fallbacks for text completion for Groq + EnableGovernance bool `json:"enable_governance"` // Enable governance on all requests + EnforceGovernanceHeader bool `json:"enforce_governance_header"` // Enforce governance on all requests + AllowDirectKeys bool `json:"allow_direct_keys"` // Allow direct keys to be used for requests + AllowedOrigins []string `json:"allowed_origins,omitempty"` // Additional allowed origins for CORS and WebSocket (localhost is always allowed) + MaxRequestBodySizeMB int `json:"max_request_body_size_mb"` // The maximum request body size in MB + EnableLiteLLMFallbacks bool `json:"enable_litellm_fallbacks"` // Enable litellm-specific fallbacks for text completion for Groq + MCPAgentDepth int `json:"mcp_agent_depth"` // The maximum depth for MCP agent mode tool execution + MCPToolExecutionTimeout int `json:"mcp_tool_execution_timeout"` // The timeout for individual tool execution in seconds } // ProviderConfig represents the configuration for a specific AI model provider. @@ -55,10 +57,10 @@ type ProviderConfig struct { // AuthConfig represents configured auth config for Bifrost dashboard type AuthConfig struct { - AdminUserName string `json:"admin_username"` - AdminPassword string `json:"admin_password"` - IsEnabled bool `json:"is_enabled"` - DisableAuthOnInference bool `json:"disable_auth_on_inference"` + AdminUserName string `json:"admin_username"` + AdminPassword string `json:"admin_password"` + IsEnabled bool `json:"is_enabled"` + DisableAuthOnInference bool `json:"disable_auth_on_inference"` } // ConfigMap maps provider names to their configurations. diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 2474f98007..724860305f 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -3,7 +3,10 @@ package configstore import ( "context" "fmt" + "log" "strconv" + "strings" + "unicode" "github.com/google/uuid" "github.com/maximhq/bifrost/framework/configstore/tables" @@ -76,12 +79,24 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationMissingProviderColumnInKeyTable(ctx, db); err != nil { return err } + if err := migrationAddToolsToAutoExecuteJSONColumn(ctx, db); err != nil { + return err + } + if err := migrationAddIsCodeModeClientColumn(ctx, db); err != nil { + return err + } if err := migrationAddLogRetentionDaysColumn(ctx, db); err != nil { return err } if err := migrationAddBatchAndCachePricingColumns(ctx, db); err != nil { return err } + if err := migrationAddMCPAgentDepthAndMCPToolExecutionTimeoutColumns(ctx, db); err != nil { + return err + } + if err := migrationNormalizeMCPClientNames(ctx, db); err != nil { + return err + } if err := migrationMoveKeysToProviderConfig(ctx, db); err != nil { return err } @@ -1043,6 +1058,74 @@ func migrationMissingProviderColumnInKeyTable(ctx context.Context, db *gorm.DB) return nil } +// migrationAddToolsToAutoExecuteJSONColumn adds the tools_to_auto_execute_json column to the mcp_client table +func migrationAddToolsToAutoExecuteJSONColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_tools_to_auto_execute_json_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableMCPClient{}, "tools_to_auto_execute_json") { + if err := migrator.AddColumn(&tables.TableMCPClient{}, "tools_to_auto_execute_json"); err != nil { + return err + } + // Initialize existing rows with empty array + if err := tx.Exec("UPDATE config_mcp_clients SET tools_to_auto_execute_json = '[]' WHERE tools_to_auto_execute_json IS NULL OR tools_to_auto_execute_json = ''").Error; err != nil { + return fmt.Errorf("failed to initialize tools_to_auto_execute_json: %w", err) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if err := migrator.DropColumn(&tables.TableMCPClient{}, "tools_to_auto_execute_json"); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + +// migrationAddIsCodeModeClientColumn adds the is_code_mode_client column to the config_mcp_clients table +func migrationAddIsCodeModeClientColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_is_code_mode_client_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableMCPClient{}, "is_code_mode_client") { + if err := migrator.AddColumn(&tables.TableMCPClient{}, "is_code_mode_client"); err != nil { + return err + } + // Initialize existing rows with false (default value) + if err := tx.Exec("UPDATE config_mcp_clients SET is_code_mode_client = false WHERE is_code_mode_client IS NULL").Error; err != nil { + return fmt.Errorf("failed to initialize is_code_mode_client: %w", err) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if err := migrator.DropColumn(&tables.TableMCPClient{}, "is_code_mode_client"); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + // migrationAddLogRetentionDaysColumn adds the log_retention_days column to the client config table func migrationAddLogRetentionDaysColumn(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ @@ -1123,6 +1206,176 @@ func migrationAddBatchAndCachePricingColumns(ctx context.Context, db *gorm.DB) e return m.Migrate() } +func migrationAddMCPAgentDepthAndMCPToolExecutionTimeoutColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_mcp_agent_depth_and_mcp_tool_execution_timeout_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableClientConfig{}, "mcp_agent_depth") { + if err := migrator.AddColumn(&tables.TableClientConfig{}, "mcp_agent_depth"); err != nil { + return err + } + } + if !migrator.HasColumn(&tables.TableClientConfig{}, "mcp_tool_execution_timeout") { + if err := migrator.AddColumn(&tables.TableClientConfig{}, "mcp_tool_execution_timeout"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if err := migrator.DropColumn(&tables.TableClientConfig{}, "mcp_agent_depth"); err != nil { + return err + } + if err := migrator.DropColumn(&tables.TableClientConfig{}, "mcp_tool_execution_timeout"); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + +// normalizeMCPClientName normalizes an MCP client name by: +// 1. Replacing hyphens and spaces with underscores +// 2. Removing leading digits +// 3. Using a default name if the result is empty +func normalizeMCPClientName(name string) string { + // Replace hyphens and spaces with underscores + normalized := strings.ReplaceAll(name, "-", "_") + normalized = strings.ReplaceAll(normalized, " ", "_") + + // Remove leading digits + normalized = strings.TrimLeftFunc(normalized, func(r rune) bool { + return unicode.IsDigit(r) + }) + + // If name becomes empty after normalization, use a default name + if normalized == "" { + normalized = "mcp_client" + } + + return normalized +} + +// migrationNormalizeMCPClientNames normalizes MCP client names by: +// 1. Replacing hyphens and spaces with underscores +// 2. Removing leading digits +// 3. Adding number suffix if name already exists +func migrationNormalizeMCPClientNames(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "normalize_mcp_client_names", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + + // Fetch all MCP clients + var mcpClients []tables.TableMCPClient + if err := tx.Find(&mcpClients).Error; err != nil { + return fmt.Errorf("failed to fetch MCP clients: %w", err) + } + + // Track assigned names in memory to avoid transaction visibility issues + // and ensure we see all updates made during this migration + assignedNames := make(map[string]bool) + + // Helper function to find a unique name + findUniqueName := func(baseName string, originalName string, excludeID uint, tx *gorm.DB, assignedNames map[string]bool) (string, error) { + // First check if base name is already assigned in this migration + if !assignedNames[baseName] { + // Also check database for existing names (excluding current client) + var existing tables.TableMCPClient + err := tx.Where("name = ? AND id != ?", baseName, excludeID).First(&existing).Error + if err == gorm.ErrRecordNotFound { + // Name is available + assignedNames[baseName] = true + // Log normalization even when no collision + if originalName != baseName { + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, baseName) + } + return baseName, nil + } else if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + } + + // Name exists (either assigned in this migration or in database), try with number suffix starting from 2 + // (base name is conceptually "1", so collisions start from "2") + suffix := 2 + const maxSuffix = 1000 // Safety limit to prevent infinite loops + for { + if suffix > maxSuffix { + return "", fmt.Errorf("could not find unique name after %d attempts for base name: %s", maxSuffix, baseName) + } + candidateName := baseName + strconv.Itoa(suffix) + + // Check both in-memory map and database + if !assignedNames[candidateName] { + var existing tables.TableMCPClient + err := tx.Where("name = ? AND id != ?", candidateName, excludeID).First(&existing).Error + if err == gorm.ErrRecordNotFound { + // Found available name - log the transformation + assignedNames[candidateName] = true + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, candidateName) + return candidateName, nil + } else if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + } + suffix++ + } + } + + // Process each client + for _, client := range mcpClients { + originalName := client.Name + needsUpdate := false + + // Check if name needs normalization + if strings.Contains(originalName, "-") || strings.Contains(originalName, " ") { + needsUpdate = true + } else if len(originalName) > 0 && unicode.IsDigit(rune(originalName[0])) { + needsUpdate = true + } + + if needsUpdate { + // Normalize the name + normalizedName := normalizeMCPClientName(originalName) + + // Find a unique name (pass assignedNames map to track names in this migration) + uniqueName, err := findUniqueName(normalizedName, originalName, client.ID, tx, assignedNames) + if err != nil { + return fmt.Errorf("failed to find unique name for client %d (original: %s): %w", client.ID, originalName, err) + } + + // Update the client name + if err := tx.Model(&client).Update("name", uniqueName).Error; err != nil { + return fmt.Errorf("failed to update MCP client %d name from %s to %s: %w", client.ID, originalName, uniqueName, err) + } + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + // Rollback is not possible as we don't store the original names + // This migration is one-way + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running MCP client name normalization migration: %s", err.Error()) + } + return nil +} + // migrationMoveKeysToProviderConfig migrates keys from virtual key level to provider config level func migrationMoveKeysToProviderConfig(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ @@ -1131,18 +1384,18 @@ func migrationMoveKeysToProviderConfig(ctx context.Context, db *gorm.DB) error { tx = tx.WithContext(ctx) gormMigrator := tx.Migrator() - // Step 1: Create the new join table for provider config -> keys relationship - // Setup the join table so GORM knows about the custom structure - if err := tx.SetupJoinTable(&tables.TableVirtualKeyProviderConfig{}, "Keys", &tables.TableVirtualKeyProviderConfigKey{}); err != nil { - return fmt.Errorf("failed to setup join table for provider config keys: %w", err) - } + // Step 1: Create the new join table for provider config -> keys relationship + // Setup the join table so GORM knows about the custom structure + if err := tx.SetupJoinTable(&tables.TableVirtualKeyProviderConfig{}, "Keys", &tables.TableVirtualKeyProviderConfigKey{}); err != nil { + return fmt.Errorf("failed to setup join table for provider config keys: %w", err) + } - // Create the join table if it doesn't exist - if !gormMigrator.HasTable(&tables.TableVirtualKeyProviderConfigKey{}) { - if err := gormMigrator.CreateTable(&tables.TableVirtualKeyProviderConfigKey{}); err != nil { - return fmt.Errorf("failed to create join table for provider config keys: %w", err) + // Create the join table if it doesn't exist + if !gormMigrator.HasTable(&tables.TableVirtualKeyProviderConfigKey{}) { + if err := gormMigrator.CreateTable(&tables.TableVirtualKeyProviderConfigKey{}); err != nil { + return fmt.Errorf("failed to create join table for provider config keys: %w", err) + } } - } // Step 2: Migrate existing key associations from virtual key to provider config level // Check if old join table exists diff --git a/framework/configstore/migrations_test.go b/framework/configstore/migrations_test.go new file mode 100644 index 0000000000..cada594b38 --- /dev/null +++ b/framework/configstore/migrations_test.go @@ -0,0 +1,539 @@ +package configstore + +import ( + "bytes" + "context" + "fmt" + "log" + "os" + "strconv" + "strings" + "testing" + "time" + + "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// setupTestDB creates an in-memory SQLite database for testing +func setupTestDB(t *testing.T) *gorm.DB { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + require.NoError(t, err, "Failed to create test database") + + // Create the MCP clients table + err = db.AutoMigrate(&tables.TableMCPClient{}) + require.NoError(t, err, "Failed to migrate test database") + + return db +} + +// captureLogOutput captures log output during a function execution +func captureLogOutput(fn func()) string { + var buf bytes.Buffer + log.SetOutput(&buf) + defer log.SetOutput(os.Stderr) + + fn() + return buf.String() +} + +func TestNormalizeName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "hyphen to underscore", + input: "my-tool", + expected: "my_tool", + }, + { + name: "space to underscore", + input: "my tool", + expected: "my_tool", + }, + { + name: "multiple hyphens", + input: "my-super-tool", + expected: "my_super_tool", + }, + { + name: "multiple spaces", + input: "my super tool", + expected: "my_super_tool", + }, + { + name: "leading digits removed", + input: "123tool", + expected: "tool", + }, + { + name: "leading digits with hyphen", + input: "123my-tool", + expected: "my_tool", + }, + { + name: "empty after normalization", + input: "123", + expected: "mcp_client", + }, + { + name: "no change needed", + input: "my_tool", + expected: "my_tool", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + normalized := normalizeMCPClientName(tt.input) + assert.Equal(t, tt.expected, normalized, "normalizeMCPClientName should produce expected output") + }) + } +} + +func TestFindUniqueName_NoCollision(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create a test client with a unique name + client := &tables.TableMCPClient{ + Name: "existing_client", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := db.WithContext(ctx).Create(client).Error + require.NoError(t, err) + + // Test findUniqueName with a different base name (no collision) + logOutput := captureLogOutput(func() { + uniqueName, err := findUniqueNameForTest("new_client", "new_client", 999, db.WithContext(ctx)) + require.NoError(t, err) + assert.Equal(t, "new_client", uniqueName, "Should return base name when no collision") + }) + + // Should not log anything when there's no collision + assert.Empty(t, logOutput, "Should not log when name is available without suffix") +} + +func TestFindUniqueName_WithCollision(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create existing clients that will cause collisions + // First client with base name + client1 := &tables.TableMCPClient{ + Name: "my_tool", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := db.WithContext(ctx).Create(client1).Error + require.NoError(t, err) + + // Second client with first suffix + client2 := &tables.TableMCPClient{ + Name: "my_tool1", + ClientID: "client-2", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err = db.WithContext(ctx).Create(client2).Error + require.NoError(t, err) + + // Test findUniqueName with collision - should find "my_tool2" + // excludeID is set to a non-existent ID (999) so all existing clients are considered + var uniqueName string + logOutput := captureLogOutput(func() { + uniqueName, err = findUniqueNameForTest("my_tool", "my-tool", 999, db.WithContext(ctx)) + }) + + require.NoError(t, err) + assert.Equal(t, "my_tool2", uniqueName, "Should return name with suffix when collision occurs") + assert.Contains(t, logOutput, "MCP Client Name Normalized: 'my-tool' -> 'my_tool2'", "Should log the transformation") +} + +func TestFindUniqueName_MultipleCollisions(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create existing clients that will cause multiple collisions + client1 := &tables.TableMCPClient{ + Name: "test_tool", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := db.WithContext(ctx).Create(client1).Error + require.NoError(t, err) + + client2 := &tables.TableMCPClient{ + Name: "test_tool1", + ClientID: "client-2", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err = db.WithContext(ctx).Create(client2).Error + require.NoError(t, err) + + client3 := &tables.TableMCPClient{ + Name: "test_tool2", + ClientID: "client-3", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err = db.WithContext(ctx).Create(client3).Error + require.NoError(t, err) + + // Test findUniqueName with multiple collisions - should find "test_tool3" + var uniqueName string + logOutput := captureLogOutput(func() { + uniqueName, err = findUniqueNameForTest("test_tool", "test tool", 999, db.WithContext(ctx)) + }) + + require.NoError(t, err) + assert.Equal(t, "test_tool3", uniqueName, "Should return name with correct suffix after multiple collisions") + assert.Contains(t, logOutput, "MCP Client Name Normalized: 'test tool' -> 'test_tool3'", "Should log the transformation") +} + +func TestFindUniqueName_NormalizationAndCollision(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create existing client with normalized name + client := &tables.TableMCPClient{ + Name: "my_tool", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := db.WithContext(ctx).Create(client).Error + require.NoError(t, err) + + // Test that "my-tool" normalizes to "my_tool" and then collides, requiring suffix + var uniqueName string + logOutput := captureLogOutput(func() { + uniqueName, err = findUniqueNameForTest("my_tool", "my-tool", 999, db.WithContext(ctx)) + }) + + require.NoError(t, err) + assert.Equal(t, "my_tool2", uniqueName, "Should handle normalization and collision") + assert.Contains(t, logOutput, "MCP Client Name Normalized: 'my-tool' -> 'my_tool2'", "Should log the full transformation") +} + +func TestFindUniqueName_MultipleNormalizationsToSameBase(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Test case: 3 entries that normalize to the same base name: + // "mcp client" -> "mcp_client" + // "mcp-client" -> "mcp_client" (collision, becomes "mcp_client2") + // "1mcp-client" -> "mcp_client" (collision, becomes "mcp_client3") + // Note: In the actual migration, names are processed sequentially and each checks + // against all previously created names. To simulate this, we need to create clients + // with the original names first, then normalize them in sequence. + + // Helper function to normalize (same logic as in migrations.go) + normalizeName := func(name string) string { + normalized := strings.ReplaceAll(name, "-", "_") + normalized = strings.ReplaceAll(normalized, " ", "_") + normalized = strings.TrimLeftFunc(normalized, func(r rune) bool { + return r >= '0' && r <= '9' + }) + if normalized == "" { + normalized = "mcp_client" + } + return normalized + } + + // Create three clients with original names (simulating pre-migration state) + clients := []*tables.TableMCPClient{ + { + Name: "mcp client", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + { + Name: "mcp-client", + ClientID: "client-2", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + { + Name: "1mcp-client", + ClientID: "client-3", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + } + + for _, client := range clients { + err := db.WithContext(ctx).Create(client).Error + require.NoError(t, err) + } + + // Now simulate the migration: process each client sequentially + // First: "mcp client" -> "mcp_client" (no collision) + client1 := clients[0] + normalizedName1 := normalizeName(client1.Name) + var uniqueName1 string + var err error + logOutput1 := captureLogOutput(func() { + uniqueName1, err = findUniqueNameForTest(normalizedName1, client1.Name, client1.ID, db.WithContext(ctx)) + }) + require.NoError(t, err) + assert.Equal(t, "mcp_client", uniqueName1, "First normalization should use base name") + assert.Empty(t, logOutput1, "Should not log when name is available without suffix") + + // Update first client + err = db.WithContext(ctx).Model(client1).Update("name", uniqueName1).Error + require.NoError(t, err) + + // Second: "mcp-client" -> "mcp_client" (collision with "mcp_client", becomes "mcp_client2") + // Note: We need to check that "mcp_client" exists (from client1), so it should skip to "mcp_client2" + client2 := clients[1] + normalizedName2 := normalizeName(client2.Name) + var uniqueName2 string + logOutput2 := captureLogOutput(func() { + uniqueName2, err = findUniqueNameForTest(normalizedName2, client2.Name, client2.ID, db.WithContext(ctx)) + }) + require.NoError(t, err) + // With the updated implementation, suffixes start from 2 when base name exists + // So "mcp-client" normalizes to "mcp_client" which collides, becomes "mcp_client2" + assert.Equal(t, "mcp_client2", uniqueName2, "Second normalization should get suffix 2 (skipping 1)") + assert.Contains(t, logOutput2, "MCP Client Name Normalized: 'mcp-client' -> 'mcp_client2'", "Should log the transformation") + + // Update second client + err = db.WithContext(ctx).Model(client2).Update("name", uniqueName2).Error + require.NoError(t, err) + + // Third: "1mcp-client" -> "mcp_client" (collision with "mcp_client" and "mcp_client2", becomes "mcp_client3") + client3 := clients[2] + normalizedName3 := normalizeName(client3.Name) + var uniqueName3 string + logOutput3 := captureLogOutput(func() { + uniqueName3, err = findUniqueNameForTest(normalizedName3, client3.Name, client3.ID, db.WithContext(ctx)) + }) + require.NoError(t, err) + // Third normalization finds "mcp_client" and "mcp_client2" exist, so becomes "mcp_client3" + assert.Equal(t, "mcp_client3", uniqueName3, "Third normalization should get suffix 3") + assert.Contains(t, logOutput3, "MCP Client Name Normalized: '1mcp-client' -> 'mcp_client3'", "Should log the transformation") + + // Update third client + err = db.WithContext(ctx).Model(client3).Update("name", uniqueName3).Error + require.NoError(t, err) + + // Final verification: all three should exist with correct names + var finalClients []tables.TableMCPClient + err = db.WithContext(ctx).Find(&finalClients).Error + require.NoError(t, err) + assert.Len(t, finalClients, 3, "Should have all 3 clients") + + names := make([]string, len(finalClients)) + for i, c := range finalClients { + names[i] = c.Name + } + assert.Contains(t, names, "mcp_client", "Should contain mcp_client") + assert.Contains(t, names, "mcp_client2", "Should contain mcp_client2") + assert.Contains(t, names, "mcp_client3", "Should contain mcp_client3") +} + +func TestFindUniqueName_MigrationScenarioWithInMemoryTracking(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // This test simulates the exact migration scenario where clients are processed in a loop + // and we need to track assigned names in memory to avoid transaction visibility issues + + // Create three clients with original names (simulating pre-migration state) + clients := []*tables.TableMCPClient{ + { + Name: "mcp client", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + { + Name: "mcp-client", + ClientID: "client-2", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + { + Name: "1mcp-client", + ClientID: "client-3", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + } + + for _, client := range clients { + err := db.WithContext(ctx).Create(client).Error + require.NoError(t, err) + } + + // Simulate the migration: process clients in a loop with in-memory tracking + assignedNames := make(map[string]bool) + normalizeName := func(name string) string { + normalized := strings.ReplaceAll(name, "-", "_") + normalized = strings.ReplaceAll(normalized, " ", "_") + normalized = strings.TrimLeftFunc(normalized, func(r rune) bool { + return r >= '0' && r <= '9' + }) + if normalized == "" { + normalized = "mcp_client" + } + return normalized + } + + var logOutputs []string + for _, client := range clients { + originalName := client.Name + needsUpdate := strings.Contains(originalName, "-") || strings.Contains(originalName, " ") || + (len(originalName) > 0 && originalName[0] >= '0' && originalName[0] <= '9') + + if needsUpdate { + normalizedName := normalizeName(originalName) + uniqueName, err := findUniqueNameForTestWithTracking(normalizedName, originalName, client.ID, db.WithContext(ctx), assignedNames) + require.NoError(t, err) + + // Capture log output + logOutput := captureLogOutput(func() { + // Log if name changed + if originalName != uniqueName { + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, uniqueName) + } + }) + if logOutput != "" { + logOutputs = append(logOutputs, logOutput) + } + + // Update client + err = db.WithContext(ctx).Model(client).Update("name", uniqueName).Error + require.NoError(t, err) + } + } + + // Verify all three clients have correct names + var finalClients []tables.TableMCPClient + err := db.WithContext(ctx).Find(&finalClients).Error + require.NoError(t, err) + assert.Len(t, finalClients, 3, "Should have all 3 clients") + + names := make([]string, len(finalClients)) + for i, c := range finalClients { + names[i] = c.Name + } + assert.Contains(t, names, "mcp_client", "Should contain mcp_client") + assert.Contains(t, names, "mcp_client2", "Should contain mcp_client2") + assert.Contains(t, names, "mcp_client3", "Should contain mcp_client3") + + // Verify logging: should log all three transformations + allLogs := strings.Join(logOutputs, "") + assert.Contains(t, allLogs, "MCP Client Name Normalized: 'mcp client' -> 'mcp_client'", "Should log first normalization") + assert.Contains(t, allLogs, "MCP Client Name Normalized: 'mcp-client' -> 'mcp_client2'", "Should log second normalization") + assert.Contains(t, allLogs, "MCP Client Name Normalized: '1mcp-client' -> 'mcp_client3'", "Should log third normalization") +} + +// findUniqueNameForTestWithTracking is a test helper that tracks assigned names in memory +func findUniqueNameForTestWithTracking(baseName string, originalName string, excludeID uint, tx *gorm.DB, assignedNames map[string]bool) (string, error) { + // First check if base name is already assigned in this migration + if !assignedNames[baseName] { + // Also check database for existing names (excluding current client) + var count int64 + err := tx.Model(&tables.TableMCPClient{}).Where("name = ? AND id != ?", baseName, excludeID).Count(&count).Error + if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + if count == 0 { + // Name is available + assignedNames[baseName] = true + // Log normalization even when no collision + if originalName != baseName { + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, baseName) + } + return baseName, nil + } + } + + // Name exists (either assigned in this migration or in database), try with number suffix starting from 2 + suffix := 2 + const maxSuffix = 1000 + for { + if suffix > maxSuffix { + return "", fmt.Errorf("could not find unique name after %d attempts for base name: %s", maxSuffix, baseName) + } + candidateName := baseName + strconv.Itoa(suffix) + + // Check both in-memory map and database + if !assignedNames[candidateName] { + var count int64 + err := tx.Model(&tables.TableMCPClient{}).Where("name = ? AND id != ?", candidateName, excludeID).Count(&count).Error + if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + if count == 0 { + // Found available name + assignedNames[candidateName] = true + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, candidateName) + return candidateName, nil + } + } + suffix++ + } +} + +// findUniqueNameForTest is a test helper that extracts the findUniqueName logic +// This mirrors the implementation in migrations.go for testing +func findUniqueNameForTest(baseName string, originalName string, excludeID uint, tx *gorm.DB) (string, error) { + // First, try the base name + var count int64 + err := tx.Model(&tables.TableMCPClient{}).Where("name = ? AND id != ?", baseName, excludeID).Count(&count).Error + if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + if count == 0 { + // Name is available + return baseName, nil + } + + // Name exists, try with number suffix starting from 2 + // (base name is conceptually "1", so collisions start from "2") + suffix := 2 + const maxSuffix = 1000 // Safety limit to prevent infinite loops + for { + if suffix > maxSuffix { + return "", fmt.Errorf("could not find unique name after %d attempts for base name: %s", maxSuffix, baseName) + } + candidateName := baseName + strconv.Itoa(suffix) + err := tx.Model(&tables.TableMCPClient{}).Where("name = ? AND id != ?", candidateName, excludeID).Count(&count).Error + if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + if count == 0 { + // Found available name - log the transformation + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, candidateName) + return candidateName, nil + } + suffix++ + } +} diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index c78acab9db..c9d158f2d7 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "strings" + "time" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" @@ -39,6 +40,8 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC AllowedOrigins: config.AllowedOrigins, MaxRequestBodySizeMB: config.MaxRequestBodySizeMB, EnableLiteLLMFallbacks: config.EnableLiteLLMFallbacks, + MCPAgentDepth: config.MCPAgentDepth, + MCPToolExecutionTimeout: config.MCPToolExecutionTimeout, } // Delete existing client config and create new one in a transaction return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { @@ -156,7 +159,6 @@ func (s *RDBConfigStore) UpdateFrameworkConfig(ctx context.Context, config *tabl } return tx.Create(config).Error }) - } // GetFrameworkConfig retrieves the framework configuration from the database. @@ -193,6 +195,8 @@ func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, er AllowedOrigins: dbConfig.AllowedOrigins, MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB, EnableLiteLLMFallbacks: dbConfig.EnableLiteLLMFallbacks, + MCPAgentDepth: dbConfig.MCPAgentDepth, + MCPToolExecutionTimeout: dbConfig.MCPToolExecutionTimeout, }, nil } @@ -666,17 +670,39 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, } clientConfigs[i] = schemas.MCPClientConfig{ - ID: dbClient.ClientID, - Name: dbClient.Name, - ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), - ConnectionString: processedConnectionString, - StdioConfig: dbClient.StdioConfig, - ToolsToExecute: dbClient.ToolsToExecute, - Headers: processedHeaders, + ID: dbClient.ClientID, + Name: dbClient.Name, + IsCodeModeClient: dbClient.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), + ConnectionString: processedConnectionString, + StdioConfig: dbClient.StdioConfig, + ToolsToExecute: dbClient.ToolsToExecute, + ToolsToAutoExecute: dbClient.ToolsToAutoExecute, + Headers: processedHeaders, + } + } + var clientConfig tables.TableClientConfig + if err := s.db.WithContext(ctx).First(&clientConfig).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + // Return MCP config with default ToolManagerConfig if no client config exists + // This will never happen, but just in case. + return &schemas.MCPConfig{ + ClientConfigs: clientConfigs, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + ToolExecutionTimeout: 30 * time.Second, // default from TableClientConfig + MaxAgentDepth: 10, // default from TableClientConfig + }, + }, nil } + return nil, err + } + toolManagerConfig := schemas.MCPToolManagerConfig{ + ToolExecutionTimeout: time.Duration(clientConfig.MCPToolExecutionTimeout) * time.Second, + MaxAgentDepth: clientConfig.MCPAgentDepth, } return &schemas.MCPConfig{ - ClientConfigs: clientConfigs, + ClientConfigs: clientConfigs, + ToolManagerConfig: &toolManagerConfig, }, nil } @@ -702,17 +728,20 @@ func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig } // Substitute environment variables back to their original form - substituteMCPClientEnvVars(&clientConfigCopy, envKeys) + // For create operations, no existing headers to restore from + substituteMCPClientEnvVars(&clientConfigCopy, envKeys, nil) // Create new client dbClient := tables.TableMCPClient{ - ClientID: clientConfigCopy.ID, - Name: clientConfigCopy.Name, - ConnectionType: string(clientConfigCopy.ConnectionType), - ConnectionString: clientConfigCopy.ConnectionString, - StdioConfig: clientConfigCopy.StdioConfig, - ToolsToExecute: clientConfigCopy.ToolsToExecute, - Headers: clientConfigCopy.Headers, + ClientID: clientConfigCopy.ID, + Name: clientConfigCopy.Name, + IsCodeModeClient: clientConfigCopy.IsCodeModeClient, + ConnectionType: string(clientConfigCopy.ConnectionType), + ConnectionString: clientConfigCopy.ConnectionString, + StdioConfig: clientConfigCopy.StdioConfig, + ToolsToExecute: clientConfigCopy.ToolsToExecute, + ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute, + Headers: clientConfigCopy.Headers, } if err := tx.WithContext(ctx).Create(&dbClient).Error; err != nil { @@ -741,17 +770,20 @@ func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, c } // Substitute environment variables back to their original form - substituteMCPClientEnvVars(&clientConfigCopy, envKeys) + // Pass existing headers to restore redacted plain values + substituteMCPClientEnvVars(&clientConfigCopy, envKeys, existingClient.Headers) // Update existing client existingClient.Name = clientConfigCopy.Name - existingClient.ConnectionType = string(clientConfigCopy.ConnectionType) - existingClient.ConnectionString = clientConfigCopy.ConnectionString - existingClient.StdioConfig = clientConfigCopy.StdioConfig + existingClient.IsCodeModeClient = clientConfigCopy.IsCodeModeClient existingClient.ToolsToExecute = clientConfigCopy.ToolsToExecute + existingClient.ToolsToAutoExecute = clientConfigCopy.ToolsToAutoExecute existingClient.Headers = clientConfigCopy.Headers - if err := tx.WithContext(ctx).Updates(&existingClient).Error; err != nil { + // Use Select to explicitly include IsCodeModeClient even when it's false (zero value) + // GORM's Updates() skips zero values by default, so we need to explicitly select fields + // Using struct field names - GORM will convert them to column names automatically + if err := tx.WithContext(ctx).Select("name", "is_code_mode_client", "tools_to_execute_json", "tools_to_auto_execute_json", "headers_json", "updated_at").Updates(&existingClient).Error; err != nil { return s.parseGormError(err) } return nil diff --git a/framework/configstore/tables/clientconfig.go b/framework/configstore/tables/clientconfig.go index 154af51b60..7e4c412e53 100644 --- a/framework/configstore/tables/clientconfig.go +++ b/framework/configstore/tables/clientconfig.go @@ -15,12 +15,14 @@ type TableClientConfig struct { AllowedOriginsJSON string `gorm:"type:text" json:"-"` // JSON serialized []string InitialPoolSize int `gorm:"default:300" json:"initial_pool_size"` EnableLogging bool `gorm:"" json:"enable_logging"` - DisableContentLogging bool `gorm:"default:false" json:"disable_content_logging"` // DisableContentLogging controls whether sensitive content (inputs, outputs, embeddings, etc.) is logged - LogRetentionDays int `gorm:"default:365" json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day) + DisableContentLogging bool `gorm:"default:false" json:"disable_content_logging"` // DisableContentLogging controls whether sensitive content (inputs, outputs, embeddings, etc.) is logged + LogRetentionDays int `gorm:"default:365" json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day) EnableGovernance bool `gorm:"" json:"enable_governance"` EnforceGovernanceHeader bool `gorm:"" json:"enforce_governance_header"` AllowDirectKeys bool `gorm:"" json:"allow_direct_keys"` MaxRequestBodySizeMB int `gorm:"default:100" json:"max_request_body_size_mb"` + MCPAgentDepth int `gorm:"default:10" json:"mcp_agent_depth"` + MCPToolExecutionTimeout int `gorm:"default:30" json:"mcp_tool_execution_timeout"` // Timeout for individual tool execution in seconds (default: 30) // LiteLLM fallback flag EnableLiteLLMFallbacks bool `gorm:"column:enable_litellm_fallbacks;default:false" json:"enable_litellm_fallbacks"` @@ -29,7 +31,7 @@ type TableClientConfig struct { // Virtual fields for runtime use (not stored in DB) PrometheusLabels []string `gorm:"-" json:"prometheus_labels"` - AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"` + AllowedOrigins []string `gorm:"-" json:"allowed_origins,omitempty"` } // TableName sets the table name for each model diff --git a/framework/configstore/tables/mcp.go b/framework/configstore/tables/mcp.go index 17a6f4f9dc..74af59d763 100644 --- a/framework/configstore/tables/mcp.go +++ b/framework/configstore/tables/mcp.go @@ -10,21 +10,24 @@ import ( // TableMCPClient represents an MCP client configuration in the database type TableMCPClient struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. - ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` - Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` - ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType - ConnectionString *string `gorm:"type:text" json:"connection_string,omitempty"` - StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig - ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string - HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string - CreatedAt time.Time `gorm:"index;not null" json:"created_at"` - UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. + ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` + Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` + IsCodeModeClient bool `gorm:"default:false" json:"is_code_mode_client"` // Whether the client is a code mode client + ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType + ConnectionString *string `gorm:"type:text" json:"connection_string,omitempty"` + StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig + ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + ToolsToAutoExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` // Virtual fields for runtime use (not stored in DB) - StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"` - ToolsToExecute []string `gorm:"-" json:"tools_to_execute"` - Headers map[string]string `gorm:"-" json:"headers"` + StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"` + ToolsToExecute []string `gorm:"-" json:"tools_to_execute"` + ToolsToAutoExecute []string `gorm:"-" json:"tools_to_auto_execute"` + Headers map[string]string `gorm:"-" json:"headers"` } // TableName sets the table name for each model @@ -52,6 +55,16 @@ func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { c.ToolsToExecuteJSON = "[]" } + if c.ToolsToAutoExecute != nil { + data, err := json.Marshal(c.ToolsToAutoExecute) + if err != nil { + return err + } + c.ToolsToAutoExecuteJSON = string(data) + } else { + c.ToolsToAutoExecuteJSON = "[]" + } + if c.Headers != nil { data, err := json.Marshal(c.Headers) if err != nil { @@ -81,6 +94,12 @@ func (c *TableMCPClient) AfterFind(tx *gorm.DB) error { } } + if c.ToolsToAutoExecuteJSON != "" { + if err := json.Unmarshal([]byte(c.ToolsToAutoExecuteJSON), &c.ToolsToAutoExecute); err != nil { + return err + } + } + if c.HeadersJSON != "" { if err := json.Unmarshal([]byte(c.HeadersJSON), &c.Headers); err != nil { return err diff --git a/framework/configstore/utils.go b/framework/configstore/utils.go index 78f0e133eb..24f1f636cc 100644 --- a/framework/configstore/utils.go +++ b/framework/configstore/utils.go @@ -183,32 +183,59 @@ func substituteMCPEnvVars(config *schemas.MCPConfig, envKeys map[string][]EnvKey } // substituteMCPClientEnvVars replaces resolved environment variable values with their original env.VAR_NAME references for a single MCP client config -func substituteMCPClientEnvVars(clientConfig *schemas.MCPClientConfig, envKeys map[string][]EnvKeyInfo) { +// If existingHeaders is provided, it will restore redacted plain header values from the existing headers before substitution +func substituteMCPClientEnvVars(clientConfig *schemas.MCPClientConfig, envKeys map[string][]EnvKeyInfo, existingHeaders map[string]string) { + // First, restore redacted plain header values from existing headers if provided + // This handles the case where UI sends redacted headers that aren't env vars + if existingHeaders != nil && clientConfig.Headers != nil { + for header, value := range clientConfig.Headers { + // Check if the value is redacted (contains **** pattern) and not an env var + if strings.Contains(value, "****") && !strings.HasPrefix(value, "env.") { + // If header exists in existing headers and wasn't an env var, restore it + if oldHeaderValue, exists := existingHeaders[header]; exists { + if !strings.HasPrefix(oldHeaderValue, "env.") { + clientConfig.Headers[header] = oldHeaderValue + } + } + } + } + } + // Find the environment variable for this client's connection string and headers for envVar, keyInfos := range envKeys { for _, keyInfo := range keyInfos { // For MCP connection strings if keyInfo.KeyType == "connection_string" { - // Extract client name from config path like "mcp.client_configs.clientName.connection_string" + // Extract client ID from config path like "mcp.client_configs.clientID.connection_string" pathParts := strings.Split(keyInfo.ConfigPath, ".") if len(pathParts) >= 3 && pathParts[0] == "mcp" && pathParts[1] == "client_configs" { - clientName := pathParts[2] - // If this environment variable is for the current client - if clientName == clientConfig.Name && clientConfig.ConnectionString != nil { + clientID := pathParts[2] + // If this environment variable is for the current client (match by ID) + if clientID == clientConfig.ID && clientConfig.ConnectionString != nil { clientConfig.ConnectionString = &[]string{fmt.Sprintf("env.%s", envVar)}[0] } } } // For MCP headers if keyInfo.KeyType == "mcp_header" { - // Extract client name and header name from config path like "mcp.client_configs.clientName.headers.headerName" + // Extract client ID and header name from config path like "mcp.client_configs.clientID.headers.headerName" pathParts := strings.Split(keyInfo.ConfigPath, ".") if len(pathParts) >= 5 && pathParts[0] == "mcp" && pathParts[1] == "client_configs" && pathParts[3] == "headers" { - clientName := pathParts[2] + clientID := pathParts[2] headerName := pathParts[4] - // If this environment variable is for the current client - if clientName == clientConfig.Name && clientConfig.Headers != nil { - clientConfig.Headers[headerName] = fmt.Sprintf("env.%s", envVar) + // If this environment variable is for the current client (match by ID) + if clientID == clientConfig.ID && clientConfig.Headers != nil { + if headerValue, exists := clientConfig.Headers[headerName]; exists { + // If it's already in env.VAR format, update to use the correct env var + if strings.HasPrefix(headerValue, "env.") { + clientConfig.Headers[headerName] = fmt.Sprintf("env.%s", envVar) + } else if strings.Contains(headerValue, "****") { + // If it's redacted (contains ****), restore to env.VAR format + // This handles the case where UI sends redacted headers back for env vars + clientConfig.Headers[headerName] = fmt.Sprintf("env.%s", envVar) + } + // If it's a plain value (not env. and not redacted), leave it as-is + } } } } diff --git a/framework/logstore/rdb.go b/framework/logstore/rdb.go index 316641a74d..fbd7d8ddfd 100644 --- a/framework/logstore/rdb.go +++ b/framework/logstore/rdb.go @@ -257,7 +257,7 @@ func (s *RDBLogStore) HasLogs(ctx context.Context) (bool, error) { return false, nil } return false, err - } + } return true, nil } diff --git a/tests/core-mcp/README.md b/tests/core-mcp/README.md new file mode 100644 index 0000000000..b2e6745de2 --- /dev/null +++ b/tests/core-mcp/README.md @@ -0,0 +1,230 @@ +# MCP Test Suite + +This directory contains comprehensive tests for the MCP (Model Context Protocol) functionality in Bifrost, covering code mode and non-code mode clients, auto-execute and non-auto-execute tools, and their various combinations. + +## Overview + +The test suite is organized into multiple test files covering different aspects of MCP: + +1. **Client Configuration Tests** (`client_config_test.go`) + - Single and multiple code mode clients + - Single and multiple non-code mode clients + - Mixed code mode + non-code mode clients + - Client connection states + - Client configuration updates + +2. **Tool Execution Tests** (`tool_execution_test.go`) + - Non-code mode tool execution (direct) + - Code mode tool execution (`executeToolCode`) + - Code mode calling code mode client tools + - Code mode calling multiple servers + - `listToolFiles` and `readToolFile` functionality + +3. **Auto-Execute Configuration Tests** (`auto_execute_config_test.go`) + - Tools in `ToolsToExecute` but not in `ToolsToAutoExecute` + - Tools in both lists (auto-execute) + - Tools in `ToolsToAutoExecute` but not in `ToolsToExecute` (should be skipped) + - Wildcard configurations + - Empty and nil configurations + - Mixed auto-execute configurations + +4. **Code Mode Auto-Execute Validation Tests** (`codemode_auto_execute_test.go`) + - `executeToolCode` with code calling only auto-execute tools + - `executeToolCode` with code calling non-auto-execute tools + - `executeToolCode` with code calling mixed auto/non-auto tools + - `executeToolCode` with no tool calls + - `executeToolCode` with `listToolFiles`/`readToolFile` calls + +5. **Agent Mode Tests** (`agent_mode_test.go`) + - Agent mode configuration validation + - Max depth configuration + - Note: Full agent mode flow testing requires LLM integration (see `integration_test.go`) + +6. **Edge Cases & Error Handling** (`edge_cases_test.go`) + - Code mode client calling non-code mode client tool (runtime error) + - Tool not in `ToolsToExecute` (should not be available) + - Tool execution timeout + - Tool execution error propagation + - Empty code execution + - Code with syntax errors + - Code with TypeScript compilation errors + - Code with runtime errors + - Code calling tools with invalid arguments + - Code mode tools always auto-executable + +7. **Integration Tests** (`integration_test.go`) + - Full workflow: `listToolFiles` → `readToolFile` → `executeToolCode` + - Multiple code mode clients with different auto-execute configs + - Tool filtering with code mode + - Code mode and non-code mode tools in same request + - Complex code execution scenarios + - Error handling in code execution + +8. **Basic MCP Connection Tests** (`mcp_connection_test.go`) + - MCP manager initialization + - Local tool registration + - Tool discovery and execution + - Multiple servers + - Tool execution timeout and errors + +## MCP Architecture + +### Client Types + +- **Code Mode Clients** (`IsCodeModeClient=true`): + - Enable code mode tools: `listToolFiles`, `readToolFile`, `executeToolCode` + - Tools accessible via TypeScript code execution in sandboxed VM + - Only code mode clients appear in `listToolFiles` output + +- **Non-Code Mode Clients** (`IsCodeModeClient=false`): + - Tools exposed directly as function-calling tools + - Cannot be called from `executeToolCode` code + +### Tool Execution Modes + +- **Auto-Execute Tools** (`ToolsToAutoExecute`): + - Automatically executed in agent mode without user approval + - Must also be in `ToolsToExecute` list + - For `executeToolCode`: validates all tool calls within code against auto-execute list + +- **Non-Auto-Execute Tools**: + - Require explicit user approval in agent mode + - Agent loop stops and returns these tools for user decision + +### Agent Mode Behavior + +When agent mode receives tool calls: + +- **All auto-execute tools**: Executes all tools, makes new LLM call, continues loop +- **All non-auto-execute tools**: Stops immediately, returns tool calls in `tool_calls` field +- **Mixed scenario** (e.g., 3 auto-execute, 2 non-auto-execute): + - Executes all auto-executable tools (3 in example) + - Adds executed tool results to message content (formatted as JSON) + - Includes non-auto-executable tool calls (2 in example) in `tool_calls` field + - Sets `finish_reason` to "stop" (not "tool_calls") to prevent loop continuation + - Returns immediately without making another LLM call + +Agent mode respects `maxAgentDepth` limit and returns an error if exceeded. + +## Test Structure + +### Setup Files + +- `setup.go` - Test setup utilities for initializing Bifrost and configuring clients + - `setupTestBifrost()` - Basic Bifrost instance + - `setupTestBifrostWithCodeMode()` - Bifrost with code mode enabled + - `setupTestBifrostWithMCPConfig()` - Bifrost with custom MCP config + - `setupCodeModeClient()` - Helper to create code mode client config + - `setupNonCodeModeClient()` - Helper to create non-code mode client config + - `setupClientWithAutoExecute()` - Helper to create client with auto-execute config + - `registerTestTools()` - Registers test tools (echo, add, multiply, etc.) + +- `fixtures.go` - Sample TypeScript code snippets and expected results + - Basic expressions and tool calls + - Auto-execute validation scenarios + - Mixed client scenarios + - Edge case scenarios + +- `utils.go` - Test helper functions for assertions and validation + - `createToolCall()` - Creates tool call messages + - `assertExecutionResult()` - Validates execution results + - `assertAgentModeResponse()` - Validates agent mode response structure + - `extractExecutedToolResults()` - Extracts executed tool results from agent mode response + - `canAutoExecuteTool()` - Checks if a tool can be auto-executed + - `createMCPClientConfig()` - Creates MCP client configs + +## Running Tests + +### Run all tests: +```bash +cd tests/core-mcp +go test -v ./... +``` + +### Run specific test file: +```bash +go test -v -run TestClientConfig ./... +``` + +### Run specific test: +```bash +go test -v -run TestSingleCodeModeClient +``` + +### Run with coverage: +```bash +go test -v -cover ./... +``` + +### Run tests by category: +```bash +# Client configuration tests +go test -v -run "^Test.*Client.*" ./... + +# Tool execution tests +go test -v -run "^Test.*Tool.*" ./... + +# Auto-execute tests +go test -v -run "^Test.*Auto.*" ./... + +# Edge case tests +go test -v -run "^Test.*Error|^Test.*Timeout|^Test.*Empty" ./... + +# Integration tests +go test -v -run "^Test.*Workflow|^Test.*Integration" ./... +``` + +## Test Tools + +The test suite registers several test tools: + +1. **echo** - Simple echo that returns input +2. **add** - Adds two numbers +3. **multiply** - Multiplies two numbers +4. **get_data** - Returns structured data (object/array) +5. **error_tool** - Tool that always returns an error +6. **slow_tool** - Tool that takes time to execute +7. **complex_args_tool** - Tool that accepts complex nested arguments + +## Key Test Scenarios + +### Scenario 1: Mixed Auto-Execute and Non-Auto-Execute Tools (Critical) + +When agent mode receives 5 tool calls: 3 auto-execute, 2 non-auto-execute: +- Agent executes the 3 auto-execute tools +- Adds their results to message content (JSON formatted) +- Includes the 2 non-auto-execute tool calls in `tool_calls` field +- Sets `finish_reason` to "stop" +- Stops immediately (no further LLM call) +- Response structure validated correctly + +### Scenario 2: Code Mode Client + Auto-Execute Tools + +- Setup: Code mode client with tools configured for auto-execute +- Test: `executeToolCode` with code calling these tools should auto-execute in agent mode + +### Scenario 3: Mixed Client Types + +- Setup: One code mode client + one non-code mode client +- Test: Code mode tools only see code mode client, non-code mode tools available separately + +### Scenario 4: Auto-Execute Validation in Code + +- Setup: Code mode client with mixed auto-execute config +- Test: `executeToolCode` validates all tool calls in code against auto-execute list + +### Scenario 5: Code Mode Tools Always Auto-Execute + +- Setup: Code mode enabled +- Test: `listToolFiles` and `readToolFile` always auto-execute regardless of config + +## Notes + +- All tests use a timeout context to prevent hanging +- Tests are designed to be independent and can run in parallel +- The test suite uses the `bifrostInternal` server for local tool registration +- Code mode tests verify that TypeScript code is transpiled and executes correctly in the sandboxed goja VM +- TypeScript compilation errors are caught and reported with helpful hints +- Async/await syntax is automatically transpiled to Promise chains compatible with goja +- Error handling tests verify that helpful error hints are provided for both runtime and TypeScript compilation errors +- Agent mode tests verify the critical mixed auto-execute/non-auto-execute scenario where some tools are executed and others are returned for user approval diff --git a/tests/core-mcp/agent_mode_test.go b/tests/core-mcp/agent_mode_test.go new file mode 100644 index 0000000000..8f3b00453e --- /dev/null +++ b/tests/core-mcp/agent_mode_test.go @@ -0,0 +1,77 @@ +package mcp + +import ( + "context" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Note: Full agent mode testing requires integration with LLM calls. +// These tests verify the configuration and tool execution aspects that can be tested directly. +// For full agent mode flow testing, see integration_test.go + +// TestAgentModeConfiguration tests the configuration aspects of agent mode +// Full agent mode flow testing requires LLM integration (see integration_test.go) +func TestAgentModeConfiguration(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Test configuration: echo auto-execute, add non-auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // Only echo is auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + + // Verify configuration + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") + assert.False(t, canAutoExecuteTool("add", bifrostClient.Config), "add should not be auto-executable") + assert.False(t, canAutoExecuteTool("multiply", bifrostClient.Config), "multiply should not be auto-executable") +} + +func TestAgentModeMaxDepthConfiguration(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Create Bifrost with max depth of 2 + mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 2, + ToolExecutionTimeout: 30 * time.Second, + }, + FetchNewRequestIDFunc: func(ctx context.Context) string { + return "test-request-id" + }, + } + b, err := setupTestBifrostWithMCPConfig(ctx, mcpConfig) + require.NoError(t, err) + + // Verify max depth is configured + clients, err := b.GetMCPClients() + require.NoError(t, err) + assert.NotNil(t, clients, "Should have clients") +} diff --git a/tests/core-mcp/auto_execute_config_test.go b/tests/core-mcp/auto_execute_config_test.go new file mode 100644 index 0000000000..a0bff19d53 --- /dev/null +++ b/tests/core-mcp/auto_execute_config_test.go @@ -0,0 +1,322 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestToolInToolsToExecuteButNotInToolsToAutoExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure echo in ToolsToExecute but not in ToolsToAutoExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo"}, + ToolsToAutoExecute: []string{}, // Empty - no auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "echo") + assert.Empty(t, bifrostClient.Config.ToolsToAutoExecute) + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable") +} + +func TestToolInBothToolsToExecuteAndToolsToAutoExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure echo in both lists + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo"}, + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "echo") + assert.Contains(t, bifrostClient.Config.ToolsToAutoExecute, "echo") + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") +} + +func TestToolInToolsToAutoExecuteButNotInToolsToExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure echo in ToolsToAutoExecute but not in ToolsToExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"add"}, // echo not in this list + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + // echo should not be auto-executable because it's not in ToolsToExecute + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable (not in ToolsToExecute)") +} + +func TestWildcardInToolsToAutoExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure wildcard in ToolsToAutoExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToAutoExecute, "*") + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable with wildcard") + assert.True(t, canAutoExecuteTool("add", bifrostClient.Config), "add should be auto-executable with wildcard") +} + +func TestEmptyToolsToAutoExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure empty ToolsToAutoExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{}, // Empty - no auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Empty(t, bifrostClient.Config.ToolsToAutoExecute) + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable") +} + +func TestNilToolsToAutoExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure nil ToolsToAutoExecute (omitted) + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + // ToolsToAutoExecute omitted (nil) + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + // nil should be treated as empty + if bifrostClient.Config.ToolsToAutoExecute == nil { + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable (nil treated as empty)") + } else { + assert.Empty(t, bifrostClient.Config.ToolsToAutoExecute) + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable") + } +} + +func TestMultipleToolsWithMixedAutoExecuteConfigs(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure mixed: echo auto-execute, add non-auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo", "add", "multiply"}, + ToolsToAutoExecute: []string{"echo", "multiply"}, // add not in auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") + assert.False(t, canAutoExecuteTool("add", bifrostClient.Config), "add should not be auto-executable") + assert.True(t, canAutoExecuteTool("multiply", bifrostClient.Config), "multiply should be auto-executable") +} + +func TestToolsToExecuteEmptyList(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure empty ToolsToExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{}, // Empty - no tools allowed + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Empty(t, bifrostClient.Config.ToolsToExecute) + // Even with wildcard in ToolsToAutoExecute, tools not in ToolsToExecute should not be auto-executable + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable (not in ToolsToExecute)") +} + +func TestToolsToExecuteNil(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure nil ToolsToExecute (omitted) + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + // ToolsToExecute omitted (nil) + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + // nil ToolsToExecute should be treated as empty + if bifrostClient.Config.ToolsToExecute == nil { + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable (nil ToolsToExecute treated as empty)") + } else { + assert.Empty(t, bifrostClient.Config.ToolsToExecute) + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable") + } +} diff --git a/tests/core-mcp/client_config_test.go b/tests/core-mcp/client_config_test.go new file mode 100644 index 0000000000..7b7b9851d5 --- /dev/null +++ b/tests/core-mcp/client_config_test.go @@ -0,0 +1,346 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSingleCodeModeClient(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + // Find bifrostInternal client + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient, "bifrostInternal client should exist") + assert.True(t, bifrostClient.Config.IsCodeModeClient, "bifrostInternal should be code mode client") + assert.Equal(t, schemas.MCPConnectionStateConnected, bifrostClient.State) +} + +func TestSingleNonCodeModeClient(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Note: For in-process clients, we need to register tools first + err = registerTestTools(b) + require.NoError(t, err) + + // Update bifrostInternal to be non-code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.False(t, bifrostClient.Config.IsCodeModeClient, "bifrostInternal should be non-code mode client") +} + +func TestMultipleCodeModeClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + codeModeCount := 0 + for _, client := range clients { + if client.Config.IsCodeModeClient { + codeModeCount++ + } + } + + assert.GreaterOrEqual(t, codeModeCount, 1, "Should have at least one code mode client") +} + +func TestMultipleNonCodeModeClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to non-code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + nonCodeModeCount := 0 + for _, client := range clients { + if !client.Config.IsCodeModeClient { + nonCodeModeCount++ + } + } + + assert.GreaterOrEqual(t, nonCodeModeCount, 1, "Should have at least one non-code mode client") +} + +func TestMixedCodeModeAndNonCodeModeClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + codeModeCount := 0 + + for _, client := range clients { + if client.Config.IsCodeModeClient { + codeModeCount++ + } + } + + // At minimum, we should have bifrostInternal as code mode + assert.GreaterOrEqual(t, codeModeCount, 1, "Should have at least one code mode client") +} + +func TestClientConnectionStates(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + // All clients should be connected + for _, client := range clients { + assert.Equal(t, schemas.MCPConnectionStateConnected, client.State, "Client %s should be connected", client.Config.ID) + } +} + +func TestClientWithNoTools(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Don't register any tools - bifrostInternal client should still exist but with no tools + clients, err := b.GetMCPClients() + require.NoError(t, err) + + // bifrostInternal client is created when MCP is initialized, but won't have tools until registered + // This test verifies the client exists even without tools + assert.NotNil(t, clients, "Clients list should exist") + + // Find bifrostInternal client + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient, "bifrostInternal client should exist") + assert.Empty(t, bifrostClient.Tools, "bifrostInternal client should have no tools") +} + +func TestClientWithEmptyToolLists(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set ToolsToExecute to empty list + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Equal(t, []string{}, bifrostClient.Config.ToolsToExecute, "ToolsToExecute should be empty") +} + +func TestClientConfigUpdate(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Initially, bifrostInternal should not be code mode (default) + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + initialIsCodeMode := bifrostClient.Config.IsCodeModeClient + + // Update to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + }) + require.NoError(t, err) + + // Verify update + clients, err = b.GetMCPClients() + require.NoError(t, err) + + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.NotEqual(t, initialIsCodeMode, bifrostClient.Config.IsCodeModeClient, "IsCodeModeClient should have changed") + assert.True(t, bifrostClient.Config.IsCodeModeClient, "Should now be code mode") +} + +func TestClientWithToolsToExecuteWildcard(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set ToolsToExecute to wildcard + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "*", "Should contain wildcard") +} + +func TestClientWithSpecificToolsToExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set ToolsToExecute to specific tools + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo", "add"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "echo") + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "add") + assert.Len(t, bifrostClient.Config.ToolsToExecute, 2) +} diff --git a/tests/core-mcp/codemode_auto_execute_test.go b/tests/core-mcp/codemode_auto_execute_test.go new file mode 100644 index 0000000000..d73c68cc4d --- /dev/null +++ b/tests/core-mcp/codemode_auto_execute_test.go @@ -0,0 +1,233 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +func TestExecuteToolCodeWithAutoExecuteTool(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Configure echo as auto-execute - preserve existing config + clients, err := b.GetMCPClients() + require.NoError(t, err) + var currentConfig *schemas.MCPClientConfig + for _, client := range clients { + if client.Config.ID == "bifrostInternal" { + currentConfig = &client.Config + break + } + } + require.NotNil(t, currentConfig) + + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ID: currentConfig.ID, + Name: currentConfig.Name, + ConnectionType: currentConfig.ConnectionType, + IsCodeModeClient: currentConfig.IsCodeModeClient, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + // Test executeToolCode with code calling auto-execute tool + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithAutoExecuteTool, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestExecuteToolCodeWithNonAutoExecuteTool(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Configure multiply as non-auto-execute - preserve existing config + clients, err := b.GetMCPClients() + require.NoError(t, err) + var currentConfig *schemas.MCPClientConfig + for _, client := range clients { + if client.Config.ID == "bifrostInternal" { + currentConfig = &client.Config + break + } + } + require.NotNil(t, currentConfig) + + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ID: currentConfig.ID, + Name: currentConfig.Name, + ConnectionType: currentConfig.ConnectionType, + IsCodeModeClient: currentConfig.IsCodeModeClient, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // multiply not in auto-execute + }) + require.NoError(t, err) + + // Test executeToolCode with code calling non-auto-execute tool + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithNonAutoExecuteTool, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestExecuteToolCodeWithMixedAutoExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Configure echo as auto-execute, multiply as non-auto-execute - preserve existing config + clients, err := b.GetMCPClients() + require.NoError(t, err) + var currentConfig *schemas.MCPClientConfig + for _, client := range clients { + if client.Config.ID == "bifrostInternal" { + currentConfig = &client.Config + break + } + } + require.NotNil(t, currentConfig) + + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ID: currentConfig.ID, + Name: currentConfig.Name, + ConnectionType: currentConfig.ConnectionType, + IsCodeModeClient: currentConfig.IsCodeModeClient, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // multiply not in auto-execute + }) + require.NoError(t, err) + + // Test executeToolCode with code calling mixed tools + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithMixedAutoExecute, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestExecuteToolCodeWithNoToolCalls(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode with no tool calls + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithNoToolCalls, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestExecuteToolCodeWithListToolFiles(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // listToolFiles should always be auto-executable + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithListToolFiles, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // listToolFiles and readToolFile are code mode meta-tools and cannot be called from within executeToolCode + // They're only available as direct tool calls, not from within code execution + // So this will fail with a runtime error + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestExecuteToolCodeWithReadToolFile(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // readToolFile should always be auto-executable + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithReadToolFile, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // listToolFiles and readToolFile are code mode meta-tools and cannot be called from within executeToolCode + // They're only available as direct tool calls, not from within code execution + // So this will fail with a runtime error + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestExecuteToolCodeWithUndefinedServer(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode with undefined server + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithUndefinedServer, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail with runtime error + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestExecuteToolCodeWithUndefinedTool(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode with undefined tool + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithUndefinedTool, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail with runtime error + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "runtime") +} diff --git a/tests/core-mcp/edge_cases_test.go b/tests/core-mcp/edge_cases_test.go new file mode 100644 index 0000000000..dfc3e780c8 --- /dev/null +++ b/tests/core-mcp/edge_cases_test.go @@ -0,0 +1,299 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCodeModeClientCallingNonCodeModeClientTool(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code trying to call non-code mode client tool + // This should fail at runtime since non-code mode clients aren't available in code execution + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeCallingNonCodeModeTool, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail with runtime error - tool call succeeds but code execution fails + requireNoBifrostError(t, bifrostErr, "Tool call should succeed") + require.NotNil(t, result, "Result should be present") + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestNonCodeModeClientToolCalledFromExecuteToolCode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Code mode can only call code mode client tools + // Non-code mode tools are not available in executeToolCode context + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": `const result = await NonExistentClient.tool({}); return result`, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail with runtime error - tool call succeeds but code execution fails + requireNoBifrostError(t, bifrostErr, "Tool call should succeed") + require.NotNil(t, result, "Result should be present") + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestToolNotInToolsToExecute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure only echo in ToolsToExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo"}, // add not in list + }) + require.NoError(t, err) + + // Try to execute add tool (not in ToolsToExecute) + addCall := createToolCall("add", map[string]interface{}{ + "a": float64(1), + "b": float64(2), + }) + _, bifrostErr := b.ExecuteMCPTool(ctx, addCall) + + // Should fail - tool not available + assert.NotNil(t, bifrostErr, "Should fail when tool not in ToolsToExecute") +} + +func TestToolExecutionTimeoutEdgeCase(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Test slow tool with timeout + slowCall := createToolCall("slow_tool", map[string]interface{}{ + "delay_ms": float64(100), + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, slowCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "Completed", "Should complete execution") +} + +func TestToolExecutionErrorPropagation(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Test error tool + errorCall := createToolCall("error_tool", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, errorCall) + + // Tool execution should succeed (no bifrostErr), but result should contain error message + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "Error:", "Result should contain error message") + assert.Contains(t, responseText, "this tool always fails", "Result should contain the error text") +} + +func TestEmptyCodeExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.EmptyCode, + }) + + _, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Empty code should return an error + require.NotNil(t, bifrostErr, "Empty code should return an error") + assert.Contains(t, bifrostErr.Error.Message, "code parameter is required", "Error should mention code parameter") +} + +func TestCodeWithSyntaxErrors(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.SyntaxError, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // Syntax errors are caught during JavaScript execution (runtime), not TypeScript compilation + // The error will be a runtime SyntaxError + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestCodeWithTypeScriptCompilationErrors(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Invalid TypeScript code + invalidCode := `const x: string = 123; return x` + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": invalidCode, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // TypeScript type errors might not be caught - the code might execute successfully + // This is acceptable behavior if type checking is disabled + // Just verify the execution completed (either with error or success) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) +} + +func TestCodeWithRuntimeErrors(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.RuntimeError, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail with runtime error + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestCodeCallingToolsWithInvalidArguments(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Code calling tool with invalid arguments + invalidArgsCode := `const result = await BifrostClient.echo({invalid: "arg"}); return result` + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": invalidArgsCode, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail - tool expects "message" parameter + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "") +} + +func TestCodeModeToolsAlwaysAutoExecutable(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{}, // Empty - no auto-execute configured + }) + require.NoError(t, err) + + // listToolFiles and readToolFile should always be auto-executable + // This is tested in integration tests that verify agent mode behavior + // For now, verify they can be executed directly + listCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, listCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) +} + +func TestCommentsOnlyCode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CommentsOnly, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // Comments-only code should execute (return null) + assertExecutionResult(t, result, true, nil, "") +} + +func TestUndefinedVariableError(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.UndefinedVariable, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + // Should fail with runtime error + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "runtime") +} diff --git a/tests/core-mcp/fixtures.go b/tests/core-mcp/fixtures.go new file mode 100644 index 0000000000..fe8b5a82e5 --- /dev/null +++ b/tests/core-mcp/fixtures.go @@ -0,0 +1,311 @@ +package mcp + +// CodeFixtures contains sample TypeScript code snippets for testing +var CodeFixtures = struct { + // Basic expressions + SimpleExpression string + SimpleString string + VariableAssignment string + ConsoleLogging string + ExplicitReturn string + AutoReturnExpression string + + // MCP tool calls + SingleToolCall string + ToolCallWithPromise string + ToolCallChain string + ToolCallErrorHandling string + MultipleServerToolCalls string + ToolCallWithComplexArgs string + + // Import/Export + ImportStatement string + ExportStatement string + MultipleImportExport string + ImportExportWithComments string + + // Expression analysis + FunctionCallExpression string + PromiseChainExpression string + ObjectLiteralExpression string + AssignmentStatement string + ControlFlowStatement string + TopLevelReturn string + + // Error cases + UndefinedVariable string + UndefinedServer string + UndefinedTool string + SyntaxError string + RuntimeError string + + // Edge cases + NestedPromiseChains string + PromiseErrorHandling string + ComplexDataStructures string + MultiLineExpression string + EmptyCode string + CommentsOnly string + FunctionDefinition string + + // Environment tests + AsyncAwaitTest string + EnvironmentTest string + + // Long code test + LongCodeExecution string + + // Auto-execute validation tests + CodeWithAutoExecuteTool string + CodeWithNonAutoExecuteTool string + CodeWithMixedAutoExecute string + CodeWithMultipleClients string + CodeWithNoToolCalls string + CodeWithListToolFiles string + CodeWithReadToolFile string + + // Mixed client scenarios + CodeCallingCodeModeTool string + CodeCallingNonCodeModeTool string + CodeCallingMultipleServers string + CodeWithUndefinedServer string + CodeWithUndefinedTool string + + // Agent mode scenarios + CodeForAgentModeAutoExecute string + CodeForAgentModeNonAutoExecute string +}{ + SimpleExpression: `return 1 + 1`, + SimpleString: `return "hello"`, + VariableAssignment: `var x = 5; return x`, + ConsoleLogging: `console.log("test"); return "logged"`, + ExplicitReturn: `return 42`, + AutoReturnExpression: `return 2 + 2`, // Note: Now requires explicit return + + SingleToolCall: `const result = await BifrostClient.echo({message: "hello"}); return result`, + ToolCallWithPromise: `const result = await BifrostClient.echo({message: "test"}); console.log(result); return result`, + ToolCallChain: `const result1 = await BifrostClient.add({a: 1, b: 2}); const result2 = await BifrostClient.multiply({a: result1, b: 3}); return result2`, + ToolCallErrorHandling: `try { await BifrostClient.error_tool({}); } catch (err) { console.error(err); return "handled"; }`, + MultipleServerToolCalls: `const r1 = await BifrostClient.echo({message: "test"}); const r2 = await BifrostClient.add({a: 1, b: 2}); return r2`, + ToolCallWithComplexArgs: `return await BifrostClient.complex_args_tool({data: {nested: {value: 42}}})`, + + ImportStatement: `import { something } from "module"; return 1 + 1`, + ExportStatement: `export const x = 5; return x`, + MultipleImportExport: `import a from "a"; import b from "b"; export const c = 1; return 2 + 2`, + ImportExportWithComments: `// comment\nimport x from "x";\n// another comment\nreturn 2 + 2`, + + FunctionCallExpression: `return Math.max(1, 2)`, // Note: Now requires explicit return + PromiseChainExpression: `return Promise.resolve(1).then(x => x + 1)`, // Note: Now requires explicit return + ObjectLiteralExpression: `return {a: 1, b: 2}`, // Note: Now requires explicit return + AssignmentStatement: `var x = 5`, // Assignment statements don't return values + ControlFlowStatement: `if (true) { return 1; } else { return 2; }`, // Note: Now requires explicit return + TopLevelReturn: `return 42`, + + UndefinedVariable: `return undefinedVar`, // Will cause runtime error + UndefinedServer: `return nonexistentServer.tool({})`, // Will cause runtime error + UndefinedTool: `return BifrostClient.nonexistentTool({})`, // Will cause runtime error + SyntaxError: `var x = `, // Syntax error - no return needed + RuntimeError: `return null.someProperty`, // Will cause runtime error + + NestedPromiseChains: `return Promise.resolve(1).then(x => Promise.resolve(x + 1).then(y => y + 1))`, // Note: Now requires explicit return + PromiseErrorHandling: `return Promise.reject("error").catch(err => "handled")`, // Note: Now requires explicit return + ComplexDataStructures: `return [{a: 1}, {b: 2}].map(x => x.a || x.b)`, // Note: Now requires explicit return + MultiLineExpression: `const result = await BifrostClient.echo({message: "test"});\n return result`, // Note: Now requires explicit return + EmptyCode: ``, + CommentsOnly: `// comment\n/* another */`, + FunctionDefinition: `function test() { return 1; } return test()`, // Note: Now requires explicit return for function call + + AsyncAwaitTest: `async function test() { const result = await Promise.resolve(1); return result; } return test()`, + EnvironmentTest: `return __MCP_ENV__.serverKeys`, + + LongCodeExecution: `// Long and complex code execution test with extensive operations\n` + + `(async function() {\n` + + ` var results = [];\n` + + ` var sum = 0;\n` + + ` var processedData = [];\n` + + ` var executionLog = [];\n` + + ` \n` + + ` // Initialize execution context\n` + + ` var context = {\n` + + ` startTime: Date.now(),\n` + + ` steps: 0,\n` + + ` errors: [],\n` + + ` warnings: []\n` + + ` };\n` + + ` \n` + + ` try {\n` + + ` // Step 1: Initial echo call\n` + + ` const result1 = await BifrostClient.echo({message: "step1"});\n` + + ` console.log("Step 1 completed:", result1);\n` + + ` results.push(result1);\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 1, action: "echo", result: result1});\n` + + ` \n` + + ` // Step 2: Add operation\n` + + ` const result2 = await BifrostClient.add({a: 10, b: 20});\n` + + ` console.log("Step 2 completed:", result2);\n` + + ` results.push(result2);\n` + + ` sum += result2;\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 2, action: "add", result: result2, sum: sum});\n` + + ` \n` + + ` // Conditional logic based on result\n` + + ` let result3;\n` + + ` if (result2 > 25) {\n` + + ` console.log("Result is greater than 25, proceeding with multiplication");\n` + + ` result3 = await BifrostClient.multiply({a: result2, b: 2});\n` + + ` } else {\n` + + ` console.log("Result is less than or equal to 25, using add again");\n` + + ` result3 = await BifrostClient.add({a: result2, b: 5});\n` + + ` }\n` + + ` console.log("Step 3 completed:", result3);\n` + + ` results.push(result3);\n` + + ` sum += result3;\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 3, action: "math", result: result3, sum: sum});\n` + + ` \n` + + ` // Step 4: Echo call\n` + + ` const result4 = await BifrostClient.echo({message: "step4"});\n` + + ` console.log("Step 4 completed:", result4);\n` + + ` results.push(result4);\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 4, action: "echo", result: result4});\n` + + ` \n` + + ` // Complex loop with nested operations\n` + + ` for (var i = 0; i < 20; i++) {\n` + + ` sum += i;\n` + + ` if (i % 3 === 0) {\n` + + ` processedData.push({\n` + + ` index: i,\n` + + ` value: i * 2,\n` + + ` isMultipleOfThree: true\n` + + ` });\n` + + ` } else if (i % 2 === 0) {\n` + + ` processedData.push({\n` + + ` index: i,\n` + + ` value: i * 1.5,\n` + + ` isEven: true\n` + + ` });\n` + + ` } else {\n` + + ` processedData.push({\n` + + ` index: i,\n` + + ` value: i,\n` + + ` isOdd: true\n` + + ` });\n` + + ` }\n` + + ` }\n` + + ` \n` + + ` console.log("Processed", processedData.length, "data items");\n` + + ` \n` + + ` // Step 5: Get data\n` + + ` const result5 = await BifrostClient.get_data({key: "test"});\n` + + ` console.log("Step 5 completed:", result5);\n` + + ` results.push(result5);\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 5, action: "get_data", result: result5});\n` + + ` \n` + + ` // Nested data processing\n` + + ` var nestedResults = [];\n` + + ` for (var j = 0; j < results.length; j++) {\n` + + ` var item = results[j];\n` + + ` nestedResults.push({\n` + + ` original: item,\n` + + ` processed: typeof item === "string" ? item.toUpperCase() : item * 1.1,\n` + + ` index: j,\n` + + ` metadata: {\n` + + ` type: typeof item,\n` + + ` isString: typeof item === "string",\n` + + ` isNumber: typeof item === "number"\n` + + ` }\n` + + ` });\n` + + ` }\n` + + ` \n` + + ` // Step 6: Final echo call\n` + + ` const result6 = await BifrostClient.echo({message: "final_step"});\n` + + ` console.log("Step 6 completed:", result6);\n` + + ` results.push(result6);\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 6, action: "echo", result: result6});\n` + + ` \n` + + ` // Calculate statistics\n` + + ` var stats = {\n` + + ` totalResults: results.length,\n` + + ` numericSum: sum,\n` + + ` average: sum / results.length,\n` + + ` processedItems: processedData.length,\n` + + ` executionSteps: context.steps\n` + + ` };\n` + + ` \n` + + ` // Create comprehensive final data structure\n` + + ` var finalData = {\n` + + ` results: results,\n` + + ` processedData: processedData,\n` + + ` executionLog: executionLog,\n` + + ` statistics: stats,\n` + + ` context: {\n` + + ` steps: context.steps,\n` + + ` executionTime: Date.now() - context.startTime,\n` + + ` errors: context.errors,\n` + + ` warnings: context.warnings\n` + + ` },\n` + + ` metadata: {\n` + + ` executed: true,\n` + + ` completed: true,\n` + + ` totalOperations: context.steps,\n` + + ` dataProcessed: processedData.length,\n` + + ` finalSum: sum,\n` + + ` resultCount: results.length\n` + + ` }\n` + + ` };\n` + + ` \n` + + ` console.log("Final statistics:", JSON.stringify(stats));\n` + + ` console.log("Execution completed successfully with", context.steps, "steps");\n` + + ` console.log("Processed", processedData.length, "data items");\n` + + ` console.log("Final sum:", sum);\n` + + ` \n` + + ` return finalData;\n` + + ` } catch (error) {\n` + + ` console.error("Error in long execution:", error);\n` + + ` context.errors.push(error.toString());\n` + + ` return {\n` + + ` error: error.toString(),\n` + + ` context: context,\n` + + ` partialResults: results,\n` + + ` partialSum: sum\n` + + ` };\n` + + ` }\n` + + `})()`, + + // Auto-execute validation tests + CodeWithAutoExecuteTool: `const result = await BifrostClient.echo({message: "auto-execute"}); return result`, + CodeWithNonAutoExecuteTool: `const result = await BifrostClient.multiply({a: 2, b: 3}); return result`, + CodeWithMixedAutoExecute: `const r1 = await BifrostClient.echo({message: "auto"}); const r2 = await BifrostClient.multiply({a: 2, b: 3}); return r2`, + CodeWithMultipleClients: `const r1 = await BifrostClient.echo({message: "test"}); const r2 = await Server2.add({a: 1, b: 2}); return r2`, + CodeWithNoToolCalls: `return 42`, + CodeWithListToolFiles: `const files = await BifrostClient.listToolFiles({}); return files`, + CodeWithReadToolFile: `const content = await BifrostClient.readToolFile({fileName: "BifrostClient.d.ts"}); return content`, + + // Mixed client scenarios + CodeCallingCodeModeTool: `const result = await BifrostClient.echo({message: "test"}); return result`, + CodeCallingNonCodeModeTool: `const result = await NonCodeModeClient.someTool({}); return result`, + CodeCallingMultipleServers: `const r1 = await BifrostClient.echo({message: "test"}); const r2 = await Server2.add({a: 1, b: 2}); return {r1, r2}`, + CodeWithUndefinedServer: `const result = await UndefinedServer.tool({}); return result`, + CodeWithUndefinedTool: `const result = await BifrostClient.undefinedTool({}); return result`, + + // Agent mode scenarios + CodeForAgentModeAutoExecute: `const result = await BifrostClient.echo({message: "agent-auto"}); return result`, + CodeForAgentModeNonAutoExecute: `const result = await BifrostClient.multiply({a: 5, b: 6}); return result`, +} + +// ExpectedResults contains expected results for validation +var ExpectedResults = struct { + SimpleExpressionResult interface{} + EchoResult string + AddResult float64 + MultiplyResult float64 +}{ + SimpleExpressionResult: float64(2), + EchoResult: "hello", + AddResult: float64(3), + MultiplyResult: float64(6), +} diff --git a/tests/core-mcp/go.mod b/tests/core-mcp/go.mod new file mode 100644 index 0000000000..2d10ebcb66 --- /dev/null +++ b/tests/core-mcp/go.mod @@ -0,0 +1,63 @@ +module github.com/maximhq/bifrost/tests/core-mcp + +go 1.24.3 + +replace github.com/maximhq/bifrost/core => ../../core + +require ( + github.com/maximhq/bifrost/core v0.0.0-00010101000000-000000000000 + github.com/stretchr/testify v1.11.1 +) + +require ( + cloud.google.com/go/compute/metadata v0.9.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.39.5 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.13 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.17 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.12 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.12 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 // indirect + github.com/aws/smithy-go v1.23.1 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.14.1 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect + github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/klauspost/compress v1.18.1 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.1 // indirect + github.com/mark3labs/mcp-go v0.41.1 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.67.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.22.0 // indirect + golang.org/x/net v0.47.0 // indirect + golang.org/x/oauth2 v0.32.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.31.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/tests/core-mcp/go.sum b/tests/core-mcp/go.sum new file mode 100644 index 0000000000..73a452e719 --- /dev/null +++ b/tests/core-mcp/go.sum @@ -0,0 +1,141 @@ +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= +github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/aws/aws-sdk-go-v2 v1.39.5 h1:e/SXuia3rkFtapghJROrydtQpfQaaUgd1cUvyO1mp2w= +github.com/aws/aws-sdk-go-v2 v1.39.5/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 h1:t9yYsydLYNBk9cJ73rgPhPWqOh/52fcWDQB5b1JsKSY= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2/go.mod h1:IusfVNTmiSN3t4rhxWFaBAqn+mcNdwKtPcV16eYdgko= +github.com/aws/aws-sdk-go-v2/config v1.31.13 h1:wcqQB3B0PgRPUF5ZE/QL1JVOyB0mbPevHFoAMpemR9k= +github.com/aws/aws-sdk-go-v2/config v1.31.13/go.mod h1:ySB5D5ybwqGbT6c3GszZ+u+3KvrlYCUQNo62+hkKOFk= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17 h1:skpEwzN/+H8cdrrtT8y+rvWJGiWWv0DeNAe+4VTf+Vs= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17/go.mod h1:Ed+nXsaYa5uBINovJhcAWkALvXw2ZLk36opcuiSZfJM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 h1:UuGVOX48oP4vgQ36oiKmW9RuSeT8jlgQgBFQD+HUiHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10/go.mod h1:vM/Ini41PzvudT4YkQyE/+WiQJiQ6jzeDyU8pQKwCac= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.12 h1:p/9flfXdoAnwJnuW9xHEAFY22R3A6skYkW19JFF9F+8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.12/go.mod h1:ZTLHakoVCTtW8AaLGSwJ3LXqHD9uQKnOcv1TrpO6u2k= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.12 h1:2lTWFvRcnWFFLzHWmtddu5MTchc5Oj2OOey++99tPZ0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.12/go.mod h1:hI92pK+ho8HVcWMHKHrK3Uml4pfG7wvL86FzO0LVtQQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 h1:DRND0dkCKtJzCj4Xl4OpVbXZgfttY5q712H9Zj7qc/0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10/go.mod h1:tGGNmJKOTernmR2+VJ0fCzQRurcPZj9ut60Zu5Fi6us= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 h1:fspVFg6qMx0svs40YgRmE7LZXh9VRZvTT35PfdQR6FM= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7/go.mod h1:BQTKL3uMECaLaUV3Zc2L4Qybv8C6BIXjuu1dOPyxTQs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 h1:scVnW+NLXasGOhy7HhkdT9AGb6kjgW7fJ5xYkUaqHs0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2/go.mod h1:FRNCY3zTEWZXBKm2h5UBUPvCVDOecTad9KhynDyGBc0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 h1:VEO5dqFkMsl8QZ2yHsFDJAIZLAkEbaYDB+xdKi0Feic= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7/go.mod h1:L1xxV3zAdB+qVrVW/pBIrIAnHFWHo6FBbFe4xOGsG/o= +github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M= +github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= +github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0 h1:e+8XbKB6IMn8A4OAyZccO4pYfB3s7bt6azNIPE7AnPg= +github.com/google/pprof v0.0.0-20240625030939-27f56978b8b0/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= +github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= +github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.41.1 h1:w78eWfiQam2i8ICL7AL0WFiq7KHNJQ6UB53ZVtH4KGA= +github.com/mark3labs/mcp-go v0.41.1/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.67.0 h1:tqKlJMUP6iuNG8hGjK/s9J4kadH7HLV4ijEcPGsezac= +github.com/valyala/fasthttp v1.67.0/go.mod h1:qYSIpqt/0XNmShgo/8Aq8E3UYWVVwNS2QYmzd8WIEPM= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= +golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= +golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/tests/core-mcp/integration_test.go b/tests/core-mcp/integration_test.go new file mode 100644 index 0000000000..def838b6a5 --- /dev/null +++ b/tests/core-mcp/integration_test.go @@ -0,0 +1,229 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFullWorkflowListToolFilesReadToolFileExecuteToolCode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Step 1: List tool files + listCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, listCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "BifrostClient.d.ts", "Should list BifrostClient") + + // Step 2: Read tool file + readCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "BifrostClient.d.ts", + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, readCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText = *result.Content.ContentStr + assert.Contains(t, responseText, "interface", "Should contain interface definitions") + assert.Contains(t, responseText, "echo", "Should contain echo tool") + + // Step 3: Execute code using the discovered tools + executeCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeCallingCodeModeTool, + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, executeCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestMultipleCodeModeClientsWithDifferentAutoExecuteConfigs(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure bifrostInternal with mixed auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo", "add"}, // multiply not auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config)) + assert.True(t, canAutoExecuteTool("add", bifrostClient.Config)) + assert.False(t, canAutoExecuteTool("multiply", bifrostClient.Config)) +} + +func TestToolFilteringWithCodeMode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure specific tools only + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + ToolsToExecute: []string{"echo", "add"}, // Only these tools available + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "echo") + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "add") + assert.NotContains(t, bifrostClient.Config.ToolsToExecute, "multiply") +} + +func TestCodeModeAndNonCodeModeToolsInSameRequest(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Code mode tools should be available + listCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, listCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // Verify direct tools are not exposed for code-mode clients + // Code mode clients expose tools via executeToolCode, not as direct tool calls + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test", + }) + _, bifrostErr = b.ExecuteMCPTool(ctx, echoCall) + require.NotNil(t, bifrostErr, "Direct tool call should fail for code-mode client") + assert.Contains(t, bifrostErr.Error.Message, "not available", "Error should indicate tool is not available") +} + +func TestComplexCodeExecutionWithMultipleToolCalls(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test complex code with multiple tool calls + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.ToolCallChain, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestCodeExecutionWithErrorHandling(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code with error handling + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.ToolCallErrorHandling, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "handled") +} + +func TestCodeExecutionWithAsyncAwait(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test async/await syntax + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.AsyncAwaitTest, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestLongCodeExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test long and complex code execution + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.LongCodeExecution, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} diff --git a/tests/core-mcp/mcp_connection_test.go b/tests/core-mcp/mcp_connection_test.go new file mode 100644 index 0000000000..e553549314 --- /dev/null +++ b/tests/core-mcp/mcp_connection_test.go @@ -0,0 +1,299 @@ +package mcp + +import ( + "context" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMCPManagerInitialization(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + require.NotNil(t, b) + + // Verify MCP is configured + clients, err := b.GetMCPClients() + require.NoError(t, err) + assert.NotNil(t, clients) +} + +func TestLocalToolRegistration(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Verify tools are available + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + // Find the bifrostInternal client + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient, "bifrostInternal client should exist") + assert.Equal(t, schemas.MCPConnectionStateConnected, bifrostClient.State) + + // Verify tools are registered + toolNames := make(map[string]bool) + for _, tool := range bifrostClient.Tools { + toolNames[tool.Name] = true + } + + assert.True(t, toolNames["echo"], "echo tool should be registered") + assert.True(t, toolNames["add"], "add tool should be registered") + assert.True(t, toolNames["multiply"], "multiply tool should be registered") +} + +func TestToolDiscovery(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Use CodeMode since we're testing CodeMode tools (listToolFiles, readToolFile) + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test listToolFiles + listToolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, listToolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "servers/", "Should list servers") + assert.Contains(t, responseText, "BifrostClient.d.ts", "Should list BifrostClient server") + + // Test readToolFile + readToolCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "BifrostClient.d.ts", + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, readToolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText = *result.Content.ContentStr + assert.Contains(t, responseText, "interface", "Should contain TypeScript interface declarations") + assert.Contains(t, responseText, "echo", "Should contain echo tool definition") + assert.Contains(t, responseText, "EchoInput", "Should contain echo input interface") +} + +func TestToolExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Test echo tool + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test message", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Equal(t, "test message", responseText) + + // Test add tool + addCall := createToolCall("add", map[string]interface{}{ + "a": float64(5), + "b": float64(3), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, addCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText = *result.Content.ContentStr + assert.Equal(t, "8", responseText) + + // Test multiply tool + multiplyCall := createToolCall("multiply", map[string]interface{}{ + "a": float64(4), + "b": float64(7), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, multiplyCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText = *result.Content.ContentStr + assert.Equal(t, "28", responseText) +} + +func TestMultipleServers(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Use CodeMode since we're testing CodeMode tools (listToolFiles) + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Verify we have at least one server + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + // Test listToolFiles with multiple servers + listToolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, listToolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "BifrostClient.d.ts", "Should list BifrostClient server") +} + +// TestExternalMCPConnection tests connection to external MCP server +// This test requires external MCP credentials to be provided via environment variables +// or test configuration. For now, it's a placeholder that can be enabled when credentials are available. +func TestExternalMCPConnection(t *testing.T) { + t.Skip("Skipping external MCP connection test - requires credentials") + + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + _, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Example: Connect to external MCP server + // Uncomment and configure when credentials are available + /* + connectionString := os.Getenv("EXTERNAL_MCP_CONNECTION_STRING") + if connectionString == "" { + t.Skip("EXTERNAL_MCP_CONNECTION_STRING not set") + } + + err = connectExternalMCP(b, "external-server", "external-1", "http", connectionString) + require.NoError(t, err) + + // Verify connection + clients := b.GetMCPClients() + found := false + for _, client := range clients { + if client.Config.ID == "external-1" { + found = true + assert.Equal(t, schemas.MCPConnectionStateConnected, client.State) + break + } + } + assert.True(t, found, "External client should be connected") + */ +} + +func TestToolExecutionTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Test slow tool with short timeout + slowCall := createToolCall("slow_tool", map[string]interface{}{ + "delay_ms": float64(100), + }) + + start := time.Now() + result, bifrostErr := b.ExecuteMCPTool(ctx, slowCall) + duration := time.Since(start) + + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assert.GreaterOrEqual(t, duration, 100*time.Millisecond, "Should take at least 100ms") +} + +func TestToolExecutionError(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Test error tool - tool execution succeeds but result contains error message + errorCall := createToolCall("error_tool", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, errorCall) + + // Tool execution should succeed (no bifrostErr), but result should contain error message + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "Error:", "Result should contain error message") + assert.Contains(t, responseText, "this tool always fails", "Result should contain the error text") +} + +func TestComplexArgsTool(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Test complex args tool + complexCall := createToolCall("complex_args_tool", map[string]interface{}{ + "data": map[string]interface{}{ + "nested": map[string]interface{}{ + "value": float64(42), + "array": []interface{}{1, 2, 3}, + }, + }, + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, complexCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "Received data", "Should process complex args") + assert.Contains(t, responseText, "42", "Should contain nested value") +} diff --git a/tests/core-mcp/responses_test.go b/tests/core-mcp/responses_test.go new file mode 100644 index 0000000000..d9c0347882 --- /dev/null +++ b/tests/core-mcp/responses_test.go @@ -0,0 +1,442 @@ +package mcp + +import ( + "context" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestResponsesNonCodeModeToolExecution tests direct tool execution via Responses API +func TestResponsesNonCodeModeToolExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to non-code mode and ensure tools are available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, // Allow all tools + }) + require.NoError(t, err) + + // Execute tool directly to verify it works + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test message", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test message", responseText, "Echo tool should return the input message") +} + +// TestResponsesCodeModeToolExecution tests code mode tool execution via Responses API +func TestResponsesCodeModeToolExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode directly to verify code mode works + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.SimpleExpression, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "completed successfully") +} + +// TestResponsesAgentModeWithAutoExecuteTools tests agent mode configuration with auto-executable tools +func TestResponsesAgentModeWithAutoExecuteTools(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure bifrostInternal with echo as auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // Only echo is auto-execute + }) + require.NoError(t, err) + + // Verify configuration + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") + assert.False(t, canAutoExecuteTool("multiply", bifrostClient.Config), "multiply should not be auto-executable") + + // Verify echo tool can be executed directly + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test message", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test message", responseText, "Echo tool should return the input message") +} + +// TestResponsesAgentModeWithNonAutoExecuteTools tests agent mode configuration with non-auto-executable tools +func TestResponsesAgentModeWithNonAutoExecuteTools(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure bifrostInternal with multiply NOT in auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // multiply is NOT auto-execute + }) + require.NoError(t, err) + + // Verify configuration + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") + assert.False(t, canAutoExecuteTool("multiply", bifrostClient.Config), "multiply should not be auto-executable") + + // Verify multiply tool can still be executed directly (just not auto-executed) + multiplyCall := createToolCall("multiply", map[string]interface{}{ + "a": float64(2), + "b": float64(3), + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, multiplyCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "6", responseText, "Multiply tool should return correct result") +} + +// TestResponsesAgentModeMaxDepth tests agent mode max depth configuration via Responses API +func TestResponsesAgentModeMaxDepth(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Create Bifrost with max depth of 2 + mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 2, + ToolExecutionTimeout: 30 * time.Second, + }, + FetchNewRequestIDFunc: func(ctx context.Context) string { + return "test-request-id" + }, + } + b, err := setupTestBifrostWithMCPConfig(ctx, mcpConfig) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure all tools as available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Verify tools still work with max depth configured + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test", responseText, "Echo tool should work with max depth configured") +} + +// TestResponsesToolExecutionTimeout tests tool execution timeout via Responses API +func TestResponsesToolExecutionTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Create Bifrost with short timeout + mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 10, + ToolExecutionTimeout: 100 * time.Millisecond, // Very short timeout + }, + FetchNewRequestIDFunc: func(ctx context.Context) string { + return "test-request-id" + }, + } + b, err := setupTestBifrostWithMCPConfig(ctx, mcpConfig) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure slow_tool + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Create a Responses request that will trigger a slow tool + req := &schemas.BifrostResponsesRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Call slow_tool with delay 500ms"), + }, + }, + }, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{ + { + Name: schemas.Ptr("slow_tool"), + Description: schemas.Ptr("A tool that takes time to execute"), + }, + }, + }, + } + + // Execute the request - should handle timeout gracefully + _, bifrostErr := b.ResponsesRequest(ctx, req) + // Timeout errors are acceptable in this test + if bifrostErr != nil { + assert.Contains(t, bifrost.GetErrorMessage(bifrostErr), "timeout", "Should contain timeout error") + } +} + +// TestResponsesMultipleToolCalls tests multiple tool calls via Responses API +func TestResponsesMultipleToolCalls(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure all tools as available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Test echo tool + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test", responseText, "Echo tool should return correct result") + + // Test add tool + addCall := createToolCall("add", map[string]interface{}{ + "a": float64(5), + "b": float64(3), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, addCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText = *result.Content.ContentStr + assert.Equal(t, "8", responseText, "Add tool should return correct result") +} + +// TestResponsesCodeModeWithCodeExecution tests code mode with code execution via Responses API +func TestResponsesCodeModeWithCodeExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code calling code mode client tools + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeCallingCodeModeTool, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "test") +} + +// TestResponsesToolFiltering tests tool filtering via Responses API +func TestResponsesToolFiltering(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure specific tools only + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"echo", "add"}, // Only these tools available + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + // Verify allowed tools work + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test", responseText, "Echo tool should work") + + addCall := createToolCall("add", map[string]interface{}{ + "a": float64(1), + "b": float64(2), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, addCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText = *result.Content.ContentStr + assert.Equal(t, "3", responseText, "Add tool should work") + + // Verify multiply tool is NOT available (should fail) + multiplyCall := createToolCall("multiply", map[string]interface{}{ + "a": float64(2), + "b": float64(3), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, multiplyCall) + // Should fail because multiply is not in ToolsToExecute + assert.NotNil(t, bifrostErr, "Multiply tool should fail when not in ToolsToExecute") +} + +// TestResponsesComplexWorkflow tests a complex workflow via Responses API +func TestResponsesComplexWorkflow(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure all tools as available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Test echo tool + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "hello", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "hello", responseText, "Echo tool should return correct result") + + // Test add tool + addCall := createToolCall("add", map[string]interface{}{ + "a": float64(5), + "b": float64(3), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, addCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText = *result.Content.ContentStr + assert.Equal(t, "8", responseText, "Add tool should return correct result") + + // Test multiply tool with result from add + multiplyCall := createToolCall("multiply", map[string]interface{}{ + "a": float64(8), // Result from add + "b": float64(2), + }) + result, bifrostErr = b.ExecuteMCPTool(ctx, multiplyCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText = *result.Content.ContentStr + assert.Equal(t, "16", responseText, "Multiply tool should return correct result") +} diff --git a/tests/core-mcp/setup.go b/tests/core-mcp/setup.go new file mode 100644 index 0000000000..b6e6cdac4c --- /dev/null +++ b/tests/core-mcp/setup.go @@ -0,0 +1,402 @@ +package mcp + +import ( + "context" + "fmt" + "os" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// TestTimeout defines the maximum duration for MCP tests +const TestTimeout = 10 * time.Minute + +// TestAccount is a minimal account implementation for testing +type TestAccount struct{} + +func (a *TestAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI}, nil +} + +func (a *TestAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, nil +} + +func (a *TestAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil +} + +// setupTestBifrost initializes and returns a Bifrost instance for testing +// This creates a basic Bifrost instance without any MCP clients configured +func setupTestBifrost(ctx context.Context) (*bifrost.Bifrost, error) { + return setupTestBifrostWithMCPConfig(ctx, &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 10, + ToolExecutionTimeout: 30 * time.Second, + }, + FetchNewRequestIDFunc: func(ctx context.Context) string { + return "test-request-id" + }, + }) +} + +// setupTestBifrostWithCodeMode initializes and returns a Bifrost instance for testing with CodeMode +// This sets up bifrostInternal client as a code mode client +// Note: Tools must be registered first to create the bifrostInternal client +func setupTestBifrostWithCodeMode(ctx context.Context) (*bifrost.Bifrost, error) { + b, err := setupTestBifrost(ctx) + if err != nil { + return nil, err + } + + // Register tools first to create the bifrostInternal client + err = registerTestTools(b) + if err != nil { + return nil, fmt.Errorf("failed to register test tools: %w", err) + } + + // Get current client config to preserve existing settings + clients, err := b.GetMCPClients() + if err != nil { + return nil, fmt.Errorf("failed to get MCP clients: %w", err) + } + + var currentConfig *schemas.MCPClientConfig + for _, client := range clients { + if client.Config.ID == "bifrostInternal" { + currentConfig = &client.Config + break + } + } + + if currentConfig == nil { + return nil, fmt.Errorf("bifrostInternal client not found") + } + + // Set bifrostInternal client to code mode and ensure tools are available + // Preserve existing ToolsToExecute if set, otherwise use wildcard + toolsToExecute := currentConfig.ToolsToExecute + if len(toolsToExecute) == 0 { + toolsToExecute = []string{"*"} + } + + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ID: currentConfig.ID, + Name: currentConfig.Name, + ConnectionType: currentConfig.ConnectionType, + IsCodeModeClient: true, + ToolsToExecute: toolsToExecute, + ToolsToAutoExecute: currentConfig.ToolsToAutoExecute, + }) + if err != nil { + return nil, fmt.Errorf("failed to set bifrostInternal client to code mode: %w", err) + } + + return b, nil +} + +// setupTestBifrostWithMCPConfig initializes Bifrost with custom MCP config +func setupTestBifrostWithMCPConfig(ctx context.Context, mcpConfig *schemas.MCPConfig) (*bifrost.Bifrost, error) { + account := &TestAccount{} + + // Ensure FetchNewRequestIDFunc is set if not provided + // This is required for the tools handler to be fully setup + if mcpConfig.FetchNewRequestIDFunc == nil { + mcpConfig.FetchNewRequestIDFunc = func(ctx context.Context) string { + return "test-request-id" + } + } + + if mcpConfig.ToolManagerConfig == nil { + mcpConfig.ToolManagerConfig = &schemas.MCPToolManagerConfig{ + MaxAgentDepth: schemas.DefaultMaxAgentDepth, + ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout, + } + } + + b, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: account, + Plugins: nil, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), + MCPConfig: mcpConfig, + }) + if err != nil { + return nil, fmt.Errorf("failed to initialize Bifrost: %w", err) + } + + return b, nil +} + +// registerTestTools registers simple test tools for testing +func registerTestTools(b *bifrost.Bifrost) error { + // Echo tool + echoSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "echo", + Description: schemas.Ptr("Echoes back the input message"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to echo", + }, + }, + Required: []string{"message"}, + }, + }, + } + if err := b.RegisterMCPTool("echo", "Echoes back the input message", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + message, ok := argsMap["message"].(string) + if !ok { + return "", fmt.Errorf("message field is required") + } + return message, nil + }, echoSchema); err != nil { + return fmt.Errorf("failed to register echo tool: %w", err) + } + + // Add tool + addSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "add", + Description: schemas.Ptr("Adds two numbers"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{ + "a": map[string]interface{}{ + "type": "number", + "description": "First number", + }, + "b": map[string]interface{}{ + "type": "number", + "description": "Second number", + }, + }, + Required: []string{"a", "b"}, + }, + }, + } + if err := b.RegisterMCPTool("add", "Adds two numbers", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + a, ok := argsMap["a"].(float64) + if !ok { + return "", fmt.Errorf("a field is required") + } + bVal, ok := argsMap["b"].(float64) + if !ok { + return "", fmt.Errorf("b field is required") + } + return fmt.Sprintf("%.0f", a+bVal), nil + }, addSchema); err != nil { + return fmt.Errorf("failed to register add tool: %w", err) + } + + // Multiply tool + multiplySchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "multiply", + Description: schemas.Ptr("Multiplies two numbers"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{ + "a": map[string]interface{}{ + "type": "number", + "description": "First number", + }, + "b": map[string]interface{}{ + "type": "number", + "description": "Second number", + }, + }, + Required: []string{"a", "b"}, + }, + }, + } + if err := b.RegisterMCPTool("multiply", "Multiplies two numbers", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + a, ok := argsMap["a"].(float64) + if !ok { + return "", fmt.Errorf("a field is required") + } + bVal, ok := argsMap["b"].(float64) + if !ok { + return "", fmt.Errorf("b field is required") + } + return fmt.Sprintf("%.0f", a*bVal), nil + }, multiplySchema); err != nil { + return fmt.Errorf("failed to register multiply tool: %w", err) + } + + // GetData tool - returns structured data + getDataSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_data", + Description: schemas.Ptr("Returns structured data"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{}, + Required: []string{}, + }, + }, + } + if err := b.RegisterMCPTool("get_data", "Returns structured data", func(args any) (string, error) { + return `{"items": [{"id": 1, "name": "test"}, {"id": 2, "name": "example"}]}`, nil + }, getDataSchema); err != nil { + return fmt.Errorf("failed to register get_data tool: %w", err) + } + + // ErrorTool - always returns an error + errorToolSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "error_tool", + Description: schemas.Ptr("A tool that always returns an error"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{}, + Required: []string{}, + }, + }, + } + if err := b.RegisterMCPTool("error_tool", "A tool that always returns an error", func(args any) (string, error) { + return "", fmt.Errorf("this tool always fails") + }, errorToolSchema); err != nil { + return fmt.Errorf("failed to register error_tool: %w", err) + } + + // SlowTool - takes time to execute + slowToolSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "slow_tool", + Description: schemas.Ptr("A tool that takes time to execute"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{ + "delay_ms": map[string]interface{}{ + "type": "number", + "description": "Delay in milliseconds", + }, + }, + Required: []string{"delay_ms"}, + }, + }, + } + if err := b.RegisterMCPTool("slow_tool", "A tool that takes time to execute", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + delayMs, ok := argsMap["delay_ms"].(float64) + if !ok { + return "", fmt.Errorf("delay_ms field is required") + } + time.Sleep(time.Duration(delayMs) * time.Millisecond) + return fmt.Sprintf("Completed after %v ms", delayMs), nil + }, slowToolSchema); err != nil { + return fmt.Errorf("failed to register slow_tool: %w", err) + } + + // ComplexArgsTool - accepts complex nested arguments + complexArgsSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "complex_args_tool", + Description: schemas.Ptr("A tool that accepts complex nested arguments"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &map[string]interface{}{ + "data": map[string]interface{}{ + "type": "object", + "description": "Complex nested data", + }, + }, + Required: []string{"data"}, + }, + }, + } + if err := b.RegisterMCPTool("complex_args_tool", "A tool that accepts complex nested arguments", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + data, ok := argsMap["data"] + if !ok { + return "", fmt.Errorf("data field is required") + } + return fmt.Sprintf("Received data: %v", data), nil + }, complexArgsSchema); err != nil { + return fmt.Errorf("failed to register complex_args_tool: %w", err) + } + + return nil +} + +// connectExternalMCP connects to an external MCP server +// This is a helper function that can be used when external MCP credentials are provided +func connectExternalMCP(b *bifrost.Bifrost, name, id, connectionType, connectionString string) error { + var clientConfig schemas.MCPClientConfig + + switch connectionType { + case "http": + clientConfig = schemas.MCPClientConfig{ + ID: id, + Name: name, + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: schemas.Ptr(connectionString), + } + case "sse": + clientConfig = schemas.MCPClientConfig{ + ID: id, + Name: name, + ConnectionType: schemas.MCPConnectionTypeSSE, + ConnectionString: schemas.Ptr(connectionString), + } + default: + return fmt.Errorf("unsupported connection type: %s", connectionType) + } + + clients, err := b.GetMCPClients() + if err != nil { + return fmt.Errorf("failed to get MCP clients: %w", err) + } + for _, client := range clients { + if client.Config.ID == id { + // Client already exists + return nil + } + } + + if err := b.AddMCPClient(clientConfig); err != nil { + return fmt.Errorf("failed to add external MCP client: %w", err) + } + + return nil +} diff --git a/tests/core-mcp/tool_execution_test.go b/tests/core-mcp/tool_execution_test.go new file mode 100644 index 0000000000..5f34678051 --- /dev/null +++ b/tests/core-mcp/tool_execution_test.go @@ -0,0 +1,246 @@ +package mcp + +import ( + "context" + "strings" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNonCodeModeToolExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to non-code mode and ensure tools are available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, // Allow all tools + }) + require.NoError(t, err) + + // Test direct tool execution + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test message", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Equal(t, "test message", responseText) +} + +func TestCodeModeToolExecution(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.SimpleExpression, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "completed successfully") +} + +func TestCodeModeCallingCodeModeClientTools(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code calling code mode client tools + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeCallingCodeModeTool, + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "test") +} + +func TestCodeModeCallingMultipleCodeModeClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code calling tools from multiple code mode clients + // Since we only have bifrostInternal, we'll test calling multiple tools from the same client + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.MultipleServerToolCalls, // This calls echo and add from BifrostClient + }) + + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestListToolFilesWithNoClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Don't register tools or set code mode - should have no code mode clients + toolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + // listToolFiles should still work but return empty/no servers message + if bifrostErr == nil && result != nil { + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "No servers", "Should indicate no servers") + } +} + +func TestListToolFilesWithOnlyNonCodeModeClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to non-code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + }) + require.NoError(t, err) + + // listToolFiles should not be available when no code mode clients exist + // But if it is called, it should return empty + toolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + if bifrostErr == nil && result != nil { + responseText := *result.Content.ContentStr + // Should indicate no servers or empty list + assert.True(t, + len(responseText) == 0 || + strings.Contains(responseText, "No servers") || strings.Contains(responseText, "servers/"), + "Should return empty or no servers message") + } +} + +func TestListToolFilesWithCodeModeClients(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "servers/", "Should list servers") + assert.Contains(t, responseText, "BifrostClient.d.ts", "Should list BifrostClient server") +} + +func TestReadToolFileForNonExistentClient(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "NonExistentClient.d.ts", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "No server found", "Should indicate server not found") +} + +func TestReadToolFileForCodeModeClient(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "BifrostClient.d.ts", + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "interface", "Should contain TypeScript interface declarations") + assert.Contains(t, responseText, "echo", "Should contain echo tool definition") +} + +func TestReadToolFileWithLineRange(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "BifrostClient.d.ts", + "startLine": float64(1), + "endLine": float64(10), + }) + result, bifrostErr := b.ExecuteMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.NotEmpty(t, responseText, "Should return content") +} diff --git a/tests/core-mcp/utils.go b/tests/core-mcp/utils.go new file mode 100644 index 0000000000..f48bb5f5b0 --- /dev/null +++ b/tests/core-mcp/utils.go @@ -0,0 +1,104 @@ +package mcp + +import ( + "encoding/json" + "fmt" + "slices" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// createToolCall creates a tool call message for testing +func createToolCall(toolName string, arguments map[string]interface{}) schemas.ChatAssistantMessageToolCall { + argsJSON, _ := json.Marshal(arguments) + argsStr := string(argsJSON) + id := fmt.Sprintf("test-tool-call-%d", len(argsStr)) + toolType := "function" + + return schemas.ChatAssistantMessageToolCall{ + ID: &id, + Type: &toolType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &toolName, + Arguments: argsStr, + }, + } +} + +// assertExecutionResult validates execution results +func assertExecutionResult(t *testing.T, result *schemas.ChatMessage, expectedSuccess bool, expectedLogs []string, expectedErrorKind string) { + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + + if expectedSuccess { + // Success case - should not contain error indicators (but allow console.error output) + assert.NotContains(t, responseText, "Execution runtime error", "Response should not contain execution runtime error for successful execution") + assert.NotContains(t, responseText, "Execution typescript error", "Response should not contain execution typescript error for successful execution") + assert.NotContains(t, responseText, "Error:", "Response should not contain Error: prefix for successful execution") + + // Check logs if expected + if len(expectedLogs) > 0 { + for _, expectedLog := range expectedLogs { + assert.Contains(t, responseText, expectedLog, "Response should contain expected log") + } + } + } else { + // Error case - should contain error information + assert.Contains(t, responseText, "error", "Response should contain error for failed execution") + + if expectedErrorKind != "" { + assert.Contains(t, responseText, expectedErrorKind, "Response should contain expected error kind") + } + } +} + +// assertResultContains validates that the result contains specific text +func assertResultContains(t *testing.T, result *schemas.ChatMessage, expectedText string) { + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, expectedText, "Response should contain expected text") +} + +// requireNoBifrostError asserts that bifrostErr is nil, using GetErrorMessage for better error reporting +func requireNoBifrostError(t *testing.T, bifrostErr *schemas.BifrostError, msgAndArgs ...interface{}) { + if bifrostErr != nil { + errorMsg := bifrost.GetErrorMessage(bifrostErr) + if len(msgAndArgs) > 0 { + require.Fail(t, fmt.Sprintf("Expected no error but got: %s", errorMsg), msgAndArgs...) + } else { + require.Fail(t, fmt.Sprintf("Expected no error but got: %s", errorMsg)) + } + } +} + +// canAutoExecuteTool checks if a tool can be auto-executed based on client config +func canAutoExecuteTool(toolName string, config schemas.MCPClientConfig) bool { + // First check if tool is in ToolsToExecute + if config.ToolsToExecute != nil { + if len(config.ToolsToExecute) == 0 { + return false // Empty list means no tools allowed + } + if !slices.Contains(config.ToolsToExecute, "*") && !slices.Contains(config.ToolsToExecute, toolName) { + return false // Tool not in allowed list + } + } else { + return false // nil means no tools allowed + } + + // Then check if tool is in ToolsToAutoExecute + if len(config.ToolsToAutoExecute) == 0 { + return false // No auto-execute tools configured + } + + return slices.Contains(config.ToolsToAutoExecute, "*") || slices.Contains(config.ToolsToAutoExecute, toolName) +} diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go index c8de88dd93..78dc8afdbc 100644 --- a/transports/bifrost-http/handlers/config.go +++ b/transports/bifrost-http/handlers/config.go @@ -26,6 +26,7 @@ type ConfigManager interface { ReloadClientConfigFromConfigStore(ctx context.Context) error ReloadPricingManager(ctx context.Context) error UpdateDropExcessRequests(ctx context.Context, value bool) + UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int) error ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any) error } @@ -162,7 +163,7 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { } // Validating framework config - if payload.FrameworkConfig.PricingURL != nil && *payload.FrameworkConfig.PricingURL != modelcatalog.DefaultPricingURL { + if payload.FrameworkConfig.PricingURL != nil && *payload.FrameworkConfig.PricingURL != modelcatalog.DefaultPricingURL { // Checking the accessibility of the pricing URL resp, err := http.Get(*payload.FrameworkConfig.PricingURL) if err != nil { @@ -196,6 +197,39 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { updatedConfig.DropExcessRequests = payload.ClientConfig.DropExcessRequests } + // Validate MCP tool manager config values before updating + if payload.ClientConfig.MCPAgentDepth <= 0 { + logger.Warn("mcp_agent_depth must be greater than 0") + SendError(ctx, fasthttp.StatusBadRequest, "mcp_agent_depth must be greater than 0") + return + } + + if payload.ClientConfig.MCPToolExecutionTimeout <= 0 { + logger.Warn("mcp_tool_execution_timeout must be greater than 0") + SendError(ctx, fasthttp.StatusBadRequest, "mcp_tool_execution_timeout must be greater than 0") + return + } + + shouldReloadMCPToolManagerConfig := false + + if payload.ClientConfig.MCPAgentDepth != currentConfig.MCPAgentDepth { + updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth + shouldReloadMCPToolManagerConfig = true + } + + if payload.ClientConfig.MCPToolExecutionTimeout != currentConfig.MCPToolExecutionTimeout { + updatedConfig.MCPToolExecutionTimeout = payload.ClientConfig.MCPToolExecutionTimeout + shouldReloadMCPToolManagerConfig = true + } + + if shouldReloadMCPToolManagerConfig { + if err := h.configManager.UpdateMCPToolManagerConfig(ctx, updatedConfig.MCPAgentDepth, updatedConfig.MCPToolExecutionTimeout); err != nil { + logger.Warn(fmt.Sprintf("failed to update mcp tool manager config: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update mcp tool manager config: %v", err)) + return + } + } + if !slices.Equal(payload.ClientConfig.PrometheusLabels, currentConfig.PrometheusLabels) { updatedConfig.PrometheusLabels = payload.ClientConfig.PrometheusLabels shouldReloadTelemetryPlugin = true @@ -213,7 +247,8 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { updatedConfig.AllowDirectKeys = payload.ClientConfig.AllowDirectKeys updatedConfig.MaxRequestBodySizeMB = payload.ClientConfig.MaxRequestBodySizeMB updatedConfig.EnableLiteLLMFallbacks = payload.ClientConfig.EnableLiteLLMFallbacks - + updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth + updatedConfig.MCPToolExecutionTimeout = payload.ClientConfig.MCPToolExecutionTimeout // Validate LogRetentionDays if payload.ClientConfig.LogRetentionDays < 1 { logger.Warn("log_retention_days must be at least 1") @@ -260,7 +295,7 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { } // Updating framework config shouldReloadFrameworkConfig := false - if payload.FrameworkConfig.PricingURL != nil && *payload.FrameworkConfig.PricingURL != *frameworkConfig.PricingURL { + if payload.FrameworkConfig.PricingURL != nil && *payload.FrameworkConfig.PricingURL != *frameworkConfig.PricingURL { // Checking the accessibility of the pricing URL resp, err := http.Get(*payload.FrameworkConfig.PricingURL) if err != nil { diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index 88b710d204..c6cdec1299 100644 --- a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -8,6 +8,7 @@ import ( "fmt" "slices" "sort" + "strings" "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" @@ -189,13 +190,28 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_execute: %v", err)) return } + + // Auto-clear tools_to_auto_execute if tools_to_execute is empty + // If no tools are allowed to execute, no tools can be auto-executed + if len(req.ToolsToExecute) == 0 { + req.ToolsToAutoExecute = []string{} + } + + if err := validateToolsToAutoExecute(req.ToolsToAutoExecute, req.ToolsToExecute); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_auto_execute: %v", err)) + return + } + if err := validateMCPClientName(req.Name); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid client name: %v", err)) + return + } if err := h.mcpManager.AddMCPClient(ctx, req); err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to add MCP client: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to connect MCP client: %v", err)) return } SendJSON(ctx, map[string]any{ "status": "success", - "message": "MCP client added successfully", + "message": "MCP client connected successfully", }) } @@ -219,6 +235,24 @@ func (h *MCPHandler) editMCPClient(ctx *fasthttp.RequestCtx) { return } + // Auto-clear tools_to_auto_execute if tools_to_execute is empty + // If no tools are allowed to execute, no tools can be auto-executed + if len(req.ToolsToExecute) == 0 { + req.ToolsToAutoExecute = []string{} + } + + // Validate tools_to_auto_execute + if err := validateToolsToAutoExecute(req.ToolsToAutoExecute, req.ToolsToExecute); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_auto_execute: %v", err)) + return + } + + // Validate client name + if err := validateMCPClientName(req.Name); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid client name: %v", err)) + return + } + if err := h.mcpManager.EditMCPClient(ctx, id, req); err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to edit MCP client: %v", err)) return @@ -280,3 +314,69 @@ func validateToolsToExecute(toolsToExecute []string) error { return nil } + +func validateToolsToAutoExecute(toolsToAutoExecute []string, toolsToExecute []string) error { + if len(toolsToAutoExecute) > 0 { + // Check if wildcard "*" is combined with other tool names + hasWildcard := slices.Contains(toolsToAutoExecute, "*") + if hasWildcard && len(toolsToAutoExecute) > 1 { + return fmt.Errorf("wildcard '*' cannot be combined with other tool names") + } + + // Check for duplicate entries + seen := make(map[string]bool) + for _, tool := range toolsToAutoExecute { + if seen[tool] { + return fmt.Errorf("duplicate tool name '%s'", tool) + } + seen[tool] = true + } + + // Check that all tools in ToolsToAutoExecute are also in ToolsToExecute + // Create a set of allowed tools from ToolsToExecute + allowedTools := make(map[string]bool) + hasWildcardInExecute := slices.Contains(toolsToExecute, "*") + if hasWildcardInExecute { + // If "*" is in ToolsToExecute, all tools are allowed + return nil + } + for _, tool := range toolsToExecute { + allowedTools[tool] = true + } + + // Validate each tool in ToolsToAutoExecute + for _, tool := range toolsToAutoExecute { + if tool == "*" { + // Wildcard is allowed if "*" is in ToolsToExecute + if !hasWildcardInExecute { + return fmt.Errorf("tool '%s' in tools_to_auto_execute is not in tools_to_execute", tool) + } + } else if !allowedTools[tool] { + return fmt.Errorf("tool '%s' in tools_to_auto_execute is not in tools_to_execute", tool) + } + } + } + + return nil +} + +func validateMCPClientName(name string) error { + if strings.TrimSpace(name) == "" { + return fmt.Errorf("client name is required") + } + for _, r := range name { + if r > 127 { // non-ASCII + return fmt.Errorf("name must contain only ASCII characters") + } + } + if strings.Contains(name, "-") { + return fmt.Errorf("client name cannot contain hyphens") + } + if strings.Contains(name, " ") { + return fmt.Errorf("client name cannot contain spaces") + } + if len(name) > 0 && name[0] >= '0' && name[0] <= '9' { + return fmt.Errorf("client name cannot start with a number") + } + return nil +} diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 8ac8cbb679..bd0cd1bc66 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -180,6 +180,8 @@ var DefaultClientConfig = configstore.ClientConfig{ AllowDirectKeys: false, AllowedOrigins: []string{"*"}, MaxRequestBodySizeMB: 100, + MCPAgentDepth: 10, + MCPToolExecutionTimeout: 30, EnableLiteLLMFallbacks: false, } @@ -608,7 +610,12 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { if !config.ClientConfig.EnableLiteLLMFallbacks && configData.Client.EnableLiteLLMFallbacks { config.ClientConfig.EnableLiteLLMFallbacks = configData.Client.EnableLiteLLMFallbacks } - + if config.ClientConfig.MCPAgentDepth == 0 && configData.Client.MCPAgentDepth != 0 { + config.ClientConfig.MCPAgentDepth = configData.Client.MCPAgentDepth + } + if config.ClientConfig.MCPToolExecutionTimeout == 0 && configData.Client.MCPToolExecutionTimeout != 0 { + config.ClientConfig.MCPToolExecutionTimeout = configData.Client.MCPToolExecutionTimeout + } // Update store with merged config if config.ConfigStore != nil { logger.Debug("updating merged client config in store") @@ -764,7 +771,6 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { } if mcpConfig != nil { config.MCPConfig = mcpConfig - // Merge with config file if present if configData.MCP != nil && len(configData.MCP.ClientConfigs) > 0 { logger.Debug("merging MCP config from config file with store") @@ -1960,7 +1966,7 @@ func (c *Config) AddMCPClient(ctx context.Context, clientConfig schemas.MCPClien if err := c.client.AddMCPClient(c.MCPConfig.ClientConfigs[len(c.MCPConfig.ClientConfigs)-1]); err != nil { c.MCPConfig.ClientConfigs = c.MCPConfig.ClientConfigs[:len(c.MCPConfig.ClientConfigs)-1] c.cleanupEnvKeys("", clientConfig.ID, newEnvKeys) - return fmt.Errorf("failed to add MCP client: %w", err) + return fmt.Errorf("failed to connect MCP client: %w", err) } if c.ConfigStore != nil { @@ -2109,8 +2115,10 @@ func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig sch // Update the in-memory config with the processed values c.MCPConfig.ClientConfigs[configIndex].Name = processedConfig.Name + c.MCPConfig.ClientConfigs[configIndex].IsCodeModeClient = processedConfig.IsCodeModeClient c.MCPConfig.ClientConfigs[configIndex].Headers = processedConfig.Headers c.MCPConfig.ClientConfigs[configIndex].ToolsToExecute = processedConfig.ToolsToExecute + c.MCPConfig.ClientConfigs[configIndex].ToolsToAutoExecute = processedConfig.ToolsToAutoExecute // Check if client is registered in Bifrost (can be not registered if client initialization failed) if clients, err := c.client.GetMCPClients(); err == nil && len(clients) > 0 { @@ -2145,12 +2153,14 @@ func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig sch func (c *Config) RedactMCPClientConfig(config schemas.MCPClientConfig) schemas.MCPClientConfig { // Create a copy with basic fields configCopy := schemas.MCPClientConfig{ - ID: config.ID, - Name: config.Name, - ConnectionType: config.ConnectionType, - ConnectionString: config.ConnectionString, - StdioConfig: config.StdioConfig, - ToolsToExecute: append([]string{}, config.ToolsToExecute...), + ID: config.ID, + Name: config.Name, + IsCodeModeClient: config.IsCodeModeClient, + ConnectionType: config.ConnectionType, + ConnectionString: config.ConnectionString, + StdioConfig: config.StdioConfig, + ToolsToExecute: append([]string{}, config.ToolsToExecute...), + ToolsToAutoExecute: append([]string{}, config.ToolsToAutoExecute...), } // Handle connection string if present diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 38b2b64b91..1e962be5ad 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -17,6 +17,7 @@ import ( "github.com/bytedance/sonic" "github.com/fasthttp/router" + "github.com/google/uuid" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" @@ -57,6 +58,7 @@ type ServerCallbacks interface { ReloadClientConfigFromConfigStore(ctx context.Context) error ReloadPricingManager(ctx context.Context) error UpdateDropExcessRequests(ctx context.Context, value bool) + UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int) error ReloadTeam(ctx context.Context, id string) (*tables.TableTeam, error) RemoveTeam(ctx context.Context, id string) error ReloadCustomer(ctx context.Context, id string) (*tables.TableCustomer, error) @@ -684,6 +686,14 @@ func (s *BifrostHTTPServer) UpdateDropExcessRequests(ctx context.Context, value s.Client.UpdateDropExcessRequests(value) } +// UpdateMCPToolManagerConfig updates the MCP tool manager config +func (s *BifrostHTTPServer) UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int) error { + if s.Config == nil { + return fmt.Errorf("config not found") + } + return s.Client.UpdateToolManagerConfig(maxAgentDepth, toolExecutionTimeoutInSeconds) +} + // UpdatePluginStatus updates the status of a plugin func (s *BifrostHTTPServer) UpdatePluginStatus(name string, status string, logs []string) error { s.pluginStatusMutex.Lock() @@ -1052,6 +1062,12 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to load plugins %v", err) } + mcpConfig := s.Config.MCPConfig + if mcpConfig != nil { + mcpConfig.FetchNewRequestIDFunc = func(ctx context.Context) string { + return uuid.New().String() + } + } // Initialize bifrost client // Create account backed by the high-performance store (all processing is done in LoadFromDatabase) // The account interface now benefits from ultra-fast config access times via in-memory storage @@ -1061,7 +1077,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { InitialPoolSize: s.Config.ClientConfig.InitialPoolSize, DropExcessRequests: s.Config.ClientConfig.DropExcessRequests, Plugins: s.Plugins, - MCPConfig: s.Config.MCPConfig, + MCPConfig: mcpConfig, Logger: logger, }) if err != nil { diff --git a/transports/config.schema.json b/transports/config.schema.json index 079ea75d11..41937c782e 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -343,6 +343,9 @@ "$ref": "#/$defs/mcp_client_config" }, "description": "MCP client configurations" + }, + "tool_manager_config": { + "$ref": "#/$defs/mcp_tool_manager_config" } }, "additionalProperties": false @@ -1407,6 +1410,23 @@ } ] }, + "mcp_tool_manager_config": { + "type": "object", + "properties": { + "tool_execution_timeout": { + "type": "integer", + "description": "Tool execution timeout in seconds", + "minimum": 1, + "default": 30 + }, + "max_agent_depth": { + "type": "integer", + "description": "Max agent depth", + "minimum": 1, + "default": 10 + } + } + }, "weaviate_config": { "type": "object", "description": "Weaviate configuration for vector store", diff --git a/ui/app/_fallbacks/enterprise/components/api-keys/APIKeysView.tsx b/ui/app/_fallbacks/enterprise/components/api-keys/APIKeysView.tsx index 0bc202c2b7..e9de5f2bbe 100644 --- a/ui/app/_fallbacks/enterprise/components/api-keys/APIKeysView.tsx +++ b/ui/app/_fallbacks/enterprise/components/api-keys/APIKeysView.tsx @@ -5,13 +5,12 @@ import { Button } from "@/components/ui/button"; import { useGetCoreConfigQuery } from "@/lib/store"; import { Copy, InfoIcon, KeyRound } from "lucide-react"; import Link from "next/link"; -import { useMemo, useState } from "react"; +import { useMemo } from "react"; import { toast } from "sonner"; import ContactUsView from "../views/contactUsView"; export default function APIKeysView() { const { data: bifrostConfig, isLoading } = useGetCoreConfigQuery({ fromDB: true }); - const [isTokenVisible, setIsTokenVisible] = useState(false); const isAuthConfigure = useMemo(() => { return bifrostConfig?.auth_config?.is_enabled; }, [bifrostConfig]); @@ -50,7 +49,7 @@ curl --location 'http://localhost:8080/v1/chat/completions' -

+

To generate API keys, you need to set up admin username and password first.{" "} Configure Security Settings @@ -71,10 +70,11 @@ curl --location 'http://localhost:8080/v1/chat/completions' -

+

{isInferenceAuthDisabled ? ( <> - Authentication is currently disabled for inference API calls. You can make inference requests without authentication. Dashboard and admin API calls still require Basic auth with your admin credentials encoded in the standard{" "} + Authentication is currently disabled for inference API calls. You can make inference requests without + authentication. Dashboard and admin API calls still require Basic auth with your admin credentials encoded in the standard{" "} username:password format with base64 encoding. ) : ( @@ -87,7 +87,7 @@ curl --location 'http://localhost:8080/v1/chat/completions' {!isInferenceAuthDisabled && ( <>
-

+

Example:

@@ -95,22 +95,21 @@ curl --location 'http://localhost:8080/v1/chat/completions' -
-									{curlExample}
-								
+
{curlExample}
)}
- + } title="Scope Based API Keys" description="Need granular access control with scope-based API keys? Enterprise customers can create multiple API keys with specific permissions for different services, teams, or environments." diff --git a/ui/app/workspace/config/page.tsx b/ui/app/workspace/config/page.tsx index a8e231c103..b99474e84c 100644 --- a/ui/app/workspace/config/page.tsx +++ b/ui/app/workspace/config/page.tsx @@ -15,6 +15,8 @@ import ObservabilityView from "./views/observabilityView"; import PerformanceTuningView from "./views/performanceTuningView"; import PricingConfigView from "./views/pricingConfigView"; import SecurityView from "./views/securityView"; +import { MCPIcon } from "@/components/ui/icons"; +import MCPView from "./views/mcpView"; const tabs = [ { @@ -37,6 +39,11 @@ const tabs = [ label: "Governance", icon: , }, + { + id: "mcp", + label: "MCP Server", + icon: , + }, { id: "caching", label: "Caching", @@ -79,7 +86,7 @@ export default function ConfigPage() { } return ( -
+
{tabs.map((tab) => (
{ return
Need to restart Bifrost to apply changes.
; }; - - diff --git a/ui/app/workspace/config/views/mcpView.tsx b/ui/app/workspace/config/views/mcpView.tsx new file mode 100644 index 0000000000..088a011303 --- /dev/null +++ b/ui/app/workspace/config/views/mcpView.tsx @@ -0,0 +1,153 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { getErrorMessage, useGetCoreConfigQuery, useUpdateCoreConfigMutation } from "@/lib/store"; +import { CoreConfig } from "@/lib/types/config"; +import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { toast } from "sonner"; + +const defaultConfig: CoreConfig = { + drop_excess_requests: false, + initial_pool_size: 1000, + prometheus_labels: [], + enable_logging: true, + enable_governance: true, + enforce_governance_header: false, + allow_direct_keys: false, + allowed_origins: [], + max_request_body_size_mb: 100, + enable_litellm_fallbacks: false, + disable_content_logging: false, + log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, +}; + +export default function MCPView() { + const hasSettingsUpdateAccess = useRbac(RbacResource.Settings, RbacOperation.Update); + const { data: bifrostConfig } = useGetCoreConfigQuery({ fromDB: true }); + const config = bifrostConfig?.client_config; + const [updateCoreConfig, { isLoading }] = useUpdateCoreConfigMutation(); + const [localConfig, setLocalConfig] = useState(defaultConfig); + + const [localValues, setLocalValues] = useState<{ + mcp_agent_depth: string; + mcp_tool_execution_timeout: string; + }>({ + mcp_agent_depth: "10", + mcp_tool_execution_timeout: "30", + }); + + useEffect(() => { + if (bifrostConfig && config) { + setLocalConfig(config); + setLocalValues({ + mcp_agent_depth: config?.mcp_agent_depth?.toString() || "10", + mcp_tool_execution_timeout: config?.mcp_tool_execution_timeout?.toString() || "30", + }); + } + }, [config, bifrostConfig]); + + const hasChanges = useMemo(() => { + if (!config) return false; + return ( + localConfig.mcp_agent_depth !== config.mcp_agent_depth || localConfig.mcp_tool_execution_timeout !== config.mcp_tool_execution_timeout + ); + }, [config, localConfig]); + + const handleAgentDepthChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_agent_depth: value })); + const numValue = Number.parseInt(value); + if (!isNaN(numValue) && numValue > 0) { + setLocalConfig((prev) => ({ ...prev, mcp_agent_depth: numValue })); + } + }, []); + + const handleToolExecutionTimeoutChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_tool_execution_timeout: value })); + const numValue = Number.parseInt(value); + if (!isNaN(numValue) && numValue > 0) { + setLocalConfig((prev) => ({ ...prev, mcp_tool_execution_timeout: numValue })); + } + }, []); + + const handleSave = useCallback(async () => { + try { + const agentDepth = Number.parseInt(localValues.mcp_agent_depth); + const toolTimeout = Number.parseInt(localValues.mcp_tool_execution_timeout); + + if (isNaN(agentDepth) || agentDepth <= 0) { + toast.error("Max agent depth must be a positive number."); + return; + } + + if (isNaN(toolTimeout) || toolTimeout <= 0) { + toast.error("Tool execution timeout must be a positive number."); + return; + } + + if (!bifrostConfig) { + toast.error("Configuration not loaded. Please refresh and try again."); + return; + } + await updateCoreConfig({ ...bifrostConfig, client_config: localConfig }).unwrap(); + toast.success("MCP settings updated successfully."); + } catch (error) { + toast.error(getErrorMessage(error)); + } + }, [bifrostConfig, localConfig, localValues, updateCoreConfig]); + + return ( +
+
+
+

MCP Settings

+

Configure MCP (Model Context Protocol) agent and tool settings.

+
+ +
+ +
+ {/* Max Agent Depth */} +
+
+ +

Maximum depth for MCP agent execution.

+
+ handleAgentDepthChange(e.target.value)} + min="1" + /> +
+ + {/* Tool Execution Timeout */} +
+
+ +

Maximum time in seconds for tool execution.

+
+ handleToolExecutionTimeoutChange(e.target.value)} + min="1" + /> +
+
+
+ ); +} diff --git a/ui/app/workspace/config/views/observabilityView.tsx b/ui/app/workspace/config/views/observabilityView.tsx index 31071bd522..b4607f7126 100644 --- a/ui/app/workspace/config/views/observabilityView.tsx +++ b/ui/app/workspace/config/views/observabilityView.tsx @@ -24,6 +24,8 @@ const defaultConfig: CoreConfig = { enable_litellm_fallbacks: false, disable_content_logging: false, log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, }; export default function ObservabilityView() { diff --git a/ui/app/workspace/config/views/performanceTuningView.tsx b/ui/app/workspace/config/views/performanceTuningView.tsx index de36e8cf38..cbb3eb1f05 100644 --- a/ui/app/workspace/config/views/performanceTuningView.tsx +++ b/ui/app/workspace/config/views/performanceTuningView.tsx @@ -23,6 +23,8 @@ const defaultConfig: CoreConfig = { enable_litellm_fallbacks: false, disable_content_logging: false, log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, }; export default function PerformanceTuningView() { @@ -91,7 +93,11 @@ export default function PerformanceTuningView() { return; } - await updateCoreConfig({ ...bifrostConfig!, client_config: localConfig }).unwrap(); + if (!bifrostConfig) { + toast.error("Configuration not loaded. Please refresh and try again."); + return; + } + await updateCoreConfig({ ...bifrostConfig, client_config: localConfig }).unwrap(); toast.success("Performance settings updated successfully."); } catch (error) { toast.error(getErrorMessage(error)); diff --git a/ui/app/workspace/config/views/securityView.tsx b/ui/app/workspace/config/views/securityView.tsx index 9f0ae683e9..5d7ac1b27a 100644 --- a/ui/app/workspace/config/views/securityView.tsx +++ b/ui/app/workspace/config/views/securityView.tsx @@ -31,6 +31,8 @@ const defaultConfig: CoreConfig = { max_request_body_size_mb: 100, enable_litellm_fallbacks: false, log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, }; export default function SecurityView() { @@ -173,7 +175,7 @@ export default function SecurityView() {

Security Settings

Configure security and access control settings.

-
diff --git a/ui/app/workspace/logs/views/columns.tsx b/ui/app/workspace/logs/views/columns.tsx index 89f02b8902..f88e5b35d5 100644 --- a/ui/app/workspace/logs/views/columns.tsx +++ b/ui/app/workspace/logs/views/columns.tsx @@ -7,7 +7,7 @@ import { ProviderName, RequestTypeColors, RequestTypeLabels, Status, StatusColor import { LogEntry, ResponsesMessageContentBlock } from "@/lib/types/logs"; import { ColumnDef } from "@tanstack/react-table"; import { ArrowUpDown, Trash2 } from "lucide-react"; -import moment from "moment" +import moment from "moment"; function getMessage(log?: LogEntry) { if (log?.input_history && log.input_history.length > 0) { @@ -23,7 +23,8 @@ function getMessage(log?: LogEntry) { } return lastTextContentBlock; } else if (log?.responses_input_history && log.responses_input_history.length > 0) { - let lastMessageContent = log.responses_input_history[log.responses_input_history.length - 1].content; + let lastMessage = log.responses_input_history[log.responses_input_history.length - 1]; + let lastMessageContent = lastMessage.content; if (typeof lastMessageContent === "string") { return lastMessageContent; } @@ -33,18 +34,23 @@ function getMessage(log?: LogEntry) { lastTextContentBlock = block.text; } } - return lastTextContentBlock; - } else if (log?.speech_input) { - return log.speech_input.input; - } else if (log?.transcription_input) { - return log.transcription_input.prompt || "Audio file"; + // If no content found in content field, check output field for Responses API + if (!lastTextContentBlock && lastMessage.output) { + // Handle output field - it could be a string, an array of content blocks, or a computer tool call output data + if (typeof lastMessage.output === "string") { + return lastMessage.output; + } else if (Array.isArray(lastMessage.output)) { + return lastMessage.output.map((block) => block.text).join("\n"); + } else if (lastMessage.output.type && lastMessage.output.type === "computer_screenshot") { + return lastMessage.output.image_url; + } + } + return lastTextContentBlock ?? ""; } return ""; } -export const createColumns = ( - onDelete: (log: LogEntry) => void, -): ColumnDef[] => [ +export const createColumns = (onDelete: (log: LogEntry) => void): ColumnDef[] => [ { accessorKey: "status", header: "Status", @@ -175,7 +181,9 @@ export const createColumns = ( cell: ({ row }) => { const log = row.original; return ( - + ); }, }, diff --git a/ui/app/workspace/logs/views/logChatMessageView.tsx b/ui/app/workspace/logs/views/logChatMessageView.tsx index 879b01c4c5..a56f0ffc51 100644 --- a/ui/app/workspace/logs/views/logChatMessageView.tsx +++ b/ui/app/workspace/logs/views/logChatMessageView.tsx @@ -86,7 +86,9 @@ export default function LogChatMessageView({ message }: LogChatMessageViewProps) options={{ scrollBeyondLastLine: false, collapsibleBlocks: true, lineNumbers: "off", alwaysConsumeMouseWheel: false }} /> ) : ( -
{message.thought}
+
+ {message.thought} +
)} )} @@ -107,14 +109,14 @@ export default function LogChatMessageView({ message }: LogChatMessageViewProps) options={{ scrollBeyondLastLine: false, collapsibleBlocks: true, lineNumbers: "off", alwaysConsumeMouseWheel: false }} /> ) : ( -
{message.refusal}
+
{message.refusal}
)} )} {/* Handle content */} {message.content && ( -
+
{typeof message.content === "string" ? ( <> {isJson(message.content) ? ( @@ -129,7 +131,7 @@ export default function LogChatMessageView({ message }: LogChatMessageViewProps) options={{ scrollBeyondLastLine: false, collapsibleBlocks: true, lineNumbers: "off", alwaysConsumeMouseWheel: false }} /> ) : ( -
{message.content}
+
{message.content}
)} ) : ( diff --git a/ui/app/workspace/logs/views/logDetailsSheet.tsx b/ui/app/workspace/logs/views/logDetailsSheet.tsx index 8ad7d3199a..5ead7a751f 100644 --- a/ui/app/workspace/logs/views/logDetailsSheet.tsx +++ b/ui/app/workspace/logs/views/logDetailsSheet.tsx @@ -4,17 +4,30 @@ import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { DottedSeparator } from "@/components/ui/separator"; import { Sheet, SheetContent, SheetHeader, SheetTitle } from "@/components/ui/sheet"; +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { ProviderIconType, RenderProviderIcon } from "@/lib/constants/icons"; import { RequestTypeColors, RequestTypeLabels, Status, StatusColors } from "@/lib/constants/logs"; import { LogEntry } from "@/lib/types/logs"; -import { DollarSign, FileText, Timer, Trash2 } from "lucide-react"; +import { Clipboard, DollarSign, FileText, Timer, Trash2 } from "lucide-react"; import moment from "moment"; +import { toast } from "sonner"; import { CodeEditor } from "./codeEditor"; import LogChatMessageView from "./logChatMessageView"; import LogEntryDetailsView from "./logEntryDetailsView"; import LogResponsesMessageView from "./logResponsesMessageView"; import SpeechView from "./speechView"; import TranscriptionView from "./transcriptionView"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alertDialog"; interface LogDetailSheetProps { log: LogEntry | null; @@ -34,6 +47,122 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet } catch (ignored) {} } + const copyRequestBody = async () => { + try { + // Check if request is for responses, chat, speech, text completion, or embedding (exclude transcriptions) + const object = log.object?.toLowerCase() || ""; + const isChat = object === "chat_completion" || object === "chat_completion_stream"; + const isResponses = object === "responses" || object === "responses_stream"; + const isSpeech = object === "speech" || object === "speech_stream"; + const isTextCompletion = object === "text_completion" || object === "text_completion_stream"; + const isEmbedding = object === "embedding"; + const isTranscription = object === "transcription" || object === "transcription_stream"; + + // Skip if transcription + if (isTranscription) { + toast.error("Copy request body is not available for transcription requests"); + return; + } + + // Skip if not a supported request type + if (!isChat && !isResponses && !isSpeech && !isTextCompletion && !isEmbedding) { + toast.error("Copy request body is only available for chat, responses, speech, text completion, and embedding requests"); + return; + } + + // Helper function to extract text content from ChatMessage + const extractTextFromMessage = (message: any): string => { + if (!message || !message.content) { + return ""; + } + if (typeof message.content === "string") { + return message.content; + } + if (Array.isArray(message.content)) { + return message.content + .filter((block: any) => block && block.type === "text" && block.text) + .map((block: any) => block.text || "") + .join(""); + } + return ""; + }; + + // Helper function to extract texts from ChatMessage content blocks (for embeddings) + const extractTextsFromMessage = (message: any): string[] => { + if (!message || !message.content) { + return []; + } + if (typeof message.content === "string") { + return message.content ? [message.content] : []; + } + if (Array.isArray(message.content)) { + return message.content.filter((block: any) => block && block.type === "text" && block.text).map((block: any) => block.text); + } + return []; + }; + + // Build request body following OpenAI schema + const requestBody: any = { + model: log.provider && log.model ? `${log.provider}/${log.model}` : log.model || "", + }; + + // Add messages/input/prompt based on request type + if (isChat && log.input_history && log.input_history.length > 0) { + requestBody.messages = log.input_history; + } else if (isResponses && log.responses_input_history && log.responses_input_history.length > 0) { + requestBody.input = log.responses_input_history; + } else if (isSpeech && log.speech_input) { + requestBody.input = log.speech_input.input; + } else if (isTextCompletion && log.input_history && log.input_history.length > 0) { + // For text completions, extract prompt from input_history + const firstMessage = log.input_history[0]; + const prompt = extractTextFromMessage(firstMessage); + if (prompt) { + requestBody.prompt = prompt; + } + } else if (isEmbedding && log.input_history && log.input_history.length > 0) { + // For embeddings, extract all texts from input_history + const texts: string[] = []; + for (const message of log.input_history) { + const messageTexts = extractTextsFromMessage(message); + texts.push(...messageTexts); + } + if (texts.length > 0) { + // Use single string if only one text, otherwise use array + requestBody.input = texts.length === 1 ? texts[0] : texts; + } + } + + // Add params (excluding tools and instructions as they're handled separately in OpenAI schema) + if (log.params) { + const paramsCopy = { ...log.params }; + // Remove tools and instructions from params as they're typically top-level in OpenAI schema + // Keep all other params (temperature, max_tokens, voice, etc.) + delete paramsCopy.tools; + delete paramsCopy.instructions; + + // Merge remaining params into request body + Object.assign(requestBody, paramsCopy); + } + + // Add tools if they exist (for chat and responses) - OpenAI schema has tools at top level + if ((isChat || isResponses) && log.params?.tools && Array.isArray(log.params.tools) && log.params.tools.length > 0) { + requestBody.tools = log.params.tools; + } + + // Add instructions if they exist (for responses) - OpenAI schema has instructions at top level + if (isResponses && log.params?.instructions) { + requestBody.instructions = log.params.instructions; + } + + const requestBodyJson = JSON.stringify(requestBody, null, 2); + navigator.clipboard.writeText(requestBodyJson); + toast.success("Request body copied to clipboard"); + } catch (error) { + toast.error("Failed to copy request body"); + } + }; + return ( @@ -46,15 +175,44 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet
- +
+ + + + + + +

Copy request body JSON

+
+
+
+ + + + + + + Are you sure you want to delete this log? + This action cannot be undone. This will permanently delete the log entry. + + + Cancel + { + handleDelete(log); + onOpenChange(false); + }} + > + Delete + + + + +
@@ -270,7 +428,7 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet
{toolsParameter && (
-
Tools
+
Tools ({log.params?.tools?.length || 0})
= ({ open, onClose, onSaved }) => { } }, [open]); - const handleChange = (field: keyof CreateMCPClientRequest, value: string | string[] | MCPConnectionType | MCPStdioConfig | undefined) => { + const handleChange = ( + field: keyof CreateMCPClientRequest, + value: string | string[] | boolean | MCPConnectionType | MCPStdioConfig | undefined, + ) => { setForm((prev) => ({ ...prev, [field]: value })); }; @@ -95,10 +100,13 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { const validator = new Validator([ // Name validation - Validator.required(form.name?.trim(), "Client name is required"), - Validator.pattern(form.name || "", /^[a-zA-Z0-9-_]+$/, "Client name can only contain letters, numbers, hyphens and underscores"), - Validator.minLength(form.name || "", 3, "Client name must be at least 3 characters"), - Validator.maxLength(form.name || "", 50, "Client name cannot exceed 50 characters"), + Validator.required(form.name?.trim(), "Server name is required"), + Validator.pattern(form.name || "", /^[a-zA-Z0-9_]+$/, "Server name can only contain letters, numbers, and underscores"), + Validator.custom(!(form.name || "").includes("-"), "Server name cannot contain hyphens"), + Validator.custom(!(form.name || "").includes(" "), "Server name cannot contain spaces"), + Validator.custom((form.name || "").length === 0 || !/^[0-9]/.test(form.name || ""), "Server name cannot start with a number"), + Validator.minLength(form.name || "", 3, "Server name must be at least 3 characters"), + Validator.maxLength(form.name || "", 50, "Server name cannot exceed 50 characters"), // Connection type specific validation ...(form.connection_type === "http" || form.connection_type === "sse" @@ -156,7 +164,7 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { setIsLoading(false); toast({ title: "Success", - description: "Client created", + description: "Server created", }); onSaved(); onClose(); @@ -170,7 +178,7 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { - New MCP Client + New MCP Server
@@ -178,7 +186,7 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { ) => handleChange("name", e.target.value)} - placeholder="Client name" + placeholder="Server name" maxLength={50} />
@@ -197,6 +205,15 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => {
+
+ + handleChange("is_code_mode_client", checked)} + /> +
+ {(form.connection_type === "http" || form.connection_type === "sse") && ( <>
@@ -275,7 +292,11 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { - diff --git a/ui/app/workspace/mcp-clients/views/mcpClientSheet.tsx b/ui/app/workspace/mcp-gateway/views/mcpClientSheet.tsx similarity index 63% rename from ui/app/workspace/mcp-clients/views/mcpClientSheet.tsx rename to ui/app/workspace/mcp-gateway/views/mcpClientSheet.tsx index 227190c5d9..57537d0289 100644 --- a/ui/app/workspace/mcp-clients/views/mcpClientSheet.tsx +++ b/ui/app/workspace/mcp-gateway/views/mcpClientSheet.tsx @@ -8,6 +8,7 @@ import { HeadersTable } from "@/components/ui/headersTable"; import { Input } from "@/components/ui/input"; import { Sheet, SheetContent, SheetDescription, SheetHeader, SheetTitle } from "@/components/ui/sheet"; import { Switch } from "@/components/ui/switch"; +import { TriStateCheckbox } from "@/components/ui/tristateCheckbox"; import { useToast } from "@/hooks/use-toast"; import { MCP_STATUS_COLORS } from "@/lib/constants/config"; import { getErrorMessage, useUpdateMCPClientMutation } from "@/lib/store"; @@ -34,8 +35,10 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: mode: "onBlur", defaultValues: { name: mcpClient.config.name, + is_code_mode_client: mcpClient.config.is_code_mode_client || false, headers: mcpClient.config.headers, tools_to_execute: mcpClient.config.tools_to_execute || [], + tools_to_auto_execute: mcpClient.config.tools_to_auto_execute || [], }, }); @@ -43,8 +46,10 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: useEffect(() => { form.reset({ name: mcpClient.config.name, + is_code_mode_client: mcpClient.config.is_code_mode_client || false, headers: mcpClient.config.headers, tools_to_execute: mcpClient.config.tools_to_execute || [], + tools_to_auto_execute: mcpClient.config.tools_to_auto_execute || [], }); }, [form, mcpClient]); @@ -54,8 +59,10 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: id: mcpClient.config.id, data: { name: data.name, + is_code_mode_client: data.is_code_mode_client, headers: data.headers, tools_to_execute: data.tools_to_execute, + tools_to_auto_execute: data.tools_to_auto_execute, }, }).unwrap(); @@ -106,6 +113,69 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: } form.setValue("tools_to_execute", newTools, { shouldDirty: true }); + + // If tool is being removed from tools_to_execute, also remove it from tools_to_auto_execute + if (!checked) { + const currentAutoExecute = form.getValues("tools_to_auto_execute") || []; + if (currentAutoExecute.includes(toolName) || currentAutoExecute.includes("*")) { + const newAutoExecute = currentAutoExecute.filter((tool) => tool !== toolName); + // If we had "*" and removed a tool, we need to recalculate + if (currentAutoExecute.includes("*")) { + // If all tools mode, keep "*" only if tool is still in tools_to_execute + if (newTools.includes("*")) { + form.setValue("tools_to_auto_execute", ["*"], { shouldDirty: true }); + } else { + // Switch to explicit list - when in wildcard mode, all remaining tools should be auto-execute + form.setValue("tools_to_auto_execute", newTools, { shouldDirty: true }); + } + } else { + form.setValue("tools_to_auto_execute", newAutoExecute, { shouldDirty: true }); + } + } + } + }; + + const handleAutoExecuteToggle = (toolName: string, checked: boolean) => { + const currentAutoExecute = form.getValues("tools_to_auto_execute") || []; + const currentTools = form.getValues("tools_to_execute") || []; + const allToolNames = mcpClient.tools?.map((tool) => tool.name) || []; + + // Check if we're in "all tools" mode (wildcard) + const isAllToolsMode = currentTools.includes("*"); + const isAllAutoExecuteMode = currentAutoExecute.includes("*"); + + let newAutoExecute: string[]; + + if (isAllAutoExecuteMode) { + if (checked) { + // Already all selected, keep wildcard + newAutoExecute = ["*"]; + } else { + // Unchecking a tool when all are selected - switch to explicit list without this tool + if (isAllToolsMode) { + newAutoExecute = allToolNames.filter((name) => name !== toolName); + } else { + newAutoExecute = currentTools.filter((name) => name !== toolName); + } + } + } else { + // We're in explicit tool selection mode + if (checked) { + // Add tool to selection + newAutoExecute = currentAutoExecute.includes(toolName) ? currentAutoExecute : [...currentAutoExecute, toolName]; + + // If we now have all allowed tools selected, switch to wildcard mode + const allowedTools = isAllToolsMode ? allToolNames : currentTools; + if (newAutoExecute.length === allowedTools.length && allowedTools.every((tool) => newAutoExecute.includes(tool))) { + newAutoExecute = ["*"]; + } + } else { + // Remove tool from selection + newAutoExecute = currentAutoExecute.filter((tool) => tool !== toolName); + } + } + + form.setValue("tools_to_auto_execute", newAutoExecute, { shouldDirty: true }); }; return ( @@ -120,7 +190,7 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: {mcpClient.config.name} {mcpClient.state} - MCP client configuration and available tools + MCP server configuration and available tools
- Manage clients that can connect to the MCP Tools endpoint. + Manage servers that can connect to the MCP Tools endpoint.
@@ -144,7 +144,10 @@ export default function MCPClientsTable({ mcpClients }: MCPClientsTableProps) { Name Connection Type + Code Mode Connection Info + Enabled Tools + Auto-execute Tools State @@ -157,51 +160,96 @@ export default function MCPClientsTable({ mcpClients }: MCPClientsTableProps) { )} - {clients.map((c: MCPClient) => ( - handleRowClick(c)}> - {c.config.name} - {getConnectionTypeDisplay(c.config.connection_type)} - {getConnectionDisplay(c)} - - {c.state} - - e.stopPropagation()}> - - - - - - - - - Remove MCP Client - - Are you sure you want to remove MCP client {c.config.name}? You will need to reconnect the client to continue - using it. - - - - Cancel - handleDelete(c)}>Delete - - - - - - ))} + + + {c.state == "connected" ? ( + <> + {autoExecuteToolsCount}/{c.tools?.length} + + ) : ( + "-" + )} + + + {c.state} + + e.stopPropagation()}> + + + + + + + + + Remove MCP Server + + Are you sure you want to remove MCP server {c.config.name}? You will need to reconnect the server to continue + using it. + + + + Cancel + handleDelete(c)}>Delete + + + + + + ); + })}
diff --git a/ui/components/sidebar.tsx b/ui/components/sidebar.tsx index bc4e5c81af..43760e44f0 100644 --- a/ui/components/sidebar.tsx +++ b/ui/components/sidebar.tsx @@ -303,7 +303,7 @@ export default function AppSidebar() { const hasLogsAccess = useRbac(RbacResource.Logs, RbacOperation.View); const hasObservabilityAccess = useRbac(RbacResource.Observability, RbacOperation.View); const hasModelProvidersAccess = useRbac(RbacResource.ModelProvider, RbacOperation.View); - const hasMCPToolsAccess = useRbac(RbacResource.MCPGateway, RbacOperation.View); + const hasMCPGatewayAccess = useRbac(RbacResource.MCPGateway, RbacOperation.View); const hasPluginsAccess = useRbac(RbacResource.Plugins, RbacOperation.View); const hasUserProvisioningAccess = useRbac(RbacResource.UserProvisioning, RbacOperation.View); const hasAuditLogsAccess = useRbac(RbacResource.AuditLogs, RbacOperation.View); @@ -353,11 +353,11 @@ export default function AppSidebar() { hasAccess: hasModelProvidersAccess, }, { - title: "MCP Tools", - url: "/workspace/mcp-clients", + title: "MCP Gateway", + url: "/workspace/mcp-gateway", icon: MCPIcon, description: "MCP configuration", - hasAccess: hasMCPToolsAccess, + hasAccess: hasMCPGatewayAccess, }, { title: "Plugins", diff --git a/ui/components/ui/headersTable.tsx b/ui/components/ui/headersTable.tsx index 980ff595a0..c943790cdd 100644 --- a/ui/components/ui/headersTable.tsx +++ b/ui/components/ui/headersTable.tsx @@ -95,10 +95,9 @@ export function HeadersTable({ {rows.map(([key, value], index) => { - // Use key for existing entries, index for the empty row - const rowKey = key !== "" ? key : `empty-${index}`; + // Use index as key to maintain stable identity during edits return ( - + void; + + /** Optional label to render to the right of the checkbox */ + label?: React.ReactNode; + + /** Optional disabled state */ + disabled?: boolean; + + /** Extra tailwind classes for the wrapper */ + className?: string; + + /** Accessible name for icon-only checkbox (e.g. when label is rendered elsewhere) */ + ariaLabel?: string; +} + +export const TriStateCheckbox: React.FC = ({ + allIds, + selectedIds, + onChange, + label, + disabled = false, + className = "", + ariaLabel, +}) => { + const state: TriState = useMemo(() => { + if (!allIds.length) return "none"; + + const selectedSet = new Set(selectedIds); + const selectedCount = allIds.filter((id) => selectedSet.has(id)).length; + + if (selectedCount === 0) return "none"; + if (selectedCount === allIds.length) return "all"; + return "some"; + }, [allIds, selectedIds]); + + const handleClick = () => { + if (disabled) return; + + let nextSelected: string[]; + + switch (state) { + case "all": + // clear all + nextSelected = []; + break; + case "some": + case "none": + default: + // select all + nextSelected = [...allIds]; + break; + } + + onChange(nextSelected); + }; + + const ariaChecked: boolean | "mixed" = state === "all" ? true : state === "none" ? false : "mixed"; + + const isChecked = state === "all"; + const isIndeterminate = state === "some"; + + return ( + + ); +}; diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index 58acc662f2..afbd5dfa42 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -259,6 +259,8 @@ export interface CoreConfig { allowed_origins: string[]; max_request_body_size_mb: number; enable_litellm_fallbacks: boolean; + mcp_agent_depth: number; + mcp_tool_execution_timeout: number; } // Semantic cache configuration types diff --git a/ui/lib/types/logs.ts b/ui/lib/types/logs.ts index 75c03d7d97..aaeee28b37 100644 --- a/ui/lib/types/logs.ts +++ b/ui/lib/types/logs.ts @@ -431,6 +431,13 @@ export interface ResponsesMessage { encrypted_content?: string; // Additional tool-specific fields [key: string]: any; + output?: string | ResponsesMessageContentBlock[] | ResponsesComputerToolCallOutputData; +} + +export interface ResponsesComputerToolCallOutputData { + type: "computer_screenshot"; + file_id?: string; + image_url?: string; } // Stream options for responses diff --git a/ui/lib/types/mcp.ts b/ui/lib/types/mcp.ts index 7b7f4f8fb8..9faee1dbb7 100644 --- a/ui/lib/types/mcp.ts +++ b/ui/lib/types/mcp.ts @@ -13,10 +13,12 @@ export interface MCPStdioConfig { export interface MCPClientConfig { id: string; name: string; + is_code_mode_client?: boolean; connection_type: MCPConnectionType; connection_string?: string; stdio_config?: MCPStdioConfig; tools_to_execute?: string[]; + tools_to_auto_execute?: string[]; headers?: Record; } @@ -28,15 +30,19 @@ export interface MCPClient { export interface CreateMCPClientRequest { name: string; + is_code_mode_client?: boolean; connection_type: MCPConnectionType; connection_string?: string; stdio_config?: MCPStdioConfig; tools_to_execute?: string[]; + tools_to_auto_execute?: string[]; headers?: Record; } export interface UpdateMCPClientRequest { name?: string; + is_code_mode_client?: boolean; headers?: Record; tools_to_execute?: string[]; + tools_to_auto_execute?: string[]; } diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index 38df112f15..7801f03b63 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -437,6 +437,8 @@ export const coreConfigSchema = z.object({ allow_direct_keys: z.boolean().default(false), allowed_origins: z.array(z.string()).default(["*"]), max_request_body_size_mb: z.number().min(1).default(100), + mcp_agent_depth: z.number().min(1).default(10), + mcp_tool_execution_timeout: z.number().min(1).default(30), }); // Bifrost config schema @@ -574,7 +576,13 @@ export const maximFormSchema = z.object({ // MCP Client update schema export const mcpClientUpdateSchema = z.object({ - name: z.string().min(1, "Name is required"), + is_code_mode_client: z.boolean().optional(), + name: z + .string() + .min(1, "Name is required") + .refine((val) => !val.includes("-"), { message: "Client name cannot contain hyphens" }) + .refine((val) => !val.includes(" "), { message: "Client name cannot contain spaces" }) + .refine((val) => !/^[0-9]/.test(val), { message: "Client name cannot start with a number" }), headers: z.record(z.string(), z.string()).optional(), tools_to_execute: z .array(z.string()) @@ -594,6 +602,24 @@ export const mcpClientUpdateSchema = z.object({ }, { message: "Duplicate tool names are not allowed" }, ), + tools_to_auto_execute: z + .array(z.string()) + .optional() + .refine( + (tools) => { + if (!tools || tools.length === 0) return true; + const hasWildcard = tools.includes("*"); + return !hasWildcard || tools.length === 1; + }, + { message: "Wildcard '*' cannot be combined with other tool names" }, + ) + .refine( + (tools) => { + if (!tools) return true; + return tools.length === new Set(tools).size; + }, + { message: "Duplicate tool names are not allowed" }, + ), }); // Export type inference helpers diff --git a/ui/lib/utils/validation.ts b/ui/lib/utils/validation.ts index 7b77cb3780..aed9c326ec 100644 --- a/ui/lib/utils/validation.ts +++ b/ui/lib/utils/validation.ts @@ -371,7 +371,11 @@ function isValidWildcardOrigin(origin: string): boolean { * @returns Object with validation result and invalid origins */ export function validateOrigins(origins: string[]): { isValid: boolean; invalidOrigins: string[] } { - const invalidOrigins = origins?.filter((origin) => !isValidOrigin(origin)) || []; + if (!origins || origins.length === 0) { + return { isValid: true, invalidOrigins: [] }; + } + + const invalidOrigins = origins.filter((origin) => !isValidOrigin(origin)); return { isValid: invalidOrigins.length === 0,