diff --git a/core/bifrost.go b/core/bifrost.go index c3665e8a49..d6b264f242 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" "github.com/maximhq/bifrost/core/mcp" + "github.com/maximhq/bifrost/core/mcp/codemode/starlark" "github.com/maximhq/bifrost/core/providers/anthropic" "github.com/maximhq/bifrost/core/providers/azure" "github.com/maximhq/bifrost/core/providers/bedrock" @@ -72,7 +73,7 @@ type Bifrost struct { oauth2Provider schemas.OAuth2Provider // OAuth provider instance logger schemas.Logger // logger instance, default logger is used if not provided tracer atomic.Value // tracer for distributed tracing (stores schemas.Tracer, NoOpTracer if not configured) - mcpManager *mcp.MCPManager // MCP integration manager (nil if MCP not configured) + McpManager mcp.MCPManagerInterface // MCP integration manager (nil if MCP not configured) mcpInitOnce sync.Once // Ensures MCP manager is initialized only once dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. keySelector schemas.KeySelector // Custom key selector function @@ -185,6 +186,12 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { logger: config.Logger, } bifrost.tracer.Store(&tracerWrapper{tracer: tracer}) + if config.LLMPlugins == nil { + config.LLMPlugins = make([]schemas.LLMPlugin, 0) + } + if config.MCPPlugins == nil { + config.MCPPlugins = make([]schemas.MCPPlugin, 0) + } bifrost.llmPlugins.Store(&config.LLMPlugins) bifrost.mcpPlugins.Store(&config.MCPPlugins) @@ -269,7 +276,17 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { bifrost.releasePluginPipeline(pp) } } - bifrost.mcpManager = mcp.NewMCPManager(bifrostCtx, mcpConfig, bifrost.oauth2Provider, bifrost.logger) + // Create Starlark CodeMode for code execution + starlark.SetLogger(bifrost.logger) + var codeModeConfig *mcp.CodeModeConfig + if mcpConfig.ToolManagerConfig != nil { + codeModeConfig = &mcp.CodeModeConfig{ + BindingLevel: mcpConfig.ToolManagerConfig.CodeModeBindingLevel, + ToolExecutionTimeout: mcpConfig.ToolManagerConfig.ToolExecutionTimeout, + } + } + codeMode := starlark.NewStarlarkCodeMode(codeModeConfig) + bifrost.McpManager = mcp.NewMCPManager(bifrostCtx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode) bifrost.logger.Info("MCP integration initialized successfully") }) } @@ -327,21 +344,6 @@ func (bifrost *Bifrost) getTracer() schemas.Tracer { // We will keep on adding other aspects as required func (bifrost *Bifrost) ReloadConfig(config schemas.BifrostConfig) error { bifrost.dropExcessRequests.Store(config.DropExcessRequests) - - // Update LLM plugins atomically - if config.LLMPlugins != nil { - llmPluginsCopy := make([]schemas.LLMPlugin, len(config.LLMPlugins)) - copy(llmPluginsCopy, config.LLMPlugins) - bifrost.llmPlugins.Store(&llmPluginsCopy) - } - - // Update MCP plugins atomically - if config.MCPPlugins != nil { - mcpPluginsCopy := make([]schemas.MCPPlugin, len(config.MCPPlugins)) - copy(mcpPluginsCopy, config.MCPPlugins) - bifrost.mcpPlugins.Store(&mcpPluginsCopy) - } - return nil } @@ -711,8 +713,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx *schemas.BifrostContext, req * } // Check if we should enter agent mode - if bifrost.mcpManager != nil { - return bifrost.mcpManager.CheckAndExecuteAgentForChatRequest( + if bifrost.McpManager != nil { + return bifrost.McpManager.CheckAndExecuteAgentForChatRequest( ctx, req, response, @@ -808,8 +810,8 @@ func (bifrost *Bifrost) ResponsesRequest(ctx *schemas.BifrostContext, req *schem } // Check if we should enter agent mode - if bifrost.mcpManager != nil { - return bifrost.mcpManager.CheckAndExecuteAgentForResponsesRequest( + if bifrost.McpManager != nil { + return bifrost.McpManager.CheckAndExecuteAgentForResponsesRequest( ctx, req, response, @@ -2105,14 +2107,22 @@ func (bifrost *Bifrost) ContainerFileDeleteRequest(ctx *schemas.BifrostContext, } // RemovePlugin removes a plugin from the server. -func (bifrost *Bifrost) RemovePlugin(name string, pluginType schemas.PluginType) error { - switch pluginType { - case schemas.PluginTypeLLM: - return bifrost.removeLLMPlugin(name) - case schemas.PluginTypeMCP: - return bifrost.removeMCPPlugin(name) - } - return fmt.Errorf("unsupported plugin type: %v", pluginType) +func (bifrost *Bifrost) RemovePlugin(name string, pluginTypes []schemas.PluginType) error { + for _, pluginType := range pluginTypes { + switch pluginType { + case schemas.PluginTypeLLM: + err := bifrost.removeLLMPlugin(name) + if err != nil { + return err + } + case schemas.PluginTypeMCP: + err := bifrost.removeMCPPlugin(name) + if err != nil { + return err + } + } + } + return nil } // removeLLMPlugin removes an LLM plugin from the server. @@ -2189,22 +2199,30 @@ func (bifrost *Bifrost) removeMCPPlugin(name string) error { // ReloadPlugin reloads a plugin with new instance // During the reload - it's stop the world phase where we take a global lock on the plugin mutex -func (bifrost *Bifrost) ReloadPlugin(plugin schemas.BasePlugin, pluginType schemas.PluginType) error { - switch pluginType { - case schemas.PluginTypeLLM: - llmPlugin, ok := plugin.(schemas.LLMPlugin) - if !ok { - return fmt.Errorf("plugin %s is not an LLMPlugin", plugin.GetName()) - } - return bifrost.reloadLLMPlugin(llmPlugin) - case schemas.PluginTypeMCP: - mcpPlugin, ok := plugin.(schemas.MCPPlugin) - if !ok { - return fmt.Errorf("plugin %s is not an MCPPlugin", plugin.GetName()) - } - return bifrost.reloadMCPPlugin(mcpPlugin) - } - return fmt.Errorf("unsupported plugin type: %v", pluginType) +func (bifrost *Bifrost) ReloadPlugin(plugin schemas.BasePlugin, pluginTypes []schemas.PluginType) error { + for _, pluginType := range pluginTypes { + switch pluginType { + case schemas.PluginTypeLLM: + llmPlugin, ok := plugin.(schemas.LLMPlugin) + if !ok { + return fmt.Errorf("plugin %s is not an LLMPlugin", plugin.GetName()) + } + err := bifrost.reloadLLMPlugin(llmPlugin) + if err != nil { + return err + } + case schemas.PluginTypeMCP: + mcpPlugin, ok := plugin.(schemas.MCPPlugin) + if !ok { + return fmt.Errorf("plugin %s is not an MCPPlugin", plugin.GetName()) + } + err := bifrost.reloadMCPPlugin(mcpPlugin) + if err != nil { + return err + } + } + } + return nil } // reloadLLMPlugin reloads an LLM plugin with new instance @@ -2643,11 +2661,11 @@ func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *syn // return args.Message, nil // }, toolSchema) func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.ChatTool) error { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return fmt.Errorf("MCP is not configured in this Bifrost instance") } - return bifrost.mcpManager.RegisterTool(name, description, handler, toolSchema) + return bifrost.McpManager.RegisterTool(name, description, handler, toolSchema) } // IMPORTANT: Running the MCP client management operations (GetMCPClients, AddMCPClient, RemoveMCPClient, EditMCPClientTools) @@ -2661,11 +2679,11 @@ func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(a // - []schemas.MCPClient: List of all MCP clients // - error: Any retrieval error func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return nil, fmt.Errorf("MCP is not configured in this Bifrost instance") } - clients := bifrost.mcpManager.GetClients() + clients := bifrost.McpManager.GetClients() clientsInConfig := make([]schemas.MCPClient, 0, len(clients)) for _, client := range clients { @@ -2703,10 +2721,10 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { // Returns: // - []schemas.ChatTool: List of available tools func (bifrost *Bifrost) GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return nil } - return bifrost.mcpManager.GetAvailableTools(ctx) + return bifrost.McpManager.GetAvailableTools(ctx) } // AddMCPClient adds a new MCP client to the Bifrost instance. @@ -2725,13 +2743,13 @@ func (bifrost *Bifrost) GetAvailableMCPTools(ctx context.Context) []schemas.Chat // ConnectionType: schemas.MCPConnectionTypeHTTP, // ConnectionString: &url, // }) -func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error { - if bifrost.mcpManager == nil { +func (bifrost *Bifrost) AddMCPClient(config *schemas.MCPClientConfig) error { + if bifrost.McpManager == nil { // Use sync.Once to ensure thread-safe initialization bifrost.mcpInitOnce.Do(func() { // Initialize with empty config - client will be added via AddClient below mcpConfig := schemas.MCPConfig{ - ClientConfigs: []schemas.MCPClientConfig{}, + ClientConfigs: []*schemas.MCPClientConfig{}, } // Set up plugin pipeline provider functions for executeCode tool hooks mcpConfig.PluginPipelineProvider = func() interface{} { @@ -2742,16 +2760,19 @@ func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error { bifrost.releasePluginPipeline(pp) } } - bifrost.mcpManager = mcp.NewMCPManager(bifrost.ctx, mcpConfig, bifrost.oauth2Provider, bifrost.logger) + // Create Starlark CodeMode for code execution (with default config) + starlark.SetLogger(bifrost.logger) + codeMode := starlark.NewStarlarkCodeMode(nil) + bifrost.McpManager = mcp.NewMCPManager(bifrost.ctx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode) }) } // Handle case where initialization succeeded elsewhere but manager is still nil - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return fmt.Errorf("MCP manager is not initialized") } - return bifrost.mcpManager.AddClient(config) + return bifrost.McpManager.AddClient(config) } // RemoveMCPClient removes an MCP client from the Bifrost instance. @@ -2770,23 +2791,23 @@ func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error { // log.Fatalf("Failed to remove MCP client: %v", err) // } func (bifrost *Bifrost) RemoveMCPClient(id string) error { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return fmt.Errorf("MCP is not configured in this Bifrost instance") } - return bifrost.mcpManager.RemoveClient(id) + return bifrost.McpManager.RemoveClient(id) } // SetMCPManager sets the MCP manager for this Bifrost instance. -// This is primarily used for testing purposes to inject a custom MCP manager. +// This allows injecting a custom MCP manager implementation (e.g., for enterprise features). // // Parameters: -// - manager: The MCP manager to set -func (bifrost *Bifrost) SetMCPManager(manager *mcp.MCPManager) { - bifrost.mcpManager = manager +// - manager: The MCP manager to set (must implement MCPManagerInterface) +func (bifrost *Bifrost) SetMCPManager(manager mcp.MCPManagerInterface) { + bifrost.McpManager = manager } -// EditMCPClient edits the tools of an MCP client. +// UpdateMCPClient updates the MCP client. // This allows for dynamic MCP client tool management at runtime. // // Parameters: @@ -2798,16 +2819,16 @@ func (bifrost *Bifrost) SetMCPManager(manager *mcp.MCPManager) { // // Example: // -// err := bifrost.EditMCPClient("my-mcp-client-id", schemas.MCPClientConfig{ +// err := bifrost.UpdateMCPClient("my-mcp-client-id", schemas.MCPClientConfig{ // Name: "my-mcp-client-name", // ToolsToExecute: []string{"tool1", "tool2"}, // }) -func (bifrost *Bifrost) EditMCPClient(id string, updatedConfig schemas.MCPClientConfig) error { - if bifrost.mcpManager == nil { +func (bifrost *Bifrost) EditMCPClient(id string, updatedConfig *schemas.MCPClientConfig) error { + if bifrost.McpManager == nil { return fmt.Errorf("MCP is not configured in this Bifrost instance") } - return bifrost.mcpManager.EditClient(id, updatedConfig) + return bifrost.McpManager.EditClient(id, updatedConfig) } // ReconnectMCPClient attempts to reconnect an MCP client if it is disconnected. @@ -2818,21 +2839,21 @@ func (bifrost *Bifrost) EditMCPClient(id string, updatedConfig schemas.MCPClient // Returns: // - error: Any reconnection error func (bifrost *Bifrost) ReconnectMCPClient(id string) error { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return fmt.Errorf("MCP is not configured in this Bifrost instance") } - return bifrost.mcpManager.ReconnectClient(id) + return bifrost.McpManager.ReconnectClient(id) } // UpdateToolManagerConfig updates the tool manager config for the MCP manager. // This allows for hot-reloading of the tool manager config at runtime. func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return fmt.Errorf("MCP is not configured in this Bifrost instance") } - bifrost.mcpManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{ + bifrost.McpManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{ MaxAgentDepth: maxAgentDepth, ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second, CodeModeBindingLevel: schemas.CodeModeBindingLevel(codeModeBindingLevel), @@ -3424,8 +3445,8 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif } // Add MCP tools to request if MCP is configured and requested - if bifrost.mcpManager != nil { - req = bifrost.mcpManager.AddToolsToRequest(ctx, req) + if bifrost.McpManager != nil { + req = bifrost.McpManager.AddToolsToRequest(ctx, req) } tracer := bifrost.getTracer() @@ -3614,8 +3635,8 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem } // Add MCP tools to request if MCP is configured and requested - if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.mcpManager != nil { - req = bifrost.mcpManager.AddToolsToRequest(ctx, req) + if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.McpManager != nil { + req = bifrost.McpManager.AddToolsToRequest(ctx, req) } tracer := bifrost.getTracer() @@ -4390,7 +4411,7 @@ func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, r // - *schemas.BifrostMCPResponse: The MCP response after all hooks // - *schemas.BifrostError: Any execution error func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpRequest *schemas.BifrostMCPRequest, requestType schemas.RequestType) (*schemas.BifrostMCPResponse, *schemas.BifrostError) { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ @@ -4449,7 +4470,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR } // Execute tool with modified request - result, err := bifrost.mcpManager.ExecuteToolCall(ctx, preReq) + result, err := bifrost.McpManager.ExecuteToolCall(ctx, preReq) // Prepare MCP response and error for post-hooks var mcpResp *schemas.BifrostMCPResponse @@ -5257,8 +5278,8 @@ func (bifrost *Bifrost) Shutdown() { }) // Cleanup MCP manager - if bifrost.mcpManager != nil { - err := bifrost.mcpManager.Cleanup() + if bifrost.McpManager != nil { + err := bifrost.McpManager.Cleanup() if err != nil { bifrost.logger.Warn("Error cleaning up MCP manager: %s", err.Error()) } diff --git a/core/go.mod b/core/go.mod index 16bc21ea89..fde33feb4b 100644 --- a/core/go.mod +++ b/core/go.mod @@ -13,14 +13,13 @@ require ( github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 github.com/aws/smithy-go v1.24.0 github.com/bytedance/sonic v1.14.2 - github.com/clarkmcc/go-typescript v0.7.0 - github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 github.com/google/uuid v1.6.0 github.com/hajimehoshi/go-mp3 v0.3.4 github.com/mark3labs/mcp-go v0.43.2 github.com/rs/zerolog v1.34.0 github.com/stretchr/testify v1.11.1 github.com/valyala/fasthttp v1.68.0 + go.starlark.net v0.0.0-20260102030733-3fee463870c9 golang.org/x/oauth2 v0.34.0 golang.org/x/text v0.32.0 ) @@ -29,7 +28,6 @@ require ( cloud.google.com/go/compute/metadata v0.9.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect - github.com/Masterminds/semver/v3 v3.3.1 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect @@ -50,10 +48,7 @@ require ( github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/dlclark/regexp2 v1.11.4 // indirect - github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect - github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect @@ -73,5 +68,6 @@ require ( golang.org/x/crypto v0.46.0 // indirect golang.org/x/net v0.48.0 // indirect golang.org/x/sys v0.39.0 // indirect + google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/core/go.sum b/core/go.sum index 83be499d2b..c47a4c961b 100644 --- a/core/go.sum +++ b/core/go.sum @@ -14,8 +14,6 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= -github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= -github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -66,8 +64,6 @@ github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPII github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= -github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= -github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -75,21 +71,13 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= -github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= -github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= -github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= -github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= -github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= @@ -155,6 +143,8 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +go.starlark.net v0.0.0-20260102030733-3fee463870c9 h1:nV1OyvU+0CYrp5eKfQ3rD03TpFYYhH08z31NK1HmtTk= +go.starlark.net v0.0.0-20260102030733-3fee463870c9/go.mod h1:YKMCv9b1WrfWmeqdV5MAuEHWsu5iC+fe6kYl2sQjdI8= golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg= golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= @@ -172,11 +162,11 @@ golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/core/internal/mcptests/agent_filtering_test.go b/core/internal/mcptests/agent_filtering_test.go index 01c1d0e8e2..1058f5a7f2 100644 --- a/core/internal/mcptests/agent_filtering_test.go +++ b/core/internal/mcptests/agent_filtering_test.go @@ -531,7 +531,7 @@ func TestAgent_FilteringWithMultipleClients(t *testing.T) { tempConfig := GetTemperatureMCPClientConfig("") tempConfig.ToolsToAutoExecute = []string{} // Not auto-executed - err = manager.AddClient(tempConfig) + err = manager.AddClient(&tempConfig) if err != nil { t.Skipf("Skipping test - temperature server not available: %v", err) return @@ -620,7 +620,7 @@ func TestAgent_ToolConflictInAgentMode(t *testing.T) { tempConfig := GetTemperatureMCPClientConfig("") tempConfig.ToolsToAutoExecute = []string{} // Not auto - err = manager.AddClient(tempConfig) + err = manager.AddClient(&tempConfig) if err != nil { t.Skipf("Skipping test - temperature server not available: %v", err) return diff --git a/core/internal/mcptests/agent_request_id_test.go b/core/internal/mcptests/agent_request_id_test.go index da04721c36..ded2180b6c 100644 --- a/core/internal/mcptests/agent_request_id_test.go +++ b/core/internal/mcptests/agent_request_id_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/maximhq/bifrost/core/mcp" + "github.com/maximhq/bifrost/core/mcp/codemode/starlark" "github.com/maximhq/bifrost/core/schemas" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,15 +21,26 @@ import ( func setupMCPManagerWithRequestIDFunc(t *testing.T, fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, clientConfigs ...schemas.MCPClientConfig) *mcp.MCPManager { t.Helper() + logger := &testLogger{t: t} + + // Convert to pointer slice for MCPConfig + clientConfigPtrs := make([]*schemas.MCPClientConfig, len(clientConfigs)) + for i := range clientConfigs { + clientConfigPtrs[i] = &clientConfigs[i] + } + // Create MCP config with request ID function mcpConfig := &schemas.MCPConfig{ - ClientConfigs: clientConfigs, + ClientConfigs: clientConfigPtrs, FetchNewRequestIDFunc: fetchNewRequestIDFunc, } - // Create MCP manager - logger := &testLogger{t: t} - manager := mcp.NewMCPManager(context.Background(), *mcpConfig, nil, logger) + // Create Starlark CodeMode + starlark.SetLogger(logger) + codeMode := starlark.NewStarlarkCodeMode(nil) + + // Create MCP manager - dependencies are injected automatically + manager := mcp.NewMCPManager(context.Background(), *mcpConfig, nil, logger, codeMode) // Cleanup t.Cleanup(func() { diff --git a/core/internal/mcptests/client_management_test.go b/core/internal/mcptests/client_management_test.go index b871d7d001..1d0e452ffa 100644 --- a/core/internal/mcptests/client_management_test.go +++ b/core/internal/mcptests/client_management_test.go @@ -25,11 +25,11 @@ func TestAddClientDuplicate(t *testing.T) { // Add client clientConfig := GetSampleHTTPClientConfig(config.HTTPServerURL) - err := manager.AddClient(clientConfig) + err := manager.AddClient(&clientConfig) require.NoError(t, err, "should add client first time") // Try to add same client again - err = manager.AddClient(clientConfig) + err = manager.AddClient(&clientConfig) // Should either return error or be idempotent if err == nil { clients := manager.GetClients() @@ -183,7 +183,7 @@ func TestEditClient(t *testing.T) { updatedConfig.Name = "UpdatedName" updatedConfig.ToolsToExecute = []string{"calculator", "echo"} - err := manager.EditClient(clientID, updatedConfig) + err := manager.EditClient(clientID, &updatedConfig) require.NoError(t, err, "should edit client") // Verify changes @@ -199,7 +199,7 @@ func TestEditClientInvalidID(t *testing.T) { // Try to edit non-existent client clientConfig := GetSampleHTTPClientConfig("http://example.com") - err := manager.EditClient("non-existent-id", clientConfig) + err := manager.EditClient("non-existent-id", &clientConfig) assert.Error(t, err, "should error when editing non-existent client") } @@ -225,7 +225,7 @@ func TestEditClientInvalidConfig(t *testing.T) { // Missing ConnectionString } - err := manager.EditClient(clientID, invalidConfig) + err := manager.EditClient(clientID, &invalidConfig) // Should return error or leave client unchanged if err == nil { clients = manager.GetClients() @@ -257,7 +257,7 @@ func TestEditClientChangeConnectionType(t *testing.T) { updatedConfig := clientConfig updatedConfig.ConnectionType = schemas.MCPConnectionTypeSSE - err := manager.EditClient(clientID, updatedConfig) + err := manager.EditClient(clientID, &updatedConfig) assert.Error(t, err, "should not allow connection type change") clients = manager.GetClients() if len(clients) > 0 { @@ -428,7 +428,7 @@ func TestConcurrentClientOperations(t *testing.T) { clientConfig := GetSampleHTTPClientConfig(config.HTTPServerURL) clientConfig.ID = string(rune('a'+id)) + "-concurrent-client" - err := manager.AddClient(clientConfig) + err := manager.AddClient(&clientConfig) if err != nil { errors <- err } diff --git a/core/internal/mcptests/concurrency_advanced_test.go b/core/internal/mcptests/concurrency_advanced_test.go index 0190ab26c9..54057ce4a5 100644 --- a/core/internal/mcptests/concurrency_advanced_test.go +++ b/core/internal/mcptests/concurrency_advanced_test.go @@ -163,7 +163,7 @@ func TestConcurrent_AddRemoveClients(t *testing.T) { ToolsToAutoExecute: []string{}, } - err := manager.AddClient(clientConfig) + err := manager.AddClient(&clientConfig) if err != nil { // InProcess connections without a server instance will fail // This is expected - we're just testing that the operations are concurrent and don't deadlock diff --git a/core/internal/mcptests/concurrency_test.go b/core/internal/mcptests/concurrency_test.go index f81cc714eb..d71e1a501f 100644 --- a/core/internal/mcptests/concurrency_test.go +++ b/core/internal/mcptests/concurrency_test.go @@ -437,7 +437,7 @@ func TestConcurrent_EditClientDuringExecution(t *testing.T) { // Edit client - update name (must not contain spaces) updatedConfig := clientConfig updatedConfig.Name = "UpdatedClientName" - err := manager.EditClient(clientConfig.ID, updatedConfig) + err := manager.EditClient(clientConfig.ID, &updatedConfig) if err != nil { errors <- fmt.Errorf("failed to edit client: %v", err) } else { diff --git a/core/internal/mcptests/error_handling_protocol_test.go b/core/internal/mcptests/error_handling_protocol_test.go index 508f04de5f..bb6ad929d9 100644 --- a/core/internal/mcptests/error_handling_protocol_test.go +++ b/core/internal/mcptests/error_handling_protocol_test.go @@ -289,7 +289,7 @@ func TestErrorHandling_STDIO_MCPErrorResponse(t *testing.T) { errorServerConfig := GetErrorTestServerConfig(bifrostRoot) manager := setupMCPManager(t) - err := manager.AddClient(errorServerConfig) + err := manager.AddClient(&errorServerConfig) if err != nil { t.Skipf("error-test-server not available: %v (build with: cd examples/mcps/error-test-server && go build -o bin/error-test-server)", err) } @@ -367,7 +367,7 @@ func TestErrorHandling_STDIO_TimeoutScenario(t *testing.T) { ToolExecutionTimeout: 2 * time.Second, // 2 second timeout }) - err := manager.AddClient(errorServerConfig) + err := manager.AddClient(&errorServerConfig) if err != nil { t.Skipf("error-test-server not available: %v", err) } @@ -420,7 +420,7 @@ func TestErrorHandling_STDIO_MalformedJSON(t *testing.T) { errorServerConfig := GetErrorTestServerConfig(bifrostRoot) manager := setupMCPManager(t) - err := manager.AddClient(errorServerConfig) + err := manager.AddClient(&errorServerConfig) if err != nil { t.Skipf("error-test-server not available: %v", err) } @@ -469,7 +469,7 @@ func TestErrorHandling_STDIO_IntermittentFailures(t *testing.T) { errorServerConfig := GetErrorTestServerConfig(bifrostRoot) manager := setupMCPManager(t) - err := manager.AddClient(errorServerConfig) + err := manager.AddClient(&errorServerConfig) if err != nil { t.Skipf("error-test-server not available: %v", err) } diff --git a/core/internal/mcptests/fixtures.go b/core/internal/mcptests/fixtures.go index f85517f3f8..38d85f6ee8 100644 --- a/core/internal/mcptests/fixtures.go +++ b/core/internal/mcptests/fixtures.go @@ -13,7 +13,7 @@ import ( bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/mcp" - mcpcore "github.com/maximhq/bifrost/core/mcp" + "github.com/maximhq/bifrost/core/mcp/codemode/starlark" "github.com/maximhq/bifrost/core/schemas" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1461,17 +1461,28 @@ func setupBifrost(t *testing.T) *bifrost.Bifrost { } // setupMCPManager creates an MCP manager for testing -func setupMCPManager(t *testing.T, clientConfigs ...schemas.MCPClientConfig) *mcpcore.MCPManager { +func setupMCPManager(t *testing.T, clientConfigs ...schemas.MCPClientConfig) *mcp.MCPManager { t.Helper() + logger := &testLogger{t: t} + + // Convert to pointer slice for MCPConfig + clientConfigPtrs := make([]*schemas.MCPClientConfig, len(clientConfigs)) + for i := range clientConfigs { + clientConfigPtrs[i] = &clientConfigs[i] + } + // Create MCP config mcpConfig := &schemas.MCPConfig{ - ClientConfigs: clientConfigs, + ClientConfigs: clientConfigPtrs, } - // Create MCP manager - logger := &testLogger{t: t} - manager := mcpcore.NewMCPManager(context.Background(), *mcpConfig, nil, logger) + // Create Starlark CodeMode + starlark.SetLogger(logger) + codeMode := starlark.NewStarlarkCodeMode(nil) + + // Create MCP manager - dependencies are injected automatically + manager := mcp.NewMCPManager(context.Background(), *mcpConfig, nil, logger, codeMode) // Cleanup t.Cleanup(func() { diff --git a/core/internal/mcptests/health_monitoring_test.go b/core/internal/mcptests/health_monitoring_test.go index f0834c24d2..7e7155d04b 100644 --- a/core/internal/mcptests/health_monitoring_test.go +++ b/core/internal/mcptests/health_monitoring_test.go @@ -50,7 +50,7 @@ func TestHealthCheckSTDIOServerDropAndRecoverIn20Seconds(t *testing.T) { t.Logf("✅ Health monitor detected server drop") // 5. Restart STDIO process (re-add client) - err = manager.AddClient(clientConfig) + err = manager.AddClient(&clientConfig) require.NoError(t, err, "should re-add client to simulate server recovery") t.Logf("🔄 Simulated STDIO server recovery by re-adding client") @@ -181,7 +181,7 @@ func TestHealthCheckStateTransitions(t *testing.T) { assert.Len(t, clients, 0, "client should be removed") // Re-add client (simulates reconnection) - err = manager.AddClient(clientConfig) + err = manager.AddClient(&clientConfig) require.NoError(t, err, "should re-add client") // Verify client is connected again @@ -392,7 +392,7 @@ func TestHealthCheckReconnectAfterFailure(t *testing.T) { time.Sleep(2 * time.Second) // Re-add client (manual reconnection) - err = manager.AddClient(clientConfig) + err = manager.AddClient(&clientConfig) require.NoError(t, err, "should re-add client") // Wait for health monitoring to stabilize diff --git a/core/internal/mcptests/integration_test.go b/core/internal/mcptests/integration_test.go index 6045070c31..776a0ab80f 100644 --- a/core/internal/mcptests/integration_test.go +++ b/core/internal/mcptests/integration_test.go @@ -34,7 +34,7 @@ func TestIntegration_FullChatWorkflow(t *testing.T) { httpConfig := GetSampleHTTPClientConfig(config.HTTPServerURL) httpConfig.ID = "http-integration-test" applyTestConfigHeaders(t, &httpConfig) - err := manager.AddClient(httpConfig) + err := manager.AddClient(&httpConfig) if err != nil { t.Logf("Could not add HTTP client: %v", err) } @@ -341,7 +341,7 @@ func TestIntegration_ReconnectDuringExecution(t *testing.T) { httpConfig := GetSampleHTTPClientConfig(config.HTTPServerURL) httpConfig.ID = "reconnect-test-client" applyTestConfigHeaders(t, &httpConfig) - err := manager.AddClient(httpConfig) + err := manager.AddClient(&httpConfig) require.NoError(t, err, "should add HTTP client") // Wait for client to connect diff --git a/core/internal/mcptests/tool_filtering_test.go b/core/internal/mcptests/tool_filtering_test.go index 2fd950e980..037738679f 100644 --- a/core/internal/mcptests/tool_filtering_test.go +++ b/core/internal/mcptests/tool_filtering_test.go @@ -335,7 +335,7 @@ func TestFilteringChangesAfterClientEdit(t *testing.T) { // Edit client to only allow second tool clientConfig.ToolsToExecute = []string{"hash"} - err := manager.EditClient(clientConfig.ID, clientConfig) + err := manager.EditClient(clientConfig.ID, &clientConfig) require.NoError(t, err, "edit should succeed") // Verify configuration changed diff --git a/core/mcp/agent.go b/core/mcp/agent.go index e1d3895874..2bb8c8974d 100644 --- a/core/mcp/agent.go +++ b/core/mcp/agent.go @@ -182,10 +182,10 @@ func executeAgent( toolName := *toolCall.Function.Name client := clientManager.GetClientForTool(toolName) if client == nil { - // Allow code mode list and read tool tools - if toolName == ToolTypeListToolFiles || toolName == ToolTypeReadToolFile { + // Allow code mode list, read, and docs tools (all read-only operations) + if toolName == ToolTypeListToolFiles || toolName == ToolTypeReadToolFile || toolName == ToolTypeGetToolDocs { autoExecutableTools = append(autoExecutableTools, toolCall) - logger.Debug(fmt.Sprintf("Tool %s can be auto-executed", toolName)) + logger.Debug("Tool %s can be auto-executed", toolName) continue } else if toolName == ToolTypeExecuteToolCode { // Build allowed auto-execution tools map for code mode validation @@ -194,14 +194,14 @@ func executeAgent( // Parse tool arguments var arguments map[string]interface{} if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { - logger.Debug(fmt.Sprintf("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err)) + logger.Debug("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err) nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) continue } code, ok := arguments["code"].(string) if !ok || code == "" { - logger.Debug(fmt.Sprintf("%s Code parameter missing or empty", CodeModeLogPrefix)) + logger.Debug("%s Code parameter missing or empty", CodeModeLogPrefix) nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) continue } @@ -209,59 +209,58 @@ func executeAgent( // Step 1: Convert literal \n escape sequences to actual newlines for parsing codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") if len(codeWithNewlines) != len(code) { - logger.Debug(fmt.Sprintf("%s Converted literal \\n escape sequences to actual newlines", CodeModeLogPrefix)) + logger.Debug("%s Converted literal \\n escape sequences to actual newlines", CodeModeLogPrefix) } // Step 2: Extract tool calls from code during AST formation extractedToolCalls, err := extractToolCallsFromCode(codeWithNewlines) if err != nil { - logger.Debug(fmt.Sprintf("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err)) + logger.Debug("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err) nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) continue } - logger.Debug(fmt.Sprintf("%s Extracted %d tool call(s) from code", CodeModeLogPrefix, len(extractedToolCalls))) + logger.Debug("%s Extracted %d tool call(s) from code", CodeModeLogPrefix, len(extractedToolCalls)) // Step 3: Validate all tool calls against allowedAutoExecutionTools canAutoExecute := true if len(extractedToolCalls) > 0 { // If there are tool calls, we need allowedAutoExecutionTools to validate them if len(allowedAutoExecutionTools) == 0 { - logger.Debug(fmt.Sprintf("%s Validation failed: no allowed auto-execution tools configured", CodeModeLogPrefix)) + logger.Debug("%s Validation failed: no allowed auto-execution tools configured", CodeModeLogPrefix) canAutoExecute = false } else { - logger.Debug(fmt.Sprintf("%s Validating %d tool call(s) against %d allowed server(s)", CodeModeLogPrefix, len(extractedToolCalls), len(allowedAutoExecutionTools))) + logger.Debug("%s Validating %d tool call(s) against %d allowed server(s)", CodeModeLogPrefix, len(extractedToolCalls), len(allowedAutoExecutionTools)) // Validate each tool call for _, extractedToolCall := range extractedToolCalls { isAllowed := isToolCallAllowedForCodeMode(extractedToolCall.serverName, extractedToolCall.toolName, allClientNames, allowedAutoExecutionTools) if !isAllowed { - logger.Debug(fmt.Sprintf("%s Tool call %s.%s: allowed=%v", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName, isAllowed)) - logger.Debug(fmt.Sprintf("%s Validation failed: tool call %s.%s not in auto-execute list", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName)) + logger.Debug("%s Validation failed: tool call %s.%s not in auto-execute list", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName) canAutoExecute = false break } } if canAutoExecute { - logger.Debug(fmt.Sprintf("%s All tool calls validated successfully", CodeModeLogPrefix)) + logger.Debug("%s All tool calls validated successfully", CodeModeLogPrefix) } } } else { - logger.Debug(fmt.Sprintf("%s No tool calls found in code, skipping validation", CodeModeLogPrefix)) + logger.Debug("%s No tool calls found in code, skipping validation", CodeModeLogPrefix) } // Add to appropriate list based on validation result if canAutoExecute { autoExecutableTools = append(autoExecutableTools, toolCall) - logger.Debug(fmt.Sprintf("Tool %s can be auto-executed (validation passed)", toolName)) + logger.Debug("Tool %s can be auto-executed (validation passed)", toolName) } else { nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) - logger.Debug(fmt.Sprintf("Tool %s cannot be auto-executed (validation failed)", toolName)) + logger.Debug("Tool %s cannot be auto-executed (validation failed)", toolName) } continue } // Else, if client not found, treat as non-auto-executable (can be a manually passed tool) - logger.Debug(fmt.Sprintf("Client not found for tool %s, treating as non-auto-executable", toolName)) + logger.Debug("Client not found for tool %s, treating as non-auto-executable", toolName) nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) continue } @@ -269,15 +268,15 @@ func executeAgent( // Check if tool can be auto-executed if canAutoExecuteTool(toolName, client.ExecutionConfig) { autoExecutableTools = append(autoExecutableTools, toolCall) - logger.Debug(fmt.Sprintf("Tool %s can be auto-executed", toolName)) + logger.Debug("Tool %s can be auto-executed", toolName) } else { nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) - logger.Debug(fmt.Sprintf("Tool %s cannot be auto-executed", toolName)) + logger.Debug("Tool %s cannot be auto-executed", toolName) } } - logger.Debug(fmt.Sprintf("Auto-executable tools: %d", len(autoExecutableTools))) - logger.Debug(fmt.Sprintf("Non-auto-executable tools: %d", len(nonAutoExecutableTools))) + logger.Debug("Auto-executable tools: %d", len(autoExecutableTools)) + logger.Debug("Non-auto-executable tools: %d", len(nonAutoExecutableTools)) // Execute auto-executable tools first var executedToolResults []*schemas.ChatMessage @@ -300,7 +299,7 @@ func executeAgent( mcpResponse, toolErr := executeToolFunc(ctx, mcpRequest) if toolErr != nil { - logger.Warn(fmt.Sprintf("Tool execution failed: %v", toolErr)) + logger.Warn("Tool execution failed: %v", toolErr) channelToolResults <- createToolResultMessage(toolCall, "", toolErr) } else if mcpResponse != nil && mcpResponse.ChatMessage != nil { channelToolResults <- mcpResponse.ChatMessage @@ -332,7 +331,7 @@ func executeAgent( // If there are non-auto-executable tools, return them immediately without continuing the loop if len(nonAutoExecutableTools) > 0 { - logger.Debug(fmt.Sprintf("Found %d non-auto-executable tools, returning them immediately without continuing the loop", len(nonAutoExecutableTools))) + logger.Debug("Found %d non-auto-executable tools, returning them immediately without continuing the loop", len(nonAutoExecutableTools)) // Return as is if its the first iteration if depth == 1 && len(allExecutedToolResults) == 0 { return currentResponse, nil @@ -361,7 +360,7 @@ func executeAgent( currentResponse = response } - logger.Debug(fmt.Sprintf("Agent mode completed after %d iterations", depth)) + logger.Debug("Agent mode completed after %d iterations", depth) return currentResponse, nil } diff --git a/core/mcp/clientmanager.go b/core/mcp/clientmanager.go index a3f2846e71..9fd9f52805 100644 --- a/core/mcp/clientmanager.go +++ b/core/mcp/clientmanager.go @@ -72,8 +72,8 @@ func (m *MCPManager) ReconnectClient(id string) error { // // Returns: // - error: Any error that occurred during client addition or connection -func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error { - if err := validateMCPClientConfig(&config); err != nil { +func (m *MCPManager) AddClient(config *schemas.MCPClientConfig) error { + if err := validateMCPClientConfig(config); err != nil { return fmt.Errorf("invalid MCP client configuration: %w", err) } @@ -89,10 +89,13 @@ func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error { // Create placeholder entry m.clientMap[config.ID] = &schemas.MCPClientState{ - Name: config.Name, - ExecutionConfig: config, - ToolMap: make(map[string]schemas.ChatTool), - ToolNameMapping: make(map[string]string), + Name: config.Name, + ExecutionConfig: config, + ToolMap: make(map[string]schemas.ChatTool), + ToolNameMapping: make(map[string]string), + ConnectionInfo: &schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + }, } // Temporarily unlock for the connection attempt @@ -119,8 +122,8 @@ func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error { // // Returns: // - error: Any error that occurred during client addition or connection -func (m *MCPManager) AddClientInMemory(config schemas.MCPClientConfig) error { - if err := validateMCPClientConfig(&config); err != nil { +func (m *MCPManager) AddClientInMemory(config *schemas.MCPClientConfig) error { + if err := validateMCPClientConfig(config); err != nil { return fmt.Errorf("invalid MCP client configuration: %w", err) } @@ -136,10 +139,13 @@ func (m *MCPManager) AddClientInMemory(config schemas.MCPClientConfig) error { // Create placeholder entry m.clientMap[config.ID] = &schemas.MCPClientState{ - Name: config.Name, - ExecutionConfig: config, - ToolMap: make(map[string]schemas.ChatTool), - ToolNameMapping: make(map[string]string), + Name: config.Name, + ExecutionConfig: config, + ToolMap: make(map[string]schemas.ChatTool), + ToolNameMapping: make(map[string]string), + ConnectionInfo: &schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + }, } // Temporarily unlock for the connection attempt @@ -186,11 +192,14 @@ func (m *MCPManager) removeClientUnsafe(id string) error { return fmt.Errorf("client %s not found", id) } - logger.Info(fmt.Sprintf("%s Disconnecting MCP server '%s'", MCPLogPrefix, client.ExecutionConfig.Name)) + logger.Info("%s Disconnecting MCP server '%s'", MCPLogPrefix, client.ExecutionConfig.Name) // Stop health monitoring for this client m.healthMonitorManager.StopMonitoring(id) + // Stop tool syncing for this client + m.toolSyncManager.StopSyncing(id) + // Cancel SSE context if present (required for proper SSE cleanup) if client.CancelFunc != nil { client.CancelFunc() @@ -225,7 +234,7 @@ func (m *MCPManager) removeClientUnsafe(id string) error { // // Returns: // - error: Any error that occurred during client update or tool retrieval -func (m *MCPManager) EditClient(id string, updatedConfig schemas.MCPClientConfig) error { +func (m *MCPManager) EditClient(id string, updatedConfig *schemas.MCPClientConfig) error { m.mu.Lock() defer m.mu.Unlock() @@ -387,8 +396,8 @@ func (m *MCPManager) RegisterTool(name, description string, toolFunction MCPTool return fmt.Errorf("tool '%s' is already registered", name) } - logger.Debug(fmt.Sprintf("%s Registering typed tool: %s -> prefixed as %s (client: %s)", MCPLogPrefix, name, prefixedToolName, BifrostMCPClientKey)) - logger.Info(fmt.Sprintf("%s Registering typed tool: %s", MCPLogPrefix, name)) + logger.Debug("%s Registering typed tool: %s -> prefixed as %s (client: %s)", MCPLogPrefix, name, prefixedToolName, BifrostMCPClientKey) + logger.Info("%s Registering typed tool: %s", MCPLogPrefix, name) // Create MCP handler wrapper that converts between typed and MCP interfaces mcpHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -421,7 +430,7 @@ func (m *MCPManager) RegisterTool(name, description string, toolFunction MCPTool // connectToMCPClient establishes a connection to an external MCP server and // registers its available tools with the manager. -func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { +func (m *MCPManager) connectToMCPClient(config *schemas.MCPClientConfig) error { // First lock: Initialize or validate client entry m.mu.Lock() @@ -440,11 +449,11 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { } // Create new client entry with configuration m.clientMap[config.ID] = &schemas.MCPClientState{ - Name: config.Name, - ExecutionConfig: config, - ToolMap: make(map[string]schemas.ChatTool), - ToolNameMapping: make(map[string]string), - ConnectionInfo: schemas.MCPClientConnectionInfo{ + Name: config.Name, + ExecutionConfig: config, + ToolMap: make(map[string]schemas.ChatTool), + ToolNameMapping: make(map[string]string), + ConnectionInfo: &schemas.MCPClientConnectionInfo{ Type: config.ConnectionType, }, } @@ -452,11 +461,11 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { // Heavy operations performed outside lock var externalClient *client.Client - var connectionInfo schemas.MCPClientConnectionInfo + var connectionInfo *schemas.MCPClientConnectionInfo var err error // Create appropriate transport based on connection type - logger.Debug(fmt.Sprintf("%s [%s] Creating %s connection...", MCPLogPrefix, config.Name, config.ConnectionType)) + logger.Debug("%s [%s] Creating %s connection...", MCPLogPrefix, config.Name, config.ConnectionType) switch config.ConnectionType { case schemas.MCPConnectionTypeHTTP: externalClient, connectionInfo, err = m.createHTTPConnection(m.ctx, config) @@ -473,7 +482,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { if err != nil { return fmt.Errorf("failed to create connection: %w", err) } - logger.Debug(fmt.Sprintf("%s [%s] Connection created successfully", MCPLogPrefix, config.Name)) + logger.Debug("%s [%s] Connection created successfully", MCPLogPrefix, config.Name) // Initialize the external client with timeout // For SSE and STDIO connections, we need a long-lived context for the connection @@ -498,14 +507,14 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { } // Start the transport first (required for STDIO and SSE clients) - logger.Debug(fmt.Sprintf("%s [%s] Starting transport...", MCPLogPrefix, config.Name)) + logger.Debug("%s [%s] Starting transport...", MCPLogPrefix, config.Name) if err := externalClient.Start(ctx); err != nil { if config.ConnectionType == schemas.MCPConnectionTypeSSE || config.ConnectionType == schemas.MCPConnectionTypeSTDIO { cancel() // Cancel long-lived context on error } return fmt.Errorf("failed to start MCP client transport %s: %v", config.Name, err) } - logger.Debug(fmt.Sprintf("%s [%s] Transport started successfully", MCPLogPrefix, config.Name)) + logger.Debug("%s [%s] Transport started successfully", MCPLogPrefix, config.Name) // Create proper initialize request for external client extInitRequest := mcp.InitializeRequest{ @@ -528,7 +537,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { // Create timeout context for initialization phase only initCtx, initCancel = context.WithTimeout(longLivedCtx, MCPClientConnectionEstablishTimeout) defer initCancel() - logger.Debug(fmt.Sprintf("%s [%s] Initializing client with %v timeout...", MCPLogPrefix, config.Name, MCPClientConnectionEstablishTimeout)) + logger.Debug("%s [%s] Initializing client with %v timeout...", MCPLogPrefix, config.Name, MCPClientConnectionEstablishTimeout) } else { // HTTP already has timeout initCtx = ctx @@ -541,10 +550,10 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { } return fmt.Errorf("failed to initialize MCP client %s: %v", config.Name, err) } - logger.Debug(fmt.Sprintf("%s [%s] Client initialized successfully", MCPLogPrefix, config.Name)) + logger.Debug("%s [%s] Client initialized successfully", MCPLogPrefix, config.Name) // Retrieve tools from the external server (this also requires network I/O) - logger.Debug(fmt.Sprintf("%s [%s] Retrieving tools...", MCPLogPrefix, config.Name)) + logger.Debug("%s [%s] Retrieving tools...", MCPLogPrefix, config.Name) tools, toolNameMapping, err := retrieveExternalTools(ctx, externalClient, config.Name) if err != nil { logger.Warn("%s Failed to retrieve tools from %s: %v", MCPLogPrefix, config.Name, err) @@ -552,11 +561,10 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { tools = make(map[string]schemas.ChatTool) toolNameMapping = make(map[string]string) } - logger.Debug(fmt.Sprintf("%s [%s] Retrieved %d tools", MCPLogPrefix, config.Name, len(tools))) + logger.Debug("%s [%s] Retrieved %d tools", MCPLogPrefix, config.Name, len(tools)) // Second lock: Update client with final connection details and tools m.mu.Lock() - defer m.mu.Unlock() // Verify client still exists (could have been cleaned up during heavy operations) if client, exists := m.clientMap[config.ID]; exists { @@ -578,9 +586,11 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { // Store tool name mapping for execution (sanitized_name -> original_mcp_name) client.ToolNameMapping = toolNameMapping - logger.Debug(fmt.Sprintf("%s [%s] Registering %d tools. Client config - ID: %s, Name: %s, IsCodeModeClient: %v", MCPLogPrefix, config.Name, len(tools), config.ID, config.Name, config.IsCodeModeClient)) - logger.Info(fmt.Sprintf("%s Connected to MCP server '%s'", MCPLogPrefix, config.Name)) + logger.Debug("%s [%s] Registering %d tools. Client config - ID: %s, Name: %s, IsCodeModeClient: %v", MCPLogPrefix, config.Name, len(tools), config.ID, config.Name, config.IsCodeModeClient) + logger.Info("%s Connected to MCP server '%s'", MCPLogPrefix, config.Name) } else { + // Release lock before cleanup and return + m.mu.Unlock() // Clean up resources before returning error: client was removed during connection setup // Cancel long-lived context if it was created if (config.ConnectionType == schemas.MCPConnectionTypeSSE || config.ConnectionType == schemas.MCPConnectionTypeSTDIO) && cancel != nil { @@ -595,6 +605,10 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { return fmt.Errorf("client %s was removed during connection setup", config.Name) } + // Release lock BEFORE starting monitors to prevent deadlock + // (StartMonitoring -> Start() tries to acquire RLock on the same mutex) + m.mu.Unlock() + // Register OnConnectionLost hook for SSE connections to detect idle timeouts if config.ConnectionType == schemas.MCPConnectionTypeSSE && externalClient != nil { externalClient.OnConnectionLost(func(err error) { @@ -612,41 +626,45 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval, config.IsPingAvailable) m.healthMonitorManager.StartMonitoring(monitor) + // Start tool syncing for the client (skip for internal bifrost client) + if config.ID != BifrostMCPClientKey { + syncInterval := ResolveToolSyncInterval(config, m.toolSyncManager.GetGlobalInterval()) + if syncInterval > 0 { + syncer := NewClientToolSyncer(m, config.ID, config.Name, syncInterval) + m.toolSyncManager.StartSyncing(syncer) + } + } + return nil } // createHTTPConnection creates an HTTP-based MCP client connection without holding locks. -func (m *MCPManager) createHTTPConnection(ctx context.Context, config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { +func (m *MCPManager) createHTTPConnection(ctx context.Context, config *schemas.MCPClientConfig) (*client.Client, *schemas.MCPClientConnectionInfo, error) { if config.ConnectionString == nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("HTTP connection string is required") + return nil, nil, fmt.Errorf("HTTP connection string is required") } - // Prepare connection info - connectionInfo := schemas.MCPClientConnectionInfo{ + connectionInfo := &schemas.MCPClientConnectionInfo{ Type: config.ConnectionType, ConnectionURL: config.ConnectionString.GetValuePtr(), } - headers, err := config.HttpHeaders(ctx, m.oauth2Provider) if err != nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to get HTTP headers: %w", err) + return nil, nil, fmt.Errorf("failed to get HTTP headers: %w", err) } - // Create StreamableHTTP transport httpTransport, err := transport.NewStreamableHTTP(config.ConnectionString.GetValue(), transport.WithHTTPHeaders(headers)) if err != nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create HTTP transport: %w", err) + return nil, nil, fmt.Errorf("failed to create HTTP transport: %w", err) } - client := client.NewClient(httpTransport) - return client, connectionInfo, nil } // createSTDIOConnection creates a STDIO-based MCP client connection without holding locks. -func (m *MCPManager) createSTDIOConnection(ctx context.Context, config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { +func (m *MCPManager) createSTDIOConnection(_ context.Context, config *schemas.MCPClientConfig) (*client.Client, *schemas.MCPClientConnectionInfo, error) { if config.StdioConfig == nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("stdio config is required") + return nil, nil, fmt.Errorf("stdio config is required") } // Prepare STDIO command info for display @@ -655,7 +673,7 @@ func (m *MCPManager) createSTDIOConnection(ctx context.Context, config schemas.M // Check if environment variables are set for _, env := range config.StdioConfig.Envs { if os.Getenv(env) == "" { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("environment variable %s is not set for MCP client %s", env, config.Name) + return nil, nil, fmt.Errorf("environment variable %s is not set for MCP client %s", env, config.Name) } } @@ -667,7 +685,7 @@ func (m *MCPManager) createSTDIOConnection(ctx context.Context, config schemas.M ) // Prepare connection info - connectionInfo := schemas.MCPClientConnectionInfo{ + connectionInfo := &schemas.MCPClientConnectionInfo{ Type: config.ConnectionType, StdioCommandString: &cmdString, } @@ -679,26 +697,26 @@ func (m *MCPManager) createSTDIOConnection(ctx context.Context, config schemas.M } // createSSEConnection creates a SSE-based MCP client connection without holding locks. -func (m *MCPManager) createSSEConnection(ctx context.Context, config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { +func (m *MCPManager) createSSEConnection(ctx context.Context, config *schemas.MCPClientConfig) (*client.Client, *schemas.MCPClientConnectionInfo, error) { if config.ConnectionString == nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("SSE connection string is required") + return nil, nil, fmt.Errorf("SSE connection string is required") } // Prepare connection info - connectionInfo := schemas.MCPClientConnectionInfo{ + connectionInfo := &schemas.MCPClientConnectionInfo{ Type: config.ConnectionType, ConnectionURL: config.ConnectionString.GetValuePtr(), // Reuse HTTPConnectionURL field for SSE URL display } headers, err := config.HttpHeaders(ctx, m.oauth2Provider) if err != nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to get HTTP headers: %w", err) + return nil, nil, fmt.Errorf("failed to get HTTP headers: %w", err) } // Create SSE transport sseTransport, err := transport.NewSSE(config.ConnectionString.GetValue(), transport.WithHeaders(headers)) if err != nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create SSE transport: %w", err) + return nil, nil, fmt.Errorf("failed to create SSE transport: %w", err) } client := client.NewClient(sseTransport) @@ -709,19 +727,19 @@ func (m *MCPManager) createSSEConnection(ctx context.Context, config schemas.MCP // createInProcessConnection creates an in-process MCP client connection without holding locks. // This allows direct connection to an MCP server running in the same process, providing // the lowest latency and highest performance for tool execution. -func (m *MCPManager) createInProcessConnection(ctx context.Context, config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { +func (m *MCPManager) createInProcessConnection(_ context.Context, config *schemas.MCPClientConfig) (*client.Client, *schemas.MCPClientConnectionInfo, error) { if config.InProcessServer == nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("InProcess connection requires a server instance") + return nil, nil, fmt.Errorf("InProcess connection requires a server instance") } // Create in-process client directly connected to the provided server inProcessClient, err := client.NewInProcessClient(config.InProcessServer) if err != nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create in-process client: %w", err) + return nil, nil, fmt.Errorf("failed to create in-process client: %w", err) } // Prepare connection info - connectionInfo := schemas.MCPClientConnectionInfo{ + connectionInfo := &schemas.MCPClientConnectionInfo{ Type: config.ConnectionType, } @@ -804,14 +822,14 @@ func (m *MCPManager) createLocalMCPClient() (*schemas.MCPClientState, error) { // Don't create the actual client connection here - it will be created // after the server is ready using NewInProcessClient return &schemas.MCPClientState{ - ExecutionConfig: schemas.MCPClientConfig{ + ExecutionConfig: &schemas.MCPClientConfig{ ID: BifrostMCPClientKey, Name: BifrostMCPClientKey, // Use same value as ID for consistent prefixing ToolsToExecute: []string{"*"}, // Allow all tools for internal client }, ToolMap: make(map[string]schemas.ChatTool), ToolNameMapping: make(map[string]string), - ConnectionInfo: schemas.MCPClientConnectionInfo{ + ConnectionInfo: &schemas.MCPClientConnectionInfo{ Type: schemas.MCPConnectionTypeInProcess, // Accurate: in-process (in-memory) transport }, }, nil diff --git a/core/mcp/codemode.go b/core/mcp/codemode.go new file mode 100644 index 0000000000..e81c984195 --- /dev/null +++ b/core/mcp/codemode.go @@ -0,0 +1,105 @@ +//go:build !tinygo && !wasm + +package mcp + +import ( + "context" + "sync" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +// CodeMode tool type constants +const ( + ToolTypeListToolFiles string = "listToolFiles" + ToolTypeReadToolFile string = "readToolFile" + ToolTypeGetToolDocs string = "getToolDocs" + ToolTypeExecuteToolCode string = "executeToolCode" +) + +// CodeModeLogPrefix is the log prefix for code mode operations +const CodeModeLogPrefix = "[CODE MODE]" + +// CodeMode defines the interface for code execution environments. +// Implementations can provide different interpreters (Starlark, Lua, JavaScript, etc.) +// while maintaining the same tool interface for the ToolsManager. +type CodeMode interface { + // GetTools returns the code mode meta-tools (listToolFiles, readToolFile, getToolDocs, executeToolCode) + // These tools are added to the available tools when a code mode client is connected. + GetTools() []schemas.ChatTool + + // ExecuteTool handles a code mode tool call by name. + // Returns the response message and any error that occurred. + ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) + + // IsCodeModeTool returns true if the given tool name is a code mode tool. + IsCodeModeTool(toolName string) bool + + // GetBindingLevel returns the current code mode binding level (server or tool). + GetBindingLevel() schemas.CodeModeBindingLevel + + // UpdateConfig updates the code mode configuration atomically. + UpdateConfig(config *CodeModeConfig) + + // SetDependencies sets the dependencies required for code execution. + // This is called by MCPManager after construction to inject the dependencies + // (ClientManager, plugin pipeline, etc.) that weren't available at CodeMode creation time. + SetDependencies(deps *CodeModeDependencies) +} + +// CodeModeConfig holds the configuration for a CodeMode implementation. +type CodeModeConfig struct { + // BindingLevel controls how tools are exposed in the VFS: "server" or "tool" + BindingLevel schemas.CodeModeBindingLevel + + // ToolExecutionTimeout is the maximum time allowed for tool execution + ToolExecutionTimeout time.Duration +} + +// CodeModeDependencies holds the dependencies required by CodeMode implementations. +type CodeModeDependencies struct { + // ClientManager provides access to MCP clients and their tools + ClientManager ClientManager + + // PluginPipelineProvider returns a plugin pipeline for running MCP hooks + PluginPipelineProvider func() PluginPipeline + + // ReleasePluginPipeline releases a plugin pipeline back to the pool + ReleasePluginPipeline func(pipeline PluginPipeline) + + // FetchNewRequestIDFunc generates unique request IDs for nested tool calls + FetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string + + // LogMutex protects concurrent access to logs during code execution + LogMutex *sync.Mutex +} + +// DefaultCodeModeConfig returns the default configuration for CodeMode. +func DefaultCodeModeConfig() *CodeModeConfig { + return &CodeModeConfig{ + BindingLevel: schemas.CodeModeBindingLevelServer, + ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout, + } +} + +// codeModeToolNames is a set of all code mode tool names for fast lookup +var codeModeToolNames = map[string]bool{ + ToolTypeListToolFiles: true, + ToolTypeReadToolFile: true, + ToolTypeGetToolDocs: true, + ToolTypeExecuteToolCode: true, +} + +// IsCodeModeTool returns true if the given tool name is a code mode tool. +// This is a package-level helper function. +func IsCodeModeTool(toolName string) bool { + return codeModeToolNames[toolName] +} + +// toolCallInfo represents a tool call extracted from code. +// Used for validating tool calls before auto-execution in agent mode. +type toolCallInfo struct { + serverName string + toolName string +} diff --git a/core/mcp/codemode/starlark/executecode.go b/core/mcp/codemode/starlark/executecode.go new file mode 100644 index 0000000000..e99b5e888b --- /dev/null +++ b/core/mcp/codemode/starlark/executecode.go @@ -0,0 +1,647 @@ +//go:build !tinygo && !wasm + +package starlark + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/bytedance/sonic" + "github.com/mark3labs/mcp-go/mcp" + codemcp "github.com/maximhq/bifrost/core/mcp" + "github.com/maximhq/bifrost/core/schemas" + "go.starlark.net/starlark" + "go.starlark.net/starlarkstruct" +) + +// toolBinding represents a tool binding for the interpreter +type toolBinding struct { + toolName string + clientName string +} + +// ExecutionResult represents the result of code execution +type ExecutionResult struct { + Result interface{} `json:"result"` + Logs []string `json:"logs"` + Errors *ExecutionError `json:"errors,omitempty"` + Environment ExecutionEnvironment `json:"environment"` +} + +// ExecutionErrorType represents the type of execution error +type ExecutionErrorType string + +const ( + ExecutionErrorTypeCompile ExecutionErrorType = "compile" + ExecutionErrorTypeSyntax ExecutionErrorType = "syntax" + ExecutionErrorTypeRuntime ExecutionErrorType = "runtime" +) + +// ExecutionError represents an error during code execution +type ExecutionError struct { + Kind ExecutionErrorType `json:"kind"` // "compile", "syntax", or "runtime" + Message string `json:"message"` + Hints []string `json:"hints"` +} + +// ExecutionEnvironment contains information about the execution environment +type ExecutionEnvironment struct { + ServerKeys []string `json:"serverKeys"` +} + +// createExecuteToolCodeTool creates the executeToolCode tool definition for code mode. +// This tool allows executing Python (Starlark) code in a sandboxed interpreter with access to MCP server tools. +func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool { + executeToolCodeProps := schemas.OrderedMap{ + "code": map[string]interface{}{ + "type": "string", + "description": "Python code to execute. The code runs in a Starlark interpreter (Python subset). Tool calls are synchronous - no async/await needed. For simple use cases, directly return results. Use print() for logging. ALWAYS retry if code fails. Example: result = server.tool_name(param=\"value\"); return result", + }, + } + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: codemcp.ToolTypeExecuteToolCode, + Description: schemas.Ptr( + "Executes Python code inside a sandboxed Starlark interpreter with access to all connected MCP servers' tools. " + + "All connected servers are exposed as global objects named after their configuration keys, and each server " + + "provides functions for every tool available on that server. The canonical usage pattern is: " + + "result = .(param=\"value\"). Both and should be discovered " + + "using listToolFiles and readToolFile. " + + + "IMPORTANT WORKFLOW: Always follow this order — first use listToolFiles to see available servers and tools, " + + "then use readToolFile to understand the tool definitions and their parameters, and finally use executeToolCode " + + "to execute your code. " + + + "SYNTAX NOTES: " + + "• Tool calls are synchronous - NO async/await needed, just call directly: result = server.tool(arg=\"value\") " + + "• Use keyword arguments: server.tool(param=\"value\") NOT server.tool({\"param\": \"value\"}) " + + "• Access dict values with brackets: result[\"key\"] NOT result.key " + + "• Use print() for logging (not console.log) " + + "• List comprehensions work: [x for x in items if x[\"active\"]] " + + "• To return a value, assign to 'result' variable: result = computed_value " + + + "RETRY POLICY: ALWAYS retry if a code block fails. Analyze the error, adjust your code, and retry. " + + + "The environment is intentionally minimal: " + + "• No imports needed or supported " + + "• No network APIs (use MCP tools for external interactions) " + + "• No file system access (use MCP tools) " + + "• No classes (use dicts and functions) " + + "• Deterministic execution (no random, no time) " + + + "Long-running operations are interrupted via execution timeout. " + + "This tool is designed specifically for orchestrating MCP tool calls and lightweight computation.", + ), + + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &executeToolCodeProps, + Required: []string{"code"}, + }, + }, + } +} + +// handleExecuteToolCode handles the executeToolCode tool call. +func (s *StarlarkCodeMode) handleExecuteToolCode(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + toolName := "unknown" + if toolCall.Function.Name != nil { + toolName = *toolCall.Function.Name + } + logger.Debug("%s Handling executeToolCode tool call: %s", codemcp.CodeModeLogPrefix, toolName) + + // Parse tool arguments + var arguments map[string]interface{} + if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + logger.Debug("%s Failed to parse tool arguments: %v", codemcp.CodeModeLogPrefix, err) + return nil, fmt.Errorf("failed to parse tool arguments: %v", err) + } + + code, ok := arguments["code"].(string) + if !ok || code == "" { + logger.Debug("%s Code parameter missing or empty", codemcp.CodeModeLogPrefix) + return nil, fmt.Errorf("code parameter is required and must be a non-empty string") + } + + logger.Debug("%s Starting code execution", codemcp.CodeModeLogPrefix) + result := s.executeCode(ctx, code) + logger.Debug("%s Code execution completed. Success: %v, Has errors: %v, Log count: %d", codemcp.CodeModeLogPrefix, result.Errors == nil, result.Errors != nil, len(result.Logs)) + + // Format response text + var responseText string + var executionSuccess bool = true + if result.Errors != nil { + logger.Debug("%s Formatting error response. Error kind: %s, Message length: %d, Hints count: %d", codemcp.CodeModeLogPrefix, result.Errors.Kind, len(result.Errors.Message), len(result.Errors.Hints)) + logsText := "" + if len(result.Logs) > 0 { + logsText = fmt.Sprintf("\n\nPrint Output:\n%s\n", strings.Join(result.Logs, "\n")) + } + + responseText = fmt.Sprintf( + "Execution %s error:\n\n%s\n\nHints:\n%s%s\n\nEnvironment:\n Available server keys: %s", + result.Errors.Kind, + result.Errors.Message, + strings.Join(result.Errors.Hints, "\n"), + logsText, + strings.Join(result.Environment.ServerKeys, ", "), + ) + logger.Debug("%s Error response formatted. Response length: %d chars", codemcp.CodeModeLogPrefix, len(responseText)) + } else { + hasLogs := len(result.Logs) > 0 + hasResult := result.Result != nil + logger.Debug("%s Formatting success response. Has logs: %v, Has result: %v", codemcp.CodeModeLogPrefix, hasLogs, hasResult) + + if !hasLogs && !hasResult { + executionSuccess = false + logger.Debug("%s Execution completed with no data (no logs, no result), marking as failure", codemcp.CodeModeLogPrefix) + hints := []string{ + "Add print() statements throughout your code to debug and see what's happening at each step", + "Assign the final value to 'result' variable if you want to return it: result = computed_value", + "Check that your tool calls are actually executing and returning data", + } + responseText = fmt.Sprintf( + "Execution completed but produced no data:\n\n"+ + "The code executed without errors but returned no output (no print output and no result variable).\n\n"+ + "Hints:\n%s\n\n"+ + "Environment:\n Available server keys: %s", + strings.Join(hints, "\n"), + strings.Join(result.Environment.ServerKeys, ", "), + ) + logger.Debug("%s No-data failure response formatted. Response length: %d chars", codemcp.CodeModeLogPrefix, len(responseText)) + } else { + if hasLogs { + responseText = fmt.Sprintf("Print output:\n%s\n\nExecution completed successfully.", + strings.Join(result.Logs, "\n")) + } else { + responseText = "Execution completed successfully." + } + if hasResult { + resultJSON, err := sonic.MarshalIndent(result.Result, "", " ") + if err == nil { + responseText += fmt.Sprintf("\nReturn value: %s", string(resultJSON)) + logger.Debug("%s Added return value to response (JSON length: %d chars)", codemcp.CodeModeLogPrefix, len(resultJSON)) + } else { + logger.Debug("%s Failed to marshal result to JSON: %v", codemcp.CodeModeLogPrefix, err) + } + } + + responseText += fmt.Sprintf("\n\nEnvironment:\n Available server keys: %s", + strings.Join(result.Environment.ServerKeys, ", ")) + responseText += "\nNote: This is a Starlark (Python subset) environment. Use MCP tools for external interactions." + logger.Debug("%s Success response formatted. Response length: %d chars, Server keys: %v", codemcp.CodeModeLogPrefix, len(responseText), result.Environment.ServerKeys) + } + } + + logger.Debug("%s Returning tool response message. Execution success: %v", codemcp.CodeModeLogPrefix, executionSuccess) + return createToolResponseMessage(toolCall, responseText), nil +} + +// executeCode executes Python (Starlark) code in a sandboxed interpreter with MCP tool bindings. +func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) ExecutionResult { + logs := []string{} + + logger.Debug("%s Starting Starlark code execution", codemcp.CodeModeLogPrefix) + + // Step 1: Convert literal \n escape sequences to actual newlines + codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") + + // Step 2: Handle empty code + trimmedCode := strings.TrimSpace(codeWithNewlines) + if trimmedCode == "" { + return ExecutionResult{ + Result: nil, + Logs: logs, + Errors: nil, + Environment: ExecutionEnvironment{ + ServerKeys: []string{}, + }, + } + } + + // Step 3: Build tool bindings for all connected servers + availableToolsPerClient := s.clientManager.GetToolPerClient(ctx) + serverKeys := make([]string, 0, len(availableToolsPerClient)) + predeclared := starlark.StringDict{} + + // Thread-safe log appender + appendLog := func(msg string) { + s.logMu.Lock() + defer s.logMu.Unlock() + logs = append(logs, msg) + } + + logger.Debug("%s GetToolPerClient returned %d clients", codemcp.CodeModeLogPrefix, len(availableToolsPerClient)) + + for clientName, tools := range availableToolsPerClient { + client := s.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName) + continue + } + logger.Debug("%s [%s] Client found. IsCodeModeClient: %v, ToolCount: %d", codemcp.CodeModeLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient, len(tools)) + if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 { + logger.Debug("%s [%s] Skipped: IsCodeModeClient=%v, HasTools=%v", codemcp.CodeModeLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient, len(tools) > 0) + continue + } + serverKeys = append(serverKeys, clientName) + + // Build struct with tool methods + structMembers := starlark.StringDict{} + + for _, tool := range tools { + if tool.Function == nil || tool.Function.Name == "" { + continue + } + + originalToolName := tool.Function.Name + unprefixedToolName := stripClientPrefix(originalToolName, clientName) + unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") + parsedToolName := parseToolName(unprefixedToolName) + + logger.Debug("%s [%s] Binding tool: %s -> %s", codemcp.CodeModeLogPrefix, clientName, originalToolName, parsedToolName) + + // Capture variables for closure + capturedToolName := originalToolName + capturedClientName := clientName + + // Create a Starlark builtin function for this tool + toolFunc := starlark.NewBuiltin(parsedToolName, func(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + // Convert kwargs to Go map + goArgs := make(map[string]interface{}) + for _, kwarg := range kwargs { + if len(kwarg) == 2 { + key := string(kwarg[0].(starlark.String)) + value := starlarkToGo(kwarg[1]) + goArgs[key] = value + } + } + + // Also handle positional args if there's exactly one dict argument + if len(args) == 1 && len(kwargs) == 0 { + if dict, ok := args[0].(*starlark.Dict); ok { + for _, item := range dict.Items() { + if keyStr, ok := item[0].(starlark.String); ok { + goArgs[string(keyStr)] = starlarkToGo(item[1]) + } + } + } + } + + // Call the MCP tool + result, err := s.callMCPTool(ctx, capturedClientName, capturedToolName, goArgs, appendLog) + if err != nil { + return starlark.None, fmt.Errorf("tool call failed: %v", err) + } + + // Convert result back to Starlark + return goToStarlark(result), nil + }) + + structMembers[parsedToolName] = toolFunc + } + + // Create a struct for this server + serverStruct := starlarkstruct.FromStringDict(starlark.String(clientName), structMembers) + predeclared[clientName] = serverStruct + logger.Debug("%s [%s] Added server struct with %d tools", codemcp.CodeModeLogPrefix, clientName, len(structMembers)) + } + + if len(serverKeys) > 0 { + logger.Debug("%s Bound %d servers with tools: %v", codemcp.CodeModeLogPrefix, len(serverKeys), serverKeys) + } else { + logger.Debug("%s No servers available for code mode execution", codemcp.CodeModeLogPrefix) + } + + // Step 4: Create Starlark thread with print function and timeout + toolExecutionTimeout := s.getToolExecutionTimeout() + timeoutCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + thread := &starlark.Thread{ + Name: "codemode", + Print: func(_ *starlark.Thread, msg string) { + appendLog(msg) + }, + } + + // Set up cancellation check + thread.SetLocal("context", timeoutCtx) + + // Step 5: Execute the code + globals, err := starlark.ExecFile(thread, "code.star", trimmedCode, predeclared) + + if err != nil { + errorMessage := err.Error() + hints := generatePythonErrorHints(errorMessage, serverKeys) + logger.Debug("%s Execution failed: %s", codemcp.CodeModeLogPrefix, errorMessage) + + errorKind := ExecutionErrorTypeRuntime + if strings.Contains(errorMessage, "syntax error") { + errorKind = ExecutionErrorTypeSyntax + } + + return ExecutionResult{ + Result: nil, + Logs: logs, + Errors: &ExecutionError{ + Kind: errorKind, + Message: errorMessage, + Hints: hints, + }, + Environment: ExecutionEnvironment{ + ServerKeys: serverKeys, + }, + } + } + + // Step 6: Extract result from globals + var result interface{} + if resultVal, ok := globals["result"]; ok && resultVal != starlark.None { + result = starlarkToGo(resultVal) + } + + logger.Debug("%s Execution completed successfully", codemcp.CodeModeLogPrefix) + return ExecutionResult{ + Result: result, + Logs: logs, + Errors: nil, + Environment: ExecutionEnvironment{ + ServerKeys: serverKeys, + }, + } +} + +// callMCPTool calls an MCP tool and returns the result. +func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { + // Get available tools per client + availableToolsPerClient := s.clientManager.GetToolPerClient(ctx) + + // Find the client by name + tools, exists := availableToolsPerClient[clientName] + if !exists || len(tools) == 0 { + return nil, fmt.Errorf("client not found for server name: %s", clientName) + } + + // Get client using a tool from this client + var client *schemas.MCPClientState + for _, tool := range tools { + if tool.Function != nil && tool.Function.Name != "" { + client = s.clientManager.GetClientForTool(tool.Function.Name) + if client != nil { + break + } + } + } + + if client == nil { + return nil, fmt.Errorf("client not found for server name: %s", clientName) + } + + // Strip the client name prefix from tool name before calling MCP server + originalToolName := stripClientPrefix(toolName, clientName) + + // Get BifrostContext for plugin pipeline + var bifrostCtx *schemas.BifrostContext + var ok bool + if bifrostCtx, ok = ctx.(*schemas.BifrostContext); !ok { + return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + } + + originalRequestID, _ := bifrostCtx.Value(schemas.BifrostContextKeyRequestID).(string) + + // Generate new request ID for this nested tool call + var newRequestID string + if s.fetchNewRequestIDFunc != nil { + newRequestID = s.fetchNewRequestIDFunc(bifrostCtx) + } else { + newRequestID = fmt.Sprintf("exec_%d_%s", time.Now().UnixNano(), toolName) + } + + // Create new child context + deadline, hasDeadline := bifrostCtx.Deadline() + if !hasDeadline { + deadline = schemas.NoDeadline + } + nestedCtx := schemas.NewBifrostContext(bifrostCtx, deadline) + nestedCtx.SetValue(schemas.BifrostContextKeyRequestID, newRequestID) + if originalRequestID != "" { + nestedCtx.SetValue(schemas.BifrostContextKeyParentMCPRequestID, originalRequestID) + } + + // Marshal arguments to JSON for the tool call + argsJSON, err := sonic.Marshal(args) + if err != nil { + return nil, fmt.Errorf("failed to marshal tool arguments: %v", err) + } + + // Build tool call for MCP request + toolCallReq := schemas.ChatAssistantMessageToolCall{ + ID: schemas.Ptr(newRequestID), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr(toolName), + Arguments: string(argsJSON), + }, + } + + // Create BifrostMCPRequest + mcpRequest := &schemas.BifrostMCPRequest{ + RequestType: schemas.MCPRequestTypeChatToolCall, + ChatAssistantMessageToolCall: &toolCallReq, + } + + // Check if plugin pipeline is available + if s.pluginPipelineProvider == nil { + return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + } + + // Get plugin pipeline and run hooks + pipeline := s.pluginPipelineProvider() + if pipeline == nil { + return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + } + defer s.releasePluginPipeline(pipeline) + + // Run PreMCPHooks + preReq, shortCircuit, preCount := pipeline.RunMCPPreHooks(nestedCtx, mcpRequest) + + // Handle short-circuit cases + if shortCircuit != nil { + if shortCircuit.Response != nil { + finalResp, _ := pipeline.RunMCPPostHooks(nestedCtx, shortCircuit.Response, nil, preCount) + if finalResp != nil { + if finalResp.ChatMessage != nil { + return extractResultFromChatMessage(finalResp.ChatMessage), nil + } + if finalResp.ResponsesMessage != nil { + result, err := extractResultFromResponsesMessage(finalResp.ResponsesMessage) + if err != nil { + return nil, err + } + if result != nil { + return result, nil + } + } + } + return nil, fmt.Errorf("plugin short-circuit returned invalid response") + } + if shortCircuit.Error != nil { + pipeline.RunMCPPostHooks(nestedCtx, nil, shortCircuit.Error, preCount) + if shortCircuit.Error.Error != nil { + return nil, fmt.Errorf("%s", shortCircuit.Error.Error.Message) + } + return nil, fmt.Errorf("plugin short-circuit error") + } + } + + // If pre-hooks modified the request, extract updated args + if preReq != nil && preReq.ChatAssistantMessageToolCall != nil { + toolCallReq = *preReq.ChatAssistantMessageToolCall + if toolCallReq.Function.Arguments != "" { + if err := sonic.Unmarshal([]byte(toolCallReq.Function.Arguments), &args); err != nil { + logger.Warn("%s Failed to parse modified tool arguments, using original: %v", codemcp.CodeModeLogPrefix, err) + } + } + } + + // Execute tool + startTime := time.Now() + toolNameToCall := originalToolName + + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: toolNameToCall, + Arguments: args, + }, + } + + toolExecutionTimeout := s.getToolExecutionTimeout() + toolCtx, cancel := context.WithTimeout(nestedCtx, toolExecutionTimeout) + defer cancel() + + toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) + latency := time.Since(startTime).Milliseconds() + + var mcpResp *schemas.BifrostMCPResponse + var bifrostErr *schemas.BifrostError + + if callErr != nil { + logger.Debug("%s Tool call failed: %s.%s - %v", codemcp.CodeModeLogPrefix, clientName, toolName, callErr) + appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, toolName, callErr)) + bifrostErr = &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: fmt.Sprintf("tool call failed for %s.%s: %v", clientName, toolName, callErr), + }, + } + } else { + rawResult := extractTextFromMCPResponse(toolResponse, toolName) + + if after, ok := strings.CutPrefix(rawResult, "Error: "); ok { + errorMsg := after + logger.Debug("%s Tool returned error result: %s.%s - %s", codemcp.CodeModeLogPrefix, clientName, toolName, errorMsg) + appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, toolName, errorMsg)) + bifrostErr = &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: errorMsg, + }, + } + } else { + mcpResp = &schemas.BifrostMCPResponse{ + ChatMessage: createToolResponseMessage(toolCallReq, rawResult), + ExtraFields: schemas.BifrostMCPResponseExtraFields{ + ClientName: clientName, + ToolName: originalToolName, + Latency: latency, + }, + } + + resultStr := formatResultForLog(rawResult) + logToolName := stripClientPrefix(toolName, clientName) + logToolName = strings.ReplaceAll(logToolName, "-", "_") + appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr)) + } + } + + // Run post-hooks + finalResp, finalErr := pipeline.RunMCPPostHooks(nestedCtx, mcpResp, bifrostErr, preCount) + + if finalErr != nil { + if finalErr.Error != nil { + return nil, fmt.Errorf("%s", finalErr.Error.Message) + } + return nil, fmt.Errorf("tool execution failed") + } + + if finalResp == nil { + return nil, fmt.Errorf("plugin post-hooks returned invalid response") + } + + if finalResp.ChatMessage != nil { + return extractResultFromChatMessage(finalResp.ChatMessage), nil + } + + if finalResp.ResponsesMessage != nil { + result, err := extractResultFromResponsesMessage(finalResp.ResponsesMessage) + if err != nil { + return nil, err + } + if result != nil { + return result, nil + } + } + + return nil, fmt.Errorf("plugin post-hooks returned invalid response") +} + +// callMCPToolDirect executes an MCP tool call directly without plugin hooks. +func (s *StarlarkCodeMode) callMCPToolDirect(ctx context.Context, client *schemas.MCPClientState, originalToolName, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: originalToolName, + Arguments: args, + }, + } + + toolExecutionTimeout := s.getToolExecutionTimeout() + toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + logToolName := stripClientPrefix(toolName, clientName) + logToolName = strings.ReplaceAll(logToolName, "-", "_") + + toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) + if callErr != nil { + logger.Debug("%s Tool call failed: %s.%s - %v", codemcp.CodeModeLogPrefix, clientName, logToolName, callErr) + appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, logToolName, callErr)) + return nil, fmt.Errorf("tool call failed for %s.%s: %v", clientName, logToolName, callErr) + } + + rawResult := extractTextFromMCPResponse(toolResponse, toolName) + + if after, ok := strings.CutPrefix(rawResult, "Error: "); ok { + errorMsg := after + logger.Debug("%s Tool returned error result: %s.%s - %s", codemcp.CodeModeLogPrefix, clientName, logToolName, errorMsg) + appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, logToolName, errorMsg)) + return nil, fmt.Errorf("%s", errorMsg) + } + + var finalResult interface{} + if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { + finalResult = rawResult + } + + resultStr := formatResultForLog(finalResult) + appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr)) + + return finalResult, nil +} diff --git a/core/mcp/codemode/starlark/getdocs.go b/core/mcp/codemode/starlark/getdocs.go new file mode 100644 index 0000000000..b4b953bd13 --- /dev/null +++ b/core/mcp/codemode/starlark/getdocs.go @@ -0,0 +1,304 @@ +//go:build !tinygo && !wasm + +package starlark + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + codemcp "github.com/maximhq/bifrost/core/mcp" + "github.com/maximhq/bifrost/core/schemas" +) + +// createGetToolDocsTool creates the getToolDocs tool definition for code mode. +// This tool provides detailed documentation for a specific tool when the compact +// signatures from readToolFile are not sufficient to understand how to use it. +func (s *StarlarkCodeMode) createGetToolDocsTool() schemas.ChatTool { + getToolDocsProps := schemas.OrderedMap{ + "server": map[string]interface{}{ + "type": "string", + "description": "The server name (e.g., 'calculator'). Use listToolFiles to see available servers.", + }, + "tool": map[string]interface{}{ + "type": "string", + "description": "The tool name (e.g., 'add'). Use readToolFile to see available tools for a server.", + }, + } + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: codemcp.ToolTypeGetToolDocs, + Description: schemas.Ptr( + "Get detailed documentation for a specific tool including full parameter descriptions, " + + "types, and usage examples. Use this when the compact signature from readToolFile " + + "is not sufficient to understand how to use a tool. " + + "Requires both server name and tool name as parameters.", + ), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &getToolDocsProps, + Required: []string{"server", "tool"}, + }, + }, + } +} + +// handleGetToolDocs handles the getToolDocs tool call. +func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + // Parse tool arguments + var arguments map[string]interface{} + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + return nil, fmt.Errorf("failed to parse tool arguments: %v", err) + } + + serverName, ok := arguments["server"].(string) + if !ok || serverName == "" { + return nil, fmt.Errorf("server parameter is required and must be a string") + } + + toolName, ok := arguments["tool"].(string) + if !ok || toolName == "" { + return nil, fmt.Errorf("tool parameter is required and must be a string") + } + + // Get available tools per client + availableToolsPerClient := s.clientManager.GetToolPerClient(ctx) + + // Find matching client + var matchedClientName string + var matchedTool *schemas.ChatTool + + serverNameLower := strings.ToLower(serverName) + toolNameLower := strings.ToLower(toolName) + + for clientName, tools := range availableToolsPerClient { + client := s.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName) + continue + } + if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 { + continue + } + + clientNameLower := strings.ToLower(clientName) + if clientNameLower == serverNameLower { + matchedClientName = clientName + + // Find the specific tool + for i, tool := range tools { + if tool.Function != nil { + // Strip client prefix and replace - with _ for comparison + unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) + unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") + if strings.ToLower(unprefixedToolName) == toolNameLower { + matchedTool = &tools[i] + break + } + } + } + break + } + } + + // Handle server not found + if matchedClientName == "" { + var availableServers []string + for name := range availableToolsPerClient { + client := s.clientManager.GetClientByName(name) + if client != nil && client.ExecutionConfig.IsCodeModeClient { + availableServers = append(availableServers, name) + } + } + errorMsg := fmt.Sprintf("Server '%s' not found. Available servers are:\n", serverName) + for _, sn := range availableServers { + errorMsg += fmt.Sprintf(" - %s\n", sn) + } + return createToolResponseMessage(toolCall, errorMsg), nil + } + + // Handle tool not found + if matchedTool == nil { + tools := availableToolsPerClient[matchedClientName] + var availableTools []string + for _, tool := range tools { + if tool.Function != nil { + unprefixedToolName := stripClientPrefix(tool.Function.Name, matchedClientName) + unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") + availableTools = append(availableTools, unprefixedToolName) + } + } + errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools are:\n", toolName, matchedClientName) + for _, t := range availableTools { + errorMsg += fmt.Sprintf(" - %s\n", t) + } + return createToolResponseMessage(toolCall, errorMsg), nil + } + + // Generate detailed documentation using generateTypeDefinitions + docContent := generateTypeDefinitions(matchedClientName, []schemas.ChatTool{*matchedTool}, true) + + return createToolResponseMessage(toolCall, docContent), nil +} + +// generateTypeDefinitions generates Python documentation with docstrings from ChatTool schemas. +func generateTypeDefinitions(clientName string, tools []schemas.ChatTool, isToolLevel bool) string { + var sb strings.Builder + + // Write comprehensive header + sb.WriteString("# ============================================================================\n") + if isToolLevel && len(tools) == 1 && tools[0].Function != nil { + sb.WriteString(fmt.Sprintf("# Documentation for %s.%s tool\n", clientName, tools[0].Function.Name)) + } else { + sb.WriteString(fmt.Sprintf("# Documentation for %s MCP server\n", clientName)) + } + sb.WriteString("# ============================================================================\n") + sb.WriteString("#\n") + if isToolLevel && len(tools) == 1 { + sb.WriteString("# This file contains Python documentation for a specific tool on this MCP server.\n") + } else { + sb.WriteString("# This file contains Python documentation for all tools available on this MCP server.\n") + } + sb.WriteString("#\n") + sb.WriteString("# USAGE INSTRUCTIONS:\n") + sb.WriteString(fmt.Sprintf("# Call tools using: result = %s.tool_name(param=value)\n", clientName)) + sb.WriteString("# No async/await needed - calls are synchronous.\n") + sb.WriteString("#\n") + sb.WriteString("# CRITICAL - HANDLING RESPONSES:\n") + sb.WriteString("# Tool responses are dicts. To avoid runtime errors:\n") + sb.WriteString("# 1. Use print(result) to inspect the response structure first\n") + sb.WriteString("# 2. Access dict values with brackets: result[\"key\"] NOT result.key\n") + sb.WriteString("# 3. Use .get() for safe access: result.get(\"key\", default)\n") + sb.WriteString("#\n") + sb.WriteString("# Common error: \"key not found\" or \"has no attribute\"\n") + sb.WriteString("# Fix: Use print() to see actual structure, then use result[\"key\"] or .get()\n") + sb.WriteString("# ============================================================================\n\n") + + // Generate function definitions for each tool + for _, tool := range tools { + if tool.Function == nil || tool.Function.Name == "" { + continue + } + + originalToolName := tool.Function.Name + unprefixedToolName := stripClientPrefix(originalToolName, clientName) + unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") + toolName := parseToolName(unprefixedToolName) + description := "" + if tool.Function.Description != nil { + description = *tool.Function.Description + } + + // Generate function signature + params := formatPythonParams(tool.Function.Parameters) + sb.WriteString(fmt.Sprintf("def %s(%s) -> dict:\n", toolName, params)) + + // Generate docstring + sb.WriteString(" \"\"\"\n") + if description != "" { + sb.WriteString(fmt.Sprintf(" %s\n", description)) + sb.WriteString("\n") + } + + // Args section + if tool.Function.Parameters != nil && tool.Function.Parameters.Properties != nil { + props := *tool.Function.Parameters.Properties + required := make(map[string]bool) + if tool.Function.Parameters.Required != nil { + for _, req := range tool.Function.Parameters.Required { + required[req] = true + } + } + + if len(props) > 0 { + sb.WriteString(" Args:\n") + + // Sort properties for consistent output + propNames := make([]string, 0, len(props)) + for name := range props { + propNames = append(propNames, name) + } + for i := 0; i < len(propNames)-1; i++ { + for j := i + 1; j < len(propNames); j++ { + if propNames[i] > propNames[j] { + propNames[i], propNames[j] = propNames[j], propNames[i] + } + } + } + + for _, propName := range propNames { + prop := props[propName] + propMap, ok := prop.(map[string]interface{}) + if !ok { + continue + } + + pyType := jsonSchemaToPython(propMap) + propDesc := "" + if desc, ok := propMap["description"].(string); ok && desc != "" { + propDesc = desc + } else { + propDesc = fmt.Sprintf("%s parameter", propName) + } + + requiredNote := "" + if required[propName] { + requiredNote = " (required)" + } else { + requiredNote = " (optional)" + } + + sb.WriteString(fmt.Sprintf(" %s (%s): %s%s\n", propName, pyType, propDesc, requiredNote)) + } + sb.WriteString("\n") + } + } + + // Returns section + sb.WriteString(" Returns:\n") + sb.WriteString(" dict: Response from the tool. Structure varies by tool.\n") + sb.WriteString(" Use print(result) to inspect the actual structure.\n") + sb.WriteString("\n") + + // Example section + sb.WriteString(" Example:\n") + sb.WriteString(fmt.Sprintf(" result = %s.%s(%s)\n", clientName, toolName, getExampleParams(tool.Function.Parameters))) + sb.WriteString(" print(result) # Always inspect response first!\n") + sb.WriteString(" value = result.get(\"key\", default) # Safe access\n") + sb.WriteString(" \"\"\"\n") + sb.WriteString(" ...\n\n") + } + + return sb.String() +} + +// getExampleParams generates example parameter usage for a function. +func getExampleParams(params *schemas.ToolFunctionParameters) string { + if params == nil || params.Properties == nil || len(*params.Properties) == 0 { + return "" + } + + props := *params.Properties + required := make(map[string]bool) + if params.Required != nil { + for _, req := range params.Required { + required[req] = true + } + } + + // Get first required param as example + for name := range props { + if required[name] { + return fmt.Sprintf("%s=\"...\"", name) + } + } + + // If no required, get first param + for name := range props { + return fmt.Sprintf("%s=\"...\"", name) + } + + return "" +} diff --git a/core/mcp/codemode/starlark/init.go b/core/mcp/codemode/starlark/init.go new file mode 100644 index 0000000000..33a1ddaac5 --- /dev/null +++ b/core/mcp/codemode/starlark/init.go @@ -0,0 +1,12 @@ +//go:build !tinygo && !wasm + +package starlark + +import "github.com/maximhq/bifrost/core/schemas" + +var logger schemas.Logger + +// SetLogger sets the logger for the starlark package. +func SetLogger(l schemas.Logger) { + logger = l +} diff --git a/core/mcp/codemodelistfiles.go b/core/mcp/codemode/starlark/listfiles.go similarity index 61% rename from core/mcp/codemodelistfiles.go rename to core/mcp/codemode/starlark/listfiles.go index 4992c3c402..c00f38c129 100644 --- a/core/mcp/codemodelistfiles.go +++ b/core/mcp/codemode/starlark/listfiles.go @@ -1,47 +1,47 @@ -package mcp +//go:build !tinygo && !wasm + +package starlark import ( "context" "fmt" "strings" + codemcp "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/schemas" ) // createListToolFilesTool creates the listToolFiles tool definition for code mode. -// This tool allows listing all available virtual .d.ts declaration files for connected MCP servers. +// This tool allows listing all available virtual .pyi stub files for connected MCP servers. // The description is dynamically generated based on the configured CodeModeBindingLevel. -// -// Returns: -// - schemas.ChatTool: The tool definition for listing tool files -func (m *ToolsManager) createListToolFilesTool() schemas.ChatTool { - bindingLevel := m.GetCodeModeBindingLevel() +func (s *StarlarkCodeMode) createListToolFilesTool() schemas.ChatTool { + bindingLevel := s.GetBindingLevel() var description string if bindingLevel == schemas.CodeModeBindingLevelServer { - description = "Returns a tree structure listing all virtual .d.ts declaration files available for connected MCP servers. " + - "Each server has a corresponding file (e.g., servers/.d.ts) that contains definitions for all tools in that server. " + - "Use readToolFile to read a specific server file and see all available tools. " + - "In code, access tools via: await serverName.toolName({ args }). " + + description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers. " + + "Each server has a corresponding file (e.g., servers/.pyi) that contains compact Python signatures for all tools in that server. " + + "Use readToolFile to read a specific server file and see all available tools with their signatures. " + + "Use getToolDocs if you need detailed documentation for a specific tool. " + + "In code, access tools via: server_name.tool_name(param=value). " + "The server names used in code correspond to the human-readable names shown in this listing. " + "This tool is generic and works with any set of servers connected at runtime. " + - "Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools. " + - "If you have even the SLIGHTEST DOUBT that the current tools might not be useful for the task, check listToolFiles to discover all available tools." + "Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools." } else { - description = "Returns a tree structure listing all virtual .d.ts declaration files available for connected MCP servers, organized by individual tool. " + - "Each tool has a corresponding file (e.g., servers//.d.ts) that contains definitions for that specific tool. " + - "Use readToolFile to read a specific tool file and see its parameters and usage. " + - "In code, access tools via: await serverName.toolName({ args }). " + + description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers, organized by individual tool. " + + "Each tool has a corresponding file (e.g., servers//.pyi) that contains compact Python signatures for that specific tool. " + + "Use readToolFile to read a specific tool file and see its signature. " + + "Use getToolDocs if you need detailed documentation for a specific tool. " + + "In code, access tools via: server_name.tool_name(param=value). " + "The server names used in code correspond to the human-readable names shown in this listing. " + "This tool is generic and works with any set of servers connected at runtime. " + - "Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools. " + - "If you have even the SLIGHTEST DOUBT that the current tools might not be useful for the task, check listToolFiles to discover all available tools." + "Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools." } return schemas.ChatTool{ Type: schemas.ChatToolTypeFunction, Function: &schemas.ChatToolFunction{ - Name: ToolTypeListToolFiles, + Name: codemcp.ToolTypeListToolFiles, Description: schemas.Ptr(description), Parameters: &schemas.ToolFunctionParameters{ Type: "object", @@ -53,38 +53,27 @@ func (m *ToolsManager) createListToolFilesTool() schemas.ChatTool { } // handleListToolFiles handles the listToolFiles tool call. -// It builds a tree structure listing all virtual .d.ts files available for code mode clients. -// The structure depends on the CodeModeBindingLevel: -// - "server": servers/.d.ts (one file per server) -// - "tool": servers//.d.ts (one file per tool) -// -// Parameters: -// - ctx: Context for accessing client tools -// - toolCall: The tool call request containing no arguments -// -// Returns: -// - *schemas.ChatMessage: A tool response message containing the file tree structure -// - error: Any error that occurred during processing -func (m *ToolsManager) handleListToolFiles(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { - availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) +// It builds a tree structure listing all virtual .pyi files available for code mode clients. +func (s *StarlarkCodeMode) handleListToolFiles(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + availableToolsPerClient := s.clientManager.GetToolPerClient(ctx) if len(availableToolsPerClient) == 0 { - responseText := "No servers are currently connected. There are no virtual .d.ts files available. " + + responseText := "No servers are currently connected. There are no virtual .pyi files available. " + "Please ensure servers are connected before using this tool." return createToolResponseMessage(toolCall, responseText), nil } // Get the code mode binding level - bindingLevel := m.GetCodeModeBindingLevel() + bindingLevel := s.GetBindingLevel() // Build file list based on binding level var files []string codeModeServerCount := 0 for clientName, tools := range availableToolsPerClient { - client := m.clientManager.GetClientByName(clientName) + client := s.clientManager.GetClientByName(clientName) if client == nil { - logger.Warn("%s Client %s not found, skipping", MCPLogPrefix, clientName) + logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName) continue } if !client.ExecutionConfig.IsCodeModeClient { @@ -94,22 +83,22 @@ func (m *ToolsManager) handleListToolFiles(ctx context.Context, toolCall schemas if bindingLevel == schemas.CodeModeBindingLevelServer { // Server-level: one file per server - files = append(files, fmt.Sprintf("servers/%s.d.ts", clientName)) + files = append(files, fmt.Sprintf("servers/%s.pyi", clientName)) } else { // Tool-level: one file per tool for _, tool := range tools { if tool.Function != nil && tool.Function.Name != "" { // Strip the client prefix from tool name (format: "client-toolname" -> "toolname") - // But replace - with _ for valid JavaScript identifiers + // But replace - with _ for valid Python identifiers toolName := stripClientPrefix(tool.Function.Name, clientName) - // Replace any remaining hyphens with underscores for JavaScript compatibility + // Replace any remaining hyphens with underscores for Python compatibility toolName = strings.ReplaceAll(toolName, "-", "_") // Validate normalized tool name to prevent path traversal if err := validateNormalizedToolName(toolName); err != nil { - logger.Warn(fmt.Sprintf("%s Skipping tool '%s' from client '%s': %v", MCPLogPrefix, tool.Function.Name, clientName, err)) + logger.Warn("%s Skipping tool '%s' from client '%s': %v", codemcp.CodeModeLogPrefix, tool.Function.Name, clientName, err) continue } - toolFileName := fmt.Sprintf("servers/%s/%s.d.ts", clientName, toolName) + toolFileName := fmt.Sprintf("servers/%s/%s.pyi", clientName, toolName) files = append(files, toolFileName) } } @@ -118,7 +107,7 @@ func (m *ToolsManager) handleListToolFiles(ctx context.Context, toolCall schemas if codeModeServerCount == 0 { responseText := "Servers are connected but none are configured for code mode. " + - "There are no virtual .d.ts files available." + "There are no virtual .pyi files available." return createToolResponseMessage(toolCall, responseText), nil } @@ -134,23 +123,6 @@ type treeNode struct { } // buildVFSTree creates a hierarchical tree structure from a flat list of file paths. -// It groups files by directory and formats them with proper indentation. -// -// Example input: -// - ["servers/calculator.d.ts", "servers/youtube.d.ts"] -// - ["servers/calculator/add.d.ts", "servers/youtube/GET_CHANNELS.d.ts"] -// -// Example output for server-level: -// servers/ -// calculator.d.ts -// youtube.d.ts -// -// Example output for tool-level: -// servers/ -// calculator/ -// add.d.ts -// youtube/ -// GET_CHANNELS.d.ts func buildVFSTree(files []string) string { if len(files) == 0 { return "" diff --git a/core/mcp/codemode/starlark/readfile.go b/core/mcp/codemode/starlark/readfile.go new file mode 100644 index 0000000000..6b199bff9a --- /dev/null +++ b/core/mcp/codemode/starlark/readfile.go @@ -0,0 +1,451 @@ +//go:build !tinygo && !wasm + +package starlark + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + codemcp "github.com/maximhq/bifrost/core/mcp" + "github.com/maximhq/bifrost/core/schemas" +) + +// createReadToolFileTool creates the readToolFile tool definition for code mode. +// This tool allows reading virtual .pyi stub files for specific MCP servers/tools, +// generating Python type stubs from the server's tool schemas. +func (s *StarlarkCodeMode) createReadToolFileTool() schemas.ChatTool { + bindingLevel := s.GetBindingLevel() + + var fileNameDescription, toolDescription string + + if bindingLevel == schemas.CodeModeBindingLevelServer { + fileNameDescription = "The virtual filename from listToolFiles in format: servers/.pyi (e.g., 'calculator.pyi')" + toolDescription = "Reads a virtual .pyi stub file for a specific MCP server, returning compact Python function signatures " + + "for all tools available on that server. The fileName should be in format servers/.pyi as listed by listToolFiles. " + + "The function performs case-insensitive matching and removes the .pyi extension. " + + "Each tool can be accessed in code via: serverName.tool_name(param=value). " + + "If the compact signature is not enough to understand a tool, use getToolDocs for detailed documentation. " + + "Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + + "IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " + + "do NOT call this tool again with startLine/endLine - you already have the complete file." + } else { + fileNameDescription = "The virtual filename from listToolFiles in format: servers//.pyi (e.g., 'calculator/add.pyi')" + toolDescription = "Reads a virtual .pyi stub file for a specific tool, returning its compact Python function signature. " + + "The fileName should be in format servers//.pyi as listed by listToolFiles. " + + "The function performs case-insensitive matching and removes the .pyi extension. " + + "The tool can be accessed in code via: serverName.tool_name(param=value). " + + "If the compact signature is not enough to understand the tool, use getToolDocs for detailed documentation. " + + "Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + + "IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " + + "do NOT call this tool again with startLine/endLine - you already have the complete file." + } + + readToolFileProps := schemas.OrderedMap{ + "fileName": map[string]interface{}{ + "type": "string", + "description": fileNameDescription, + }, + "startLine": map[string]interface{}{ + "type": "number", + "description": "Optional 1-based starting line number for partial file read. Usually not needed - omit to read the entire file. Files are typically small (under 50 lines).", + }, + "endLine": map[string]interface{}{ + "type": "number", + "description": "Optional 1-based ending line number for partial file read. Usually not needed - omit to read the entire file. Will be clamped to actual file size if too large.", + }, + } + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: codemcp.ToolTypeReadToolFile, + Description: schemas.Ptr(toolDescription), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &readToolFileProps, + Required: []string{"fileName"}, + }, + }, + } +} + +// handleReadToolFile handles the readToolFile tool call. +func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + // Parse tool arguments + var arguments map[string]interface{} + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + return nil, fmt.Errorf("failed to parse tool arguments: %v", err) + } + + fileName, ok := arguments["fileName"].(string) + if !ok || fileName == "" { + return nil, fmt.Errorf("fileName parameter is required and must be a string") + } + + // Parse the file path to extract server name and optional tool name + serverName, toolName, isToolLevel := parseVFSFilePath(fileName) + + // Get available tools per client + availableToolsPerClient := s.clientManager.GetToolPerClient(ctx) + + // Find matching client + var matchedClientName string + var matchedTools []schemas.ChatTool + matchCount := 0 + + for clientName, tools := range availableToolsPerClient { + client := s.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn("%s Client %s not found, skipping", codemcp.CodeModeLogPrefix, clientName) + continue + } + if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 { + continue + } + + clientNameLower := strings.ToLower(clientName) + serverNameLower := strings.ToLower(serverName) + + if clientNameLower == serverNameLower { + matchCount++ + if matchCount > 1 { + // Multiple matches found + errorMsg := fmt.Sprintf("Multiple servers match filename '%s':\n", fileName) + for name := range availableToolsPerClient { + if strings.ToLower(name) == serverNameLower { + errorMsg += fmt.Sprintf(" - %s\n", name) + } + } + errorMsg += "\nPlease use a more specific filename. Use the exact display name from listToolFiles to avoid ambiguity." + return createToolResponseMessage(toolCall, errorMsg), nil + } + + matchedClientName = clientName + + if isToolLevel { + // Tool-level: filter to specific tool + var foundTool *schemas.ChatTool + toolNameLower := strings.ToLower(toolName) + for i, tool := range tools { + if tool.Function != nil { + // Strip client prefix and replace - with _ for comparison + unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) + unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") + if strings.ToLower(unprefixedToolName) == toolNameLower { + foundTool = &tools[i] + break + } + } + } + + if foundTool == nil { + availableTools := make([]string, 0) + for _, tool := range tools { + if tool.Function != nil { + // Strip client prefix and replace - with _ for display + unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) + unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") + availableTools = append(availableTools, unprefixedToolName) + } + } + errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools in this server are:\n", toolName, clientName) + for _, t := range availableTools { + errorMsg += fmt.Sprintf(" - %s/%s.pyi\n", clientName, t) + } + return createToolResponseMessage(toolCall, errorMsg), nil + } + + matchedTools = []schemas.ChatTool{*foundTool} + } else { + // Server-level: use all tools + matchedTools = tools + } + } + } + + if matchedClientName == "" { + // Build helpful error message with available files + bindingLevel := s.GetBindingLevel() + var availableFiles []string + + for name := range availableToolsPerClient { + if bindingLevel == schemas.CodeModeBindingLevelServer { + availableFiles = append(availableFiles, fmt.Sprintf("%s.pyi", name)) + } else { + client := s.clientManager.GetClientByName(name) + if client != nil && client.ExecutionConfig.IsCodeModeClient { + if tools, ok := availableToolsPerClient[name]; ok { + for _, tool := range tools { + if tool.Function != nil { + // Strip client prefix and replace - with _ for display + unprefixedToolName := stripClientPrefix(tool.Function.Name, name) + unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") + availableFiles = append(availableFiles, fmt.Sprintf("%s/%s.pyi", name, unprefixedToolName)) + } + } + } + } + } + } + + errorMsg := fmt.Sprintf("No server found matching '%s'. Available virtual files are:\n", serverName) + for _, f := range availableFiles { + errorMsg += fmt.Sprintf(" - %s\n", f) + } + return createToolResponseMessage(toolCall, errorMsg), nil + } + + // Generate compact Python signatures + fileContent := generateCompactSignatures(matchedClientName, matchedTools, isToolLevel) + lines := strings.Split(fileContent, "\n") + totalLines := len(lines) + + // Prepend total lines info so LLM knows the file size upfront + fileContent = fmt.Sprintf("# Total lines: %d (this is the complete file, no need to paginate)\n%s", totalLines+1, fileContent) + // Recalculate lines after prepending + lines = strings.Split(fileContent, "\n") + totalLines = len(lines) + + // Handle line slicing if provided + var startLine, endLine *int + if sl, ok := arguments["startLine"].(float64); ok { + slInt := int(sl) + startLine = &slInt + } + if el, ok := arguments["endLine"].(float64); ok { + elInt := int(el) + endLine = &elInt + } + + if startLine != nil || endLine != nil { + start := 1 + if startLine != nil { + start = *startLine + } + end := totalLines + if endLine != nil { + end = *endLine + } + + // Clamp values to valid range instead of erroring + // This handles cases where LLM requests more lines than exist + if start < 1 { + start = 1 + } + if start > totalLines { + start = totalLines + } + if end < 1 { + end = 1 + } + if end > totalLines { + end = totalLines + } + if start > end { + // If start > end after clamping, just return the start line + end = start + } + + // Slice lines (convert to 0-based indexing) + selectedLines := lines[start-1 : end] + fileContent = strings.Join(selectedLines, "\n") + } + + return createToolResponseMessage(toolCall, fileContent), nil +} + +// parseVFSFilePath parses a VFS file path and extracts the server name and optional tool name. +func parseVFSFilePath(fileName string) (serverName, toolName string, isToolLevel bool) { + // Remove .pyi extension + basePath := strings.TrimSuffix(fileName, ".pyi") + + // Remove "servers/" prefix if present + basePath = strings.TrimPrefix(basePath, "servers/") + + // Defensive validation: reject paths with path traversal attempts + if strings.Contains(basePath, "..") { + // Return empty to indicate invalid path + return "", "", false + } + + // Check for path separator + parts := strings.Split(basePath, "/") + if len(parts) == 2 { + // Tool-level: "serverName/toolName" + // Validate that tool name doesn't contain additional path separators or traversal + if parts[1] == "" || strings.Contains(parts[1], "/") || strings.Contains(parts[1], "..") { + // Invalid tool name, treat as server-level + return parts[0], "", false + } + return parts[0], parts[1], true + } + // Server-level: "serverName" + // Validate server name doesn't contain path separators or traversal + if strings.Contains(basePath, "/") || strings.Contains(basePath, "..") { + // Invalid path + return "", "", false + } + return basePath, "", false +} + +// generateCompactSignatures generates compact Python function signatures for tools. +func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isToolLevel bool) string { + var sb strings.Builder + + // Minimal header + if isToolLevel && len(tools) == 1 && tools[0].Function != nil { + toolName := parseToolName(stripClientPrefix(tools[0].Function.Name, clientName)) + sb.WriteString(fmt.Sprintf("# %s.%s tool\n", clientName, toolName)) + } else { + sb.WriteString(fmt.Sprintf("# %s server tools\n", clientName)) + } + sb.WriteString(fmt.Sprintf("# Usage: %s.tool_name(param=value)\n", clientName)) + sb.WriteString(fmt.Sprintf("# For detailed docs: use getToolDocs(server=\"%s\", tool=\"tool_name\")\n", clientName)) + sb.WriteString("# Note: Descriptions may be truncated. Use getToolDocs for full details.\n\n") + + for _, tool := range tools { + if tool.Function == nil || tool.Function.Name == "" { + continue + } + + // Strip client prefix and replace - with _ for code mode compatibility + unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) + unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") + toolName := parseToolName(unprefixedToolName) + + // Format inline parameters in Python style + params := formatPythonParams(tool.Function.Parameters) + + // Get description (truncate if too long) + desc := "" + if tool.Function.Description != nil && *tool.Function.Description != "" { + desc = *tool.Function.Description + // Truncate long descriptions to first sentence or 80 chars + if idx := strings.Index(desc, ". "); idx > 0 && idx < 80 { + desc = desc[:idx+1] + } else if len(desc) > 80 { + desc = desc[:77] + "..." + } + } + + // Write Python signature: def tool_name(param: type, param: type = None) -> dict: # description + if desc != "" { + sb.WriteString(fmt.Sprintf("def %s(%s) -> dict: # %s\n", toolName, params, desc)) + } else { + sb.WriteString(fmt.Sprintf("def %s(%s) -> dict\n", toolName, params)) + } + } + + return sb.String() +} + +// formatPythonParams formats tool parameters as Python function parameters. +func formatPythonParams(params *schemas.ToolFunctionParameters) string { + if params == nil || params.Properties == nil || len(*params.Properties) == 0 { + return "" + } + + props := *params.Properties + required := make(map[string]bool) + if params.Required != nil { + for _, req := range params.Required { + required[req] = true + } + } + + // Sort properties: required first, then optional, alphabetically within each group + requiredNames := make([]string, 0) + optionalNames := make([]string, 0) + for name := range props { + if required[name] { + requiredNames = append(requiredNames, name) + } else { + optionalNames = append(optionalNames, name) + } + } + // Simple alphabetical sort for each group + for i := 0; i < len(requiredNames)-1; i++ { + for j := i + 1; j < len(requiredNames); j++ { + if requiredNames[i] > requiredNames[j] { + requiredNames[i], requiredNames[j] = requiredNames[j], requiredNames[i] + } + } + } + for i := 0; i < len(optionalNames)-1; i++ { + for j := i + 1; j < len(optionalNames); j++ { + if optionalNames[i] > optionalNames[j] { + optionalNames[i], optionalNames[j] = optionalNames[j], optionalNames[i] + } + } + } + + parts := make([]string, 0, len(props)) + + // Add required params first + for _, propName := range requiredNames { + prop := props[propName] + propMap, ok := prop.(map[string]interface{}) + if !ok { + continue + } + pyType := jsonSchemaToPython(propMap) + parts = append(parts, fmt.Sprintf("%s: %s", propName, pyType)) + } + + // Add optional params with default None + for _, propName := range optionalNames { + prop := props[propName] + propMap, ok := prop.(map[string]interface{}) + if !ok { + continue + } + pyType := jsonSchemaToPython(propMap) + parts = append(parts, fmt.Sprintf("%s: %s = None", propName, pyType)) + } + + return strings.Join(parts, ", ") +} + +// jsonSchemaToPython converts a JSON Schema type definition to a Python type string. +func jsonSchemaToPython(prop map[string]interface{}) string { + // Check for enum first - takes precedence over type to show allowed values + if enum, ok := prop["enum"].([]interface{}); ok && len(enum) > 0 { + enumStrs := make([]string, 0, len(enum)) + for _, e := range enum { + enumStrs = append(enumStrs, fmt.Sprintf("%q", e)) + } + return "Literal[" + strings.Join(enumStrs, ", ") + "]" + } + + // Check for const (single fixed value) + if constVal, ok := prop["const"]; ok { + return fmt.Sprintf("Literal[%q]", constVal) + } + + // Fall back to type-based conversion + if typeVal, ok := prop["type"].(string); ok { + switch typeVal { + case "string": + return "str" + case "number": + return "float" + case "integer": + return "int" + case "boolean": + return "bool" + case "array": + itemsType := "Any" + if items, ok := prop["items"].(map[string]interface{}); ok { + itemsType = jsonSchemaToPython(items) + } + return fmt.Sprintf("list[%s]", itemsType) + case "object": + return "dict" + case "null": + return "None" + } + } + + return "Any" +} diff --git a/core/mcp/codemode/starlark/starlark.go b/core/mcp/codemode/starlark/starlark.go new file mode 100644 index 0000000000..fc85488867 --- /dev/null +++ b/core/mcp/codemode/starlark/starlark.go @@ -0,0 +1,164 @@ +//go:build !tinygo && !wasm + +// Package starlark provides a Starlark-based implementation of the CodeMode interface. +// Starlark is a Python-like language designed for configuration and embedded scripting. +// See https://github.com/google/starlark-go for more information. +package starlark + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/maximhq/bifrost/core/mcp" + "github.com/maximhq/bifrost/core/schemas" +) + +// StarlarkCodeMode implements the CodeMode interface using a Starlark interpreter. +// It provides a sandboxed Python-like execution environment with access to MCP tools. +type StarlarkCodeMode struct { + // Configuration (atomic for thread-safe updates) + bindingLevel atomic.Value // schemas.CodeModeBindingLevel + toolExecutionTimeout atomic.Value // time.Duration + + // Dependencies + clientManager mcp.ClientManager + pluginPipelineProvider func() mcp.PluginPipeline + releasePluginPipeline func(pipeline mcp.PluginPipeline) + fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string + + // Mutex for protecting logs during concurrent execution + logMu sync.Mutex +} + +// NewStarlarkCodeMode creates a new Starlark-based CodeMode implementation. +// +// Parameters: +// - config: Configuration for the code mode (binding level, timeouts). Can be nil for defaults. +// +// Returns: +// - *StarlarkCodeMode: A new Starlark code mode instance +// +// Note: Dependencies must be set via SetDependencies before the CodeMode can execute tools. +// This allows the CodeMode to be created before the MCPManager, avoiding circular dependencies. +func NewStarlarkCodeMode(config *mcp.CodeModeConfig) *StarlarkCodeMode { + if config == nil { + config = mcp.DefaultCodeModeConfig() + } + + if config.BindingLevel == "" { + config.BindingLevel = schemas.CodeModeBindingLevelServer + } + + if config.ToolExecutionTimeout <= 0 { + config.ToolExecutionTimeout = schemas.DefaultToolExecutionTimeout + } + + s := &StarlarkCodeMode{} + + // Initialize atomic values + s.bindingLevel.Store(config.BindingLevel) + s.toolExecutionTimeout.Store(config.ToolExecutionTimeout) + + logger.Info("%s Starlark code mode initialized with binding level: %s, timeout: %v", + mcp.CodeModeLogPrefix, config.BindingLevel, config.ToolExecutionTimeout) + + return s +} + +// SetDependencies sets the dependencies required for code execution. +// This must be called after the MCPManager is created, as the dependencies +// include the ClientManager (which is the MCPManager itself). +func (s *StarlarkCodeMode) SetDependencies(deps *mcp.CodeModeDependencies) { + if deps != nil { + s.clientManager = deps.ClientManager + s.pluginPipelineProvider = deps.PluginPipelineProvider + s.releasePluginPipeline = deps.ReleasePluginPipeline + s.fetchNewRequestIDFunc = deps.FetchNewRequestIDFunc + } +} + +// GetTools returns the code mode meta-tools for Starlark execution. +// These tools allow LLMs to discover, read, and execute code against MCP servers. +func (s *StarlarkCodeMode) GetTools() []schemas.ChatTool { + return []schemas.ChatTool{ + s.createListToolFilesTool(), + s.createReadToolFileTool(), + s.createGetToolDocsTool(), + s.createExecuteToolCodeTool(), + } +} + +// ExecuteTool handles a code mode tool call. +// It dispatches to the appropriate handler based on the tool name. +// +// Parameters: +// - ctx: Context for tool execution +// - toolCall: The tool call to execute +// +// Returns: +// - *schemas.ChatMessage: The tool response message +// - error: Any error that occurred during execution +func (s *StarlarkCodeMode) ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + if toolCall.Function.Name == nil { + return nil, fmt.Errorf("tool call missing function name") + } + + toolName := *toolCall.Function.Name + + switch toolName { + case mcp.ToolTypeListToolFiles: + return s.handleListToolFiles(ctx, toolCall) + case mcp.ToolTypeReadToolFile: + return s.handleReadToolFile(ctx, toolCall) + case mcp.ToolTypeGetToolDocs: + return s.handleGetToolDocs(ctx, toolCall) + case mcp.ToolTypeExecuteToolCode: + return s.handleExecuteToolCode(ctx, toolCall) + default: + return nil, fmt.Errorf("unknown code mode tool: %s", toolName) + } +} + +// IsCodeModeTool returns true if the given tool name is a code mode tool. +func (s *StarlarkCodeMode) IsCodeModeTool(toolName string) bool { + return mcp.IsCodeModeTool(toolName) +} + +// GetBindingLevel returns the current code mode binding level. +func (s *StarlarkCodeMode) GetBindingLevel() schemas.CodeModeBindingLevel { + val := s.bindingLevel.Load() + if val == nil { + return schemas.CodeModeBindingLevelServer + } + return val.(schemas.CodeModeBindingLevel) +} + +// UpdateConfig updates the code mode configuration atomically. +func (s *StarlarkCodeMode) UpdateConfig(config *mcp.CodeModeConfig) { + if config == nil { + return + } + + if config.BindingLevel != "" { + s.bindingLevel.Store(config.BindingLevel) + } + + if config.ToolExecutionTimeout > 0 { + s.toolExecutionTimeout.Store(config.ToolExecutionTimeout) + } + + logger.Info("%s Starlark code mode configuration updated: binding level=%s, timeout=%v", + mcp.CodeModeLogPrefix, config.BindingLevel, config.ToolExecutionTimeout) +} + +// getToolExecutionTimeout returns the current tool execution timeout. +func (s *StarlarkCodeMode) getToolExecutionTimeout() time.Duration { + val := s.toolExecutionTimeout.Load() + if val == nil { + return schemas.DefaultToolExecutionTimeout + } + return val.(time.Duration) +} diff --git a/core/mcp/codemodeexecutecode_test.go b/core/mcp/codemode/starlark/starlark_test.go similarity index 55% rename from core/mcp/codemodeexecutecode_test.go rename to core/mcp/codemode/starlark/starlark_test.go index 86f6481e9a..dba557f88a 100644 --- a/core/mcp/codemodeexecutecode_test.go +++ b/core/mcp/codemode/starlark/starlark_test.go @@ -1,12 +1,243 @@ -package mcp +//go:build !tinygo && !wasm + +package starlark import ( "testing" "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" + "go.starlark.net/starlark" ) +func TestStarlarkToGo(t *testing.T) { + t.Run("Convert None", func(t *testing.T) { + result := starlarkToGo(starlark.None) + if result != nil { + t.Errorf("Expected nil, got %v", result) + } + }) + + t.Run("Convert Bool", func(t *testing.T) { + result := starlarkToGo(starlark.Bool(true)) + if result != true { + t.Errorf("Expected true, got %v", result) + } + }) + + t.Run("Convert Int", func(t *testing.T) { + result := starlarkToGo(starlark.MakeInt(42)) + if result != int64(42) { + t.Errorf("Expected 42, got %v", result) + } + }) + + t.Run("Convert Float", func(t *testing.T) { + result := starlarkToGo(starlark.Float(3.14)) + if result != 3.14 { + t.Errorf("Expected 3.14, got %v", result) + } + }) + + t.Run("Convert String", func(t *testing.T) { + result := starlarkToGo(starlark.String("hello")) + if result != "hello" { + t.Errorf("Expected 'hello', got %v", result) + } + }) + + t.Run("Convert List", func(t *testing.T) { + list := starlark.NewList([]starlark.Value{ + starlark.MakeInt(1), + starlark.MakeInt(2), + starlark.MakeInt(3), + }) + result := starlarkToGo(list) + arr, ok := result.([]interface{}) + if !ok { + t.Errorf("Expected []interface{}, got %T", result) + } + if len(arr) != 3 { + t.Errorf("Expected length 3, got %d", len(arr)) + } + if arr[0] != int64(1) { + t.Errorf("Expected first element 1, got %v", arr[0]) + } + }) + + t.Run("Convert Dict", func(t *testing.T) { + dict := starlark.NewDict(2) + dict.SetKey(starlark.String("key1"), starlark.String("value1")) + dict.SetKey(starlark.String("key2"), starlark.MakeInt(42)) + + result := starlarkToGo(dict) + m, ok := result.(map[string]interface{}) + if !ok { + t.Errorf("Expected map[string]interface{}, got %T", result) + } + if m["key1"] != "value1" { + t.Errorf("Expected key1='value1', got %v", m["key1"]) + } + if m["key2"] != int64(42) { + t.Errorf("Expected key2=42, got %v", m["key2"]) + } + }) +} + +func TestGoToStarlark(t *testing.T) { + t.Run("Convert nil", func(t *testing.T) { + result := goToStarlark(nil) + if result != starlark.None { + t.Errorf("Expected None, got %v", result) + } + }) + + t.Run("Convert bool", func(t *testing.T) { + result := goToStarlark(true) + if result != starlark.Bool(true) { + t.Errorf("Expected True, got %v", result) + } + }) + + t.Run("Convert int", func(t *testing.T) { + result := goToStarlark(42) + expected := starlark.MakeInt(42) + if result.String() != expected.String() { + t.Errorf("Expected %v, got %v", expected, result) + } + }) + + t.Run("Convert float64", func(t *testing.T) { + result := goToStarlark(3.14) + if result != starlark.Float(3.14) { + t.Errorf("Expected 3.14, got %v", result) + } + }) + + t.Run("Convert string", func(t *testing.T) { + result := goToStarlark("hello") + if result != starlark.String("hello") { + t.Errorf("Expected 'hello', got %v", result) + } + }) + + t.Run("Convert slice", func(t *testing.T) { + result := goToStarlark([]interface{}{1, "two", 3.0}) + list, ok := result.(*starlark.List) + if !ok { + t.Errorf("Expected *starlark.List, got %T", result) + } + if list.Len() != 3 { + t.Errorf("Expected length 3, got %d", list.Len()) + } + }) + + t.Run("Convert map", func(t *testing.T) { + result := goToStarlark(map[string]interface{}{ + "key1": "value1", + "key2": 42, + }) + dict, ok := result.(*starlark.Dict) + if !ok { + t.Errorf("Expected *starlark.Dict, got %T", result) + } + val, found, _ := dict.Get(starlark.String("key1")) + if !found { + t.Errorf("Expected key1 to exist") + } + if val != starlark.String("value1") { + t.Errorf("Expected value1, got %v", val) + } + }) +} + +func TestGeneratePythonErrorHints(t *testing.T) { + serverKeys := []string{"calculator", "weather"} + + t.Run("Undefined variable hint", func(t *testing.T) { + hints := generatePythonErrorHints("name 'foo' is not defined", serverKeys) + if len(hints) == 0 { + t.Error("Expected hints, got none") + } + found := false + for _, hint := range hints { + if containsAny(hint, "not defined", "undefined") { + found = true + break + } + } + if !found { + t.Error("Expected hint about undefined variable") + } + }) + + t.Run("Syntax error hint", func(t *testing.T) { + hints := generatePythonErrorHints("syntax error at line 5", serverKeys) + if len(hints) == 0 { + t.Error("Expected hints, got none") + } + found := false + for _, hint := range hints { + if containsAny(hint, "syntax", "indentation", "colon") { + found = true + break + } + } + if !found { + t.Error("Expected hint about syntax error") + } + }) + + t.Run("Attribute error hint", func(t *testing.T) { + hints := generatePythonErrorHints("'dict' object has no attribute 'foo'", serverKeys) + if len(hints) == 0 { + t.Error("Expected hints, got none") + } + found := false + for _, hint := range hints { + if containsAny(hint, "attribute", "brackets", "key") { + found = true + break + } + } + if !found { + t.Error("Expected hint about attribute access") + } + }) +} + +func containsAny(s string, substrs ...string) bool { + for _, sub := range substrs { + if containsIgnoreCase(s, sub) { + return true + } + } + return false +} + +func containsIgnoreCase(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && (containsIgnoreCase(s[1:], substr) || (len(s) >= len(substr) && equalFold(s[:len(substr)], substr)))) +} + +func equalFold(a, b string) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + ca, cb := a[i], b[i] + if ca >= 'A' && ca <= 'Z' { + ca += 'a' - 'A' + } + if cb >= 'A' && cb <= 'Z' { + cb += 'a' - 'A' + } + if ca != cb { + return false + } + } + return true +} + func TestExtractResultFromResponsesMessage(t *testing.T) { t.Run("Extract error from ResponsesMessage", func(t *testing.T) { errorMsg := "Tool is not allowed by security policy: dangerous_tool" diff --git a/core/mcp/codemode/starlark/utils.go b/core/mcp/codemode/starlark/utils.go new file mode 100644 index 0000000000..c3baf3cfb3 --- /dev/null +++ b/core/mcp/codemode/starlark/utils.go @@ -0,0 +1,365 @@ +//go:build !tinygo && !wasm + +package starlark + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" + "unicode" + + "github.com/bytedance/sonic" + "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/schemas" + "go.starlark.net/starlark" + "go.starlark.net/starlarkstruct" +) + +// starlarkToGo converts a Starlark value to a Go value +func starlarkToGo(v starlark.Value) interface{} { + switch val := v.(type) { + case starlark.NoneType: + return nil + case starlark.Bool: + return bool(val) + case starlark.Int: + if i, ok := val.Int64(); ok { + return i + } + if i, ok := val.Uint64(); ok { + return i + } + return val.String() + case starlark.Float: + return float64(val) + case starlark.String: + return string(val) + case *starlark.List: + result := make([]interface{}, val.Len()) + for i := 0; i < val.Len(); i++ { + result[i] = starlarkToGo(val.Index(i)) + } + return result + case starlark.Tuple: + result := make([]interface{}, len(val)) + for i, item := range val { + result[i] = starlarkToGo(item) + } + return result + case *starlark.Dict: + result := make(map[string]interface{}) + for _, item := range val.Items() { + if keyStr, ok := item[0].(starlark.String); ok { + result[string(keyStr)] = starlarkToGo(item[1]) + } else { + // Use string representation for non-string keys + result[item[0].String()] = starlarkToGo(item[1]) + } + } + return result + case *starlarkstruct.Struct: + result := make(map[string]interface{}) + for _, name := range val.AttrNames() { + if attrVal, err := val.Attr(name); err == nil { + result[name] = starlarkToGo(attrVal) + } + } + return result + default: + return val.String() + } +} + +// goToStarlark converts a Go value to a Starlark value +func goToStarlark(v interface{}) starlark.Value { + if v == nil { + return starlark.None + } + + switch val := v.(type) { + case bool: + return starlark.Bool(val) + case int: + return starlark.MakeInt(val) + case int64: + return starlark.MakeInt64(val) + case uint64: + return starlark.MakeUint64(val) + case float64: + return starlark.Float(val) + case string: + return starlark.String(val) + case []interface{}: + items := make([]starlark.Value, len(val)) + for i, item := range val { + items[i] = goToStarlark(item) + } + return starlark.NewList(items) + case map[string]interface{}: + dict := starlark.NewDict(len(val)) + for k, v := range val { + dict.SetKey(starlark.String(k), goToStarlark(v)) + } + return dict + default: + // Try to marshal to JSON and parse as a generic structure + if jsonBytes, err := sonic.Marshal(val); err == nil { + var generic interface{} + if sonic.Unmarshal(jsonBytes, &generic) == nil { + return goToStarlark(generic) + } + } + return starlark.String(fmt.Sprintf("%v", val)) + } +} + +// extractResultFromChatMessage extracts the result from a chat message and parses it as JSON if possible. +func extractResultFromChatMessage(msg *schemas.ChatMessage) interface{} { + if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil { + return nil + } + + rawResult := *msg.Content.ContentStr + + var finalResult interface{} + if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { + return rawResult + } + + return finalResult +} + +// extractResultFromResponsesMessage extracts the result or error from a ResponsesMessage. +func extractResultFromResponsesMessage(msg *schemas.ResponsesMessage) (interface{}, error) { + if msg == nil { + return nil, nil + } + + if msg.ResponsesToolMessage != nil { + if msg.ResponsesToolMessage.Error != nil && *msg.ResponsesToolMessage.Error != "" { + return nil, fmt.Errorf("%s", *msg.ResponsesToolMessage.Error) + } + + if msg.ResponsesToolMessage.Output != nil { + if msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil { + rawResult := *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr + + var finalResult interface{} + if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { + return rawResult, nil + } + return finalResult, nil + } + + if len(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) > 0 { + var textParts []string + for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks { + if block.Text != nil { + textParts = append(textParts, *block.Text) + } + } + if len(textParts) > 0 { + result := strings.Join(textParts, "\n") + var finalResult interface{} + if err := sonic.Unmarshal([]byte(result), &finalResult); err != nil { + return result, nil + } + return finalResult, nil + } + } + } + } + + return nil, nil +} + +// formatResultForLog formats a result value for logging purposes. +func formatResultForLog(result interface{}) string { + var resultStr string + if result == nil { + resultStr = "null" + } else if resultBytes, err := sonic.Marshal(result); err == nil { + resultStr = string(resultBytes) + } else { + resultStr = fmt.Sprintf("%v", result) + } + return resultStr +} + +// generatePythonErrorHints generates helpful hints for Python/Starlark errors. +func generatePythonErrorHints(errorMessage string, serverKeys []string) []string { + hints := []string{} + + if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") { + re := regexp.MustCompile(`(\w+).*(?:undefined|not defined)`) + if match := re.FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar := match[1] + hints = append(hints, fmt.Sprintf("Variable '%s' is not defined.", undefinedVar)) + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + hints = append(hints, "Access tools using: server_name.tool_name(param=\"value\")") + } + } + } else if strings.Contains(errorMessage, "syntax error") { + hints = append(hints, "Python syntax error detected.") + hints = append(hints, "Check for proper indentation (use spaces, not tabs).") + hints = append(hints, "Ensure colons after if/for/def statements.") + hints = append(hints, "Check for matching parentheses and brackets.") + } else if strings.Contains(errorMessage, "has no") && strings.Contains(errorMessage, "attribute") { + hints = append(hints, "You're trying to access an attribute that doesn't exist.") + hints = append(hints, "Use dict access syntax: result[\"key\"] instead of result.key") + hints = append(hints, "Use print(result) to see the actual structure.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + } else if strings.Contains(errorMessage, "not callable") { + hints = append(hints, "You're trying to call something that is not a function.") + hints = append(hints, "Ensure you're using the correct tool name.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + hints = append(hints, "Use readToolFile to see available tools for a server.") + } else if strings.Contains(errorMessage, "key") && strings.Contains(errorMessage, "not found") { + hints = append(hints, "Dictionary key not found.") + hints = append(hints, "Use print() to inspect the dict structure before accessing keys.") + hints = append(hints, "Use .get(\"key\", default) for safe access.") + } else { + hints = append(hints, "Check the error message above for details.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + hints = append(hints, "Use: result = server_name.tool_name(param=\"value\")") + hints = append(hints, "Access dict values with brackets: result[\"key\"]") + } + + return hints +} + +// extractTextFromMCPResponse extracts text content from an MCP tool response. +func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string { + if toolResponse == nil { + return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) + } + + var result strings.Builder + for _, contentBlock := range toolResponse.Content { + // Handle typed content + switch content := contentBlock.(type) { + case mcp.TextContent: + result.WriteString(content.Text) + case mcp.ImageContent: + result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) + case mcp.AudioContent: + result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) + case mcp.EmbeddedResource: + result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type)) + default: + // Fallback: try to extract from map structure + if jsonBytes, err := json.Marshal(contentBlock); err == nil { + var contentMap map[string]interface{} + if json.Unmarshal(jsonBytes, &contentMap) == nil { + if text, ok := contentMap["text"].(string); ok { + result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text)) + continue + } + } + // Final fallback: serialize as JSON + result.WriteString(string(jsonBytes)) + } + } + } + + if result.Len() > 0 { + return strings.TrimSpace(result.String()) + } + return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) +} + +// createToolResponseMessage creates a tool response message with the execution result. +func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage { + return &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: &responseText, + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: toolCall.ID, + }, + } +} + +// parseToolName parses the tool name to be JavaScript-compatible. +func parseToolName(toolName string) string { + if toolName == "" { + return "" + } + + var result strings.Builder + runes := []rune(toolName) + + // Process first character - must be letter, underscore, or dollar sign + if len(runes) > 0 { + first := runes[0] + if unicode.IsLetter(first) || first == '_' || first == '$' { + result.WriteRune(unicode.ToLower(first)) + } else { + // If first char is invalid, prefix with underscore + result.WriteRune('_') + if unicode.IsDigit(first) { + result.WriteRune(first) + } + } + } + + // Process remaining characters + for i := 1; i < len(runes); i++ { + r := runes[i] + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' { + result.WriteRune(unicode.ToLower(r)) + } else if unicode.IsSpace(r) || r == '-' { + // Replace spaces and hyphens with single underscore + // Avoid consecutive underscores + if result.Len() > 0 && result.String()[result.Len()-1] != '_' { + result.WriteRune('_') + } + } + // Skip other invalid characters + } + + parsed := result.String() + + // Remove trailing underscores + parsed = strings.TrimRight(parsed, "_") + + // Ensure we have at least one character + if parsed == "" { + return "tool" + } + + return parsed +} + +// validateNormalizedToolName validates a normalized tool name to prevent path traversal. +func validateNormalizedToolName(normalizedName string) error { + if normalizedName == "" { + return fmt.Errorf("tool name cannot be empty after normalization") + } + if strings.Contains(normalizedName, "/") { + return fmt.Errorf("tool name cannot contain '/' (path separator) after normalization: %s", normalizedName) + } + if strings.Contains(normalizedName, "..") { + return fmt.Errorf("tool name cannot contain '..' (path traversal) after normalization: %s", normalizedName) + } + return nil +} + +// stripClientPrefix removes the client name prefix from a tool name. +func stripClientPrefix(prefixedToolName, clientName string) string { + prefix := clientName + "-" + if strings.HasPrefix(prefixedToolName, prefix) { + return strings.TrimPrefix(prefixedToolName, prefix) + } + // If prefix doesn't match, return as-is (shouldn't happen, but be safe) + return prefixedToolName +} diff --git a/core/mcp/codemodeexecutecode.go b/core/mcp/codemodeexecutecode.go deleted file mode 100644 index a9d360145f..0000000000 --- a/core/mcp/codemodeexecutecode.go +++ /dev/null @@ -1,1369 +0,0 @@ -package mcp - -import ( - "context" - "fmt" - "regexp" - "strings" - "time" - - "github.com/bytedance/sonic" - "github.com/clarkmcc/go-typescript" - "github.com/dop251/goja" - "github.com/mark3labs/mcp-go/mcp" - "github.com/maximhq/bifrost/core/schemas" -) - -// toolBinding represents a tool binding for the VM -type toolBinding struct { - toolName string - clientName string -} - -// toolCallInfo represents a tool call extracted from code -type toolCallInfo struct { - serverName string - toolName string -} - -// ExecutionResult represents the result of code execution -type ExecutionResult struct { - Result interface{} `json:"result"` - Logs []string `json:"logs"` - Errors *ExecutionError `json:"errors,omitempty"` - Environment ExecutionEnvironment `json:"environment"` -} - -type ExecutionErrorType string - -const ( - ExecutionErrorTypeCompile ExecutionErrorType = "compile" - ExecutionErrorTypeTypescript ExecutionErrorType = "typescript" - ExecutionErrorTypeRuntime ExecutionErrorType = "runtime" -) - -// ExecutionError represents an error during code execution -type ExecutionError struct { - Kind ExecutionErrorType `json:"kind"` // "compile", "typescript", or "runtime" - Message string `json:"message"` - Hints []string `json:"hints"` -} - -// ExecutionEnvironment contains information about the execution environment -type ExecutionEnvironment struct { - ServerKeys []string `json:"serverKeys"` - ImportsStripped bool `json:"importsStripped"` - StrippedLines []int `json:"strippedLines"` - TypeScriptUsed bool `json:"typescriptUsed"` -} - -const ( - CodeModeLogPrefix = "[CODE MODE]" -) - -// createExecuteToolCodeTool creates the executeToolCode tool definition for code mode. -// This tool allows executing TypeScript code in a sandboxed VM with access to MCP server tools. -// -// Returns: -// - schemas.ChatTool: The tool definition for executing tool code -func (m *ToolsManager) createExecuteToolCodeTool() schemas.ChatTool { - executeToolCodeProps := schemas.OrderedMap{ - "code": map[string]interface{}{ - "type": "string", - "description": "TypeScript code to execute. The code will be transpiled to JavaScript and validated before execution. Import/export statements will be stripped. You can use async/await syntax for async operations. For simple use cases, directly return results. Check keys and value types only for debugging. Do not print entire outputs in console logs - only print structure (keys, types) when debugging. ALWAYS retry if code fails. Example (simple): const result = await serverName.toolName({arg: 'value'}); return result; Example (debugging): const result = await serverName.toolName({arg: 'value'}); const getStruct = (o, d=0) => d>2 ? '...' : o===null ? 'null' : Array.isArray(o) ? `Array[${o.length}]` : typeof o !== 'object' ? typeof o : Object.keys(o).reduce((a,k) => (a[k]=getStruct(o[k],d+1), a), {}); console.log('Structure:', getStruct(result)); return result;", - }, - } - return schemas.ChatTool{ - Type: schemas.ChatToolTypeFunction, - Function: &schemas.ChatToolFunction{ - Name: ToolTypeExecuteToolCode, - Description: schemas.Ptr( - "Executes TypeScript code inside a sandboxed goja-based VM with access to all connected MCP servers' tools. " + - "TypeScript code is automatically transpiled to JavaScript and validated before execution, providing type checking and validation. " + - "All connected servers are exposed as global objects named after their configuration keys, and each server " + - "provides async (Promise-returning) functions for every tool available on that server. The canonical usage " + - "pattern is: const result = await .({ ...args }); Both and " + - "should be discovered using listToolFiles and readToolFile. " + - - "IMPORTANT WORKFLOW: Always follow this order — first use listToolFiles to see available servers and tools, " + - "then use readToolFile to understand the tool definitions and their parameters, and finally use executeToolCode " + - "to execute your code. Check listToolFiles whenever you are unsure about what tools you have available or if you want to verify available servers and their tools. " + - - "LOGGING GUIDELINES: For simple use cases, you can directly return results without logging. Check for keys and value types only " + - "for debugging purposes when you need to understand the response structure. Do not print the entire output in console logs. " + - "When debugging, use console logs to print just the output structure to understand its type. For nested objects, use a recursive helper to show types at all levels. " + - "For example: const getStruct = (o, d=0) => d>2 ? '...' : o===null ? 'null' : Array.isArray(o) ? `Array[${o.length}]` : typeof o !== 'object' ? typeof o : Object.keys(o).reduce((a,k) => (a[k]=getStruct(o[k],d+1), a), {}); " + - "console.log('Structure:', getStruct(result)); Only print the entire data if absolutely necessary for debugging. " + - "This helps understand the response structure without cluttering the output with full object contents. " + - - "RETRY POLICY: ALWAYS retry if a code block fails. If execution produces an error or unexpected result, analyze the error, " + - "adjust your code accordingly for better results or debugging, and retry the execution. Do not give up after a single failure — iterate and improve your code until it succeeds. " + - - "The environment is intentionally minimal and has several constraints: " + - "• ES modules are not supported — any leading import/export statements are automatically stripped and imported symbols will not exist. " + - "• Browser and Node APIs such as fetch, XMLHttpRequest, axios, require, setTimeout, setInterval, window, and document do not exist. " + - "• async/await syntax is supported and automatically transpiled to Promise chains compatible with goja. " + - "• Using undefined server names or tool names will result in reference or function errors. " + - "• The VM does not emulate a browser or Node.js environment — no DOM, timers, modules, or network APIs are available. " + - "• Only ES5.1+ features supported by goja are guaranteed to work. " + - "• TypeScript type checking occurs during transpilation — type errors will prevent execution. " + - - "If you want a value returned from the code, write a top-level 'return '; otherwise the return value will be null. " + - "Console output (log, error, warn, info) is captured and returned. " + - "Long-running or blocked operations are interrupted via execution timeout. " + - "This tool is designed specifically for orchestrating MCP tool calls and lightweight TypeScript computation.", - ), - - Parameters: &schemas.ToolFunctionParameters{ - Type: "object", - Properties: &executeToolCodeProps, - Required: []string{"code"}, - }, - }, - } -} - -// handleExecuteToolCode handles the executeToolCode tool call. -// It parses the code argument, executes it in a sandboxed VM, and formats the response -// with execution results, logs, errors, and environment information. -// -// Parameters: -// - ctx: Context for code execution -// - toolCall: The tool call request containing the TypeScript code to execute -// -// Returns: -// - *schemas.ChatMessage: A tool response message containing execution results -// - error: Any error that occurred during processing -func (m *ToolsManager) handleExecuteToolCode(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { - toolName := "unknown" - if toolCall.Function.Name != nil { - toolName = *toolCall.Function.Name - } - logger.Debug(fmt.Sprintf("%s Handling executeToolCode tool call: %s", CodeModeLogPrefix, toolName)) - - // Parse tool arguments - var arguments map[string]interface{} - if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { - logger.Debug(fmt.Sprintf("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err)) - return nil, fmt.Errorf("failed to parse tool arguments: %v", err) - } - - code, ok := arguments["code"].(string) - if !ok || code == "" { - logger.Debug(fmt.Sprintf("%s Code parameter missing or empty", CodeModeLogPrefix)) - return nil, fmt.Errorf("code parameter is required and must be a non-empty string") - } - - logger.Debug(fmt.Sprintf("%s Starting code execution", CodeModeLogPrefix)) - result := m.executeCode(ctx, code) - logger.Debug(fmt.Sprintf("%s Code execution completed. Success: %v, Has errors: %v, Log count: %d", CodeModeLogPrefix, result.Errors == nil, result.Errors != nil, len(result.Logs))) - - // Format response text - var responseText string - var executionSuccess bool = true // Track if execution was successful (has data) - if result.Errors != nil { - logger.Debug(fmt.Sprintf("%s Formatting error response. Error kind: %s, Message length: %d, Hints count: %d", CodeModeLogPrefix, result.Errors.Kind, len(result.Errors.Message), len(result.Errors.Hints))) - logsText := "" - if len(result.Logs) > 0 { - logsText = fmt.Sprintf("\n\nConsole/Log Output:\n%s\n", - strings.Join(result.Logs, "\n")) - } - errorKindLabel := result.Errors.Kind - - responseText = fmt.Sprintf( - "Execution %s error:\n\n%s\n\nHints:\n%s%s\n\nEnvironment:\n Available server keys: %s\n TypeScript used: %s\n Imports stripped: %s", - errorKindLabel, - result.Errors.Message, - strings.Join(result.Errors.Hints, "\n"), - logsText, - strings.Join(result.Environment.ServerKeys, ", "), - map[bool]string{true: "Yes", false: "No"}[result.Environment.TypeScriptUsed], - map[bool]string{true: "Yes", false: "No"}[result.Environment.ImportsStripped], - ) - if len(result.Environment.StrippedLines) > 0 { - strippedStr := make([]string, len(result.Environment.StrippedLines)) - for i, line := range result.Environment.StrippedLines { - strippedStr[i] = fmt.Sprintf("%d", line) - } - responseText += fmt.Sprintf("\n Stripped lines: %s", strings.Join(strippedStr, ", ")) - } - logger.Debug(fmt.Sprintf("%s Error response formatted. Response length: %d chars", CodeModeLogPrefix, len(responseText))) - } else { - // Success case - check if execution produced any data - hasLogs := len(result.Logs) > 0 - hasResult := result.Result != nil - logger.Debug(fmt.Sprintf("%s Formatting success response. Has logs: %v, Has result: %v", CodeModeLogPrefix, hasLogs, hasResult)) - - // If execution completed but produced no data (no logs, no return value), treat as failure - if !hasLogs && !hasResult { - executionSuccess = false - logger.Debug(fmt.Sprintf("%s Execution completed with no data (no logs, no result), marking as failure", CodeModeLogPrefix)) - hints := []string{ - "Add console.log() statements throughout your code to debug and see what's happening at each step", - "Ensure your code has a top-level return statement if you want to return a value", - "Check that your tool calls are actually executing and returning data", - "Verify that async operations (like await) are properly handled", - } - responseText = fmt.Sprintf( - "Execution completed but produced no data:\n\n"+ - "The code executed without errors but returned no output (no console logs and no return value).\n\n"+ - "Hints:\n%s\n\n"+ - "Environment:\n Available server keys: %s\n TypeScript used: %s\n Imports stripped: %s", - strings.Join(hints, "\n"), - strings.Join(result.Environment.ServerKeys, ", "), - map[bool]string{true: "Yes", false: "No"}[result.Environment.TypeScriptUsed], - map[bool]string{true: "Yes", false: "No"}[result.Environment.ImportsStripped], - ) - if len(result.Environment.StrippedLines) > 0 { - strippedStr := make([]string, len(result.Environment.StrippedLines)) - for i, line := range result.Environment.StrippedLines { - strippedStr[i] = fmt.Sprintf("%d", line) - } - responseText += fmt.Sprintf("\n Stripped lines: %s", strings.Join(strippedStr, ", ")) - } - logger.Debug(fmt.Sprintf("%s No-data failure response formatted. Response length: %d chars", CodeModeLogPrefix, len(responseText))) - } else { - // Normal success case with data - if hasLogs { - responseText = fmt.Sprintf("Console output:\n%s\n\nExecution completed successfully.", - strings.Join(result.Logs, "\n")) - } else { - responseText = "Execution completed successfully." - } - if hasResult { - resultJSON, err := sonic.MarshalIndent(result.Result, "", " ") - if err == nil { - responseText += fmt.Sprintf("\nReturn value: %s", string(resultJSON)) - logger.Debug(fmt.Sprintf("%s Added return value to response (JSON length: %d chars)", CodeModeLogPrefix, len(resultJSON))) - } else { - logger.Debug(fmt.Sprintf("%s Failed to marshal result to JSON: %v", CodeModeLogPrefix, err)) - } - } - - // Add environment information for successful executions - responseText += fmt.Sprintf("\n\nEnvironment:\n Available server keys: %s\n TypeScript used: %s\n Imports stripped: %s", - strings.Join(result.Environment.ServerKeys, ", "), - map[bool]string{true: "Yes", false: "No"}[result.Environment.TypeScriptUsed], - map[bool]string{true: "Yes", false: "No"}[result.Environment.ImportsStripped]) - if len(result.Environment.StrippedLines) > 0 { - strippedStr := make([]string, len(result.Environment.StrippedLines)) - for i, line := range result.Environment.StrippedLines { - strippedStr[i] = fmt.Sprintf("%d", line) - } - responseText += fmt.Sprintf("\n Stripped lines: %s", strings.Join(strippedStr, ", ")) - } - responseText += "\nNote: Browser APIs like fetch, setTimeout are not available. Use MCP tools for external interactions." - logger.Debug(fmt.Sprintf("%s Success response formatted. Response length: %d chars, Server keys: %v", CodeModeLogPrefix, len(responseText), result.Environment.ServerKeys)) - } - } - - logger.Debug(fmt.Sprintf("%s Returning tool response message. Execution success: %v", CodeModeLogPrefix, executionSuccess)) - return createToolResponseMessage(toolCall, responseText), nil -} - -// executeCode executes TypeScript code in a sandboxed VM with MCP tool bindings. -// It handles code preprocessing (stripping imports/exports), TypeScript transpilation, -// VM setup with tool bindings, and promise-based async execution with timeout handling. -// -// Parameters: -// - ctx: Context for code execution (used for timeout and tool access) -// - code: TypeScript code string to execute -// -// Returns: -// - ExecutionResult: Result containing execution output, logs, errors, and environment info -func (m *ToolsManager) executeCode(ctx context.Context, code string) ExecutionResult { - logs := []string{} - strippedLines := []int{} - - logger.Debug(fmt.Sprintf("%s Starting TypeScript code execution", CodeModeLogPrefix)) - - // Step 1: Convert literal \n escape sequences to actual newlines first - // This ensures multiline code and import/export stripping work correctly - codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") - - // Step 2: Strip import/export statements - cleanedCode, strippedLineNumbers := stripImportsAndExports(codeWithNewlines) - strippedLines = append(strippedLines, strippedLineNumbers...) - if len(strippedLineNumbers) > 0 { - logger.Debug(fmt.Sprintf("%s Stripped %d import/export lines", CodeModeLogPrefix, len(strippedLineNumbers))) - } - - // Step 3: Handle empty code after stripping (in case stripping made it empty) - trimmedCode := strings.TrimSpace(cleanedCode) - if trimmedCode == "" { - // Empty code should return null - return early without VM execution - return ExecutionResult{ - Result: nil, - Logs: logs, - Errors: nil, - Environment: ExecutionEnvironment{ - ServerKeys: []string{}, // Will be populated below if needed, but empty code doesn't need tools - ImportsStripped: len(strippedLines) > 0, - StrippedLines: strippedLines, - TypeScriptUsed: true, - }, - } - } - - // Step 4: Wrap code in async function for proper await transpilation - // TypeScript needs an async function context to properly transpile await expressions - // Check if code is already an async IIFE - if so, await it - trimmedLower := strings.ToLower(strings.TrimSpace(trimmedCode)) - isAsyncIIFE := strings.HasPrefix(trimmedLower, "(async") && strings.Contains(trimmedCode, ")()") - - var codeToTranspile string - if isAsyncIIFE { - // Code is already an async IIFE - await it to get the result - codeToTranspile = fmt.Sprintf("async function __execute__() {\nreturn await %s\n}", trimmedCode) - } else { - // Regular code - wrap in async function - codeToTranspile = fmt.Sprintf("async function __execute__() {\n%s\n}", trimmedCode) - } - - // Step 5: Transpile TypeScript to JavaScript with validation - // Configure TypeScript compiler to transpile async/await to Promise chains (ES5 compatible) - logger.Debug(fmt.Sprintf("%s Transpiling TypeScript code", CodeModeLogPrefix)) - compileOptions := map[string]interface{}{ - "target": "ES5", // Target ES5 for goja compatibility - "module": "None", // No module system - "lib": []string{}, // No lib (minimal environment) - "downlevelIteration": true, // Support async/await transpilation - } - jsCode, transpileErr := typescript.TranspileString(codeToTranspile, typescript.WithCompileOptions(compileOptions)) - if transpileErr != nil { - logger.Debug(fmt.Sprintf("%s TypeScript transpilation failed: %v", CodeModeLogPrefix, transpileErr)) - // Build bindings to get server keys for error hints - availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) - serverKeys := make([]string, 0, len(availableToolsPerClient)) - for clientName := range availableToolsPerClient { - client := m.clientManager.GetClientByName(clientName) - if client == nil { - logger.Warn("%s Client %s not found, skipping", MCPLogPrefix, clientName) - continue - } - if !client.ExecutionConfig.IsCodeModeClient { - continue - } - serverKeys = append(serverKeys, clientName) - } - - errorMessage := transpileErr.Error() - hints := generateTypeScriptErrorHints(errorMessage, serverKeys) - - return ExecutionResult{ - Result: nil, - Logs: logs, - Errors: &ExecutionError{ - Kind: ExecutionErrorTypeTypescript, - Message: fmt.Sprintf("TypeScript compilation error: %s", errorMessage), - Hints: hints, - }, - Environment: ExecutionEnvironment{ - ServerKeys: serverKeys, - ImportsStripped: len(strippedLines) > 0, - StrippedLines: strippedLines, - TypeScriptUsed: true, - }, - } - } - - logger.Debug(fmt.Sprintf("%s TypeScript transpiled successfully", CodeModeLogPrefix)) - - // Step 5: Create timeout context early so goroutines can use it - toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) - timeoutCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) - defer cancel() - - // Step 6: Build bindings for all connected servers - availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) - bindings := make(map[string]map[string]toolBinding) - serverKeys := make([]string, 0, len(availableToolsPerClient)) - - logger.Debug(fmt.Sprintf("%s GetToolPerClient returned %d clients", CodeModeLogPrefix, len(availableToolsPerClient))) - - for clientName, tools := range availableToolsPerClient { - client := m.clientManager.GetClientByName(clientName) - if client == nil { - logger.Warn("%s Client %s not found, skipping", MCPLogPrefix, clientName) - continue - } - logger.Debug(fmt.Sprintf("%s [%s] Client found. IsCodeModeClient: %v, ToolCount: %d", CodeModeLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient, len(tools))) - if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 { - logger.Debug(fmt.Sprintf("%s [%s] Skipped: IsCodeModeClient=%v, HasTools=%v", CodeModeLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient, len(tools) > 0)) - continue - } - serverKeys = append(serverKeys, clientName) - - toolFunctions := make(map[string]toolBinding) - - // Create a function for each tool - for _, tool := range tools { - if tool.Function == nil || tool.Function.Name == "" { - continue - } - - originalToolName := tool.Function.Name - // Strip client prefix and replace - with _ for code mode compatibility - unprefixedToolName := stripClientPrefix(originalToolName, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - // Parse tool name for property name compatibility (used as property name in the runtime) - parsedToolName := parseToolName(unprefixedToolName) - - logger.Debug(fmt.Sprintf("%s [%s] Bound tool: %s -> %s -> %s", CodeModeLogPrefix, clientName, originalToolName, unprefixedToolName, parsedToolName)) - - // Store tool binding - toolFunctions[parsedToolName] = toolBinding{ - toolName: originalToolName, - clientName: clientName, - } - } - - bindings[clientName] = toolFunctions - logger.Debug(fmt.Sprintf("%s [%s] Added to bindings with %d functions", CodeModeLogPrefix, clientName, len(toolFunctions))) - } - - if len(serverKeys) > 0 { - logger.Debug(fmt.Sprintf("%s Bound %d servers with tools: %v", CodeModeLogPrefix, len(serverKeys), serverKeys)) - } else { - logger.Debug(fmt.Sprintf("%s No servers available for code mode execution", CodeModeLogPrefix)) - } - - // Step 7: Wrap transpiled code to execute the async function and return its result - // The transpiled code contains an async function __execute__() that we need to call - // Trim trailing newlines to avoid issues when wrapping - codeToWrap := strings.TrimRight(jsCode, "\n\r") - // Wrap in IIFE that calls the transpiled async function and returns the promise - wrappedCode := fmt.Sprintf("(function() {\n%s\nreturn __execute__();\n})()", codeToWrap) - - // Step 8: Create goja runtime - vm := goja.New() - - // Step 9: Set up thread-safe logging - appendLog := func(msg string) { - m.logMu.Lock() - defer m.logMu.Unlock() - logs = append(logs, msg) - } - - // Step 10: Set up console - consoleObj := vm.NewObject() - consoleObj.Set("log", func(args ...interface{}) { - message := formatConsoleArgs(args) - appendLog(message) - }) - consoleObj.Set("error", func(args ...interface{}) { - message := formatConsoleArgs(args) - appendLog(fmt.Sprintf("[ERROR] %s", message)) - }) - consoleObj.Set("warn", func(args ...interface{}) { - message := formatConsoleArgs(args) - appendLog(fmt.Sprintf("[WARN] %s", message)) - }) - consoleObj.Set("info", func(args ...interface{}) { - message := formatConsoleArgs(args) - appendLog(fmt.Sprintf("[INFO] %s", message)) - }) - vm.Set("console", consoleObj) - - // Step 11: Set up server bindings - logger.Debug(fmt.Sprintf("%s Setting up %d server bindings in VM", CodeModeLogPrefix, len(bindings))) - for serverKey, tools := range bindings { - logger.Debug(fmt.Sprintf("%s [%s] Setting up server object with %d tools", CodeModeLogPrefix, serverKey, len(tools))) - serverObj := vm.NewObject() - for toolName, binding := range tools { - // Capture variables for closure - toolNameFinal := binding.toolName - clientNameFinal := binding.clientName - - logger.Debug(fmt.Sprintf("%s [%s] Binding tool function: %s -> %s", CodeModeLogPrefix, serverKey, toolName, toolNameFinal)) - - serverObj.Set(toolName, func(call goja.FunctionCall) goja.Value { - args := call.Argument(0).Export() - - // Convert args to map[string]interface{} - argsMap, ok := args.(map[string]interface{}) - if !ok { - logger.Debug(fmt.Sprintf("%s Invalid args type for %s.%s: expected object, got %T", - CodeModeLogPrefix, clientNameFinal, toolNameFinal, args)) - // Return rejected promise for invalid args - promise, _, reject := vm.NewPromise() - err := fmt.Errorf("expected object argument, got %T", args) - reject(vm.ToValue(err)) - return vm.ToValue(promise) - } - - // Create promise on VM goroutine (thread-safe) - promise, resolve, reject := vm.NewPromise() - - // Define result struct for channel communication - type toolResult struct { - result interface{} - err error - } - - // Create buffered channel for worker communication - resultChan := make(chan toolResult, 1) - - // Call tool asynchronously with timeout context and panic recovery - // Worker goroutine - NO VM calls allowed here - go func() { - defer func() { - if r := recover(); r != nil { - logger.Debug(fmt.Sprintf("%s Panic in tool call goroutine for %s.%s: %v", - CodeModeLogPrefix, clientNameFinal, toolNameFinal, r)) - // Send panic as error through channel (no VM calls in worker) - select { - case resultChan <- toolResult{nil, fmt.Errorf("tool call panic: %v", r)}: - case <-timeoutCtx.Done(): - // Context cancelled, ignore - } - } - }() - - // Check if context is already cancelled before starting - select { - case <-timeoutCtx.Done(): - // Send timeout error through channel (no VM calls in worker) - select { - case resultChan <- toolResult{nil, fmt.Errorf("execution timeout")}: - case <-timeoutCtx.Done(): - // Already cancelled, ignore - } - return - default: - } - - // Pass the original ctx (BifrostContext) to callMCPTool, not timeoutCtx - // callMCPTool will handle timeout internally - result, err := m.callMCPTool(ctx, clientNameFinal, toolNameFinal, argsMap, appendLog) - - // Check if context was cancelled during execution - select { - case <-timeoutCtx.Done(): - // Send timeout error through channel (no VM calls in worker) - select { - case resultChan <- toolResult{nil, fmt.Errorf("execution timeout")}: - case <-timeoutCtx.Done(): - // Already cancelled, ignore - } - return - default: - } - - // Send result through channel (no VM calls in worker) - select { - case resultChan <- toolResult{result, err}: - case <-timeoutCtx.Done(): - // Context cancelled, ignore - } - }() - - // Process result synchronously on VM goroutine to ensure thread safety - // This blocks the VM goroutine until the tool call completes, but ensures - // all VM operations (vm.ToValue, resolve, reject) happen on the correct thread - select { - case res := <-resultChan: - if res.err != nil { - logger.Debug(fmt.Sprintf("%s Tool call failed: %s.%s - %v", - CodeModeLogPrefix, clientNameFinal, toolNameFinal, res.err)) - reject(vm.ToValue(res.err)) - } else { - resolve(vm.ToValue(res.result)) - } - case <-timeoutCtx.Done(): - reject(vm.ToValue(fmt.Errorf("execution timeout"))) - } - - return vm.ToValue(promise) - }) - } - vm.Set(serverKey, serverObj) - logger.Debug(fmt.Sprintf("%s [%s] Server object set in VM", CodeModeLogPrefix, serverKey)) - } - - // Step 12: Set up environment info - envObj := vm.NewObject() - envObj.Set("serverKeys", serverKeys) - envObj.Set("version", "1.0.0") - vm.Set("__MCP_ENV__", envObj) - - // Step 13: Execute code with timeout - - // Set up interrupt handler - interruptDone := make(chan struct{}) - go func() { - select { - case <-timeoutCtx.Done(): - logger.Debug(fmt.Sprintf("%s Execution timeout reached", CodeModeLogPrefix)) - vm.Interrupt("execution timeout") - case <-interruptDone: - } - }() - - var result interface{} - var executionErr error - - func() { - defer close(interruptDone) - val, err := vm.RunString(wrappedCode) - if err != nil { - logger.Debug(fmt.Sprintf("%s VM execution error: %v", CodeModeLogPrefix, err)) - executionErr = err - return - } - - // Check if the result is a promise by checking its type - // First check if val is nil or undefined (these can't be converted to objects) - if val == nil || val == goja.Undefined() { - result = nil - return - } - - // Try to convert to object to check if it's a promise - // Use recover to safely handle null values that can't be converted to objects - var valObj *goja.Object - func() { - defer func() { - if r := recover(); r != nil { - // Value is null or can't be converted to object, just export it - valObj = nil - } - }() - valObj = val.ToObject(vm) - }() - - if valObj != nil { - // Check if it has a 'then' method (Promise-like) - if then := valObj.Get("then"); then != nil && then != goja.Undefined() { - // It's a promise, we need to await it - // Use buffered channels to prevent blocking if handlers are called after timeout - resultChan := make(chan interface{}, 1) - errChan := make(chan error, 1) - - // Set up promise handlers - thenFunc, ok := goja.AssertFunction(then) - if ok { - // Call then with resolve and reject handlers - _, err := thenFunc(val, - vm.ToValue(func(res goja.Value) { - select { - case resultChan <- res.Export(): - case <-timeoutCtx.Done(): - // Timeout already occurred, ignore result - } - }), - vm.ToValue(func(err goja.Value) { - var errMsg string - if err == nil || err == goja.Undefined() { - errMsg = "unknown error" - } else { - // Try to get error message from Error object - if errObj := err.ToObject(vm); errObj != nil { - if msg := errObj.Get("message"); msg != nil && msg != goja.Undefined() { - errMsg = msg.String() - } else if name := errObj.Get("name"); name != nil && name != goja.Undefined() { - errMsg = name.String() - } else { - errMsg = err.String() - } - } else { - // Fallback to string conversion - errMsg = err.String() - } - } - select { - case errChan <- fmt.Errorf("%s", errMsg): - case <-timeoutCtx.Done(): - // Timeout already occurred, ignore error - } - }), - ) - if err != nil { - executionErr = err - return - } - - // Wait for result or error with timeout - select { - case res := <-resultChan: - result = res - case err := <-errChan: - logger.Debug(fmt.Sprintf("%s Promise rejected: %v", CodeModeLogPrefix, err)) - executionErr = err - case <-timeoutCtx.Done(): - logger.Debug(fmt.Sprintf("%s Promise timeout while waiting for result", CodeModeLogPrefix)) - executionErr = fmt.Errorf("execution timeout") - } - } else { - result = val.Export() - } - } else { - result = val.Export() - } - } else { - // Not an object (or null/undefined), just export the value - result = val.Export() - } - }() - - if executionErr != nil { - errorMessage := executionErr.Error() - hints := generateErrorHints(errorMessage, serverKeys) - logger.Debug(fmt.Sprintf("%s Execution failed: %s", CodeModeLogPrefix, errorMessage)) - - return ExecutionResult{ - Result: nil, - Logs: logs, - Errors: &ExecutionError{ - Kind: ExecutionErrorTypeRuntime, - Message: errorMessage, - Hints: hints, - }, - Environment: ExecutionEnvironment{ - ServerKeys: serverKeys, - ImportsStripped: len(strippedLines) > 0, - StrippedLines: strippedLines, - TypeScriptUsed: true, - }, - } - } - - logger.Debug(fmt.Sprintf("%s Execution completed successfully", CodeModeLogPrefix)) - return ExecutionResult{ - Result: result, - Logs: logs, - Errors: nil, - Environment: ExecutionEnvironment{ - ServerKeys: serverKeys, - ImportsStripped: len(strippedLines) > 0, - StrippedLines: strippedLines, - TypeScriptUsed: true, - }, - } -} - -// callMCPTool calls an MCP tool and returns the result. -// It locates the client by name, constructs the MCP tool call request, executes it -// with timeout handling, and parses the response as JSON or returns it as a string. -// This function now runs MCP plugin hooks (PreMCPHook/PostMCPHook) for nested tool calls. -// -// Parameters: -// - ctx: Context for tool execution (used for timeout and plugin hooks) -// - clientName: Name of the MCP client/server to call -// - toolName: Name of the tool to execute -// - args: Tool arguments as a map -// - appendLog: Function to append log messages during execution -// -// Returns: -// - interface{}: Parsed tool result (JSON object or string) -// - error: Any error that occurred during tool execution -func (m *ToolsManager) callMCPTool(ctx context.Context, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { - // Get available tools per client - availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) - - // Find the client by name - tools, exists := availableToolsPerClient[clientName] - if !exists || len(tools) == 0 { - return nil, fmt.Errorf("client not found for server name: %s", clientName) - } - - // Get client using a tool from this client - // Find the first tool with a valid Function to use for client lookup - var client *schemas.MCPClientState - for _, tool := range tools { - if tool.Function != nil && tool.Function.Name != "" { - client = m.clientManager.GetClientForTool(tool.Function.Name) - if client != nil { - break - } - } - } - - if client == nil { - return nil, fmt.Errorf("client not found for server name: %s", clientName) - } - - // Strip the client name prefix from tool name to get sanitized name - // Then look up the original MCP tool name from the mapping - sanitizedToolName := stripClientPrefix(toolName, clientName) - originalMCPToolName := getOriginalToolName(sanitizedToolName, client) - - // ==================== PLUGIN PIPELINE INTEGRATION ==================== - // Set up parent-child request ID tracking and run plugin hooks - - // Get original executeCode request ID from context (will become parent) - var bifrostCtx *schemas.BifrostContext - var ok bool - if bifrostCtx, ok = ctx.(*schemas.BifrostContext); !ok { - // Fallback: if not a BifrostContext, execute directly without plugins - return m.callMCPToolDirect(ctx, client, originalMCPToolName, clientName, toolName, args, appendLog) - } - - originalRequestID, _ := bifrostCtx.Value(schemas.BifrostContextKeyRequestID).(string) - - // Generate new request ID for this nested tool call - var newRequestID string - if m.fetchNewRequestIDFunc != nil { - newRequestID = m.fetchNewRequestIDFunc(bifrostCtx) - } else { - // Fallback: generate a simple UUID-like ID - newRequestID = fmt.Sprintf("exec_%d_%s", time.Now().UnixNano(), toolName) - } - - // Create new CHILD context with parent-child relationship - // IMPORTANT: We must use NewBifrostContext() to create a proper child context with its own - // userValues map. Using WithValue() would modify the parent context in-place, which would - // cause the parent executeToolCode's request ID to be overwritten with the last nested tool's - // request ID, leading to the parent's response overwriting the last nested tool's log entry. - deadline, hasDeadline := bifrostCtx.Deadline() - if !hasDeadline { - deadline = schemas.NoDeadline - } - nestedCtx := schemas.NewBifrostContext(bifrostCtx, deadline) - nestedCtx.SetValue(schemas.BifrostContextKeyRequestID, newRequestID) - if originalRequestID != "" { - nestedCtx.SetValue(schemas.BifrostContextKeyParentMCPRequestID, originalRequestID) - } - - // Marshal arguments to JSON for the tool call - argsJSON, err := sonic.Marshal(args) - if err != nil { - return nil, fmt.Errorf("failed to marshal tool arguments: %v", err) - } - - // Build tool call for MCP request - toolCall := schemas.ChatAssistantMessageToolCall{ - ID: schemas.Ptr(newRequestID), - Function: schemas.ChatAssistantMessageToolCallFunction{ - Name: schemas.Ptr(toolName), - Arguments: string(argsJSON), - }, - } - - // Create BifrostMCPRequest - mcpRequest := &schemas.BifrostMCPRequest{ - RequestType: schemas.MCPRequestTypeChatToolCall, - ChatAssistantMessageToolCall: &toolCall, - } - - // Check if plugin pipeline is available - if m.pluginPipelineProvider == nil { - // Fallback: execute directly without plugins - return m.callMCPToolDirect(ctx, client, originalMCPToolName, clientName, toolName, args, appendLog) - } - - // Get plugin pipeline and run hooks - pipeline := m.pluginPipelineProvider() - if pipeline == nil { - // Fallback: execute directly if pipeline is nil - return m.callMCPToolDirect(ctx, client, originalMCPToolName, clientName, toolName, args, appendLog) - } - defer m.releasePluginPipeline(pipeline) - - // Run PreMCPHooks - preReq, shortCircuit, preCount := pipeline.RunMCPPreHooks(nestedCtx, mcpRequest) - - // Handle short-circuit cases - if shortCircuit != nil { - if shortCircuit.Response != nil { - finalResp, _ := pipeline.RunMCPPostHooks(nestedCtx, shortCircuit.Response, nil, preCount) - if finalResp != nil { - // Try ChatMessage first - if finalResp.ChatMessage != nil { - return extractResultFromChatMessage(finalResp.ChatMessage), nil - } - // Try ResponsesMessage - if finalResp.ResponsesMessage != nil { - result, err := extractResultFromResponsesMessage(finalResp.ResponsesMessage) - if err != nil { - return nil, err - } - if result != nil { - return result, nil - } - } - } - return nil, fmt.Errorf("plugin short-circuit returned invalid response") - } - if shortCircuit.Error != nil { - pipeline.RunMCPPostHooks(nestedCtx, nil, shortCircuit.Error, preCount) - if shortCircuit.Error.Error != nil { - return nil, fmt.Errorf("%s", shortCircuit.Error.Error.Message) - } - return nil, fmt.Errorf("plugin short-circuit error") - } - } - - // If pre-hooks modified the request, extract updated tool name and args - if preReq != nil && preReq.ChatAssistantMessageToolCall != nil { - toolCall = *preReq.ChatAssistantMessageToolCall - if toolCall.Function.Arguments != "" { - // Re-parse arguments if they were modified - if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil { - logger.Warn(fmt.Sprintf("%s Failed to parse modified tool arguments, using original: %v", CodeModeLogPrefix, err)) - } - } - } - - // ==================== EXECUTE TOOL ==================== - - // Capture start time for latency calculation - startTime := time.Now() - - // Derive tool name from originalMCPToolName (ignore pre-hook modifications to tool name) - // Pre-hooks should not modify which tool gets called, only arguments - toolNameToCall := originalMCPToolName - - // Call the tool via MCP client - callRequest := mcp.CallToolRequest{ - Request: mcp.Request{ - Method: string(mcp.MethodToolsCall), - }, - Params: mcp.CallToolParams{ - Name: toolNameToCall, - Arguments: args, - }, - } - - // Create timeout context - toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) - toolCtx, cancel := context.WithTimeout(nestedCtx, toolExecutionTimeout) - defer cancel() - - toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) - - // Calculate latency - latency := time.Since(startTime).Milliseconds() - - // ==================== PREPARE RESPONSE FOR POST-HOOKS ==================== - - var mcpResp *schemas.BifrostMCPResponse - var bifrostErr *schemas.BifrostError - - if callErr != nil { - logger.Debug(fmt.Sprintf("%s Tool call failed: %s.%s - %v", CodeModeLogPrefix, clientName, toolName, callErr)) - appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, toolName, callErr)) - bifrostErr = &schemas.BifrostError{ - IsBifrostError: false, - Error: &schemas.ErrorField{ - Message: fmt.Sprintf("tool call failed for %s.%s: %v", clientName, toolName, callErr), - }, - } - } else { - // Extract result - rawResult := extractTextFromMCPResponse(toolResponse, toolName) - - // Check if this is an error result (from NewToolResultError) - // Error results start with "Error: " prefix - if after, ok := strings.CutPrefix(rawResult, "Error: "); ok { - errorMsg := after - logger.Debug(fmt.Sprintf("%s Tool returned error result: %s.%s - %s", CodeModeLogPrefix, clientName, toolName, errorMsg)) - appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, toolName, errorMsg)) - bifrostErr = &schemas.BifrostError{ - IsBifrostError: false, - Error: &schemas.ErrorField{ - Message: errorMsg, - }, - } - } else { - // Success case - create response - mcpResp = &schemas.BifrostMCPResponse{ - ChatMessage: createToolResponseMessage(toolCall, rawResult), - ExtraFields: schemas.BifrostMCPResponseExtraFields{ - ClientName: clientName, - ToolName: originalMCPToolName, - Latency: latency, - }, - } - - // Log the result - resultStr := formatResultForLog(rawResult) - // Strip prefix and replace - with _ for code mode display - logToolName := stripClientPrefix(toolName, clientName) - logToolName = strings.ReplaceAll(logToolName, "-", "_") - appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr)) - } - } - - // ==================== RUN POST-HOOKS ==================== - - finalResp, finalErr := pipeline.RunMCPPostHooks(nestedCtx, mcpResp, bifrostErr, preCount) - - // Return result - if finalErr != nil { - if finalErr.Error != nil { - return nil, fmt.Errorf("%s", finalErr.Error.Message) - } - return nil, fmt.Errorf("tool execution failed") - } - - if finalResp == nil { - return nil, fmt.Errorf("plugin post-hooks returned invalid response") - } - - // Extract and parse the final result from the chat message or responses message - if finalResp.ChatMessage != nil { - return extractResultFromChatMessage(finalResp.ChatMessage), nil - } - - // Try ResponsesMessage if ChatMessage is not present - if finalResp.ResponsesMessage != nil { - result, err := extractResultFromResponsesMessage(finalResp.ResponsesMessage) - if err != nil { - return nil, err - } - if result != nil { - return result, nil - } - } - - return nil, fmt.Errorf("plugin post-hooks returned invalid response") -} - -// callMCPToolDirect executes an MCP tool call directly without plugin hooks. -// This is used as a fallback when the plugin pipeline is not available or context is not BifrostContext. -func (m *ToolsManager) callMCPToolDirect(ctx context.Context, client *schemas.MCPClientState, originalToolName, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { - // Call the tool via MCP client - callRequest := mcp.CallToolRequest{ - Request: mcp.Request{ - Method: string(mcp.MethodToolsCall), - }, - Params: mcp.CallToolParams{ - Name: originalToolName, - Arguments: args, - }, - } - - // Create timeout context - toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) - toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) - defer cancel() - - // Strip prefix and replace - with _ for code mode display - logToolName := stripClientPrefix(toolName, clientName) - logToolName = strings.ReplaceAll(logToolName, "-", "_") - - toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) - if callErr != nil { - logger.Debug(fmt.Sprintf("%s Tool call failed: %s.%s - %v", CodeModeLogPrefix, clientName, logToolName, callErr)) - appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, logToolName, callErr)) - return nil, fmt.Errorf("tool call failed for %s.%s: %v", clientName, logToolName, callErr) - } - - // Extract result - rawResult := extractTextFromMCPResponse(toolResponse, toolName) - - // Check if this is an error result (from NewToolResultError) - // Error results start with "Error: " prefix - if after, ok := strings.CutPrefix(rawResult, "Error: "); ok { - errorMsg := after - logger.Debug(fmt.Sprintf("%s Tool returned error result: %s.%s - %s", CodeModeLogPrefix, clientName, logToolName, errorMsg)) - appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, logToolName, errorMsg)) - return nil, fmt.Errorf("%s", errorMsg) - } - - // Try to parse as JSON, otherwise use as string - var finalResult interface{} - if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { - // Not JSON, use as string - finalResult = rawResult - } - - // Log the result - resultStr := formatResultForLog(finalResult) - appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr)) - - return finalResult, nil -} - -// extractResultFromChatMessage extracts the result from a chat message and parses it as JSON if possible. -func extractResultFromChatMessage(msg *schemas.ChatMessage) interface{} { - if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil { - return nil - } - - rawResult := *msg.Content.ContentStr - - // Try to parse as JSON, otherwise use as string - var finalResult interface{} - if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { - // Not JSON, use as string - return rawResult - } - - return finalResult -} - -// extractResultFromResponsesMessage extracts the result or error from a ResponsesMessage. -// It checks for tool errors first, then extracts output from the ResponsesToolMessage. -// Returns the extracted result/error, and a boolean indicating if it was an error. -func extractResultFromResponsesMessage(msg *schemas.ResponsesMessage) (interface{}, error) { - if msg == nil { - return nil, nil - } - - // Check if this is a tool message - if msg.ResponsesToolMessage != nil { - // Check for tool error first - if msg.ResponsesToolMessage.Error != nil && *msg.ResponsesToolMessage.Error != "" { - return nil, fmt.Errorf("%s", *msg.ResponsesToolMessage.Error) - } - - // Extract output if present - if msg.ResponsesToolMessage.Output != nil { - // Try ResponsesToolCallOutputStr first - if msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != nil { - rawResult := *msg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr - - // Try to parse as JSON, otherwise use as string - var finalResult interface{} - if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { - // Not JSON, use as string - return rawResult, nil - } - return finalResult, nil - } - - // Try ResponsesFunctionToolCallOutputBlocks if OutputStr is not present - if len(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) > 0 { - // For now, extract text from blocks and concatenate - var textParts []string - for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks { - if block.Text != nil { - textParts = append(textParts, *block.Text) - } - } - if len(textParts) > 0 { - result := strings.Join(textParts, "\n") - // Try to parse as JSON - var finalResult interface{} - if err := sonic.Unmarshal([]byte(result), &finalResult); err != nil { - return result, nil - } - return finalResult, nil - } - } - } - } - - // If no tool message or output, return nil - return nil, nil -} - -// HELPER FUNCTIONS - -// formatResultForLog formats a result value for logging purposes. -// It attempts to marshal to JSON for structured output, falling back to string representation. -// -// Parameters: -// - result: The result value to format -// -// Returns: -// - string: Formatted string representation of the result -func formatResultForLog(result interface{}) string { - var resultStr string - if result == nil { - resultStr = "null" - } else if resultBytes, err := sonic.Marshal(result); err == nil { - resultStr = string(resultBytes) - } else { - resultStr = fmt.Sprintf("%v", result) - } - return resultStr -} - -// formatConsoleArgs formats console arguments for logging. -// It formats each argument as JSON if possible, otherwise uses string representation. -// -// Parameters: -// - args: Array of console arguments to format -// -// Returns: -// - string: Formatted string with all arguments joined by spaces -func formatConsoleArgs(args []interface{}) string { - parts := make([]string, len(args)) - for i, arg := range args { - if argBytes, err := sonic.MarshalIndent(arg, "", " "); err == nil { - parts[i] = string(argBytes) - } else { - parts[i] = fmt.Sprintf("%v", arg) - } - } - return strings.Join(parts, " ") -} - -// stripImportsAndExports strips import and export statements from code. -// It removes lines that start with import or export keywords and returns -// the cleaned code along with 1-based line numbers of stripped lines. -// -// Parameters: -// - code: Source code string to process -// -// Returns: -// - string: Code with import/export statements removed -// - []int: 1-based line numbers of stripped lines -func stripImportsAndExports(code string) (string, []int) { - lines := strings.Split(code, "\n") - keptLines := []string{} - strippedLineNumbers := []int{} - - importExportRegex := regexp.MustCompile(`^\s*(import|export)\b`) - - for i, line := range lines { - trimmed := strings.TrimSpace(line) - - // Skip empty lines - if trimmed == "" { - keptLines = append(keptLines, line) - continue - } - - // Check if this is an import or export statement - isImportOrExport := importExportRegex.MatchString(line) - - if isImportOrExport { - strippedLineNumbers = append(strippedLineNumbers, i+1) // 1-based line numbers - continue // Skip import/export lines - } - - // Keep comment lines and all other non-import/export lines - keptLines = append(keptLines, line) - } - - return strings.Join(keptLines, "\n"), strippedLineNumbers -} - -// generateTypeScriptErrorHints generates helpful hints for TypeScript compilation errors. -// It analyzes the error message and provides context-specific guidance based on error patterns. -// -// Parameters: -// - errorMessage: The TypeScript compilation error message -// - serverKeys: List of available MCP server keys for context -// -// Returns: -// - []string: Array of helpful hint messages -func generateTypeScriptErrorHints(errorMessage string, serverKeys []string) []string { - hints := []string{} - - // TypeScript-specific error patterns - if strings.Contains(errorMessage, "Cannot find name") || strings.Contains(errorMessage, "is not defined") { - hints = append(hints, "TypeScript compilation error: undefined variable or identifier.") - hints = append(hints, "Check that all variables are properly declared and typed.") - if len(serverKeys) > 0 { - hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) - hints = append(hints, "Use server keys to access MCP tools: .(args)") - } - } else if strings.Contains(errorMessage, "Type") && (strings.Contains(errorMessage, "is not assignable") || strings.Contains(errorMessage, "does not exist")) { - hints = append(hints, "TypeScript type error detected.") - hints = append(hints, "Check that variable types match their usage.") - hints = append(hints, "Ensure function arguments match the expected types.") - } else if strings.Contains(errorMessage, "Expected") { - hints = append(hints, "TypeScript syntax error detected.") - hints = append(hints, "Check for missing parentheses, brackets, or semicolons.") - hints = append(hints, "Ensure all code blocks are properly closed.") - } else if strings.Contains(errorMessage, "async") || strings.Contains(errorMessage, "await") { - hints = append(hints, "async/await syntax should be supported. If you see this error, it may be a TypeScript compilation issue.") - hints = append(hints, "Ensure async functions are properly declared: async function myFunction() { ... }") - hints = append(hints, "Example: const result = await serverName.toolName({...});") - } else { - hints = append(hints, "TypeScript compilation error detected.") - hints = append(hints, "Review the error message above for specific details.") - hints = append(hints, "Ensure your TypeScript code follows valid syntax and type rules.") - if len(serverKeys) > 0 { - hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) - } - } - - return hints -} - -// generateErrorHints generates helpful hints based on runtime error messages. -// It analyzes common runtime error patterns (undefined variables, missing functions, etc.) -// and provides context-specific guidance including available server keys and usage examples. -// -// Parameters: -// - errorMessage: The runtime error message -// - serverKeys: List of available MCP server keys for context -// -// Returns: -// - []string: Array of helpful hint messages -func generateErrorHints(errorMessage string, serverKeys []string) []string { - hints := []string{} - - if strings.Contains(errorMessage, "is not defined") { - re := regexp.MustCompile(`(\w+)\s+is not defined`) - if match := re.FindStringSubmatch(errorMessage); len(match) > 1 { - undefinedVar := match[1] - - // Special handling for common browser/Node.js APIs - if undefinedVar == "fetch" { - hints = append(hints, "The 'fetch' API is not available in this runtime environment.") - hints = append(hints, "Instead of using fetch for HTTP requests, use the available MCP tools.") - if len(serverKeys) > 0 { - hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) - hints = append(hints, fmt.Sprintf("Example: const result = await %s.({ url: 'https://example.com' });", serverKeys[0])) - } - hints = append(hints, "MCP tools handle HTTP requests, file operations, and other external interactions.") - return hints - } else if undefinedVar == "XMLHttpRequest" || undefinedVar == "axios" { - hints = append(hints, fmt.Sprintf("The '%s' API is not available in this runtime environment.", undefinedVar)) - hints = append(hints, "Use MCP tools instead for HTTP requests and external API calls.") - if len(serverKeys) > 0 { - hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) - } - return hints - } else if undefinedVar == "setTimeout" || undefinedVar == "setInterval" { - hints = append(hints, fmt.Sprintf("The '%s' API is not available in this runtime environment.", undefinedVar)) - hints = append(hints, "This is a sandboxed environment focused on MCP tool interactions.") - hints = append(hints, "Use Promise chains with MCP tools instead of timing functions.") - return hints - } else if undefinedVar == "require" || undefinedVar == "import" { - hints = append(hints, "Module imports are not supported in this runtime environment.") - hints = append(hints, "Use the available MCP tools for external functionality.") - if len(serverKeys) > 0 { - hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) - } - return hints - } - - // Generic undefined variable handling - hints = append(hints, fmt.Sprintf("Variable or identifier '%s' is not defined.", undefinedVar)) - if len(serverKeys) > 0 { - hints = append(hints, fmt.Sprintf("Use one of the available server keys as the object name: %s", strings.Join(serverKeys, ", "))) - hints = append(hints, "Then access tools using: .(args)") - hints = append(hints, fmt.Sprintf("For example: const result = await %s.({ ... });", serverKeys[0])) - } - } - } else if strings.Contains(errorMessage, "is not a function") { - re := regexp.MustCompile(`(\w+(?:\.\w+)?)\s+is not a function`) - if match := re.FindStringSubmatch(errorMessage); len(match) > 1 { - notFunction := match[1] - hints = append(hints, fmt.Sprintf("'%s' is not a function.", notFunction)) - hints = append(hints, "Ensure you're using the correct server key and tool name.") - if len(serverKeys) > 0 { - hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) - } - hints = append(hints, "To see available tools for a server, use listToolFiles and readToolFile.") - } - } else if strings.Contains(errorMessage, "Cannot read property") || - strings.Contains(errorMessage, "Cannot read properties") || - strings.Contains(errorMessage, "is not an object") { - hints = append(hints, "You're trying to access a property that doesn't exist or is undefined.") - hints = append(hints, "The tool response structure might be different than expected.") - hints = append(hints, "Check the console logs above to see the actual response structure from the tool.") - hints = append(hints, "Add console.log() statements to inspect the response before accessing properties.") - hints = append(hints, "Example: console.log('searchResults:', searchResults);") - if len(serverKeys) > 0 { - hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) - } - } else { - hints = append(hints, "Check the error message above for details.") - hints = append(hints, "Check the console logs above to see tool responses and debug the issue.") - if len(serverKeys) > 0 { - hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) - } - hints = append(hints, "Ensure you're using the correct syntax: const result = await .({ ...args });") - } - - return hints -} diff --git a/core/mcp/codemodereadfile.go b/core/mcp/codemodereadfile.go deleted file mode 100644 index 2c67141000..0000000000 --- a/core/mcp/codemodereadfile.go +++ /dev/null @@ -1,563 +0,0 @@ -package mcp - -import ( - "context" - "encoding/json" - "fmt" - "strings" - - "github.com/maximhq/bifrost/core/schemas" -) - -// createReadToolFileTool creates the readToolFile tool definition for code mode. -// This tool allows reading virtual .d.ts declaration files for specific MCP servers/tools, -// generating TypeScript type definitions from the server's tool schemas. -// The description is dynamically generated based on the configured CodeModeBindingLevel. -// -// Returns: -// - schemas.ChatTool: The tool definition for reading tool files -func (m *ToolsManager) createReadToolFileTool() schemas.ChatTool { - bindingLevel := m.GetCodeModeBindingLevel() - - var fileNameDescription, toolDescription string - - if bindingLevel == schemas.CodeModeBindingLevelServer { - fileNameDescription = "The virtual filename from listToolFiles in format: servers/.d.ts (e.g., 'calculator.d.ts')" - toolDescription = "Reads a virtual .d.ts declaration file for a specific MCP server, generating TypeScript type definitions " + - "for all tools available on that server. The fileName should be in format servers/.d.ts as listed by listToolFiles. " + - "The function performs case-insensitive matching and removes the .d.ts extension. " + - "Optionally, you can specify startLine and endLine (1-based, inclusive) to read only a portion of the file. " + - "IMPORTANT: Line numbers are 1-based, not 0-based. The first line is line 1, not line 0. " + - "This generates TypeScript type definitions describing all tools in the server and their argument types, " + - "enabling code-mode execution. Each tool can be accessed in code via: await serverName.toolName({ args }). " + - "Always follow this workflow: first use listToolFiles to see available servers, then use readToolFile to understand " + - "all available tool definitions for a server, and finally use executeToolCode to execute your code." - } else { - fileNameDescription = "The virtual filename from listToolFiles in format: servers//.d.ts (e.g., 'calculator/add.d.ts')" - toolDescription = "Reads a virtual .d.ts declaration file for a specific tool, generating TypeScript type definitions " + - "for that individual tool. The fileName should be in format servers//.d.ts as listed by listToolFiles. " + - "The function performs case-insensitive matching and removes the .d.ts extension. " + - "Optionally, you can specify startLine and endLine (1-based, inclusive) to read only a portion of the file. " + - "IMPORTANT: Line numbers are 1-based, not 0-based. The first line is line 1, not line 0. " + - "This generates TypeScript type definitions for a single tool, describing its parameters and usage, " + - "enabling focused code-mode execution. The tool can be accessed in code via: await serverName.toolName({ args }). " + - "Always follow this workflow: first use listToolFiles to see available tools, then use readToolFile to understand " + - "a specific tool's definition, and finally use executeToolCode to execute your code." - } - - readToolFileProps := schemas.OrderedMap{ - "fileName": map[string]interface{}{ - "type": "string", - "description": fileNameDescription, - }, - "startLine": map[string]interface{}{ - "type": "number", - "description": "Optional 1-based starting line number for partial file read (inclusive). Note: Line numbers start at 1, not 0. The first line is line 1.", - }, - "endLine": map[string]interface{}{ - "type": "number", - "description": "Optional 1-based ending line number for partial file read (inclusive)", - }, - } - return schemas.ChatTool{ - Type: schemas.ChatToolTypeFunction, - Function: &schemas.ChatToolFunction{ - Name: ToolTypeReadToolFile, - Description: schemas.Ptr(toolDescription), - Parameters: &schemas.ToolFunctionParameters{ - Type: "object", - Properties: &readToolFileProps, - Required: []string{"fileName"}, - }, - }, - } -} - -// handleReadToolFile handles the readToolFile tool call. -// It reads a virtual .d.ts file for a specific MCP server/tool, generates TypeScript type definitions, -// and optionally returns a portion of the file based on line range parameters. -// Supports both server-level files (e.g., "calculator.d.ts") and tool-level files (e.g., "calculator/add.d.ts"). -// -// Parameters: -// - ctx: Context for accessing client tools -// - toolCall: The tool call request containing fileName and optional startLine/endLine -// -// Returns: -// - *schemas.ChatMessage: A tool response message containing the TypeScript definitions -// - error: Any error that occurred during processing -func (m *ToolsManager) handleReadToolFile(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { - // Parse tool arguments - var arguments map[string]interface{} - if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { - return nil, fmt.Errorf("failed to parse tool arguments: %v", err) - } - - fileName, ok := arguments["fileName"].(string) - if !ok || fileName == "" { - return nil, fmt.Errorf("fileName parameter is required and must be a string") - } - - // Parse the file path to extract server name and optional tool name - serverName, toolName, isToolLevel := parseVFSFilePath(fileName) - - // Get available tools per client - availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) - - // Find matching client - var matchedClientName string - var matchedTools []schemas.ChatTool - matchCount := 0 - - for clientName, tools := range availableToolsPerClient { - client := m.clientManager.GetClientByName(clientName) - if client == nil { - logger.Warn("%s Client %s not found, skipping", MCPLogPrefix, clientName) - continue - } - if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 { - continue - } - - clientNameLower := strings.ToLower(clientName) - serverNameLower := strings.ToLower(serverName) - - if clientNameLower == serverNameLower { - matchCount++ - if matchCount > 1 { - // Multiple matches found - errorMsg := fmt.Sprintf("Multiple servers match filename '%s':\n", fileName) - for name := range availableToolsPerClient { - if strings.ToLower(name) == serverNameLower { - errorMsg += fmt.Sprintf(" - %s\n", name) - } - } - errorMsg += "\nPlease use a more specific filename. Use the exact display name from listToolFiles to avoid ambiguity." - return createToolResponseMessage(toolCall, errorMsg), nil - } - - matchedClientName = clientName - - if isToolLevel { - // Tool-level: filter to specific tool - var foundTool *schemas.ChatTool - toolNameLower := strings.ToLower(toolName) - for i, tool := range tools { - if tool.Function != nil { - // Strip client prefix and replace - with _ for comparison - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - if strings.ToLower(unprefixedToolName) == toolNameLower { - foundTool = &tools[i] - break - } - } - } - - if foundTool == nil { - availableTools := make([]string, 0) - for _, tool := range tools { - if tool.Function != nil { - // Strip client prefix and replace - with _ for display - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - availableTools = append(availableTools, unprefixedToolName) - } - } - errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools in this server are:\n", toolName, clientName) - for _, t := range availableTools { - errorMsg += fmt.Sprintf(" - %s/%s.d.ts\n", clientName, t) - } - return createToolResponseMessage(toolCall, errorMsg), nil - } - - matchedTools = []schemas.ChatTool{*foundTool} - } else { - // Server-level: use all tools - matchedTools = tools - } - } - } - - if matchedClientName == "" { - // Build helpful error message with available files - bindingLevel := m.GetCodeModeBindingLevel() - var availableFiles []string - - for name := range availableToolsPerClient { - if bindingLevel == schemas.CodeModeBindingLevelServer { - availableFiles = append(availableFiles, fmt.Sprintf("%s.d.ts", name)) - } else { - client := m.clientManager.GetClientByName(name) - if client != nil && client.ExecutionConfig.IsCodeModeClient { - if tools, ok := availableToolsPerClient[name]; ok { - for _, tool := range tools { - if tool.Function != nil { - // Strip client prefix and replace - with _ for display - unprefixedToolName := stripClientPrefix(tool.Function.Name, name) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - availableFiles = append(availableFiles, fmt.Sprintf("%s/%s.d.ts", name, unprefixedToolName)) - } - } - } - } - } - } - - errorMsg := fmt.Sprintf("No server found matching '%s'. Available virtual files are:\n", serverName) - for _, f := range availableFiles { - errorMsg += fmt.Sprintf(" - %s\n", f) - } - return createToolResponseMessage(toolCall, errorMsg), nil - } - - // Generate TypeScript definitions - fileContent := generateTypeDefinitions(matchedClientName, matchedTools, isToolLevel) - lines := strings.Split(fileContent, "\n") - totalLines := len(lines) - - // Handle line slicing if provided - var startLine, endLine *int - if sl, ok := arguments["startLine"].(float64); ok { - slInt := int(sl) - startLine = &slInt - } - if el, ok := arguments["endLine"].(float64); ok { - elInt := int(el) - endLine = &elInt - } - - if startLine != nil || endLine != nil { - start := 1 - if startLine != nil { - start = *startLine - } - end := totalLines - if endLine != nil { - end = *endLine - } - - // Validate line numbers - if start < 1 || start > totalLines { - errorMsg := fmt.Sprintf("Invalid startLine: %d. Must be between 1 and %d (total lines in file). Provided: startLine=%d, endLine=%v, totalLines=%d", - start, totalLines, start, endLine, totalLines) - return createToolResponseMessage(toolCall, errorMsg), nil - } - if end < 1 || end > totalLines { - errorMsg := fmt.Sprintf("Invalid endLine: %d. Must be between 1 and %d (total lines in file). Provided: startLine=%d, endLine=%d, totalLines=%d", - end, totalLines, start, end, totalLines) - return createToolResponseMessage(toolCall, errorMsg), nil - } - if start > end { - errorMsg := fmt.Sprintf("Invalid line range: startLine (%d) must be less than or equal to endLine (%d). Total lines in file: %d", - start, end, totalLines) - return createToolResponseMessage(toolCall, errorMsg), nil - } - - // Slice lines (convert to 0-based indexing) - selectedLines := lines[start-1 : end] - fileContent = strings.Join(selectedLines, "\n") - } - - return createToolResponseMessage(toolCall, fileContent), nil -} - -// HELPER FUNCTIONS - -// parseVFSFilePath parses a VFS file path and extracts the server name and optional tool name. -// For server-level paths (e.g., "calculator.d.ts"), returns (serverName="calculator", toolName="", isToolLevel=false) -// For tool-level paths (e.g., "calculator/add.d.ts"), returns (serverName="calculator", toolName="add", isToolLevel=true) -// -// Parameters: -// - fileName: The virtual file path from listToolFiles -// -// Returns: -// - serverName: The name of the MCP server -// - toolName: The name of the tool (empty for server-level) -// - isToolLevel: Whether this is a tool-level path -func parseVFSFilePath(fileName string) (serverName, toolName string, isToolLevel bool) { - // Remove .d.ts extension - basePath := strings.TrimSuffix(fileName, ".d.ts") - - // Remove "servers/" prefix if present - basePath = strings.TrimPrefix(basePath, "servers/") - - // Defensive validation: reject paths with path traversal attempts - if strings.Contains(basePath, "..") { - // Return empty to indicate invalid path - return "", "", false - } - - // Check for path separator - parts := strings.Split(basePath, "/") - if len(parts) == 2 { - // Tool-level: "serverName/toolName" - // Validate that tool name doesn't contain additional path separators or traversal - if parts[1] == "" || strings.Contains(parts[1], "/") || strings.Contains(parts[1], "..") { - // Invalid tool name, treat as server-level - return parts[0], "", false - } - return parts[0], parts[1], true - } - // Server-level: "serverName" - // Validate server name doesn't contain path separators or traversal - if strings.Contains(basePath, "/") || strings.Contains(basePath, "..") { - // Invalid path - return "", "", false - } - return basePath, "", false -} - -// generateTypeDefinitions generates TypeScript type definitions from ChatTool schemas -// with comprehensive comments to help LLMs understand how to use the tools. -// It creates interfaces for tool inputs and responses, along with function declarations. -// -// Parameters: -// - clientName: Name of the MCP client/server -// - tools: List of chat tools to generate definitions for -// - isToolLevel: Whether this is a tool-level definition (single tool) or server-level (all tools) -// -// Returns: -// - string: Complete TypeScript declaration file content -func generateTypeDefinitions(clientName string, tools []schemas.ChatTool, isToolLevel bool) string { - var sb strings.Builder - - // Write comprehensive header comment - sb.WriteString("// ============================================================================\n") - if isToolLevel && len(tools) == 1 && tools[0].Function != nil { - // Tool-level: show individual tool name - sb.WriteString(fmt.Sprintf("// Type definitions for %s.%s tool\n", clientName, tools[0].Function.Name)) - } else { - // Server-level: show all tools in server - sb.WriteString(fmt.Sprintf("// Type definitions for %s MCP server\n", clientName)) - } - sb.WriteString("// ============================================================================\n") - sb.WriteString("//\n") - if isToolLevel && len(tools) == 1 { - sb.WriteString("// This file contains TypeScript type definitions for a specific tool on this MCP server.\n") - } else { - sb.WriteString("// This file contains TypeScript type definitions for all tools available on this MCP server.\n") - } - sb.WriteString("// These definitions enable code-mode execution as described in the MCP code execution pattern.\n") - sb.WriteString("//\n") - sb.WriteString("// USAGE INSTRUCTIONS:\n") - sb.WriteString("// 1. Each tool has an input interface (e.g., ToolNameInput) that defines the required parameters\n") - sb.WriteString("// 2. Each tool has a function declaration showing how to call it\n") - sb.WriteString("// 3. To use these tools in executeToolCode, you would call them like:\n") - sb.WriteString("// const result = await .({ ...args });\n") - sb.WriteString("//\n") - sb.WriteString("// NOTE: The server name used in executeToolCode is the same as the display name shown here.\n") - sb.WriteString("//\n") - sb.WriteString("// ⚠️ CRITICAL - HANDLING RESPONSES:\n") - sb.WriteString("// Tool responses have dynamic structures that vary by tool. To avoid runtime errors:\n") - sb.WriteString("// 1. ALWAYS use console.log() to inspect the response structure before accessing properties\n") - sb.WriteString("// 2. NEVER assume a property exists - use optional chaining (result?.property) or explicit checks\n") - sb.WriteString("// 3. Provide fallback values for arrays/objects (result?.items || [])\n") - sb.WriteString("//\n") - sb.WriteString("// Common error: \"Cannot read property 'map' of undefined or null\"\n") - sb.WriteString("// Fix: Add console.log() to see actual structure, then use safe access patterns\n") - sb.WriteString("// ============================================================================\n\n") - - // Generate interfaces and function declarations for each tool - for _, tool := range tools { - if tool.Function == nil || tool.Function.Name == "" { - continue - } - - originalToolName := tool.Function.Name - // Strip client prefix and replace - with _ for code mode compatibility - unprefixedToolName := stripClientPrefix(originalToolName, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - // Parse tool name for property name compatibility (used in virtual TypeScript files) - toolName := parseToolName(unprefixedToolName) - description := "" - if tool.Function.Description != nil { - description = *tool.Function.Description - } - - // Generate input interface with detailed comments - inputInterfaceName := toPascalCase(toolName) + "Input" - sb.WriteString("// ----------------------------------------------------------------------------\n") - sb.WriteString(fmt.Sprintf("// Tool: %s\n", toolName)) - sb.WriteString("// ----------------------------------------------------------------------------\n") - if description != "" { - sb.WriteString(fmt.Sprintf("// Description: %s\n", description)) - } - sb.WriteString(fmt.Sprintf("// Input interface for %s\n", toolName)) - sb.WriteString(fmt.Sprintf("// This interface defines all parameters that can be passed to the %s tool.\n", toolName)) - sb.WriteString(fmt.Sprintf("interface %s {\n", inputInterfaceName)) - - if tool.Function.Parameters != nil && tool.Function.Parameters.Properties != nil { - props := *tool.Function.Parameters.Properties - required := make(map[string]bool) - if tool.Function.Parameters.Required != nil { - for _, req := range tool.Function.Parameters.Required { - required[req] = true - } - } - - // Sort properties for consistent output - propNames := make([]string, 0, len(props)) - for name := range props { - propNames = append(propNames, name) - } - // Simple alphabetical sort - for i := 0; i < len(propNames)-1; i++ { - for j := i + 1; j < len(propNames); j++ { - if propNames[i] > propNames[j] { - propNames[i], propNames[j] = propNames[j], propNames[i] - } - } - } - - for _, propName := range propNames { - prop := props[propName] - propMap, ok := prop.(map[string]interface{}) - if !ok { - continue - } - - tsType := jsonSchemaToTypeScript(propMap) - optional := "" - if !required[propName] { - optional = "?" - } - - propDesc := "" - if desc, ok := propMap["description"].(string); ok && desc != "" { - propDesc = fmt.Sprintf(" // %s", desc) - } else { - propDesc = fmt.Sprintf(" // %s parameter", propName) - } - - requiredNote := "" - if required[propName] { - requiredNote = " (required)" - } else { - requiredNote = " (optional)" - } - - sb.WriteString(fmt.Sprintf(" %s%s: %s;%s%s\n", propName, optional, tsType, propDesc, requiredNote)) - } - } - - sb.WriteString("}\n\n") - - // Generate response interface with helpful comments - responseInterfaceName := toPascalCase(toolName) + "Response" - sb.WriteString(fmt.Sprintf("// Response interface for %s\n", toolName)) - sb.WriteString("// The actual response structure depends on the tool implementation.\n") - sb.WriteString("// This is a placeholder interface - the actual response may contain different fields.\n") - sb.WriteString("//\n") - sb.WriteString("// IMPORTANT - HANDLING TOOL RESPONSES:\n") - sb.WriteString("// 1. ALWAYS check if the response or its properties exist before accessing them\n") - sb.WriteString("// 2. ALWAYS use console.log() to inspect the actual response structure first\n") - sb.WriteString("// 3. NEVER assume a property exists - use optional chaining or explicit checks\n") - sb.WriteString("//\n") - sb.WriteString("// Common error: \"Cannot read property 'X' of undefined or null\"\n") - sb.WriteString("// This means you're trying to access a property that doesn't exist.\n") - sb.WriteString("//\n") - sb.WriteString("// BEST PRACTICES:\n") - sb.WriteString("// ❌ BAD: const items = result.data.map(...) // Fails if result.data is undefined\n") - sb.WriteString("// ✅ GOOD: console.log('result:', result); // Inspect the structure first\n") - sb.WriteString("// const items = result?.data?.map(...) || []; // Safe access with fallback\n") - sb.WriteString("//\n") - sb.WriteString("// ❌ BAD: return response.items.filter(...) // Fails if response.items is undefined\n") - sb.WriteString("// ✅ GOOD: if (!response || !response.items) { return []; }\n") - sb.WriteString("// return response.items.filter(...); // Explicit check before use\n") - sb.WriteString("//\n") - sb.WriteString(fmt.Sprintf("interface %s {\n", responseInterfaceName)) - sb.WriteString(" // Response structure depends on the tool implementation\n") - sb.WriteString(" // Common fields may include: result, error, data, items, etc.\n") - sb.WriteString(" // ALWAYS inspect the actual response with console.log() before accessing properties\n") - sb.WriteString(" [key: string]: any;\n") - sb.WriteString("}\n\n") - - // Generate function declaration with usage example - sb.WriteString(fmt.Sprintf("// Function declaration for %s\n", toolName)) - if description != "" { - sb.WriteString(fmt.Sprintf("// %s\n", description)) - } - sb.WriteString("//\n") - sb.WriteString("// Usage example in executeToolCode:\n") - sb.WriteString(fmt.Sprintf("// const result = await .%s({ ... });\n", toolName)) - sb.WriteString("// console.log('result:', result); // ALWAYS inspect the response first!\n") - sb.WriteString(fmt.Sprintf("// const data = result?.someProperty || defaultValue; // Safe access\n")) - sb.WriteString("// // Replace with the actual server name/ID\n") - sb.WriteString(fmt.Sprintf("// // Replace { ... } with the appropriate %sInput object\n", inputInterfaceName)) - sb.WriteString(fmt.Sprintf("export async function %s(input: %s): Promise<%s>;\n\n", toolName, inputInterfaceName, responseInterfaceName)) - } - - return sb.String() -} - -// jsonSchemaToTypeScript converts a JSON Schema type definition to a TypeScript type string. -// It handles basic types, arrays, enums, and defaults to "any" for unknown types. -// -// Parameters: -// - prop: JSON Schema property definition map -// -// Returns: -// - string: TypeScript type string representation -func jsonSchemaToTypeScript(prop map[string]interface{}) string { - // Check for explicit type - if typeVal, ok := prop["type"].(string); ok { - switch typeVal { - case "string": - return "string" - case "number", "integer": - return "number" - case "boolean": - return "boolean" - case "array": - itemsType := "any" - if items, ok := prop["items"].(map[string]interface{}); ok { - itemsType = jsonSchemaToTypeScript(items) - } - return fmt.Sprintf("%s[]", itemsType) - case "object": - return "object" - case "null": - return "null" - } - } - - // Check for enum - if enum, ok := prop["enum"].([]interface{}); ok && len(enum) > 0 { - enumStrs := make([]string, 0, len(enum)) - for _, e := range enum { - enumStrs = append(enumStrs, fmt.Sprintf("%q", e)) - } - return strings.Join(enumStrs, " | ") - } - - // Default to any - return "any" -} - -// toPascalCase converts a string to PascalCase format. -// It splits on underscores, hyphens, and spaces, then capitalizes the first letter -// of each word and lowercases the rest. -// -// Parameters: -// - s: Input string to convert -// -// Returns: -// - string: PascalCase formatted string -func toPascalCase(s string) string { - if s == "" { - return s - } - parts := strings.FieldsFunc(s, func(r rune) bool { - return r == '_' || r == '-' || r == ' ' - }) - result := "" - for _, part := range parts { - if len(part) > 0 { - result += strings.ToUpper(part[:1]) + strings.ToLower(part[1:]) - } - } - if result == "" { - return strings.ToUpper(s[:1]) + strings.ToLower(s[1:]) - } - return result -} diff --git a/core/mcp/health_monitor.go b/core/mcp/healthmonitor.go similarity index 99% rename from core/mcp/health_monitor.go rename to core/mcp/healthmonitor.go index c8215f7520..6d4eda73d5 100644 --- a/core/mcp/health_monitor.go +++ b/core/mcp/healthmonitor.go @@ -277,4 +277,4 @@ func (hmm *HealthMonitorManager) StopAll() { monitor.Stop() } hmm.monitors = make(map[string]*ClientHealthMonitor) -} +} \ No newline at end of file diff --git a/core/mcp/interface.go b/core/mcp/interface.go new file mode 100644 index 0000000000..b74e559b9c --- /dev/null +++ b/core/mcp/interface.go @@ -0,0 +1,73 @@ +//go:build !tinygo && !wasm + +package mcp + +import ( + "context" + + "github.com/maximhq/bifrost/core/schemas" +) + +// MCPManagerInterface defines the interface for MCP management functionality. +// This interface allows different implementations (OSS and Enterprise) to be used +// interchangeably in the Bifrost core. +type MCPManagerInterface interface { + // Tool Operations + // AddToolsToRequest parses available MCP tools and adds them to the request + AddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest + + // GetAvailableTools returns all available MCP tools for the given context + GetAvailableTools(ctx context.Context) []schemas.ChatTool + + // ExecuteToolCall executes a single tool call and returns the result + ExecuteToolCall(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) + + // UpdateToolManagerConfig updates the configuration for the tool manager + UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig) + + // Agent Mode Operations + // CheckAndExecuteAgentForChatRequest handles agent mode for Chat Completions API + CheckAndExecuteAgentForChatRequest( + ctx *schemas.BifrostContext, + req *schemas.BifrostChatRequest, + response *schemas.BifrostChatResponse, + makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), + executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), + ) (*schemas.BifrostChatResponse, *schemas.BifrostError) + + // CheckAndExecuteAgentForResponsesRequest handles agent mode for Responses API + CheckAndExecuteAgentForResponsesRequest( + ctx *schemas.BifrostContext, + req *schemas.BifrostResponsesRequest, + response *schemas.BifrostResponsesResponse, + makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError), + executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), + ) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) + + // Client Management + // GetClients returns all MCP clients + GetClients() []schemas.MCPClientState + + // AddClient adds a new MCP client with the given configuration + AddClient(config *schemas.MCPClientConfig) error + + // RemoveClient removes an MCP client by ID + RemoveClient(id string) error + + // EditClient updates an existing MCP client configuration + EditClient(id string, updatedConfig *schemas.MCPClientConfig) error + + // ReconnectClient reconnects an MCP client by ID + ReconnectClient(id string) error + + // Tool Registration + // RegisterTool registers a local tool with the MCP server + RegisterTool(name, description string, toolFunction MCPToolFunction[any], toolSchema schemas.ChatTool) error + + // Lifecycle + // Cleanup performs cleanup of all MCP resources + Cleanup() error +} + +// Ensure MCPManager implements MCPManagerInterface +var _ MCPManagerInterface = (*MCPManager)(nil) diff --git a/core/mcp/mcp.go b/core/mcp/mcp.go index 657f272b1e..45d833efbd 100644 --- a/core/mcp/mcp.go +++ b/core/mcp/mcp.go @@ -46,6 +46,7 @@ type MCPManager struct { mu sync.RWMutex // Read-write mutex for thread-safe operations serverRunning bool // Track whether local MCP server is running healthMonitorManager *HealthMonitorManager // Manager for client health monitors + toolSyncManager *ToolSyncManager // Manager for periodic tool synchronization } // MCPToolFunction is a generic function type for handling tool calls with typed arguments. @@ -59,13 +60,17 @@ type MCPToolFunction[T any] func(args T) (string, error) // NewMCPManager creates and initializes a new MCP manager instance. // // Parameters: +// - ctx: Context for the MCP manager // - config: MCP configuration including server port and client configs +// - oauth2Provider: OAuth2 provider for authentication // - logger: Logger instance for structured logging (uses default if nil) +// - codeMode: Optional CodeMode implementation for code execution (e.g., Starlark). +// Pass nil if code mode is not needed. The CodeMode's dependencies will be +// injected automatically via SetDependencies after the manager is created. // // Returns: // - *MCPManager: Initialized manager instance -// - error: Any initialization error -func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider schemas.OAuth2Provider, logger schemas.Logger) *MCPManager { +func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider schemas.OAuth2Provider, logger schemas.Logger, codeMode CodeMode) *MCPManager { SetLogger(logger) // Set default values if config.ToolManagerConfig == nil { @@ -79,6 +84,7 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider ctx: ctx, clientMap: make(map[string]*schemas.MCPClientState), healthMonitorManager: NewHealthMonitorManager(), + toolSyncManager: NewToolSyncManager(config.ToolSyncInterval), oauth2Provider: oauth2Provider, } // Convert plugin pipeline provider functions to the interface expected by ToolsManager @@ -101,13 +107,20 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider manager.toolsManager = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc, pluginPipelineProvider, releasePluginPipeline) + // Set up CodeMode if provided - inject dependencies after manager is created + if codeMode != nil { + deps := manager.toolsManager.GetCodeModeDependencies() + codeMode.SetDependencies(deps) + manager.toolsManager.SetCodeMode(codeMode) + } + // Process client configs: create client map entries and establish connections if len(config.ClientConfigs) > 0 { // Add clients in parallel wg := sync.WaitGroup{} wg.Add(len(config.ClientConfigs)) for _, clientConfig := range config.ClientConfigs { - go func(clientConfig schemas.MCPClientConfig) { + go func(clientConfig *schemas.MCPClientConfig) { defer wg.Done() if err := manager.AddClient(clientConfig); err != nil { logger.Warn("%s Failed to add MCP client %s: %v", MCPLogPrefix, clientConfig.Name, err) @@ -273,6 +286,9 @@ func (m *MCPManager) Cleanup() error { // Stop all health monitors first m.healthMonitorManager.StopAll() + // Stop all tool syncers + m.toolSyncManager.StopAll() + m.mu.Lock() defer m.mu.Unlock() diff --git a/core/mcp/toolmanager.go b/core/mcp/toolmanager.go index c2e625d0bd..e47bb1d20a 100644 --- a/core/mcp/toolmanager.go +++ b/core/mcp/toolmanager.go @@ -1,3 +1,5 @@ +//go:build !tinygo && !wasm + package mcp import ( @@ -5,7 +7,6 @@ import ( "encoding/json" "fmt" "strings" - "sync" "sync/atomic" "time" @@ -13,18 +14,28 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) +// ClientManager interface for accessing MCP clients and tools type ClientManager interface { GetClientByName(clientName string) *schemas.MCPClientState GetClientForTool(toolName string) *schemas.MCPClientState GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool } +// PluginPipeline represents the plugin execution pipeline interface +// This allows ToolsManager to run plugin hooks without direct dependency on Bifrost +type PluginPipeline interface { + RunMCPPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, int) + RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostMCPResponse, *schemas.BifrostError) +} + +// ToolsManager manages MCP tool execution and agent mode. type ToolsManager struct { toolExecutionTimeout atomic.Value maxAgentDepth atomic.Int32 - codeModeBindingLevel atomic.Value // Stores CodeModeBindingLevel clientManager ClientManager - logMu sync.Mutex // Protects concurrent access to logs slice in codemode execution + + // CodeMode implementation for code execution (Starlark by default) + codeMode CodeMode // Function to fetch a new request ID for each tool call result message in agent mode, // this is used to ensure that the tool call result messages are unique and can be tracked in plugins or by the user. @@ -40,19 +51,6 @@ type ToolsManager struct { releasePluginPipeline func(pipeline PluginPipeline) } -// PluginPipeline represents the plugin execution pipeline interface -// This allows ToolsManager to run plugin hooks without direct dependency on Bifrost -type PluginPipeline interface { - RunMCPPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, int) - RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostMCPResponse, *schemas.BifrostError) -} - -const ( - ToolTypeListToolFiles string = "listToolFiles" - ToolTypeReadToolFile string = "readToolFile" - ToolTypeExecuteToolCode string = "executeToolCode" -) - // NewToolsManager creates and initializes a new tools manager instance. // It validates the configuration, sets defaults if needed, and initializes atomic values // for thread-safe configuration updates. @@ -72,6 +70,37 @@ func NewToolsManager( fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, pluginPipelineProvider func() PluginPipeline, releasePluginPipeline func(pipeline PluginPipeline), +) *ToolsManager { + return NewToolsManagerWithCodeMode( + config, + clientManager, + fetchNewRequestIDFunc, + pluginPipelineProvider, + releasePluginPipeline, + nil, // Use default code mode (will be set later via SetCodeMode) + ) +} + +// NewToolsManagerWithCodeMode creates a new tools manager with a custom CodeMode implementation. +// This allows using alternative code execution environments (e.g., Lua, JavaScript, WASM). +// +// Parameters: +// - config: Tool manager configuration with execution timeout and max agent depth +// - clientManager: Client manager interface for accessing MCP clients and tools +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for agent mode +// - pluginPipelineProvider: Optional function to get a plugin pipeline for running MCP hooks +// - releasePluginPipeline: Optional function to release a plugin pipeline back to the pool +// - codeMode: Optional CodeMode implementation (if nil, must be set later via SetCodeMode) +// +// Returns: +// - *ToolsManager: Initialized tools manager instance +func NewToolsManagerWithCodeMode( + config *schemas.MCPToolManagerConfig, + clientManager ClientManager, + fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, + pluginPipelineProvider func() PluginPipeline, + releasePluginPipeline func(pipeline PluginPipeline), + codeMode CodeMode, ) *ToolsManager { if config == nil { config = &schemas.MCPToolManagerConfig{ @@ -90,21 +119,45 @@ func NewToolsManager( if config.CodeModeBindingLevel == "" { config.CodeModeBindingLevel = schemas.CodeModeBindingLevelServer } + manager := &ToolsManager{ clientManager: clientManager, fetchNewRequestIDFunc: fetchNewRequestIDFunc, pluginPipelineProvider: pluginPipelineProvider, releasePluginPipeline: releasePluginPipeline, + codeMode: codeMode, } + // Initialize atomic values manager.toolExecutionTimeout.Store(config.ToolExecutionTimeout) manager.maxAgentDepth.Store(int32(config.MaxAgentDepth)) - manager.codeModeBindingLevel.Store(config.CodeModeBindingLevel) - logger.Info(fmt.Sprintf("%s tool manager initialized with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel)) + logger.Info("%s tool manager initialized with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel) return manager } +// SetCodeMode sets the CodeMode implementation for code execution. +// This should be called after construction if no CodeMode was provided to the constructor. +func (m *ToolsManager) SetCodeMode(codeMode CodeMode) { + m.codeMode = codeMode +} + +// GetCodeMode returns the current CodeMode implementation. +func (m *ToolsManager) GetCodeMode() CodeMode { + return m.codeMode +} + +// GetCodeModeDependencies returns the dependencies needed by CodeMode implementations. +// This is useful when constructing a CodeMode implementation externally. +func (m *ToolsManager) GetCodeModeDependencies() *CodeModeDependencies { + return &CodeModeDependencies{ + ClientManager: m.clientManager, + PluginPipelineProvider: m.pluginPipelineProvider, + ReleasePluginPipeline: m.releasePluginPipeline, + FetchNewRequestIDFunc: m.fetchNewRequestIDFunc, + } +} + // GetAvailableTools returns the available tools for the given context. func (m *ToolsManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool { availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) @@ -135,12 +188,9 @@ func (m *ToolsManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool } } - if includeCodeModeTools { - codeModeTools := []schemas.ChatTool{ - m.createListToolFilesTool(), - m.createReadToolFileTool(), - m.createExecuteToolCodeTool(), - } + // Add code mode tools if any client is configured for code mode and we have a CodeMode implementation + if includeCodeModeTools && m.codeMode != nil { + codeModeTools := m.codeMode.GetTools() // Add code mode tools, checking for duplicates for _, tool := range codeModeTools { if tool.Function != nil && tool.Function.Name != "" { @@ -420,89 +470,83 @@ func (m *ToolsManager) ExecuteTool(ctx *schemas.BifrostContext, request *schemas func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall *schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, string, string, error) { toolName := *toolCall.Function.Name - // Handle code mode tools (they don't have a client, and tool name is the full name) - switch toolName { - case ToolTypeListToolFiles: - msg, err := m.handleListToolFiles(ctx, *toolCall) - return msg, "", toolName, err - case ToolTypeReadToolFile: - msg, err := m.handleReadToolFile(ctx, *toolCall) - return msg, "", toolName, err - case ToolTypeExecuteToolCode: - msg, err := m.handleExecuteToolCode(ctx, *toolCall) + // Check if this is a code mode tool and delegate to CodeMode implementation + if m.codeMode != nil && m.codeMode.IsCodeModeTool(toolName) { + msg, err := m.codeMode.ExecuteTool(ctx, *toolCall) return msg, "", toolName, err - default: - // Check if the user has permission to execute the tool call - availableTools := m.clientManager.GetToolPerClient(ctx) - toolFound := false - for _, tools := range availableTools { - for _, mcpTool := range tools { - if mcpTool.Function != nil && mcpTool.Function.Name == toolName { - toolFound = true - break - } - } - if toolFound { + } + + // Handle regular MCP tools + // Check if the user has permission to execute the tool call + availableTools := m.clientManager.GetToolPerClient(ctx) + toolFound := false + for _, tools := range availableTools { + for _, mcpTool := range tools { + if mcpTool.Function != nil && mcpTool.Function.Name == toolName { + toolFound = true break } } - - if !toolFound { - return nil, "", "", fmt.Errorf("tool '%s' is not available or not permitted", toolName) + if toolFound { + break } + } - client := m.clientManager.GetClientForTool(toolName) - if client == nil { - return nil, "", "", fmt.Errorf("client not found for tool %s", toolName) - } + if !toolFound { + return nil, "", "", fmt.Errorf("tool '%s' is not available or not permitted", toolName) + } - // Parse tool arguments - var arguments map[string]interface{} - if strings.TrimSpace(toolCall.Function.Arguments) == "" { - arguments = map[string]interface{}{} - } else { - if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { - return nil, "", "", fmt.Errorf("failed to parse tool arguments for '%s': %v", toolName, err) - } - } + client := m.clientManager.GetClientForTool(toolName) + if client == nil { + return nil, "", "", fmt.Errorf("client not found for tool %s", toolName) + } - // Strip the client name prefix from tool name before calling MCP server - // The MCP server expects the original tool name (with hyphens), not the sanitized version - sanitizedToolName := stripClientPrefix(toolName, client.ExecutionConfig.Name) - originalMCPToolName := getOriginalToolName(sanitizedToolName, client) - - // Call the tool via MCP client -> MCP server - callRequest := mcp.CallToolRequest{ - Request: mcp.Request{ - Method: string(mcp.MethodToolsCall), - }, - Params: mcp.CallToolParams{ - Name: originalMCPToolName, - Arguments: arguments, - }, + // Parse tool arguments + var arguments map[string]interface{} + if strings.TrimSpace(toolCall.Function.Arguments) == "" { + arguments = map[string]interface{}{} + } else { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + return nil, "", "", fmt.Errorf("failed to parse tool arguments for '%s': %v", toolName, err) } + } - // Create timeout context for tool execution - toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) - toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) - defer cancel() + // Strip the client name prefix from tool name before calling MCP server + // The MCP server expects the original tool name (with hyphens), not the sanitized version + sanitizedToolName := stripClientPrefix(toolName, client.ExecutionConfig.Name) + originalMCPToolName := getOriginalToolName(sanitizedToolName, client) + + // Call the tool via MCP client -> MCP server + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: originalMCPToolName, + Arguments: arguments, + }, + } - toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) - if callErr != nil { - // Check if it was a timeout error - if toolCtx.Err() == context.DeadlineExceeded { - return nil, "", "", fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName) - } - logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr) - return nil, "", "", fmt.Errorf("MCP tool call failed: %v", callErr) + // Create timeout context for tool execution + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) + if callErr != nil { + // Check if it was a timeout error + if toolCtx.Err() == context.DeadlineExceeded { + return nil, "", "", fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName) } + logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr) + return nil, "", "", fmt.Errorf("MCP tool call failed: %v", callErr) + } - // Extract text from MCP response - responseText := extractTextFromMCPResponse(toolResponse, toolName) + // Extract text from MCP response + responseText := extractTextFromMCPResponse(toolResponse, toolName) - // Create tool response message - return createToolResponseMessage(*toolCall, responseText), client.ExecutionConfig.Name, sanitizedToolName, nil - } + // Create tool response message + return createToolResponseMessage(*toolCall, responseText), client.ExecutionConfig.Name, sanitizedToolName, nil } // ExecuteAgentForChatRequest executes agent mode for a chat request, handling @@ -591,19 +635,23 @@ func (m *ToolsManager) UpdateConfig(config *schemas.MCPToolManagerConfig) { if config.MaxAgentDepth > 0 { m.maxAgentDepth.Store(int32(config.MaxAgentDepth)) } - if config.CodeModeBindingLevel != "" { - m.codeModeBindingLevel.Store(config.CodeModeBindingLevel) + + // Update CodeMode configuration if present + if m.codeMode != nil && config.CodeModeBindingLevel != "" { + m.codeMode.UpdateConfig(&CodeModeConfig{ + BindingLevel: config.CodeModeBindingLevel, + ToolExecutionTimeout: config.ToolExecutionTimeout, + }) } - logger.Info(fmt.Sprintf("%s tool manager configuration updated with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel)) + logger.Info("%s tool manager configuration updated with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel) } // GetCodeModeBindingLevel returns the current code mode binding level. // This method is safe to call concurrently from multiple goroutines. func (m *ToolsManager) GetCodeModeBindingLevel() schemas.CodeModeBindingLevel { - val := m.codeModeBindingLevel.Load() - if val == nil { - return schemas.CodeModeBindingLevelServer + if m.codeMode != nil { + return m.codeMode.GetBindingLevel() } - return val.(schemas.CodeModeBindingLevel) + return schemas.CodeModeBindingLevelServer } diff --git a/core/mcp/toolsync.go b/core/mcp/toolsync.go new file mode 100644 index 0000000000..2f1c0815e5 --- /dev/null +++ b/core/mcp/toolsync.go @@ -0,0 +1,242 @@ +package mcp + +import ( + "context" + "sync" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +const ( + // Tool sync configuration + DefaultToolSyncInterval = 10 * time.Minute // Default interval for syncing tools from MCP servers + ToolSyncTimeout = 10 * time.Second // Timeout for each sync operation +) + +// ClientToolSyncer periodically syncs tools from an MCP server +type ClientToolSyncer struct { + manager *MCPManager + clientID string + clientName string + interval time.Duration + timeout time.Duration + mu sync.Mutex + ticker *time.Ticker + ctx context.Context + cancel context.CancelFunc + isSyncing bool +} + +// NewClientToolSyncer creates a new tool syncer for an MCP client +func NewClientToolSyncer( + manager *MCPManager, + clientID string, + clientName string, + interval time.Duration, +) *ClientToolSyncer { + if interval <= 0 { + interval = DefaultToolSyncInterval + } + + return &ClientToolSyncer{ + manager: manager, + clientID: clientID, + clientName: clientName, + interval: interval, + timeout: ToolSyncTimeout, + isSyncing: false, + } +} + +// Start begins syncing tools in a background goroutine +func (cts *ClientToolSyncer) Start() { + cts.mu.Lock() + defer cts.mu.Unlock() + + if cts.isSyncing { + return // Already syncing + } + + cts.isSyncing = true + cts.ctx, cts.cancel = context.WithCancel(context.Background()) + cts.ticker = time.NewTicker(cts.interval) + + go cts.syncLoop() + logger.Debug("%s Tool syncer started for client %s (interval: %v)", MCPLogPrefix, cts.clientID, cts.interval) +} + +// Stop stops syncing tools +func (cts *ClientToolSyncer) Stop() { + cts.mu.Lock() + defer cts.mu.Unlock() + + if !cts.isSyncing { + return // Not syncing + } + + cts.isSyncing = false + if cts.ticker != nil { + cts.ticker.Stop() + } + if cts.cancel != nil { + cts.cancel() + } + logger.Debug("%s Tool syncer stopped for client %s", MCPLogPrefix, cts.clientID) +} + +// syncLoop runs the tool sync loop +func (cts *ClientToolSyncer) syncLoop() { + for { + select { + case <-cts.ctx.Done(): + return + case <-cts.ticker.C: + cts.performSync() + } + } +} + +// performSync performs a tool sync for the client +func (cts *ClientToolSyncer) performSync() { + // Get the client connection (read lock) + cts.manager.mu.RLock() + clientState, exists := cts.manager.clientMap[cts.clientID] + if !exists { + cts.manager.mu.RUnlock() + cts.Stop() + return + } + + if clientState.Conn == nil { + cts.manager.mu.RUnlock() + logger.Debug("%s Skipping tool sync for %s: client not connected", MCPLogPrefix, cts.clientID) + return + } + + // Get the connection reference while holding the lock + conn := clientState.Conn + clientName := clientState.ExecutionConfig.Name + cts.manager.mu.RUnlock() + + // Perform tool sync with timeout (outside of lock) + ctx, cancel := context.WithTimeout(context.Background(), cts.timeout) + defer cancel() + + newTools, newMapping, err := retrieveExternalTools(ctx, conn, clientName) + if err != nil { + // On failure, keep existing tools intact + logger.Warn("%s Tool sync failed for %s, keeping existing tools: %v", MCPLogPrefix, cts.clientID, err) + return + } + + // Update tools atomically (write lock) + cts.manager.mu.Lock() + clientState, exists = cts.manager.clientMap[cts.clientID] + if !exists { + cts.manager.mu.Unlock() + return + } + + // Check if tools have changed + oldToolCount := len(clientState.ToolMap) + newToolCount := len(newTools) + + clientState.ToolMap = newTools + clientState.ToolNameMapping = newMapping + cts.manager.mu.Unlock() + + if oldToolCount != newToolCount { + logger.Info("%s Tool sync completed for %s: %d -> %d tools", MCPLogPrefix, cts.clientID, oldToolCount, newToolCount) + } else { + logger.Debug("%s Tool sync completed for %s: %d tools (no change)", MCPLogPrefix, cts.clientID, newToolCount) + } +} + +// ToolSyncManager manages all client tool syncers +type ToolSyncManager struct { + syncers map[string]*ClientToolSyncer + globalInterval time.Duration + mu sync.RWMutex +} + +// NewToolSyncManager creates a new tool sync manager +func NewToolSyncManager(globalInterval time.Duration) *ToolSyncManager { + if globalInterval <= 0 { + globalInterval = DefaultToolSyncInterval + } + + return &ToolSyncManager{ + syncers: make(map[string]*ClientToolSyncer), + globalInterval: globalInterval, + } +} + +// GetGlobalInterval returns the global tool sync interval +func (tsm *ToolSyncManager) GetGlobalInterval() time.Duration { + return tsm.globalInterval +} + +// StartSyncing starts syncing for a specific client +func (tsm *ToolSyncManager) StartSyncing(syncer *ClientToolSyncer) { + tsm.mu.Lock() + defer tsm.mu.Unlock() + + // Stop any existing syncer for this client + if existing, ok := tsm.syncers[syncer.clientID]; ok { + existing.Stop() + } + + tsm.syncers[syncer.clientID] = syncer + syncer.Start() +} + +// StopSyncing stops syncing for a specific client +func (tsm *ToolSyncManager) StopSyncing(clientID string) { + tsm.mu.Lock() + defer tsm.mu.Unlock() + + if syncer, ok := tsm.syncers[clientID]; ok { + syncer.Stop() + delete(tsm.syncers, clientID) + } +} + +// StopAll stops all syncing +func (tsm *ToolSyncManager) StopAll() { + tsm.mu.Lock() + defer tsm.mu.Unlock() + + for _, syncer := range tsm.syncers { + syncer.Stop() + } + tsm.syncers = make(map[string]*ClientToolSyncer) +} + +// ResolveToolSyncInterval determines the effective tool sync interval for a client. +// Priority: per-client override > global setting > default +// +// Per-client semantics: +// - Negative value: disabled for this client +// - Zero: use global setting +// - Positive value: use this interval +// +// Returns 0 if sync is disabled for this client. +func ResolveToolSyncInterval(clientConfig *schemas.MCPClientConfig, globalInterval time.Duration) time.Duration { + // Per-client explicitly disabled (negative value) + if clientConfig.ToolSyncInterval < 0 { + return 0 // Disabled for this client + } + + // Per-client override (positive value) + if clientConfig.ToolSyncInterval > 0 { + return clientConfig.ToolSyncInterval + } + + // Use global interval (or default if global is 0) + if globalInterval > 0 { + return globalInterval + } + + return DefaultToolSyncInterval +} diff --git a/core/mcp/utils.go b/core/mcp/utils.go index 2482ea0c46..a46990f8e2 100644 --- a/core/mcp/utils.go +++ b/core/mcp/utils.go @@ -54,7 +54,7 @@ func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas. includeClients = existingIncludeClients } - logger.Debug(fmt.Sprintf("%s GetToolPerClient: Total clients in manager: %d, Filter: %v", MCPLogPrefix, len(m.clientMap), includeClients)) + logger.Debug("%s GetToolPerClient: Total clients in manager: %d, Filter: %v", MCPLogPrefix, len(m.clientMap), includeClients) tools := make(map[string][]schemas.ChatTool) for _, client := range m.clientMap { @@ -62,11 +62,11 @@ func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas. clientName := client.ExecutionConfig.Name clientID := client.ExecutionConfig.ID - logger.Debug(fmt.Sprintf("%s Evaluating client %s (ID: %s) for tools", MCPLogPrefix, clientName, clientID)) + logger.Debug("%s Evaluating client %s (ID: %s) for tools", MCPLogPrefix, clientName, clientID) // Apply client filtering logic - check both ID and Name for compatibility if !shouldIncludeClient(clientName, includeClients) { - logger.Debug(fmt.Sprintf("%s Skipping MCP client %s: not in include clients list", MCPLogPrefix, clientName)) + logger.Debug("%s Skipping MCP client %s: not in include clients list", MCPLogPrefix, clientName) continue } @@ -91,7 +91,7 @@ func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas. tools[clientName] = append(tools[clientName], tool) } if len(tools[clientName]) > 0 { - logger.Debug(fmt.Sprintf("%s Added %d tools for MCP client %s", MCPLogPrefix, len(tools[clientName]), clientName)) + logger.Debug("%s Added %d tools for MCP client %s", MCPLogPrefix, len(tools[clientName]), clientName) } } return tools @@ -107,18 +107,18 @@ func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas. func (m *MCPManager) GetClientByName(clientName string) *schemas.MCPClientState { m.mu.RLock() defer m.mu.RUnlock() - logger.Debug(fmt.Sprintf("%s GetClientByName: Looking for client '%s' among %d clients", MCPLogPrefix, clientName, len(m.clientMap))) + logger.Debug("%s GetClientByName: Looking for client '%s' among %d clients", MCPLogPrefix, clientName, len(m.clientMap)) for _, client := range m.clientMap { - logger.Debug(fmt.Sprintf("%s Checking client with Name: %s, ID: %s", MCPLogPrefix, client.ExecutionConfig.Name, client.ExecutionConfig.ID)) + logger.Debug("%s Checking client with Name: %s, ID: %s", MCPLogPrefix, client.ExecutionConfig.Name, client.ExecutionConfig.ID) if client.ExecutionConfig.Name == clientName { // Return a copy to prevent TOCTOU race conditions // The caller receives a snapshot of the client state at this point in time - logger.Debug(fmt.Sprintf("%s Found client '%s' with IsCodeModeClient=%v", MCPLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient)) + logger.Debug("%s Found client '%s' with IsCodeModeClient=%v", MCPLogPrefix, clientName, client.ExecutionConfig.IsCodeModeClient) clientCopy := *client return &clientCopy } } - logger.Debug(fmt.Sprintf("%s Client '%s' not found", MCPLogPrefix, clientName)) + logger.Debug("%s Client '%s' not found", MCPLogPrefix, clientName) return nil } @@ -148,18 +148,18 @@ func retrieveExternalTools(ctx context.Context, client *client.Client, clientNam // toolsResponse is already a ListToolsResult for _, mcpTool := range toolsResponse.Tools { - originalMCPName := mcpTool.Name // Original name from MCP server (e.g., "notion-search") - sanitizedToolName := strings.ReplaceAll(mcpTool.Name, "-", "_") // For code mode and internal use (e.g., "notion_search") - - if err := validateNormalizedToolName(sanitizedToolName); err != nil { - logger.Warn(fmt.Sprintf("%s Skipping MCP tool %q: %v", MCPLogPrefix, mcpTool.Name, err)) + // Validate the original tool name (with hyphens replaced by underscores for validation only) + validationName := strings.ReplaceAll(mcpTool.Name, "-", "_") + if err := validateNormalizedToolName(validationName); err != nil { + logger.Warn("%s Skipping MCP tool %q: %v", MCPLogPrefix, mcpTool.Name, err) continue } // Convert MCP tool schema to Bifrost format bifrostTool := convertMCPToolToBifrostSchema(&mcpTool) // Prefix tool name with client name to make it permanent (using '-' as separator) - prefixedToolName := fmt.Sprintf("%s-%s", clientName, sanitizedToolName) + // Keep the original tool name (don't sanitize) so we can call the MCP server correctly + prefixedToolName := fmt.Sprintf("%s-%s", clientName, mcpTool.Name) // Update the tool's function name to match the prefixed name if bifrostTool.Function != nil { bifrostTool.Function.Name = prefixedToolName @@ -167,7 +167,8 @@ func retrieveExternalTools(ctx context.Context, client *client.Client, clientNam // Store the tool with the prefixed name tools[prefixedToolName] = bifrostTool // Store the mapping from sanitized name to original MCP name for later lookup during execution - toolNameMapping[sanitizedToolName] = originalMCPName + sanitizedToolName := strings.ReplaceAll(mcpTool.Name, "-", "_") + toolNameMapping[sanitizedToolName] = mcpTool.Name } return tools, toolNameMapping, nil @@ -179,29 +180,32 @@ func shouldIncludeClient(clientName string, includeClients []string) bool { if includeClients != nil { // Handle empty array [] - means no clients are included if len(includeClients) == 0 { - logger.Debug(fmt.Sprintf("%s shouldIncludeClient: %s - BLOCKED (empty include list)", MCPLogPrefix, clientName)) + logger.Debug("%s shouldIncludeClient: %s - BLOCKED (empty include list)", MCPLogPrefix, clientName) return false // No clients allowed } // Handle wildcard "*" - if present, all clients are included if slices.Contains(includeClients, "*") { - logger.Debug(fmt.Sprintf("%s shouldIncludeClient: %s - ALLOWED (wildcard filter)", MCPLogPrefix, clientName)) + logger.Debug("%s shouldIncludeClient: %s - ALLOWED (wildcard filter)", MCPLogPrefix, clientName) return true // All clients allowed } // Check if specific client is in the list included := slices.Contains(includeClients, clientName) - logger.Debug(fmt.Sprintf("%s shouldIncludeClient: %s - %s (filter: %v)", MCPLogPrefix, clientName, map[bool]string{true: "ALLOWED", false: "BLOCKED"}[included], includeClients)) + logger.Debug("%s shouldIncludeClient: %s - %s (filter: %v)", MCPLogPrefix, clientName, map[bool]string{true: "ALLOWED", false: "BLOCKED"}[included], includeClients) return included } // Default: include all clients when no filtering specified (nil case) - logger.Debug(fmt.Sprintf("%s shouldIncludeClient: %s - ALLOWED (no filter)", MCPLogPrefix, clientName)) + logger.Debug("%s shouldIncludeClient: %s - ALLOWED (no filter)", MCPLogPrefix, clientName) return true } // shouldSkipToolForConfig checks if a tool should be skipped based on client configuration (without accessing clientMap). -func shouldSkipToolForConfig(toolName string, config schemas.MCPClientConfig) bool { +func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) bool { + if config == nil { + return true // No tools allowed + } // If ToolsToExecute is specified (not nil), apply filtering if config.ToolsToExecute != nil { // Handle empty array [] - means no tools are allowed @@ -228,7 +232,7 @@ func shouldSkipToolForConfig(toolName string, config schemas.MCPClientConfig) bo // canAutoExecuteTool checks if a tool can be auto-executed based on client configuration. // Returns true if the tool can be auto-executed, false otherwise. -func canAutoExecuteTool(toolName string, config schemas.MCPClientConfig) bool { +func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool { // First check if tool is in ToolsToExecute (must be executable first) if shouldSkipToolForConfig(toolName, config) { return false // Tool is not in ToolsToExecute, so it cannot be auto-executed @@ -506,8 +510,8 @@ func validateNormalizedToolName(normalizedName string) error { return nil } -// extractToolCallsFromCode extracts tool calls from TypeScript code -// Tool calls are in the format: serverName.toolName(...) or await serverName.toolName(...) +// extractToolCallsFromCode extracts tool calls from Python/Starlark code +// Tool calls are in the format: server_name.tool_name(...) func extractToolCallsFromCode(code string) ([]toolCallInfo, error) { toolCalls := []toolCallInfo{} @@ -540,7 +544,7 @@ func extractToolCallsFromCode(code string) ([]toolCallInfo, error) { func isToolCallAllowedForCodeMode(serverName, toolName string, allClientNames []string, allowedAutoExecutionTools map[string][]string) bool { // Check if the server name is in the list of all client names if !slices.Contains(allClientNames, serverName) { - // It can be a built-in JavaScript/TypeScript object, if not then downstream execution will fail with a runtime error. + // It can be a built-in Python/Starlark object, if not then downstream execution will fail with a runtime error. return true } @@ -676,15 +680,16 @@ func FixArraySchemas(properties map[string]interface{}) { if _, hasItems := schemaMap["items"]; !hasItems { // Add a default 'items' schema (unconstrained) schemaMap["items"] = map[string]interface{}{} - logger.Debug(fmt.Sprintf("%s Fixed array schema for property '%s': added missing 'items' field", MCPLogPrefix, key)) + logger.Debug("%s Fixed array schema for property '%s': added missing 'items' field", MCPLogPrefix, key) } // Recurse into items regardless of type (object or array) if itemsMap, ok := schemaMap["items"].(map[string]interface{}); ok { itemsType, _ := itemsMap["type"].(string) - if itemsType == "array" { + switch itemsType { + case "array": // Handle nested arrays (array-of-array) FixArraySchemas(map[string]interface{}{"": itemsMap}) - } else if itemsType == "object" { + case "object": // Recurse into object properties if itemsProps, ok := itemsMap["properties"].(map[string]interface{}); ok { FixArraySchemas(itemsProps) @@ -708,15 +713,16 @@ func FixArraySchemas(properties map[string]interface{}) { if unionType, ok := unionMap["type"].(string); ok && unionType == "array" { if _, hasItems := unionMap["items"]; !hasItems { unionMap["items"] = map[string]interface{}{} - logger.Debug(fmt.Sprintf("%s Fixed array schema in %s for property '%s': added missing 'items' field", MCPLogPrefix, unionKey, key)) + logger.Debug("%s Fixed array schema in %s for property '%s': added missing 'items' field", MCPLogPrefix, unionKey, key) } // Recurse into items regardless of type if itemsMap, ok := unionMap["items"].(map[string]interface{}); ok { itemsType, _ := itemsMap["type"].(string) - if itemsType == "array" { + switch itemsType { + case "array": // Handle nested arrays FixArraySchemas(map[string]interface{}{"": itemsMap}) - } else if itemsType == "object" { + case "object": if itemsProps, ok := itemsMap["properties"].(map[string]interface{}); ok { FixArraySchemas(itemsProps) } diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index dd1eed72bc..a8f6ffe60d 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -6,6 +6,7 @@ package schemas import ( "context" "errors" + "strings" "time" "github.com/bytedance/sonic" @@ -25,8 +26,9 @@ var ( // MCPConfig represents the configuration for MCP integration in Bifrost. // It enables tool auto-discovery and execution from local and external MCP servers. type MCPConfig struct { - ClientConfigs []MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations + ClientConfigs []*MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations ToolManagerConfig *MCPToolManagerConfig `json:"tool_manager_config,omitempty"` // MCP tool manager configuration + ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Global default interval for syncing tools from MCP servers (0 = use default 10 min) // Function to fetch a new request ID for each tool call result message in agent mode, // this is used to ensure that the tool call result messages are unique and can be tracked in plugins or by the user. @@ -74,7 +76,7 @@ const ( // MCPClientConfig defines tool filtering for an MCP client. type MCPClientConfig struct { - ID string `json:"id"` // Client ID + ID string `json:"client_id"` // Client ID Name string `json:"name"` // Client name IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) @@ -98,8 +100,10 @@ type MCPClientConfig struct { // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => auto-execute only the specified tools // Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped. - IsPingAvailable bool `json:"is_ping_available"` // Whether the MCP server supports ping for health checks (default: true). If false, uses listTools for health checks. - ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) + IsPingAvailable bool `json:"is_ping_available"` // Whether the MCP server supports ping for health checks (default: true). If false, uses listTools for health checks. + ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled) + ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution) + ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) } // NewMCPClientConfigFromMap creates a new MCP client config from a map[string]any. @@ -131,6 +135,14 @@ func (c *MCPClientConfig) HttpHeaders(ctx context.Context, oauth2Provider OAuth2 if err != nil { return nil, err } + // Validate token format - trim whitespace and check for invalid characters + accessToken = strings.TrimSpace(accessToken) + if accessToken == "" { + return nil, errors.New("access token is empty") + } + if strings.ContainsAny(accessToken, "\n\r\t") { + return nil, errors.New("access token contains invalid characters") + } headers["Authorization"] = "Bearer " + accessToken case MCPAuthTypeHeaders: for key, value := range c.Headers { @@ -176,14 +188,14 @@ const ( // MCPClientState represents a connected MCP client with its configuration and tools. // It is used internally by the MCP manager to track the state of a connected MCP client. type MCPClientState struct { - Name string // Unique name for this client - Conn *client.Client // Active MCP client connection - ExecutionConfig MCPClientConfig // Tool filtering settings - ToolMap map[string]ChatTool // Available tools mapped by name - ToolNameMapping map[string]string // Maps sanitized_name -> original_mcp_name (e.g., "notion_search" -> "notion-search") - ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management - CancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized) - State MCPConnectionState // Connection state (connected, disconnected, error) + Name string // Unique name for this client + Conn *client.Client // Active MCP client connection + ExecutionConfig *MCPClientConfig // Tool filtering settings + ToolMap map[string]ChatTool // Available tools mapped by name + ToolNameMapping map[string]string // Maps sanitized_name -> original_mcp_name (e.g., "notion_search" -> "notion-search") + ConnectionInfo *MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management + CancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized) + State MCPConnectionState // Connection state (connected, disconnected, error) } // MCPClientConnectionInfo stores metadata about how a client is connected. @@ -197,7 +209,7 @@ type MCPClientConnectionInfo struct { // and connection information, after it has been initialized. // It is returned by GetMCPClients() method in bifrost. type MCPClient struct { - Config MCPClientConfig `json:"config"` // Tool filtering settings + Config *MCPClientConfig `json:"config"` // Tool filtering settings Tools []ChatToolFunction `json:"tools"` // Available tools State MCPConnectionState `json:"state"` // Connection state } diff --git a/core/utils.go b/core/utils.go index 641e2f653d..f54ffa555e 100644 --- a/core/utils.go +++ b/core/utils.go @@ -447,5 +447,5 @@ func sanitizeSpanName(name string) string { // IsCodemodeTool returns true if the given tool name is a codemode tool. func IsCodemodeTool(toolName string) bool { - return toolName == mcp.ToolTypeListToolFiles || toolName == mcp.ToolTypeReadToolFile || toolName == mcp.ToolTypeExecuteToolCode + return mcp.IsCodeModeTool(toolName) } diff --git a/docs/architecture/core/mcp.mdx b/docs/architecture/core/mcp.mdx index b78ca29edf..24c44761e8 100644 --- a/docs/architecture/core/mcp.mdx +++ b/docs/architecture/core/mcp.mdx @@ -612,23 +612,23 @@ When max depth is reached, the response may contain pending tool calls that were ## Code Mode Architecture -Code Mode enables AI models to write and execute TypeScript code that orchestrates multiple MCP tools in a single request. This provides a powerful meta-layer for complex multi-tool workflows. +Code Mode enables AI models to write and execute Python code (Starlark) that orchestrates multiple MCP tools in a single request. This provides a powerful meta-layer for complex multi-tool workflows. ### **Code Mode System Overview** ```mermaid graph TB subgraph "Code Mode Components" - VM["🖥️ Goja VM
TypeScript/JavaScript Runtime"] - VFS["📁 Virtual File System
Tool Definitions as .d.ts"] - TS["📝 TypeScript Transpiler
TS → JS Conversion"] + VM["🖥️ Starlark Interpreter
Python-like Runtime"] + VFS["📁 Virtual File System
Tool Definitions as .pyi"] EXEC["⚙️ Code Executor
Sandboxed Execution"] end subgraph "Meta Tools" LIST["listToolFiles()
Discover available servers"] - READ["readToolFile(fileName)
Get tool definitions"] - CODE["executeToolCode(code)
Run TypeScript code"] + READ["readToolFile(fileName)
Get tool signatures"] + DOCS["getToolDocs(server, tool)
Get detailed docs"] + CODE["executeToolCode(code)
Run Python code"] end subgraph "MCP Integration" @@ -642,9 +642,11 @@ graph TB LLM --> READ READ --> VFS VFS --> LLM + LLM --> DOCS + DOCS --> VFS + VFS --> LLM LLM --> CODE - CODE --> TS - TS --> VM + CODE --> VM VM --> EXEC EXEC --> TOOLS TOOLS --> RESULTS @@ -657,7 +659,7 @@ graph TB ### **Virtual File System (VFS)** -Code Mode generates TypeScript declaration files (`.d.ts`) for all connected MCP tools, enabling type-safe tool invocation: +Code Mode generates Python stub files (`.pyi`) for all connected MCP tools, providing compact function signatures: @@ -666,26 +668,27 @@ When `code_mode_binding_level: "server"` (default), tools are grouped by MCP cli ``` servers/ -├── filesystem.d.ts → All filesystem tools -├── web_search.d.ts → All web search tools -└── database.d.ts → All database tools +├── filesystem.pyi → All filesystem tools +├── web_search.pyi → All web search tools +└── database.pyi → All database tools ``` -**Generated Declaration Example:** -```typescript -// servers/filesystem.d.ts -declare const filesystem: { - read_file(args: { path: string }): Promise; - write_file(args: { path: string; content: string }): Promise; - list_directory(args: { path: string }): Promise; -}; +**Generated Stub Example:** +```python +# servers/filesystem.pyi +# Usage: filesystem.tool_name(param=value) +# For detailed docs: use getToolDocs(server="filesystem", tool="tool_name") + +def read_file(path: str) -> dict: # Read contents of a file +def write_file(path: str, content: str) -> dict: # Write content to a file +def list_directory(path: str) -> dict: # List directory contents ``` **Usage in Code:** -```typescript -const files = await filesystem.list_directory({ path: "." }); -const content = await filesystem.read_file({ path: files[0] }); -return content; +```python +files = filesystem.list_directory(path=".") +content = filesystem.read_file(path=files["entries"][0]) +result = content ``` @@ -694,24 +697,29 @@ return content; When `code_mode_binding_level: "tool"`, each tool gets its own file: ``` -tools/ -├── filesystem_read_file.d.ts -├── filesystem_write_file.d.ts -├── filesystem_list_directory.d.ts -├── web_search_search.d.ts -└── database_query.d.ts +servers/ +├── filesystem/ +│ ├── read_file.pyi +│ ├── write_file.pyi +│ └── list_directory.pyi +├── web_search/ +│ └── search.pyi +└── database/ + └── query.pyi ``` -**Generated Declaration Example:** -```typescript -// tools/filesystem_read_file.d.ts -declare function filesystem_read_file(args: { path: string }): Promise; +**Generated Stub Example:** +```python +# servers/filesystem/read_file.pyi +# Usage: filesystem.read_file(param=value) + +def read_file(path: str) -> dict: # Read contents of a file ``` **Usage in Code:** -```typescript -const content = await filesystem_read_file({ path: "config.json" }); -return content; +```python +content = filesystem.read_file(path="config.json") +result = content ``` @@ -723,52 +731,48 @@ return content; sequenceDiagram participant LLM as 🤖 LLM participant CM as 📝 Code Mode Handler - participant TS as 🔄 TypeScript Transpiler - participant VM as 🖥️ Goja VM + participant VM as 🖥️ Starlark Interpreter participant TM as 🔧 Tools Manager participant MCP as 🌐 MCP Servers LLM->>CM: executeToolCode({ code: "..." }) - CM->>TS: Transpile TypeScript - TS-->>CM: JavaScript code - CM->>VM: Initialize sandbox CM->>VM: Inject tool bindings - CM->>VM: Execute code + CM->>VM: Execute Python code loop For each tool call in code - VM->>TM: await server.tool(args) + VM->>TM: server.tool(param=value) TM->>MCP: Execute tool MCP-->>TM: Tool result TM-->>VM: Return result end VM-->>CM: Execution result - CM-->>LLM: { result, console_output } + CM-->>LLM: { result, logs } ``` -### **Goja VM Sandbox** +### **Starlark Sandbox** -The code execution environment is carefully sandboxed: +The code execution environment is carefully sandboxed using Starlark, a Python-like language designed for configuration and embedded scripting: - - ✅ **ES5.1+ JavaScript** - Core language features - - ✅ **async/await** - Transpiled to Promise chains - - ✅ **TypeScript** - Full type checking during transpilation - - ✅ **console.log/error/warn** - Output captured and returned - - ✅ **JSON.parse/stringify** - Data serialization + - ✅ **Python-like syntax** - Familiar Python syntax and semantics + - ✅ **Synchronous calls** - No async/await needed, direct function calls + - ✅ **List comprehensions** - `[x for x in items if condition]` + - ✅ **print()** - Output captured and returned in logs + - ✅ **Dict/List operations** - Standard Python data structures - ✅ **Tool bindings** - All connected MCP tools as globals - - ❌ **ES Modules** - `import`/`export` statements stripped - - ❌ **Node.js APIs** - No `require`, `fs`, `path`, etc. - - ❌ **Browser APIs** - No `fetch`, `XMLHttpRequest`, `DOM` - - ❌ **Timers** - No `setTimeout`, `setInterval` + - ❌ **Imports** - No `import` statements (tools are pre-bound) + - ❌ **Classes** - Use dicts and functions instead + - ❌ **File I/O** - No direct filesystem access (use MCP tools) - ❌ **Network** - No direct network access (use MCP tools) + - ❌ **Randomness/Time** - Deterministic execution only @@ -778,8 +782,8 @@ The code execution environment is carefully sandboxed: ```mermaid graph TB subgraph "Security Layers" - L1["🔒 TypeScript Validation
Type checking before execution"] - L2["🛡️ Import Stripping
No external module access"] + L1["🔒 Code Validation
Syntax checking before execution"] + L2["🛡️ Sandboxed Runtime
No external module access"] L3["⏱️ Execution Timeout
Bounded runtime"] L4["🔐 Tool ACL
Only allowed tools accessible"] end @@ -788,7 +792,7 @@ graph TB B1["No filesystem access
(except via MCP tools)"] B2["No network access
(except via MCP tools)"] B3["No process spawning"] - B4["Memory limits enforced"] + B4["Memory isolation enforced"] end L1 --> L2 --> L3 --> L4 diff --git a/docs/mcp/code-mode.mdx b/docs/mcp/code-mode.mdx index a8ca3371e9..a151c2ae30 100644 --- a/docs/mcp/code-mode.mdx +++ b/docs/mcp/code-mode.mdx @@ -1,7 +1,7 @@ --- title: "Code Mode" sidebarTitle: "Code Mode" -description: "AI writes TypeScript to orchestrate tools. Reduces token usage by 50%+ when using multiple MCP servers." +description: "AI writes Python to orchestrate tools. Reduces token usage by 50%+ when using multiple MCP servers." icon: "code" --- @@ -15,7 +15,7 @@ This feature is only available on `v1.4.0-prerelease1` and above. > **The Problem:** When you connect 8-10 MCP servers (150+ tools), every single request includes all tool definitions in the context. The LLM spends most of its budget reading tool catalogs instead of doing actual work. -**The Solution:** Instead of exposing 150 tools directly, Code Mode exposes just **three generic tools**. The LLM uses those three tools to write TypeScript code that orchestrates everything else in a sandbox. +**The Solution:** Instead of exposing 150 tools directly, Code Mode exposes just **four generic tools**. The LLM uses those tools to write Python code (Starlark) that orchestrates everything else in a sandbox. ### The Impact @@ -28,15 +28,16 @@ Compare a workflow across 5 MCP servers with ~100 tools: **Code Mode Flow:** - 3-4 LLM turns -- Only 3 tools + definitions on-demand +- Only 4 tools + definitions on-demand - Intermediate results processed in sandbox **Result: ~50% cost reduction + 30-40% faster execution** -Code Mode provides three meta-tools to the AI: +Code Mode provides four meta-tools to the AI: 1. **`listToolFiles`** - Discover available MCP servers -2. **`readToolFile`** - Load TypeScript definitions on-demand -3. **`executeToolCode`** - Execute TypeScript code with full tool bindings +2. **`readToolFile`** - Load Python stub signatures on-demand +3. **`getToolDocs`** - Get detailed documentation for a specific tool +4. **`executeToolCode`** - Execute Python code with full tool bindings ## When to Use Code Mode @@ -57,31 +58,35 @@ Code Mode provides three meta-tools to the AI: ## How Code Mode Works -### The Three Tools +### The Four Tools -Instead of seeing 150+ tool definitions, the model sees three generic tools: +Instead of seeing 150+ tool definitions, the model sees four generic tools: ```mermaid graph LR LLM["LLM Context
Compact & Efficient"] List["listToolFiles
Discover servers"] - Read["readToolFile
Load definitions"] + Read["readToolFile
Load signatures"] + Docs["getToolDocs
Get detailed docs"] Execute["executeToolCode
Run code with bindings"] - Hidden["All other MCP servers
hidden behind these 3 tools
"] + Hidden["All other MCP servers
hidden behind these 4 tools
"] LLM --> List LLM --> Read + LLM --> Docs LLM --> Execute List -.-> Hidden Read -.-> Hidden + Docs -.-> Hidden Execute -.-> Hidden style LLM fill:#E3F2FD,stroke:#0D47A1,stroke-width:2.5px,color:#1A1A1A style List fill:#E8F5E9,stroke:#1B5E20,stroke-width:2.5px,color:#1A1A1A style Read fill:#FFF3E0,stroke:#BF360C,stroke-width:2.5px,color:#1A1A1A + style Docs fill:#E1F5FE,stroke:#0288D1,stroke-width:2.5px,color:#1A1A1A style Execute fill:#F3E5F5,stroke:#4A148C,stroke-width:2.5px,color:#1A1A1A style Hidden fill:#EEEEEE,stroke:#424242,stroke-width:1.5px,stroke-dasharray: 5 5,color:#1A1A1A ``` @@ -96,7 +101,7 @@ graph LR GetDefs["3. Load Definitions
readToolFile()"] - Write["4. Write Code
TypeScript
in sandbox"] + Write["4. Write Code
Python
in sandbox"] Execute["5. Execute
Real MCP calls
contained in VM"] @@ -142,11 +147,11 @@ Total: 6 LLM calls, ~600+ tokens in tool definitions alone ### Code Mode with same 5 servers: ``` -Turn 1: Prompt + 3 tools (listToolFiles, readToolFile, executeToolCode) -Turn 2: Prompt + server list + 3 tools -Turn 3: Prompt + selected definitions + 3 tools + [EXECUTES CODE] +Turn 1: Prompt + 4 tools (listToolFiles, readToolFile, getToolDocs, executeToolCode) +Turn 2: Prompt + server list + 4 tools +Turn 3: Prompt + selected definitions + 4 tools + [EXECUTES CODE] [YouTube search, channel list, videos, summaries, doc creation all happen in sandbox] -Turn 4: Prompt + final result + 3 tools +Turn 4: Prompt + final result + 4 tools Total: 3-4 LLM calls, ~50 tokens in tool definitions Result: 50% cost reduction, 3-4x fewer LLM round trips @@ -156,7 +161,7 @@ Result: 50% cost reduction, 3-4x fewer LLM round trips ## Enabling Code Mode -Code Mode must be enabled **per MCP client**. Once enabled, that client's tools are accessed through the three meta-tools rather than exposed directly. +Code Mode must be enabled **per MCP client**. Once enabled, that client's tools are accessed through the four meta-tools rather than exposed directly. **Best practice:** Enable Code Mode for 3+ servers or any "heavy" server (web search, documents, databases). @@ -267,80 +272,123 @@ mcpConfig := &schemas.MCPConfig{ --- -## The Three Code Mode Tools +## The Four Code Mode Tools -When Code Mode clients are connected, Bifrost automatically adds three meta-tools to every request: +When Code Mode clients are connected, Bifrost automatically adds four meta-tools to every request: ### 1. listToolFiles -Lists all available virtual `.d.ts` declaration files for connected code mode servers. +Lists all available virtual `.pyi` stub files for connected code mode servers. **Example output (Server-level binding):** ``` servers/ - youtube.d.ts - filesystem.d.ts + youtube.pyi + filesystem.pyi ``` **Example output (Tool-level binding):** ``` servers/ youtube/ - search.d.ts - get_video.d.ts + search.pyi + get_video.pyi filesystem/ - read_file.d.ts - write_file.d.ts + read_file.pyi + write_file.pyi ``` ### 2. readToolFile -Reads a virtual `.d.ts` file to get TypeScript type definitions for tools. +Reads a virtual `.pyi` file to get compact Python function signatures for tools. **Parameters:** -- `fileName` (required): Path like `servers/youtube.d.ts` or `servers/youtube/search.d.ts` +- `fileName` (required): Path like `servers/youtube.pyi` or `servers/youtube/search.pyi` - `startLine` (optional): 1-based starting line for partial reads - `endLine` (optional): 1-based ending line for partial reads **Example output:** -```typescript -// Type definitions for youtube MCP server -// Usage: const result = await youtube.search({ query: "..." }); +```python +# youtube server tools +# Usage: youtube.tool_name(param=value) +# For detailed docs: use getToolDocs(server="youtube", tool="tool_name") -interface SearchInput { - query: string; // Search query (required) - maxResults?: number; // Max results to return (optional) -} +def search(query: str, maxResults: int = None) -> dict: # Search for videos +def get_video(id: str) -> dict: # Get video details +``` -interface SearchResponse { - [key: string]: any; -} +### 3. getToolDocs + +Get detailed documentation for a specific tool when the compact signature from `readToolFile` is not sufficient. -export async function search(input: SearchInput): Promise; +**Parameters:** +- `server` (required): The server name (e.g., `"youtube"`) +- `tool` (required): The tool name (e.g., `"search"`) + +**Example output:** +```python +# ============================================================================ +# Documentation for youtube.search tool +# ============================================================================ +# +# USAGE INSTRUCTIONS: +# Call tools using: result = youtube.tool_name(param=value) +# No async/await needed - calls are synchronous. +# +# CRITICAL - HANDLING RESPONSES: +# Tool responses are dicts. To avoid runtime errors: +# 1. Use print(result) to inspect the response structure first +# 2. Access dict values with brackets: result["key"] NOT result.key +# 3. Use .get() for safe access: result.get("key", default) +# ============================================================================ + +def search(query: str, maxResults: int = None) -> dict: + """ + Search for videos on YouTube. + + Args: + query (str): Search query (required) + maxResults (int): Max results to return (optional) + + Returns: + dict: Response from the tool. Structure varies by tool. + Use print(result) to inspect the actual structure. + + Example: + result = youtube.search(query="...") + print(result) # Always inspect response first! + value = result.get("key", default) # Safe access + """ + ... ``` -### 3. executeToolCode +### 4. executeToolCode -Executes TypeScript code in a sandboxed VM with access to all code mode server tools. +Executes Python code in a sandboxed Starlark interpreter with access to all code mode server tools. **Parameters:** -- `code` (required): TypeScript code to execute +- `code` (required): Python code to execute **Execution Environment:** -- TypeScript is transpiled to ES5-compatible JavaScript +- Python code runs in a Starlark interpreter (Python subset) - All code mode servers are exposed as global objects (e.g., `youtube`, `filesystem`) -- Each server has async functions for its tools (e.g., `youtube.search()`) -- Console output (`log`, `error`, `warn`, `info`) is captured -- Use `return` to return a value from the code +- Tool calls are **synchronous** - no async/await needed +- Use `print()` for logging (output captured in logs) +- Assign to `result` variable to return a value - Tool execution timeout applies (default 30s) +**Syntax notes:** +- Use keyword arguments: `server.tool(param="value")` NOT `server.tool({"param": "value"})` +- Access dict values with brackets: `result["key"]` NOT `result.key` +- List comprehensions work: `[x for x in items if x["active"]]` + **Example code:** -```typescript -// Search YouTube and return formatted results -const results = await youtube.search({ query: "AI news", maxResults: 5 }); -const titles = results.items.map(item => item.snippet.title); -console.log("Found", titles.length, "videos"); -return { titles, count: titles.length }; +```python +# Search YouTube and return formatted results +results = youtube.search(query="AI news", maxResults=5) +titles = [item["snippet"]["title"] for item in results["items"]] +print("Found", len(titles), "videos") +result = {"titles": titles, "count": len(titles)} ``` --- @@ -351,12 +399,12 @@ Code Mode supports two binding levels that control how tools are organized in th ### Server-Level Binding (Default) -All tools from a server are grouped into a single `.d.ts` file. +All tools from a server are grouped into a single `.pyi` file. ``` servers/ - youtube.d.ts ← Contains all youtube tools - filesystem.d.ts ← Contains all filesystem tools + youtube.pyi ← Contains all youtube tools + filesystem.pyi ← Contains all filesystem tools ``` **Best for:** @@ -366,18 +414,18 @@ servers/ ### Tool-Level Binding -Each tool gets its own `.d.ts` file. +Each tool gets its own `.pyi` file. ``` servers/ youtube/ - search.d.ts - get_video.d.ts - get_channel.d.ts + search.pyi + get_video.pyi + get_channel.pyi filesystem/ - read_file.d.ts - write_file.d.ts - list_directory.d.ts + read_file.pyi + write_file.pyi + list_directory.pyi ``` **Best for:** @@ -398,15 +446,15 @@ Binding level can be viewed in the MCP configuration overview: MCP Gateway Configuration -- **Server-level (default)**: One `.d.ts` file per MCP server +- **Server-level (default)**: One `.pyi` file per MCP server - Use when: 5-20 tools per server, want simple discovery - - Example: `servers/youtube.d.ts` contains all YouTube tools + - Example: `servers/youtube.pyi` contains all YouTube tools -- **Tool-level**: One `.d.ts` file per individual tool +- **Tool-level**: One `.pyi` file per individual tool - Use when: 30+ tools per server, want minimal context bloat - - Example: `servers/youtube/search.d.ts`, `servers/youtube/list_channels.d.ts` + - Example: `servers/youtube/search.pyi`, `servers/youtube/list_channels.pyi` -Both modes use the same three-tool interface (`listToolFiles`, `readToolFile`, `executeToolCode`). The choice is purely about **context efficiency per read operation**. +Both modes use the same four-tool interface (`listToolFiles`, `readToolFile`, `getToolDocs`, `executeToolCode`). The choice is purely about **context efficiency per read operation**. @@ -453,7 +501,7 @@ Code Mode tools can be auto-executed in [Agent Mode](./agent-mode), but with **a When `executeToolCode` is called in agent mode: -1. Bifrost parses the TypeScript code +1. Bifrost parses the Python code 2. Extracts all `serverName.toolName()` calls 3. Checks each call against `tools_to_auto_execute` for that server 4. If ALL calls are allowed → auto-execute @@ -469,13 +517,13 @@ When `executeToolCode` is called in agent mode: } ``` -```typescript -// This code WILL auto-execute (only uses search) -const results = await youtube.search({ query: "AI" }); -return results; +```python +# This code WILL auto-execute (only uses search) +results = youtube.search(query="AI") +result = results -// This code will NOT auto-execute (uses delete_video which is not in auto-execute list) -await youtube.delete_video({ id: "abc123" }); +# This code will NOT auto-execute (uses delete_video which is not in auto-execute list) +youtube.delete_video(id="abc123") ``` --- @@ -486,45 +534,44 @@ await youtube.delete_video({ id: "abc123" }); | Available | Not Available | |-----------|---------------| -| `async/await` | `fetch`, `XMLHttpRequest` | -| `Promise` | `setTimeout`, `setInterval` | -| `console.log/error/warn/info` | `require`, `import` | -| JSON operations | DOM APIs (`document`, `window`) | -| String/Array/Object methods | Node.js APIs | +| Python-like syntax | `import` statements | +| Synchronous tool calls | Classes (use dicts) | +| `print()` for logging | File I/O | +| Dict/List operations | Network access | +| List comprehensions | `random`, `time` modules | ### Runtime Environment Details -**Engine:** Goja VM with ES5+ JavaScript compatibility +**Engine:** Starlark interpreter (Python subset) **Tool Exposure:** Tools from code mode clients are exposed as global objects: -```typescript -// If you have a 'youtube' code mode client with a 'search' tool -const results = await youtube.search({ query: "AI news" }); +```python +# If you have a 'youtube' code mode client with a 'search' tool +results = youtube.search(query="AI news") ``` **Code Processing:** -1. Import/export statements are stripped -2. TypeScript is transpiled to JavaScript (ES5 compatible) -3. Tool calls are extracted and validated -4. Code executes in isolated VM context -5. Return value is automatically serialized to JSON +1. Code is validated for syntax errors +2. Tool calls are extracted and validated +3. Code executes in isolated Starlark context +4. Result variable is automatically serialized to JSON **Execution Limits:** - Default timeout: 30 seconds per tool execution - Memory isolation: Each execution gets its own context - No access to host file system or network -- Logs captured from console methods +- Logs captured from print() calls ### Error Handling Bifrost provides detailed error messages with hints: -```typescript -// Error: youtube is not defined -// Hints: -// - Variable or identifier 'youtube' is not defined -// - Available server keys: youtubeAPI, filesystem -// - Use one of the available server keys as the object name +```python +# Error: youtube is not defined +# Hints: +# - Variable or identifier 'youtube' is not defined +# - Available server keys: youtubeAPI, filesystem +# - Use one of the available server keys as the object name ``` ### Timeouts @@ -566,7 +613,7 @@ Bifrost provides detailed error messages with hints: | Avg Total Cost | $1.20-1.80 | | Latency | 8-12 seconds | -**Benefit:** Model writes one TypeScript script. All orchestration happens in sandbox. Only compact result returned to LLM. +**Benefit:** Model writes one Python script. All orchestration happens in sandbox. Only compact result returned to LLM. --- diff --git a/docs/mcp/overview.mdx b/docs/mcp/overview.mdx index f221b11ea3..69afc880ec 100644 --- a/docs/mcp/overview.mdx +++ b/docs/mcp/overview.mdx @@ -14,7 +14,7 @@ Bifrost provides a comprehensive MCP integration that goes beyond simple tool ex - **MCP Client**: Connect to any MCP-compatible server (filesystem tools, web search, databases, etc.) - **MCP Server**: Expose your connected tools to external MCP clients (like Claude Desktop) - **Agent Mode**: Autonomous tool execution with configurable auto-approval -- **Code Mode**: Let AI write and execute TypeScript to orchestrate multiple tools +- **Code Mode**: Let AI write and execute Python to orchestrate multiple tools ## Security-First Design @@ -47,7 +47,7 @@ By default, Bifrost does NOT automatically execute tool calls. All tool executio Enable autonomous tool execution with configurable auto-approval - Let AI write TypeScript to orchestrate multiple tools in one request + Let AI write Python to orchestrate multiple tools in one request Expose Bifrost as an MCP server for Claude Desktop and other clients @@ -111,7 +111,7 @@ This pattern ensures: If you're planning to use **3+ MCP servers**, read the [Code Mode](./code-mode) documentation carefully. -Code Mode reduces token usage by **50%+ and execution latency by 40-50%** compared to classic MCP by having the AI write TypeScript code to orchestrate tools in a sandbox, rather than exposing 100+ tool definitions directly to the LLM. +Code Mode reduces token usage by **50%+ and execution latency by 40-50%** compared to classic MCP by having the AI write Python code to orchestrate tools in a sandbox, rather than exposing 100+ tool definitions directly to the LLM. --- diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index ad3015b04b..93b6f052d4 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -53,6 +53,7 @@ type ClientConfig struct { MCPAgentDepth int `json:"mcp_agent_depth"` // The maximum depth for MCP agent mode tool execution MCPToolExecutionTimeout int `json:"mcp_tool_execution_timeout"` // The timeout for individual tool execution in seconds MCPCodeModeBindingLevel string `json:"mcp_code_mode_binding_level"` // Code mode binding level: "server" or "tool" + MCPToolSyncInterval int `json:"mcp_tool_sync_interval"` // Global tool sync interval in minutes (default: 10, 0 = disabled) HeaderFilterConfig *tables.GlobalHeaderFilterConfig `json:"header_filter_config,omitempty"` // Global header filtering configuration for x-bf-eh-* headers ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) } @@ -129,6 +130,12 @@ func (c *ClientConfig) GenerateClientConfigHash() (string, error) { hash.Write([]byte("mcpCodeModeBindingLevel:server")) } + if c.MCPToolSyncInterval > 0 { + hash.Write([]byte("mcpToolSyncInterval:" + strconv.Itoa(c.MCPToolSyncInterval))) + } else { + hash.Write([]byte("mcpToolSyncInterval:0")) + } + // Hash integer fields data, err := sonic.Marshal(c.InitialPoolSize) if err != nil { diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 9c16d0ff81..d9ce9ba521 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -167,6 +167,12 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddOAuthTables(ctx, db); err != nil { return err } + if err := migrationAddToolSyncIntervalColumns(ctx, db); err != nil { + return err + } + if err := migrationAddMCPClientConfigToOAuthConfig(ctx, db); err != nil { + return err + } return nil } @@ -3060,3 +3066,79 @@ func migrationAddOAuthTables(ctx context.Context, db *gorm.DB) error { } return nil } + +// migrationAddToolSyncIntervalColumns adds the tool_sync_interval columns to config_client and config_mcp_clients tables +func migrationAddToolSyncIntervalColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_tool_sync_interval_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + // Add mcp_tool_sync_interval column to config_client table (global setting) + if !migrator.HasColumn(&tables.TableClientConfig{}, "mcp_tool_sync_interval") { + if err := migrator.AddColumn(&tables.TableClientConfig{}, "mcp_tool_sync_interval"); err != nil { + return err + } + } + // Add tool_sync_interval column to config_mcp_clients table (per-client setting) + if !migrator.HasColumn(&tables.TableMCPClient{}, "tool_sync_interval") { + if err := migrator.AddColumn(&tables.TableMCPClient{}, "tool_sync_interval"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + + if err := migrator.DropColumn(&tables.TableClientConfig{}, "mcp_tool_sync_interval"); err != nil { + return err + } + if err := migrator.DropColumn(&tables.TableMCPClient{}, "tool_sync_interval"); err != nil { + return err + } + + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running tool sync interval migration: %s", err.Error()) + } + return nil +} + +// migrationAddMCPClientConfigToOAuthConfig adds the mcp_client_config_json column to oauth_configs table +// This enables multi-instance support by storing pending MCP client config in the database +// instead of in-memory, so OAuth callbacks can be handled by any server instance +func migrationAddMCPClientConfigToOAuthConfig(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_mcp_client_config_to_oauth_config", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableOauthConfig{}, "mcp_client_config_json") { + if err := migrator.AddColumn(&tables.TableOauthConfig{}, "mcp_client_config_json"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if migrator.HasColumn(&tables.TableOauthConfig{}, "mcp_client_config_json") { + if err := migrator.DropColumn(&tables.TableOauthConfig{}, "mcp_client_config_json"); err != nil { + return err + } + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running mcp client config oauth migration: %s", err.Error()) + } + return nil +} diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index c4f7ea4c2e..413f81abec 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -55,6 +55,7 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC MCPAgentDepth: config.MCPAgentDepth, MCPToolExecutionTimeout: config.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: config.MCPCodeModeBindingLevel, + MCPToolSyncInterval: config.MCPToolSyncInterval, HeaderFilterConfig: config.HeaderFilterConfig, ConfigHash: config.ConfigHash, } @@ -215,6 +216,7 @@ func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, er MCPAgentDepth: dbConfig.MCPAgentDepth, MCPToolExecutionTimeout: dbConfig.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: dbConfig.MCPCodeModeBindingLevel, + MCPToolSyncInterval: dbConfig.MCPToolSyncInterval, HeaderFilterConfig: dbConfig.HeaderFilterConfig, ConfigHash: dbConfig.ConfigHash, }, nil @@ -747,7 +749,7 @@ func (s *RDBConfigStore) GetProviderByName(ctx context.Context, name string) (*t } // GetMCPConfig retrieves the MCP configuration from the database. -func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*tables.MCPConfig, error) { +func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) { var dbMCPClients []tables.TableMCPClient // Get all MCP clients if err := s.db.WithContext(ctx).Find(&dbMCPClients).Error; err != nil { @@ -761,8 +763,27 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*tables.MCPConfig, e if errors.Is(err, gorm.ErrRecordNotFound) { // Return MCP config with default ToolManagerConfig if no client config exists // This will never happen, but just in case. - return &tables.MCPConfig{ - ClientConfigs: dbMCPClients, + clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients)) + for i, dbClient := range dbMCPClients { + clientConfigs[i] = &schemas.MCPClientConfig{ + ID: dbClient.ClientID, + Name: dbClient.Name, + IsCodeModeClient: dbClient.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), + ConnectionString: dbClient.ConnectionString, + StdioConfig: dbClient.StdioConfig, + AuthType: schemas.MCPAuthType(dbClient.AuthType), + OauthConfigID: dbClient.OauthConfigID, + ToolsToExecute: dbClient.ToolsToExecute, + ToolsToAutoExecute: dbClient.ToolsToAutoExecute, + Headers: dbClient.Headers, + IsPingAvailable: dbClient.IsPingAvailable, + ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, + ToolPricing: dbClient.ToolPricing, + } + } + return &schemas.MCPConfig{ + ClientConfigs: clientConfigs, ToolManagerConfig: &schemas.MCPToolManagerConfig{ ToolExecutionTimeout: 30 * time.Second, // default from TableClientConfig MaxAgentDepth: 10, // default from TableClientConfig @@ -776,20 +797,9 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*tables.MCPConfig, e MaxAgentDepth: clientConfig.MCPAgentDepth, CodeModeBindingLevel: schemas.CodeModeBindingLevel(clientConfig.MCPCodeModeBindingLevel), } - return &tables.MCPConfig{ - ClientConfigs: dbMCPClients, - ToolManagerConfig: &toolManagerConfig, - }, nil -} - -// ConvertTableMCPConfigToSchemas converts tables.MCPConfig to schemas.MCPConfig -func ConvertTableMCPConfigToSchemas(tableConfig *tables.MCPConfig) *schemas.MCPConfig { - if tableConfig == nil { - return nil - } - clientConfigs := make([]schemas.MCPClientConfig, len(tableConfig.ClientConfigs)) - for i, dbClient := range tableConfig.ClientConfigs { - clientConfigs[i] = schemas.MCPClientConfig{ + clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients)) + for i, dbClient := range dbMCPClients { + clientConfigs[i] = &schemas.MCPClientConfig{ ID: dbClient.ClientID, Name: dbClient.Name, IsCodeModeClient: dbClient.IsCodeModeClient, @@ -802,14 +812,17 @@ func ConvertTableMCPConfigToSchemas(tableConfig *tables.MCPConfig) *schemas.MCPC ToolsToAutoExecute: dbClient.ToolsToAutoExecute, Headers: dbClient.Headers, IsPingAvailable: dbClient.IsPingAvailable, + ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, + ToolPricing: dbClient.ToolPricing, } } return &schemas.MCPConfig{ ClientConfigs: clientConfigs, - ToolManagerConfig: tableConfig.ToolManagerConfig, - } + ToolManagerConfig: &toolManagerConfig, + }, nil } + // GetMCPClientByID retrieves an MCP client by ID from the database. func (s *RDBConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error) { var mcpClient tables.TableMCPClient @@ -835,10 +848,10 @@ func (s *RDBConfigStore) GetMCPClientByName(ctx context.Context, name string) (* } // CreateMCPClientConfig creates a new MCP client configuration in the database. -func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig schemas.MCPClientConfig) error { +func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error { return s.db.Transaction(func(tx *gorm.DB) error { // Create a deep copy to avoid modifying the original - clientConfigCopy, err := deepCopy(clientConfig) + clientConfigCopy, err := deepCopy(*clientConfig) if err != nil { return err } @@ -856,6 +869,7 @@ func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute, Headers: clientConfigCopy.Headers, IsPingAvailable: clientConfigCopy.IsPingAvailable, + ToolSyncInterval: int(clientConfigCopy.ToolSyncInterval.Minutes()), } if err := tx.WithContext(ctx).Create(&dbClient).Error; err != nil { return s.parseGormError(err) @@ -865,7 +879,7 @@ func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig } // UpdateMCPClientConfig updates an existing MCP client configuration in the database. -func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, clientConfig tables.TableMCPClient) error { +func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error { return s.db.Transaction(func(tx *gorm.DB) error { // Find existing client var existingClient tables.TableMCPClient @@ -928,6 +942,7 @@ func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, c "headers_json": string(headersJSON), "tool_pricing_json": string(toolPricingJSON), "is_ping_available": clientConfigCopy.IsPingAvailable, + "tool_sync_interval": clientConfigCopy.ToolSyncInterval, "updated_at": time.Now(), } diff --git a/framework/configstore/store.go b/framework/configstore/store.go index b39d18ca5b..52d1a68974 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -38,11 +38,11 @@ type ConfigStore interface { GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error) // MCP config CRUD - GetMCPConfig(ctx context.Context) (*tables.MCPConfig, error) + GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error) GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error) - CreateMCPClientConfig(ctx context.Context, clientConfig schemas.MCPClientConfig) error - UpdateMCPClientConfig(ctx context.Context, id string, clientConfig tables.TableMCPClient) error + CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error + UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error DeleteMCPClientConfig(ctx context.Context, id string) error // Vector store config CRUD diff --git a/framework/configstore/tables/clientconfig.go b/framework/configstore/tables/clientconfig.go index fcf1db2d42..66d9d83e65 100644 --- a/framework/configstore/tables/clientconfig.go +++ b/framework/configstore/tables/clientconfig.go @@ -27,6 +27,7 @@ type TableClientConfig struct { MCPAgentDepth int `gorm:"default:10" json:"mcp_agent_depth"` MCPToolExecutionTimeout int `gorm:"default:30" json:"mcp_tool_execution_timeout"` // Timeout for individual tool execution in seconds (default: 30) MCPCodeModeBindingLevel string `gorm:"default:server" json:"mcp_code_mode_binding_level"` // How tools are exposed in VFS: "server" or "tool" + MCPToolSyncInterval int `gorm:"default:10" json:"mcp_tool_sync_interval"` // Global tool sync interval in minutes (default: 10, 0 = disabled) // LiteLLM fallback flag EnableLiteLLMFallbacks bool `gorm:"column:enable_litellm_fallbacks;default:false" json:"enable_litellm_fallbacks"` diff --git a/framework/configstore/tables/mcp.go b/framework/configstore/tables/mcp.go index b2e105d58c..e85c21e4cd 100644 --- a/framework/configstore/tables/mcp.go +++ b/framework/configstore/tables/mcp.go @@ -9,14 +9,6 @@ import ( "gorm.io/gorm" ) -type MCPConfig struct { - ClientConfigs []TableMCPClient `json:"client_configs,omitempty"` // Per-client execution configurations - ToolManagerConfig *schemas.MCPToolManagerConfig `json:"tool_manager_config,omitempty"` // MCP tool manager configuration - FetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string `json:"-"` - PluginPipelineProvider func() interface{} `json:"-"` - ReleasePluginPipeline func(pipeline interface{}) `json:"-"` -} - // TableMCPClient represents an MCP client configuration in the database type TableMCPClient struct { ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. @@ -29,8 +21,9 @@ type TableMCPClient struct { ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string ToolsToAutoExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string - IsPingAvailable bool `gorm:"default:true" json:"is_ping_available"` // Whether the MCP server supports ping for health checks - ToolPricingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]float64 + IsPingAvailable bool `gorm:"default:true" json:"is_ping_available"` // Whether the MCP server supports ping for health checks + ToolPricingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]float64 + ToolSyncInterval int `gorm:"default:0" json:"tool_sync_interval"` // Per-client tool sync interval in minutes (0 = use global, -1 = disabled) // OAuth authentication fields AuthType string `gorm:"type:varchar(20);default:'headers'" json:"auth_type"` // "none", "headers", "oauth" diff --git a/framework/configstore/tables/oauth.go b/framework/configstore/tables/oauth.go index ada014a206..b3b897bb50 100644 --- a/framework/configstore/tables/oauth.go +++ b/framework/configstore/tables/oauth.go @@ -22,9 +22,10 @@ type TableOauthConfig struct { CodeChallenge string `gorm:"type:varchar(255)" json:"code_challenge"` // PKCE code challenge (sent to provider) Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired", "revoked" TokenID *string `gorm:"type:varchar(255);index" json:"token_id"` // Foreign key to oauth_tokens.ID (set after callback) - ServerURL string `gorm:"type:text" json:"server_url"` // MCP server URL for OAuth discovery - UseDiscovery bool `gorm:"default:false" json:"use_discovery"` // Flag to enable OAuth discovery - CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + ServerURL string `gorm:"type:text" json:"server_url"` // MCP server URL for OAuth discovery + UseDiscovery bool `gorm:"default:false" json:"use_discovery"` // Flag to enable OAuth discovery + MCPClientConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized MCPClientConfig for multi-instance support (pending MCP client waiting for OAuth completion) + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // State expiry (15 min) } diff --git a/framework/oauth2/main.go b/framework/oauth2/main.go index 2b41c729e9..598dff04da 100644 --- a/framework/oauth2/main.go +++ b/framework/oauth2/main.go @@ -20,19 +20,11 @@ import ( "github.com/maximhq/bifrost/framework/configstore/tables" ) -// PendingMCPClient represents an MCP client waiting for OAuth completion -type PendingMCPClient struct { - MCPClientConfig schemas.MCPClientConfig - OauthConfigID string - CreatedAt time.Time -} - // OAuth2Provider implements the schemas.OAuth2Provider interface // It provides OAuth 2.0 authentication functionality with database persistence type OAuth2Provider struct { - configStore configstore.ConfigStore - mu sync.RWMutex - pendingMCPClients map[string]*PendingMCPClient // Key: mcp_client_id + configStore configstore.ConfigStore + mu sync.RWMutex } // NewOAuth2Provider creates a new OAuth provider instance @@ -41,15 +33,9 @@ func NewOAuth2Provider(configStore configstore.ConfigStore, logger schemas.Logge logger = bifrost.NewDefaultLogger(schemas.LogLevelInfo) } SetLogger(logger) - p := &OAuth2Provider{ - configStore: configStore, - pendingMCPClients: make(map[string]*PendingMCPClient), + return &OAuth2Provider{ + configStore: configStore, } - - // Start background cleanup goroutine for expired pending clients - go p.cleanupExpiredPendingClients() - - return p } // GetAccessToken retrieves the access token for a given oauth_config_id @@ -95,8 +81,12 @@ func (p *OAuth2Provider) GetAccessToken(ctx context.Context, oauthConfigID strin } } - // Return access token directly (no encryption needed for internal use) - return token.AccessToken, nil + // Sanitize and return access token (trim whitespace/newlines that may cause header formatting issues) + accessToken := strings.TrimSpace(token.AccessToken) + if accessToken == "" { + return "", fmt.Errorf("access token is empty after sanitization") + } + return accessToken, nil } // RefreshAccessToken refreshes the access token for a given oauth_config_id @@ -131,11 +121,11 @@ func (p *OAuth2Provider) RefreshAccessToken(ctx context.Context, oauthConfigID s return fmt.Errorf("token refresh failed: %w", err) } - // Update token in database + // Update token in database (sanitize tokens to prevent header formatting issues) now := time.Now() - token.AccessToken = newTokenResponse.AccessToken + token.AccessToken = strings.TrimSpace(newTokenResponse.AccessToken) if newTokenResponse.RefreshToken != "" { - token.RefreshToken = newTokenResponse.RefreshToken + token.RefreshToken = strings.TrimSpace(newTokenResponse.RefreshToken) } token.ExpiresAt = now.Add(time.Duration(newTokenResponse.ExpiresIn) * time.Second) token.LastRefreshedAt = &now @@ -209,53 +199,115 @@ func (p *OAuth2Provider) RevokeToken(ctx context.Context, oauthConfigID string) } // StorePendingMCPClient stores an MCP client config that's waiting for OAuth completion -func (p *OAuth2Provider) StorePendingMCPClient(mcpClientID string, mcpClientConfig schemas.MCPClientConfig) { - p.mu.Lock() - defer p.mu.Unlock() - oauthConfigID := "" - if mcpClientConfig.OauthConfigID != nil { - oauthConfigID = *mcpClientConfig.OauthConfigID +// The config is persisted in the database (oauth_configs.mcp_client_config_json) to support +// multi-instance deployments where OAuth callback may hit a different server instance. +func (p *OAuth2Provider) StorePendingMCPClient(oauthConfigID string, mcpClientConfig schemas.MCPClientConfig) error { + ctx := context.Background() + + oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID) + if err != nil { + return fmt.Errorf("failed to get oauth config: %w", err) } - p.pendingMCPClients[mcpClientID] = &PendingMCPClient{ - MCPClientConfig: mcpClientConfig, - OauthConfigID: oauthConfigID, - CreatedAt: time.Now(), + if oauthConfig == nil { + return fmt.Errorf("oauth config not found: %s", oauthConfigID) } -} -// GetPendingMCPClient retrieves an MCP client config by mcp_client_id -func (p *OAuth2Provider) GetPendingMCPClient(mcpClientID string) *schemas.MCPClientConfig { - p.mu.RLock() - defer p.mu.RUnlock() - if pending, exists := p.pendingMCPClients[mcpClientID]; exists { - return &pending.MCPClientConfig + configJSON, err := json.Marshal(mcpClientConfig) + if err != nil { + return fmt.Errorf("failed to marshal MCP client config: %w", err) } + configStr := string(configJSON) + oauthConfig.MCPClientConfigJSON = &configStr + + if err := p.configStore.UpdateOauthConfig(ctx, oauthConfig); err != nil { + return fmt.Errorf("failed to update oauth config with MCP client config: %w", err) + } + + logger.Debug("Stored pending MCP client config", "oauth_config_id", oauthConfigID) return nil } -// RemovePendingMCPClient removes a pending MCP client after OAuth completion -func (p *OAuth2Provider) RemovePendingMCPClient(mcpClientID string) { - p.mu.Lock() - defer p.mu.Unlock() - delete(p.pendingMCPClients, mcpClientID) +// GetPendingMCPClient retrieves an MCP client config by oauth_config_id +// Returns nil if no pending config is found or if the oauth config has expired +func (p *OAuth2Provider) GetPendingMCPClient(oauthConfigID string) (*schemas.MCPClientConfig, error) { + ctx := context.Background() + + oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID) + if err != nil { + return nil, fmt.Errorf("failed to get oauth config: %w", err) + } + if oauthConfig == nil { + return nil, nil + } + + // Check if expired + if time.Now().After(oauthConfig.ExpiresAt) { + return nil, nil + } + + if oauthConfig.MCPClientConfigJSON == nil || *oauthConfig.MCPClientConfigJSON == "" { + return nil, nil + } + + var config schemas.MCPClientConfig + if err := json.Unmarshal([]byte(*oauthConfig.MCPClientConfigJSON), &config); err != nil { + return nil, fmt.Errorf("failed to unmarshal MCP client config: %w", err) + } + + return &config, nil } -// cleanupExpiredPendingClients removes pending MCP clients older than 5 minutes -func (p *OAuth2Provider) cleanupExpiredPendingClients() { - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() - - for range ticker.C { - p.mu.Lock() - now := time.Now() - for mcpClientID, pending := range p.pendingMCPClients { - if now.Sub(pending.CreatedAt) > 5*time.Minute { - delete(p.pendingMCPClients, mcpClientID) - logger.Debug("Cleaned up expired pending MCP client", "mcp_client_id", mcpClientID) - } - } - p.mu.Unlock() +// GetPendingMCPClientByState retrieves an MCP client config by OAuth state token +// This is useful when the callback only has the state parameter +func (p *OAuth2Provider) GetPendingMCPClientByState(state string) (*schemas.MCPClientConfig, string, error) { + ctx := context.Background() + + oauthConfig, err := p.configStore.GetOauthConfigByState(ctx, state) + if err != nil { + return nil, "", fmt.Errorf("failed to get oauth config by state: %w", err) + } + if oauthConfig == nil { + return nil, "", nil + } + + // Check if expired + if time.Now().After(oauthConfig.ExpiresAt) { + return nil, "", nil + } + + if oauthConfig.MCPClientConfigJSON == nil || *oauthConfig.MCPClientConfigJSON == "" { + return nil, oauthConfig.ID, nil + } + + var config schemas.MCPClientConfig + if err := json.Unmarshal([]byte(*oauthConfig.MCPClientConfigJSON), &config); err != nil { + return nil, oauthConfig.ID, fmt.Errorf("failed to unmarshal MCP client config: %w", err) + } + + return &config, oauthConfig.ID, nil +} + +// RemovePendingMCPClient clears the pending MCP client config from the oauth config +// This is called after OAuth completion to clean up +func (p *OAuth2Provider) RemovePendingMCPClient(oauthConfigID string) error { + ctx := context.Background() + + oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID) + if err != nil { + return fmt.Errorf("failed to get oauth config: %w", err) } + if oauthConfig == nil { + return nil // Already removed or doesn't exist + } + + oauthConfig.MCPClientConfigJSON = nil + + if err := p.configStore.UpdateOauthConfig(ctx, oauthConfig); err != nil { + return fmt.Errorf("failed to clear pending MCP client config: %w", err) + } + + logger.Debug("Removed pending MCP client config", "oauth_config_id", oauthConfigID) + return nil } // InitiateOAuthFlow creates an OAuth config and returns the authorization URL @@ -479,12 +531,12 @@ func (p *OAuth2Provider) CompleteOAuthFlow(ctx context.Context, state, code stri } scopesJSON, _ := json.Marshal(scopes) - // Create oauth_token record + // Create oauth_token record (sanitize tokens to prevent header formatting issues) tokenID := uuid.New().String() tokenRecord := &tables.TableOauthToken{ ID: tokenID, - AccessToken: tokenResponse.AccessToken, - RefreshToken: tokenResponse.RefreshToken, + AccessToken: strings.TrimSpace(tokenResponse.AccessToken), + RefreshToken: strings.TrimSpace(tokenResponse.RefreshToken), TokenType: tokenResponse.TokenType, ExpiresAt: time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second), Scopes: string(scopesJSON), @@ -591,8 +643,30 @@ func (p *OAuth2Provider) callTokenEndpoint(tokenURL string, data url.Values) (*s } var tokenResponse schemas.OAuth2TokenExchangeResponse + + // Try to parse as JSON first if err := json.Unmarshal(body, &tokenResponse); err != nil { - return nil, fmt.Errorf("failed to parse token response: %w", err) + // If JSON parsing fails, try to parse as URL-encoded form data + // (GitHub's OAuth endpoint may return application/x-www-form-urlencoded) + formValues, parseErr := url.ParseQuery(string(body)) + if parseErr != nil { + return nil, fmt.Errorf("failed to parse token response as JSON or form data: JSON error: %w, form error: %v", err, parseErr) + } + + tokenResponse.AccessToken = formValues.Get("access_token") + tokenResponse.RefreshToken = formValues.Get("refresh_token") + tokenResponse.TokenType = formValues.Get("token_type") + tokenResponse.Scope = formValues.Get("scope") + + // Parse expires_in if present + if expiresIn := formValues.Get("expires_in"); expiresIn != "" { + fmt.Sscanf(expiresIn, "%d", &tokenResponse.ExpiresIn) + } + } + + // Validate that we got an access token + if tokenResponse.AccessToken == "" { + return nil, fmt.Errorf("token response missing access_token, body: %s", string(body)) } return &tokenResponse, nil diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index 8fa68881ef..1042efdb52 100644 --- a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -9,6 +9,7 @@ import ( "slices" "sort" "strings" + "time" "github.com/fasthttp/router" "github.com/google/uuid" @@ -20,10 +21,10 @@ import ( ) type MCPManager interface { - ReconnectMCPClient(ctx context.Context, id string) error - AddMCPClient(ctx context.Context, clientConfig schemas.MCPClientConfig) error + AddMCPClient(ctx context.Context, clientConfig *schemas.MCPClientConfig) error RemoveMCPClient(ctx context.Context, id string) error - EditMCPClient(ctx context.Context, id string, updatedConfig configstoreTables.TableMCPClient) error + UpdateMCPClient(ctx context.Context, id string, updatedConfig *schemas.MCPClientConfig) error + ReconnectMCPClient(ctx context.Context, id string) error } // MCPHandler manages HTTP requests for MCP tool operations @@ -48,21 +49,25 @@ func NewMCPHandler(mcpManager MCPManager, client *bifrost.Bifrost, store *lib.Co func (h *MCPHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.GET("/api/mcp/clients", lib.ChainMiddlewares(h.getMCPClients, middlewares...)) r.POST("/api/mcp/client", lib.ChainMiddlewares(h.addMCPClient, middlewares...)) - r.PUT("/api/mcp/client/{id}", lib.ChainMiddlewares(h.editMCPClient, middlewares...)) - r.DELETE("/api/mcp/client/{id}", lib.ChainMiddlewares(h.removeMCPClient, middlewares...)) + r.PUT("/api/mcp/client/{id}", lib.ChainMiddlewares(h.updateMCPClient, middlewares...)) + r.DELETE("/api/mcp/client/{id}", lib.ChainMiddlewares(h.deleteMCPClient, middlewares...)) r.POST("/api/mcp/client/{id}/reconnect", lib.ChainMiddlewares(h.reconnectMCPClient, middlewares...)) r.POST("/api/mcp/client/{id}/complete-oauth", lib.ChainMiddlewares(h.completeMCPClientOAuth, middlewares...)) } // MCPClientResponse represents the response structure for MCP clients type MCPClientResponse struct { - Config configstoreTables.TableMCPClient `json:"config"` - Tools []schemas.ChatToolFunction `json:"tools"` - State schemas.MCPConnectionState `json:"state"` + Config *schemas.MCPClientConfig `json:"config"` + Tools []schemas.ChatToolFunction `json:"tools"` + State schemas.MCPConnectionState `json:"state"` } // getMCPClients handles GET /api/mcp/clients - Get all MCP clients func (h *MCPHandler) getMCPClients(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendJSON(ctx, []MCPClientResponse{}) + return + } // Get clients from store config configsInStore := h.store.MCPConfig if configsInStore == nil { @@ -85,9 +90,8 @@ func (h *MCPHandler) getMCPClients(ctx *fasthttp.RequestCtx) { for _, configClient := range configsInStore.ClientConfigs { // Redact sensitive fields before sending to UI - redactedConfig := h.store.RedactTableMCPClient(configClient) - - if connectedClient, exists := connectedClientsMap[configClient.ClientID]; exists { + redactedConfig := h.store.RedactMCPClientConfig(configClient) + if connectedClient, exists := connectedClientsMap[configClient.ID]; exists { // Sort tools alphabetically by name sortedTools := make([]schemas.ChatToolFunction, len(connectedClient.Tools)) copy(sortedTools, connectedClient.Tools) @@ -114,6 +118,10 @@ func (h *MCPHandler) getMCPClients(ctx *fasthttp.RequestCtx) { // reconnectMCPClient handles POST /api/mcp/client/{id}/reconnect - Reconnect an MCP client func (h *MCPHandler) reconnectMCPClient(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "MCP operations unavailable: config store is disabled") + return + } id, err := getIDFromCtx(ctx) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid id: %v", err)) @@ -147,6 +155,10 @@ type MCPClientRequest struct { // addMCPClient handles POST /api/mcp/client - Add a new MCP client func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "MCP operations unavailable: config store is disabled") + return + } var req MCPClientRequest if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err)) @@ -238,8 +250,12 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { Headers: req.Headers, } - // Store in OAuth provider memory (will auto-cleanup after 5 minutes if not completed) - h.oauthHandler.StorePendingMCPClient(req.ClientID, pendingConfig) + // Store pending config in database (associated with oauth_config_id for multi-instance support) + if err := h.oauthHandler.StorePendingMCPClient(flowInitiation.OauthConfigID, pendingConfig); err != nil { + logger.Error(fmt.Sprintf("[Add MCP Client] Failed to store pending MCP client: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to store pending MCP client: %v", err)) + return + } // Return OAuth flow initiation response SendJSON(ctx, map[string]any{ @@ -253,19 +269,36 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { return } - // For "none" or "headers" auth, proceed with immediate connection - schemasConfig := schemas.MCPClientConfig{ + toolSyncInterval := 10 * time.Minute + if req.ToolSyncInterval != 0 { + toolSyncInterval = time.Duration(req.ToolSyncInterval) * time.Minute + } else { + config, err := h.store.ConfigStore.GetClientConfig(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get client config: %v", err)) + return + } + if config != nil { + toolSyncInterval = time.Duration(config.MCPToolSyncInterval) * time.Minute + } + + } + // Convert to schemas.MCPClientConfig for runtime bifrost client (without tool_pricing) + schemasConfig := &schemas.MCPClientConfig{ ID: req.ClientID, Name: req.Name, IsCodeModeClient: req.IsCodeModeClient, ConnectionType: schemas.MCPConnectionType(req.ConnectionType), ConnectionString: req.ConnectionString, StdioConfig: req.StdioConfig, - AuthType: schemas.MCPAuthType(req.AuthType), - OauthConfigID: nil, ToolsToExecute: req.ToolsToExecute, ToolsToAutoExecute: req.ToolsToAutoExecute, Headers: req.Headers, + AuthType: schemas.MCPAuthType(req.AuthType), + OauthConfigID: req.OauthConfigID, + IsPingAvailable: req.IsPingAvailable, + ToolSyncInterval: toolSyncInterval, + ToolPricing: req.ToolPricing, } if err := h.mcpManager.AddMCPClient(ctx, schemasConfig); err != nil { @@ -274,34 +307,36 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { } // Creating MCP client config in config store if h.store.ConfigStore != nil { - if err := h.store.ConfigStore.CreateMCPClientConfig(ctx, req); err != nil { + if err := h.store.ConfigStore.CreateMCPClientConfig(ctx, schemasConfig); err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Added to core but failed to create MCP config in database: %v, please restart bifrost to keep core and database in sync", err)) return } } + SendJSON(ctx, map[string]any{ "status": "success", "message": "MCP client connected successfully", }) } -// editMCPClient handles PUT /api/mcp/client/{id} - Edit MCP client -func (h *MCPHandler) editMCPClient(ctx *fasthttp.RequestCtx) { +// updateMCPClient handles PUT /api/mcp/client/{id} - Edit MCP client +func (h *MCPHandler) updateMCPClient(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "MCP operations unavailable: config store is disabled") + return + } id, err := getIDFromCtx(ctx) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid id: %v", err)) return } - // Accept the full table client config to support tool_pricing - var req configstoreTables.TableMCPClient + var req *configstoreTables.TableMCPClient if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err)) return } - req.ClientID = id - // Validate tools_to_execute if err := validateToolsToExecute(req.ToolsToExecute); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_execute: %v", err)) @@ -322,53 +357,85 @@ func (h *MCPHandler) editMCPClient(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid client name: %v", err)) return } - // Get existing config to handle redacted values - var existingConfig *configstoreTables.TableMCPClient + var existingConfig *schemas.MCPClientConfig if h.store.MCPConfig != nil { for i, client := range h.store.MCPConfig.ClientConfigs { - if client.ClientID == id { - existingConfig = &h.store.MCPConfig.ClientConfigs[i] + if client.ID == id { + existingConfig = h.store.MCPConfig.ClientConfigs[i] break } } } - if existingConfig == nil { SendError(ctx, fasthttp.StatusNotFound, "MCP client not found") return } // Merge redacted values - preserve old values if incoming values are redacted and unchanged - req = mergeMCPRedactedValues(req, *existingConfig, h.store.RedactTableMCPClient(*existingConfig)) - - // EditMCPClient internally handles database update, env vars, and runtime update - if err := h.mcpManager.EditMCPClient(ctx, id, req); err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to edit MCP client: %v", err)) - return - } - // Update MCP client config in config store with merged values + req = mergeMCPRedactedValues(req, existingConfig, h.store.RedactMCPClientConfig(existingConfig)) + // Persist changes to config store if h.store.ConfigStore != nil { - if err := h.store.ConfigStore.UpdateMCPClientConfig(ctx, id, mergedConfig); err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Updated in core but failed to save MCP config in database: %v, please restart bifrost to keep core and database in sync", err)) + if err := h.store.ConfigStore.UpdateMCPClientConfig(ctx, id, req); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update mcp client config in store: %v", err)) return } } + toolSyncInterval := 10 * time.Minute + if req.ToolSyncInterval != 0 { + toolSyncInterval = time.Duration(req.ToolSyncInterval) * time.Minute + } else { + config, err := h.store.ConfigStore.GetClientConfig(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get client config: %v", err)) + return + } + if config != nil { + toolSyncInterval = time.Duration(config.MCPToolSyncInterval) * time.Minute + } + + } + // Convert to schemas.MCPClientConfig for runtime bifrost client (without tool_pricing) + schemasConfig := &schemas.MCPClientConfig{ + ID: req.ClientID, + Name: req.Name, + IsCodeModeClient: req.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(req.ConnectionType), + ConnectionString: req.ConnectionString, + StdioConfig: req.StdioConfig, + ToolsToExecute: req.ToolsToExecute, + ToolsToAutoExecute: req.ToolsToAutoExecute, + Headers: req.Headers, + AuthType: schemas.MCPAuthType(req.AuthType), + OauthConfigID: req.OauthConfigID, + IsPingAvailable: req.IsPingAvailable, + ToolSyncInterval: toolSyncInterval, + ToolPricing: req.ToolPricing, + } + // Update MCP client in memory + if err := h.mcpManager.UpdateMCPClient(ctx, id, schemasConfig); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update mcp client: %v", err)) + return + } SendJSON(ctx, map[string]any{ "status": "success", "message": "MCP client edited successfully", }) } -// removeMCPClient handles DELETE /api/mcp/client/{id} - Remove an MCP client -func (h *MCPHandler) removeMCPClient(ctx *fasthttp.RequestCtx) { +// deleteMCPClient handles DELETE /api/mcp/client/{id} - Remove an MCP client +func (h *MCPHandler) deleteMCPClient(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "MCP operations unavailable: config store is disabled") + return + } id, err := getIDFromCtx(ctx) if err != nil { - SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid id: %v", err)) + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid id: %v", err)) return } if err := h.mcpManager.RemoveMCPClient(ctx, id); err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to remove MCP client: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to remove MCP client: %v", err)) return } // Deleting MCP client config from config store @@ -572,7 +639,7 @@ func validateMCPClientName(name string) error { // mergeMCPRedactedValues merges incoming MCP client config with existing config, // preserving old values when incoming values are redacted and unchanged. // This follows the same pattern as provider config updates. -func mergeMCPRedactedValues(incoming, oldRaw, oldRedacted configstoreTables.TableMCPClient) configstoreTables.TableMCPClient { +func mergeMCPRedactedValues(incoming *configstoreTables.TableMCPClient, oldRaw, oldRedacted *schemas.MCPClientConfig) *configstoreTables.TableMCPClient { merged := incoming // Handle ConnectionString - if incoming is redacted and equals old redacted, keep old raw value @@ -606,30 +673,23 @@ func mergeMCPRedactedValues(incoming, oldRaw, oldRedacted configstoreTables.Tabl } // completeMCPClientOAuth handles POST /api/mcp/client/{id}/complete-oauth - Complete MCP client creation after OAuth authorization +// The {id} parameter is the oauth_config_id returned from the initial addMCPClient call func (h *MCPHandler) completeMCPClientOAuth(ctx *fasthttp.RequestCtx) { - id, err := getIDFromCtx(ctx) - if err != nil { - logger.Error(fmt.Sprintf("[OAuth Complete] Invalid id: %v", err)) - SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid id: %v", err)) + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "MCP operations unavailable: config store is disabled") return } - - logger.Debug(fmt.Sprintf("[OAuth Complete] Completing OAuth for MCP client: %s", id)) - - // Get MCP client config from OAuth provider memory - mcpClientConfig := h.oauthHandler.GetPendingMCPClient(id) - if mcpClientConfig == nil { - SendError(ctx, fasthttp.StatusNotFound, "MCP client not found in pending OAuth clients") + oauthConfigID, err := getIDFromCtx(ctx) + if err != nil { + logger.Error(fmt.Sprintf("[OAuth Complete] Invalid oauth_config_id: %v", err)) + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid oauth_config_id: %v", err)) return } - if mcpClientConfig.OauthConfigID == nil { - SendError(ctx, fasthttp.StatusBadRequest, "No OAuth config linked to this MCP client") - return - } + logger.Debug(fmt.Sprintf("[OAuth Complete] Completing OAuth for oauth_config_id: %s", oauthConfigID)) // Check if OAuth flow is authorized - oauthConfig, err := h.store.ConfigStore.GetOauthConfigByID(ctx, *mcpClientConfig.OauthConfigID) + oauthConfig, err := h.store.ConfigStore.GetOauthConfigByID(ctx, oauthConfigID) if err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get OAuth config: %v", err)) return @@ -645,17 +705,32 @@ func (h *MCPHandler) completeMCPClientOAuth(ctx *fasthttp.RequestCtx) { return } + // Get MCP client config from database (stored with oauth_config for multi-instance support) + mcpClientConfig, err := h.oauthHandler.GetPendingMCPClient(oauthConfigID) + if err != nil { + logger.Error(fmt.Sprintf("[OAuth Complete] Failed to get pending MCP client: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get pending MCP client: %v", err)) + return + } + if mcpClientConfig == nil { + SendError(ctx, fasthttp.StatusNotFound, "MCP client not found in pending OAuth clients. The OAuth flow may have expired or already been completed.") + return + } + // Add MCP client to Bifrost (this will save to database and connect) - if err := h.mcpManager.AddMCPClient(ctx, *mcpClientConfig); err != nil { + if err := h.mcpManager.AddMCPClient(ctx, mcpClientConfig); err != nil { logger.Error(fmt.Sprintf("[OAuth Complete] Failed to connect MCP client: %v", err)) SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to connect MCP client: %v", err)) return } - // Remove from pending OAuth clients memory - h.oauthHandler.RemovePendingMCPClient(id) + // Clear pending MCP client config from oauth_config (cleanup) + if err := h.oauthHandler.RemovePendingMCPClient(oauthConfigID); err != nil { + logger.Warn(fmt.Sprintf("[OAuth Complete] Failed to clear pending MCP client config: %v", err)) + // Don't fail the request - the MCP client was successfully created + } - logger.Debug(fmt.Sprintf("[OAuth Complete] MCP client connected successfully: %s", id)) + logger.Debug(fmt.Sprintf("[OAuth Complete] MCP client connected successfully: %s", mcpClientConfig.ID)) SendJSON(ctx, map[string]any{ "status": "success", "message": "MCP client connected successfully with OAuth", diff --git a/transports/bifrost-http/handlers/oauth2.go b/transports/bifrost-http/handlers/oauth2.go index faf5c86654..7e47330c0a 100644 --- a/transports/bifrost-http/handlers/oauth2.go +++ b/transports/bifrost-http/handlers/oauth2.go @@ -223,17 +223,23 @@ func (h *OAuthHandler) InitiateOAuthFlow(ctx context.Context, req OAuthInitiatio return h.oauthProvider.InitiateOAuthFlow(ctx, config) } -// StorePendingMCPClient stores an MCP client config in OAuth provider memory while waiting for OAuth completion -func (h *OAuthHandler) StorePendingMCPClient(mcpClientID string, mcpClientConfig schemas.MCPClientConfig) { - h.oauthProvider.StorePendingMCPClient(mcpClientID, mcpClientConfig) +// StorePendingMCPClient stores an MCP client config in the database while waiting for OAuth completion +// This supports multi-instance deployments where OAuth callback may hit a different server instance +func (h *OAuthHandler) StorePendingMCPClient(oauthConfigID string, mcpClientConfig schemas.MCPClientConfig) error { + return h.oauthProvider.StorePendingMCPClient(oauthConfigID, mcpClientConfig) } -// GetPendingMCPClient retrieves a pending MCP client config by mcp_client_id -func (h *OAuthHandler) GetPendingMCPClient(mcpClientID string) *schemas.MCPClientConfig { - return h.oauthProvider.GetPendingMCPClient(mcpClientID) +// GetPendingMCPClient retrieves a pending MCP client config by oauth_config_id +func (h *OAuthHandler) GetPendingMCPClient(oauthConfigID string) (*schemas.MCPClientConfig, error) { + return h.oauthProvider.GetPendingMCPClient(oauthConfigID) +} + +// GetPendingMCPClientByState retrieves a pending MCP client config by OAuth state token +func (h *OAuthHandler) GetPendingMCPClientByState(state string) (*schemas.MCPClientConfig, string, error) { + return h.oauthProvider.GetPendingMCPClientByState(state) } // RemovePendingMCPClient removes a pending MCP client after OAuth completion -func (h *OAuthHandler) RemovePendingMCPClient(mcpClientID string) { - h.oauthProvider.RemovePendingMCPClient(mcpClientID) +func (h *OAuthHandler) RemovePendingMCPClient(oauthConfigID string) error { + return h.oauthProvider.RemovePendingMCPClient(oauthConfigID) } diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index ee585a755b..e774c52a5d 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -90,7 +90,7 @@ type ConfigData struct { AuthConfig *configstore.AuthConfig `json:"auth_config,omitempty"` Providers map[string]configstore.ProviderConfig `json:"providers"` FrameworkConfig *framework.FrameworkConfig `json:"framework,omitempty"` - MCP *configstoreTables.MCPConfig `json:"mcp,omitempty"` + MCP *schemas.MCPConfig `json:"mcp,omitempty"` Governance *configstore.GovernanceConfig `json:"governance,omitempty"` VectorStoreConfig *vectorstore.Config `json:"vector_store,omitempty"` ConfigStoreConfig *configstore.Config `json:"config_store,omitempty"` @@ -109,7 +109,7 @@ func (cd *ConfigData) UnmarshalJSON(data []byte) error { EncryptionKey string `json:"encryption_key"` AuthConfig *configstore.AuthConfig `json:"auth_config,omitempty"` Providers map[string]configstore.ProviderConfig `json:"providers"` - MCP *configstoreTables.MCPConfig `json:"mcp,omitempty"` + MCP *schemas.MCPConfig `json:"mcp,omitempty"` Governance *configstore.GovernanceConfig `json:"governance,omitempty"` VectorStoreConfig json.RawMessage `json:"vector_store,omitempty"` ConfigStoreConfig json.RawMessage `json:"config_store,omitempty"` @@ -245,7 +245,7 @@ type Config struct { // In-memory storage ClientConfig configstore.ClientConfig Providers map[schemas.ModelProvider]configstore.ProviderConfig - MCPConfig *configstoreTables.MCPConfig + MCPConfig *schemas.MCPConfig GovernanceConfig *configstore.GovernanceConfig FrameworkConfig *framework.FrameworkConfig ProxyConfig *configstoreTables.GlobalProxyConfig @@ -268,8 +268,8 @@ type Config struct { pluginStatusMu sync.RWMutex pluginStatus map[string]schemas.PluginStatus // name -> status - OAuthProvider *oauth2.OAuth2Provider - TokenRefreshWorker *oauth2.TokenRefreshWorker + OAuthProvider *oauth2.OAuth2Provider + TokenRefreshWorker *oauth2.TokenRefreshWorker // Catalog managers ModelCatalog *modelcatalog.ModelCatalog @@ -806,6 +806,12 @@ func reconcileProviderKeys(provider schemas.ModelProvider, fileKeys, dbKeys []sc // loadMCPConfigFromFile loads and merges MCP config from file func loadMCPConfigFromFile(ctx context.Context, config *Config, configData *ConfigData) { + if config.ConfigStore == nil { + if configData.MCP != nil && len(configData.MCP.ClientConfigs) > 0 { + logger.Warn("config store is disabled - MCP manager will not be initialized. MCP clients require config store for persistence.") + } + return + } if config.ConfigStore != nil { logger.Debug("getting MCP config from store") tableMCPConfig, err := config.ConfigStore.GetMCPConfig(ctx) @@ -828,11 +834,8 @@ func loadMCPConfigFromFile(ctx context.Context, config *Config, configData *Conf if config.ConfigStore != nil && config.MCPConfig != nil { logger.Debug("updating MCP config in store") for _, clientConfig := range config.MCPConfig.ClientConfigs { - schemasClientConfig := configstore.ConvertTableMCPConfigToSchemas(&configstoreTables.MCPConfig{ - ClientConfigs: []configstoreTables.TableMCPClient{clientConfig}, - }) - if schemasClientConfig != nil && len(schemasClientConfig.ClientConfigs) > 0 { - if err := config.ConfigStore.CreateMCPClientConfig(ctx, schemasClientConfig.ClientConfigs[0]); err != nil { + if clientConfig != nil { + if err := config.ConfigStore.CreateMCPClientConfig(ctx, clientConfig); err != nil { logger.Warn("failed to create MCP client config: %v", err) } } @@ -842,7 +845,7 @@ func loadMCPConfigFromFile(ctx context.Context, config *Config, configData *Conf } // mergeMCPConfig merges MCP config from file with store -func mergeMCPConfig(ctx context.Context, config *Config, configData *ConfigData, mcpConfig *configstoreTables.MCPConfig) { +func mergeMCPConfig(ctx context.Context, config *Config, configData *ConfigData, mcpConfig *schemas.MCPConfig) { logger.Debug("merging MCP config from config file with store") if configData.MCP == nil { @@ -851,11 +854,11 @@ func mergeMCPConfig(ctx context.Context, config *Config, configData *ConfigData, tempMCPConfig := configData.MCP config.MCPConfig = tempMCPConfig // Merge ClientConfigs arrays by ClientID or Name - clientConfigsToAdd := make([]configstoreTables.TableMCPClient, 0) + clientConfigsToAdd := make([]*schemas.MCPClientConfig, 0) for _, newClientConfig := range tempMCPConfig.ClientConfigs { found := false for _, existingClientConfig := range mcpConfig.ClientConfigs { - if (newClientConfig.ClientID != "" && existingClientConfig.ClientID == newClientConfig.ClientID) || + if (newClientConfig.ID != "" && existingClientConfig.ID == newClientConfig.ID) || (newClientConfig.Name != "" && existingClientConfig.Name == newClientConfig.Name) { found = true break @@ -871,11 +874,8 @@ func mergeMCPConfig(ctx context.Context, config *Config, configData *ConfigData, if config.ConfigStore != nil && len(clientConfigsToAdd) > 0 { logger.Debug("updating MCP config in store with %d new client configs", len(clientConfigsToAdd)) for _, clientConfig := range clientConfigsToAdd { - schemasClientConfig := configstore.ConvertTableMCPConfigToSchemas(&configstoreTables.MCPConfig{ - ClientConfigs: []configstoreTables.TableMCPClient{clientConfig}, - }) - if schemasClientConfig != nil && len(schemasClientConfig.ClientConfigs) > 0 { - if err := config.ConfigStore.CreateMCPClientConfig(ctx, schemasClientConfig.ClientConfigs[0]); err != nil { + if clientConfig != nil { + if err := config.ConfigStore.CreateMCPClientConfig(ctx, clientConfig); err != nil { logger.Warn("failed to create MCP client config: %v", err) } } @@ -1486,8 +1486,8 @@ func mergePluginsFromFile(ctx context.Context, config *Config, configData *Confi } // convertSchemasMCPClientConfigToTable converts schemas.MCPClientConfig to tables.TableMCPClient -func convertSchemasMCPClientConfigToTable(clientConfig schemas.MCPClientConfig) configstoreTables.TableMCPClient { - return configstoreTables.TableMCPClient{ +func convertSchemasMCPClientConfigToTable(clientConfig *schemas.MCPClientConfig) *configstoreTables.TableMCPClient { + return &configstoreTables.TableMCPClient{ ClientID: clientConfig.ID, Name: clientConfig.Name, IsCodeModeClient: clientConfig.IsCodeModeClient, @@ -1860,11 +1860,8 @@ func loadDefaultMCPConfig(ctx context.Context, config *Config) error { if tableMCPConfig == nil { if config.MCPConfig != nil { for _, clientConfig := range config.MCPConfig.ClientConfigs { - schemasClientConfig := configstore.ConvertTableMCPConfigToSchemas(&configstoreTables.MCPConfig{ - ClientConfigs: []configstoreTables.TableMCPClient{clientConfig}, - }) - if schemasClientConfig != nil && len(schemasClientConfig.ClientConfigs) > 0 { - if err := config.ConfigStore.CreateMCPClientConfig(ctx, schemasClientConfig.ClientConfigs[0]); err != nil { + if clientConfig != nil { + if err := config.ConfigStore.CreateMCPClientConfig(ctx, clientConfig); err != nil { logger.Warn("failed to create MCP client config: %v", err) continue } @@ -2414,11 +2411,11 @@ func (c *Config) GetPluginStatusByName(name string) (schemas.PluginStatus, bool) return status, ok } -// RegisterPlugin adds or updates a plugin in the registry +// ReloadPlugin adds or updates a plugin in the registry // This is the single entry point for all plugin additions/updates // If a plugin with the same name exists, it will be replaced (atomic find-and-replace) // If no plugin exists with that name, it will be added -func (c *Config) RegisterPlugin(plugin schemas.BasePlugin) error { +func (c *Config) ReloadPlugin(plugin schemas.BasePlugin) error { c.pluginsMu.Lock() defer c.pluginsMu.Unlock() @@ -2775,13 +2772,8 @@ func (c *Config) GetMCPClient(id string) (*schemas.MCPClientConfig, error) { } for _, clientConfig := range c.MCPConfig.ClientConfigs { - if clientConfig.ClientID == id { - schemasConfig := configstore.ConvertTableMCPConfigToSchemas(&configstoreTables.MCPConfig{ - ClientConfigs: []configstoreTables.TableMCPClient{clientConfig}, - }) - if schemasConfig != nil && len(schemasConfig.ClientConfigs) > 0 { - return &schemasConfig.ClientConfigs[0], nil - } + if clientConfig.ID == id { + return clientConfig, nil } } @@ -2795,17 +2787,17 @@ func (c *Config) GetMCPClient(id string) (*schemas.MCPClientConfig, error) { // - Validates that the MCP client doesn't already exist // - Processes environment variables in the MCP client configuration // - Stores the processed configuration in memory -func (c *Config) AddMCPClient(ctx context.Context, clientConfig schemas.MCPClientConfig) error { +func (c *Config) AddMCPClient(ctx context.Context, clientConfig *schemas.MCPClientConfig) error { if c.client == nil { return fmt.Errorf("bifrost client not set") } c.muMCP.Lock() defer c.muMCP.Unlock() if c.MCPConfig == nil { - c.MCPConfig = &configstoreTables.MCPConfig{} + c.MCPConfig = &schemas.MCPConfig{} } // Track new environment variables - c.MCPConfig.ClientConfigs = append(c.MCPConfig.ClientConfigs, convertSchemasMCPClientConfigToTable(clientConfig)) + c.MCPConfig.ClientConfigs = append(c.MCPConfig.ClientConfigs, clientConfig) // Config with processed env vars if err := c.client.AddMCPClient(clientConfig); err != nil { c.MCPConfig.ClientConfigs = c.MCPConfig.ClientConfigs[:len(c.MCPConfig.ClientConfigs)-1] @@ -2813,10 +2805,17 @@ func (c *Config) AddMCPClient(ctx context.Context, clientConfig schemas.MCPClien } // Updating in config store if c.ConfigStore != nil { - if err := c.ConfigStore.CreateMCPClientConfig(ctx, clientConfig); err != nil { - return fmt.Errorf("failed to create MCP client config in store: %w", err) + skipDBUpdate := false + if ctx.Value(schemas.BifrostContextKeySkipDBUpdate) != nil { + if skip, ok := ctx.Value(schemas.BifrostContextKeySkipDBUpdate).(bool); ok { + skipDBUpdate = skip + } + } + if !skipDBUpdate { + if err := c.ConfigStore.CreateMCPClientConfig(ctx, clientConfig); err != nil { + return fmt.Errorf("failed to create MCP client config in store: %w", err) + } } - // Update MCP catalog pricing data for the new client if c.MCPCatalog != nil { // Get the created client config from store to get tool_pricing @@ -2862,10 +2861,10 @@ func (c *Config) RemoveMCPClient(ctx context.Context, id string) error { } } // Find and remove client from in-memory config - var removedClientConfig *configstoreTables.TableMCPClient + var removedClientConfig *schemas.MCPClientConfig for i, clientConfig := range c.MCPConfig.ClientConfigs { - if clientConfig.ClientID == id { - removedClientConfig = &clientConfig + if clientConfig.ID == id { + removedClientConfig = clientConfig c.MCPConfig.ClientConfigs = append(c.MCPConfig.ClientConfigs[:i], c.MCPConfig.ClientConfigs[i+1:]...) break } @@ -2886,13 +2885,13 @@ func (c *Config) RemoveMCPClient(ctx context.Context, id string) error { return nil } -// EditMCPClient edits an MCP client configuration. +// UpdateMCPClient edits an MCP client configuration. // This allows for dynamic MCP client management at runtime with proper env var handling. // // Parameters: // - id: ID of the client to edit // - updatedConfig: Updated MCP client configuration -func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig configstoreTables.TableMCPClient) error { +func (c *Config) UpdateMCPClient(ctx context.Context, id string, updatedConfig *schemas.MCPClientConfig) error { if c.client == nil { return fmt.Errorf("bifrost client not set") } @@ -2903,11 +2902,11 @@ func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig con return fmt.Errorf("no MCP config found") } // Find the existing client config - var oldConfig configstoreTables.TableMCPClient + var oldConfig *schemas.MCPClientConfig var found bool var configIndex int for i, clientConfig := range c.MCPConfig.ClientConfigs { - if clientConfig.ClientID == id { + if clientConfig.ID == id { oldConfig = clientConfig configIndex = i found = true @@ -2917,27 +2916,11 @@ func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig con if !found { return fmt.Errorf("MCP client '%s' not found", id) } - - // Convert to schemas.MCPClientConfig for runtime bifrost client (without tool_pricing) - // Use oldConfig for connection info since those fields are read-only and not sent in update request - schemasConfig := schemas.MCPClientConfig{ - ID: updatedConfig.ClientID, - Name: updatedConfig.Name, - IsCodeModeClient: updatedConfig.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(oldConfig.ConnectionType), - ConnectionString: oldConfig.ConnectionString, - StdioConfig: oldConfig.StdioConfig, - ToolsToExecute: updatedConfig.ToolsToExecute, - ToolsToAutoExecute: updatedConfig.ToolsToAutoExecute, - Headers: updatedConfig.Headers, - IsPingAvailable: updatedConfig.IsPingAvailable, - } - // Check if client is registered in Bifrost (can be not registered if client initialization failed) if clients, err := c.client.GetMCPClients(); err == nil && len(clients) > 0 { for _, client := range clients { if client.Config.ID == id { - if err := c.client.EditMCPClient(id, schemasConfig); err != nil { + if err := c.client.EditMCPClient(id, updatedConfig); err != nil { // Rollback in-memory changes c.MCPConfig.ClientConfigs[configIndex] = oldConfig return fmt.Errorf("failed to edit MCP client: %w", err) @@ -2946,66 +2929,28 @@ func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig con } } } - return nil -} - -// RemoveMCPClient removes an MCP client from the configuration. -// This method is called when an MCP client is removed via the HTTP API. -// -// The method: -// - Validates that the MCP client exists -// - Removes the MCP client from the configuration -// - Removes the MCP client from the Bifrost client -func (c *Config) RemoveMCPClient(ctx context.Context, id string) error { - if c.client == nil { - return fmt.Errorf("bifrost client not set") - } - c.muMCP.Lock() - defer c.muMCP.Unlock() - if c.MCPConfig == nil { - return fmt.Errorf("no MCP config found") - } - // Check if client is registered in Bifrost (can be not registered if client initialization failed) - if clients, err := c.client.GetMCPClients(); err == nil && len(clients) > 0 { - for _, client := range clients { - if client.Config.ID == id { - if err := c.client.RemoveMCPClient(id); err != nil { - return fmt.Errorf("failed to remove MCP client: %w", err) - } - break + // Update MCP catalog pricing data for the edited client + if c.MCPCatalog != nil { + // If the client name has changed, delete all old pricing entries under the old name + if updatedConfig.Name != oldConfig.Name { + for toolName := range oldConfig.ToolPricing { + c.MCPCatalog.DeletePricingData(oldConfig.Name, toolName) } - } - } - for i, clientConfig := range c.MCPConfig.ClientConfigs { - if clientConfig.ID == id { - c.MCPConfig.ClientConfigs = append(c.MCPConfig.ClientConfigs[:i], c.MCPConfig.ClientConfigs[i+1:]...) - break - } - - // Update MCP catalog pricing data for the edited client - if c.MCPCatalog != nil { - // If the client name has changed, delete all old pricing entries under the old name - if updatedConfig.Name != oldConfig.Name { - for toolName := range oldConfig.ToolPricing { - c.MCPCatalog.DeletePricingData(oldConfig.Name, toolName) - } - logger.Debug("deleted old MCP catalog pricing for renamed client: %s -> %s (%d tools)", oldConfig.Name, updatedConfig.Name, len(oldConfig.ToolPricing)) - } else { - // If name hasn't changed, remove pricing entries that were deleted - for toolName := range oldConfig.ToolPricing { - if _, exists := updatedConfig.ToolPricing[toolName]; !exists { - c.MCPCatalog.DeletePricingData(updatedConfig.Name, toolName) - } + logger.Debug("deleted old MCP catalog pricing for renamed client: %s -> %s (%d tools)", oldConfig.Name, updatedConfig.Name, len(oldConfig.ToolPricing)) + } else { + // If name hasn't changed, remove pricing entries that were deleted + for toolName := range oldConfig.ToolPricing { + if _, exists := updatedConfig.ToolPricing[toolName]; !exists { + c.MCPCatalog.DeletePricingData(updatedConfig.Name, toolName) } } - // Then, add or update pricing entries from the new config (with new name if changed) - for toolName, costPerExecution := range updatedConfig.ToolPricing { - c.MCPCatalog.UpdatePricingData(updatedConfig.Name, toolName, costPerExecution) - } - logger.Debug("updated MCP catalog pricing for client: %s (%d tools)", updatedConfig.Name, len(updatedConfig.ToolPricing)) } + // Then, add or update pricing entries from the new config (with new name if changed) + for toolName, costPerExecution := range updatedConfig.ToolPricing { + c.MCPCatalog.UpdatePricingData(updatedConfig.Name, toolName, costPerExecution) + } + logger.Debug("updated MCP catalog pricing for client: %s (%d tools)", updatedConfig.Name, len(updatedConfig.ToolPricing)) } - // Update the in-memory configuration with only the fields that were changed // Preserve connection info (connection_type, connection_string, stdio_config) from oldConfig // as these are read-only and not sent in the update request @@ -3019,11 +2964,12 @@ func (c *Config) RemoveMCPClient(ctx context.Context, id string) error { return nil } -// RedactTableMCPClient creates a redacted copy of a TableMCPClient configuration. +// RedactMCPClientConfig creates a redacted copy of a MCPClientConfig configuration. // Connection strings and headers are redacted for safe external exposure. -func (c *Config) RedactTableMCPClient(config configstoreTables.TableMCPClient) configstoreTables.TableMCPClient { - // Create a shallow copy - configCopy := config +func (c *Config) RedactMCPClientConfig(config *schemas.MCPClientConfig) *schemas.MCPClientConfig { + // Create an actual copy of the struct (not just a pointer copy) + // This prevents modifying the original config when redacting + configCopy := *config // Redact connection string if present if config.ConnectionString != nil { @@ -3038,7 +2984,7 @@ func (c *Config) RedactTableMCPClient(config configstoreTables.TableMCPClient) c } } - return configCopy + return &configCopy } // autoDetectProviders automatically detects common environment variables and sets up providers diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index d26cf3df94..0f93abb401 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -350,7 +350,7 @@ import ( type MockConfigStore struct { clientConfig *configstore.ClientConfig providers map[schemas.ModelProvider]configstore.ProviderConfig - mcpConfig *tables.MCPConfig + mcpConfig *schemas.MCPConfig governanceConfig *configstore.GovernanceConfig authConfig *configstore.AuthConfig frameworkConfig *tables.TableFrameworkConfig @@ -361,7 +361,7 @@ type MockConfigStore struct { // Track update calls for verification clientConfigUpdated bool providersConfigUpdated bool - mcpConfigsCreated []schemas.MCPClientConfig + mcpConfigsCreated []*schemas.MCPClientConfig mcpClientConfigUpdates []struct { ID string Config tables.TableMCPClient @@ -438,7 +438,7 @@ func (m *MockConfigStore) DeleteProvider(ctx context.Context, provider schemas.M } // MCP config -func (m *MockConfigStore) GetMCPConfig(ctx context.Context) (*tables.MCPConfig, error) { +func (m *MockConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) { return m.mcpConfig, nil } @@ -446,25 +446,8 @@ func (m *MockConfigStore) GetMCPClientByName(ctx context.Context, name string) ( return nil, nil } -func (m *MockConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig schemas.MCPClientConfig) error { - if m.mcpConfig == nil { - m.mcpConfig = &tables.MCPConfig{ - ClientConfigs: []tables.TableMCPClient{}, - } - } - // Convert schemas.MCPClientConfig to tables.TableMCPClient - tableClient := tables.TableMCPClient{ - ClientID: clientConfig.ID, - Name: clientConfig.Name, - IsCodeModeClient: clientConfig.IsCodeModeClient, - ConnectionType: string(clientConfig.ConnectionType), - ConnectionString: clientConfig.ConnectionString, - StdioConfig: clientConfig.StdioConfig, - ToolsToExecute: clientConfig.ToolsToExecute, - ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, - Headers: clientConfig.Headers, - } - m.mcpConfig.ClientConfigs = append(m.mcpConfig.ClientConfigs, tableClient) +func (m *MockConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error { + m.mcpConfig.ClientConfigs = append(m.mcpConfig.ClientConfigs, clientConfig) m.mcpConfigsCreated = append(m.mcpConfigsCreated, clientConfig) return nil } @@ -480,20 +463,20 @@ func (m *MockConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, // Initialize m.mcpConfig if nil (same pattern as CreateMCPClientConfig) if m.mcpConfig == nil { - m.mcpConfig = &tables.MCPConfig{ - ClientConfigs: []tables.TableMCPClient{}, + m.mcpConfig = &schemas.MCPConfig{ + ClientConfigs: []*schemas.MCPClientConfig{}, } } // Update the in-memory state to ensure GetMCPConfig returns updated data for i := range m.mcpConfig.ClientConfigs { - if m.mcpConfig.ClientConfigs[i].ClientID == id { + if m.mcpConfig.ClientConfigs[i].ID == id { // Found the entry, update it with the new config - m.mcpConfig.ClientConfigs[i] = tables.TableMCPClient{ - ClientID: clientConfig.ClientID, + m.mcpConfig.ClientConfigs[i] = &schemas.MCPClientConfig{ + ID: clientConfig.ClientID, Name: clientConfig.Name, IsCodeModeClient: clientConfig.IsCodeModeClient, - ConnectionType: clientConfig.ConnectionType, + ConnectionType: schemas.MCPConnectionType(clientConfig.ConnectionType), ConnectionString: clientConfig.ConnectionString, StdioConfig: clientConfig.StdioConfig, Headers: clientConfig.Headers, @@ -504,11 +487,11 @@ func (m *MockConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, } } // If not found, create a new entry (similar to CreateMCPClientConfig behavior) - m.mcpConfig.ClientConfigs = append(m.mcpConfig.ClientConfigs, tables.TableMCPClient{ - ClientID: clientConfig.ClientID, + m.mcpConfig.ClientConfigs = append(m.mcpConfig.ClientConfigs, &schemas.MCPClientConfig{ + ID: clientConfig.ClientID, Name: clientConfig.Name, IsCodeModeClient: clientConfig.IsCodeModeClient, - ConnectionType: clientConfig.ConnectionType, + ConnectionType: schemas.MCPConnectionType(clientConfig.ConnectionType), ConnectionString: clientConfig.ConnectionString, StdioConfig: clientConfig.StdioConfig, Headers: clientConfig.Headers, @@ -1431,7 +1414,7 @@ func TestLoadConfig_Providers_Merge(t *testing.T) { func TestLoadConfig_MCP_Merge(t *testing.T) { // Setup DB MCP config dbMCPConfig := &schemas.MCPConfig{ - ClientConfigs: []schemas.MCPClientConfig{ + ClientConfigs: []*schemas.MCPClientConfig{ { ID: "mcp-1", Name: "db-client-1", @@ -1447,7 +1430,7 @@ func TestLoadConfig_MCP_Merge(t *testing.T) { // Setup file MCP config with some overlapping and some new fileMCPConfig := &schemas.MCPConfig{ - ClientConfigs: []schemas.MCPClientConfig{ + ClientConfigs: []*schemas.MCPClientConfig{ { ID: "mcp-1", // Same ID - should be skipped Name: "different-name", @@ -1467,7 +1450,7 @@ func TestLoadConfig_MCP_Merge(t *testing.T) { } // Simulate merge logic - clientConfigsToAdd := make([]schemas.MCPClientConfig, 0) + clientConfigsToAdd := make([]*schemas.MCPClientConfig, 0) for _, newClientConfig := range fileMCPConfig.ClientConfigs { found := false for _, existingClientConfig := range dbMCPConfig.ClientConfigs { @@ -9001,7 +8984,7 @@ func TestSQLite_VirtualKey_WithMCPConfigs(t *testing.T) { Name: "test-mcp-client", ConnectionType: schemas.MCPConnectionTypeHTTP, } - err = config1.ConfigStore.CreateMCPClientConfig(ctx, mcpClientConfig) + err = config1.ConfigStore.CreateMCPClientConfig(ctx, &mcpClientConfig) if err != nil { t.Fatalf("Failed to create MCP client: %v", err) } @@ -9090,7 +9073,7 @@ func TestSQLite_VKMCPConfig_Reconciliation(t *testing.T) { Name: "mcp-client-1", ConnectionType: schemas.MCPConnectionTypeHTTP, } - err = config1.ConfigStore.CreateMCPClientConfig(ctx, mcpClientConfig1) + err = config1.ConfigStore.CreateMCPClientConfig(ctx, &mcpClientConfig1) if err != nil { t.Fatalf("Failed to create MCP client 1: %v", err) } @@ -9100,7 +9083,7 @@ func TestSQLite_VKMCPConfig_Reconciliation(t *testing.T) { Name: "mcp-client-2", ConnectionType: schemas.MCPConnectionTypeHTTP, } - err = config1.ConfigStore.CreateMCPClientConfig(ctx, mcpClientConfig2) + err = config1.ConfigStore.CreateMCPClientConfig(ctx, &mcpClientConfig2) if err != nil { t.Fatalf("Failed to create MCP client 2: %v", err) } @@ -9412,7 +9395,7 @@ func TestSQLite_VirtualKey_DashboardMCPConfig_DeletedOnFileChange(t *testing.T) Name: "mcp-client-1", ConnectionType: schemas.MCPConnectionTypeHTTP, } - err = config1.ConfigStore.CreateMCPClientConfig(ctx, mcpClient1Config) + err = config1.ConfigStore.CreateMCPClientConfig(ctx, &mcpClient1Config) if err != nil { t.Fatalf("Failed to create MCP client 1: %v", err) } @@ -9422,7 +9405,7 @@ func TestSQLite_VirtualKey_DashboardMCPConfig_DeletedOnFileChange(t *testing.T) Name: "mcp-client-2", ConnectionType: schemas.MCPConnectionTypeHTTP, } - err = config1.ConfigStore.CreateMCPClientConfig(ctx, mcpClient2Config) + err = config1.ConfigStore.CreateMCPClientConfig(ctx, &mcpClient2Config) if err != nil { t.Fatalf("Failed to create MCP client 2: %v", err) } @@ -9577,8 +9560,8 @@ func TestSQLite_VKMCPConfig_AddRemove(t *testing.T) { } // Create MCP clients - config1.ConfigStore.CreateMCPClientConfig(ctx, schemas.MCPClientConfig{ID: "mcp-1", Name: "mcp-1", ConnectionType: schemas.MCPConnectionTypeHTTP}) - config1.ConfigStore.CreateMCPClientConfig(ctx, schemas.MCPClientConfig{ID: "mcp-2", Name: "mcp-2", ConnectionType: schemas.MCPConnectionTypeHTTP}) + config1.ConfigStore.CreateMCPClientConfig(ctx, &schemas.MCPClientConfig{ID: "mcp-1", Name: "mcp-1", ConnectionType: schemas.MCPConnectionTypeHTTP}) + config1.ConfigStore.CreateMCPClientConfig(ctx, &schemas.MCPClientConfig{ID: "mcp-2", Name: "mcp-2", ConnectionType: schemas.MCPConnectionTypeHTTP}) mcpClient1, _ := config1.ConfigStore.GetMCPClientByName(ctx, "mcp-1") mcpClient2, _ := config1.ConfigStore.GetMCPClientByName(ctx, "mcp-2") @@ -9699,7 +9682,7 @@ func TestSQLite_VKMCPConfig_UpdateTools(t *testing.T) { } // Create MCP client - config1.ConfigStore.CreateMCPClientConfig(ctx, schemas.MCPClientConfig{ID: "mcp-client", Name: "mcp-client", ConnectionType: schemas.MCPConnectionTypeHTTP}) + config1.ConfigStore.CreateMCPClientConfig(ctx, &schemas.MCPClientConfig{ID: "mcp-client", Name: "mcp-client", ConnectionType: schemas.MCPConnectionTypeHTTP}) mcpClient, _ := config1.ConfigStore.GetMCPClientByName(ctx, "mcp-client") // Create VK with MCP config @@ -9793,7 +9776,7 @@ func TestSQLite_VK_ProviderAndMCPConfigs_Combined(t *testing.T) { } // Create MCP client - config1.ConfigStore.CreateMCPClientConfig(ctx, schemas.MCPClientConfig{ID: "mcp-client", Name: "mcp-client", ConnectionType: schemas.MCPConnectionTypeHTTP}) + config1.ConfigStore.CreateMCPClientConfig(ctx, &schemas.MCPClientConfig{ID: "mcp-client", Name: "mcp-client", ConnectionType: schemas.MCPConnectionTypeHTTP}) mcpClient, _ := config1.ConfigStore.GetMCPClientByName(ctx, "mcp-client") config1.ConfigStore.Close(ctx) diff --git a/transports/bifrost-http/server/plugins.go b/transports/bifrost-http/server/plugins.go index 0019b8ebfd..15ee838836 100644 --- a/transports/bifrost-http/server/plugins.go +++ b/transports/bifrost-http/server/plugins.go @@ -17,10 +17,9 @@ import ( "github.com/maximhq/bifrost/transports/bifrost-http/lib" ) -// getPluginTypes determines which interface types a plugin implements -func getPluginTypes(plugin schemas.BasePlugin) []schemas.PluginType { +// InferPluginTypes determines which interface types a plugin implements +func InferPluginTypes(plugin schemas.BasePlugin) []schemas.PluginType { var types []schemas.PluginType - if _, ok := plugin.(schemas.LLMPlugin); ok { types = append(types, schemas.PluginTypeLLM) } @@ -30,7 +29,6 @@ func getPluginTypes(plugin schemas.BasePlugin) []schemas.PluginType { if _, ok := plugin.(schemas.HTTPTransportPlugin); ok { types = append(types, schemas.PluginTypeHTTP) } - return types } @@ -122,17 +120,16 @@ func loadCustomPlugin(ctx context.Context, path *string, pluginConfig any, bifro // InstantiatePlugins loads all plugins from configuration // This is called once during Bootstrap -func (s *BifrostHTTPServer) InstantiatePlugins(ctx context.Context) error { + +func (s *BifrostHTTPServer) LoadPlugins(ctx context.Context) error { // Load built-in plugins first (order matters) if err := s.loadBuiltinPlugins(ctx); err != nil { return err } - // Load custom plugins from config if err := s.loadCustomPlugins(ctx); err != nil { return err } - return nil } @@ -173,7 +170,6 @@ func (s *BifrostHTTPServer) loadCustomPlugins(ctx context.Context) error { if isBuiltinPlugin(cfg.Name) { continue } - // Handle disabled plugins if !cfg.Enabled { // For custom plugins with a path, verify to get the real plugin name @@ -218,9 +214,9 @@ func (s *BifrostHTTPServer) loadCustomPlugins(ctx context.Context) error { } // Register enabled plugin and mark as active - s.Config.RegisterPlugin(plugin) + s.Config.ReloadPlugin(plugin) s.Config.UpdatePluginOverallStatus(plugin.GetName(), cfg.Name, schemas.PluginStatusActive, - []string{fmt.Sprintf("plugin %s initialized successfully", cfg.Name)}, getPluginTypes(plugin)) + []string{fmt.Sprintf("plugin %s initialized successfully", cfg.Name)}, InferPluginTypes(plugin)) } return nil } diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 27a9edb130..d51aed898b 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -47,33 +47,43 @@ var enterprisePlugins = []string{ // ServerCallbacks is a interface that defines the callbacks for the server. type ServerCallbacks interface { + // Plugins callbacks ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any) error RemovePlugin(ctx context.Context, name string) error GetPluginStatus(ctx context.Context) map[string]schemas.PluginStatus - GetModelsForProvider(provider schemas.ModelProvider) []string + // Auth related callbacks UpdateAuthConfig(ctx context.Context, authConfig *configstore.AuthConfig) error ReloadClientConfigFromConfigStore(ctx context.Context) error + // Pricing related callbacks ReloadPricingManager(ctx context.Context) error ForceReloadPricing(ctx context.Context) error + // Proxy related callbacks ReloadProxyConfig(ctx context.Context, config *tables.GlobalProxyConfig) error + // Client config related callbacks ReloadHeaderFilterConfig(ctx context.Context, config *tables.GlobalHeaderFilterConfig) error UpdateDropExcessRequests(ctx context.Context, value bool) - UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error + // Governance related callbacks + GetGovernanceData() *governance.GovernanceData ReloadTeam(ctx context.Context, id string) (*tables.TableTeam, error) RemoveTeam(ctx context.Context, id string) error ReloadCustomer(ctx context.Context, id string) (*tables.TableCustomer, error) RemoveCustomer(ctx context.Context, id string) error + // Virtual key related callbacks ReloadVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error) RemoveVirtualKey(ctx context.Context, id string) error + // Provider related callbacks + GetModelsForProvider(provider schemas.ModelProvider) []string ReloadModelConfig(ctx context.Context, id string) (*tables.TableModelConfig, error) RemoveModelConfig(ctx context.Context, id string) error ReloadProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error) RemoveProvider(ctx context.Context, provider schemas.ModelProvider) error - GetGovernanceData() *governance.GovernanceData - ReconnectMCPClient(ctx context.Context, id string) error - AddMCPClient(ctx context.Context, clientConfig schemas.MCPClientConfig) error + // MCP related callbacks + AddMCPClient(ctx context.Context, clientConfig *schemas.MCPClientConfig) error RemoveMCPClient(ctx context.Context, id string) error - EditMCPClient(ctx context.Context, id string, updatedConfig tables.TableMCPClient) error + UpdateMCPClient(ctx context.Context, id string, updatedConfig *schemas.MCPClientConfig) error + UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error + ReconnectMCPClient(ctx context.Context, id string) error + // Logging related callbacks NewLogEntryAdded(ctx context.Context, logEntry *logstore.Log) error } @@ -139,7 +149,7 @@ func (s *GovernanceInMemoryStore) GetConfiguredProviders() map[schemas.ModelProv } // AddMCPClient adds a new MCP client to the in-memory store -func (s *BifrostHTTPServer) AddMCPClient(ctx context.Context, clientConfig schemas.MCPClientConfig) error { +func (s *BifrostHTTPServer) AddMCPClient(ctx context.Context, clientConfig *schemas.MCPClientConfig) error { if err := s.Config.AddMCPClient(ctx, clientConfig); err != nil { return err } @@ -149,9 +159,36 @@ func (s *BifrostHTTPServer) AddMCPClient(ctx context.Context, clientConfig schem return nil } -// EditMCPClient edits an MCP client in the in-memory store -func (s *BifrostHTTPServer) EditMCPClient(ctx context.Context, id string, updatedConfig tables.TableMCPClient) error { - if err := s.Config.EditMCPClient(ctx, id, updatedConfig); err != nil { +// ReconnectMCPClient reconnects an MCP client to the in-memory store +func (s *BifrostHTTPServer) ReconnectMCPClient(ctx context.Context, id string) error { + // Check if client is registered in Bifrost (can be not registered if client initialization failed) + if clients, err := s.Client.GetMCPClients(); err == nil && len(clients) > 0 { + for _, client := range clients { + if client.Config.ID == id { + if err := s.Client.ReconnectMCPClient(id); err != nil { + return err + } + return nil + } + } + } + // Config exists in store, but not in Bifrost (can happen if client initialization failed) + clientConfig, err := s.Config.GetMCPClient(id) + if err != nil { + return err + } + if err := s.Client.AddMCPClient(clientConfig); err != nil { + return err + } + if err := s.MCPServerHandler.SyncAllMCPServers(ctx); err != nil { + logger.Warn("failed to sync MCP servers after adding client: %v", err) + } + return nil +} + +// UpdateMCPClient updates an MCP client in the in-memory store +func (s *BifrostHTTPServer) UpdateMCPClient(ctx context.Context, id string, updatedConfig *schemas.MCPClientConfig) error { + if err := s.Config.UpdateMCPClient(ctx, id, updatedConfig); err != nil { return err } if err := s.MCPServerHandler.SyncAllMCPServers(ctx); err != nil { @@ -502,7 +539,7 @@ func (s *BifrostHTTPServer) ReloadClientConfigFromConfigStore(ctx context.Contex account := lib.NewBaseAccount(s.Config) var mcpConfig *schemas.MCPConfig if s.Config.MCPConfig != nil { - mcpConfig = configstore.ConvertTableMCPConfigToSchemas(s.Config.MCPConfig) + mcpConfig = s.Config.MCPConfig } s.Client.ReloadConfig(schemas.BifrostConfig{ Account: account, @@ -564,28 +601,9 @@ func (s *BifrostHTTPServer) UpdateMCPToolManagerConfig(ctx context.Context, maxA return s.Client.UpdateToolManagerConfig(maxAgentDepth, toolExecutionTimeoutInSeconds, codeModeBindingLevel) } -// reloadBifrostPlugins syncs Config plugins to Bifrost client -func (s *BifrostHTTPServer) reloadBifrostPlugins() error { - account := lib.NewBaseAccount(s.Config) - var mcpConfig *schemas.MCPConfig - if s.Config.MCPConfig != nil { - mcpConfig = configstore.ConvertTableMCPConfigToSchemas(s.Config.MCPConfig) - } - - return s.Client.ReloadConfig(schemas.BifrostConfig{ - Account: account, - InitialPoolSize: s.Config.ClientConfig.InitialPoolSize, - DropExcessRequests: s.Config.ClientConfig.DropExcessRequests, - LLMPlugins: s.Config.GetLoadedLLMPlugins(), - MCPPlugins: s.Config.GetLoadedMCPPlugins(), - MCPConfig: mcpConfig, - Logger: logger, - }) -} - // reloadObservabilityPlugins reloads all observability plugins in the tracing middleware func (s *BifrostHTTPServer) reloadObservabilityPlugins() { - observabilityPlugins := s.collectObservabilityPlugins() + observabilityPlugins := s.CollectObservabilityPlugins() // Always update the tracing middleware, even with empty slice, to clear stale plugins s.TracingMiddleware.SetObservabilityPlugins(observabilityPlugins) } @@ -651,50 +669,50 @@ func (s *BifrostHTTPServer) GetPluginStatus(ctx context.Context) map[string]sche return s.Config.GetPluginStatus() } -// ReloadPlugin reloads a plugin with new instance and updates Bifrost core. -// The plugin is checked for LLM and MCP interfaces independently and registered -// to the appropriate arrays based on which interfaces it implements. -func (s *BifrostHTTPServer) ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any) error { - logger.Debug("reloading plugin %s", name) - - // Helper to update error status - updateError := func(step string, err error) error { - if err := s.Config.UpdatePluginStatus(name, schemas.PluginStatusError); err != nil { - return err - } - if err := s.Config.AppendPluginStateLogs(name, []string{fmt.Sprintf("error %s plugin %s: %v", step, name, err)}); err != nil { - return err - } +// Helper to update error status +func (s *BifrostHTTPServer) updatePluginErrorStatus(name, step string, err error) error { + if err := s.Config.UpdatePluginStatus(name, schemas.PluginStatusError); err != nil { return err } - - // 1. Instantiate new version - plugin, err := InstantiatePlugin(ctx, name, path, pluginConfig, s.Config) - if err != nil { - return updateError("loading", err) + if err := s.Config.AppendPluginStateLogs(name, []string{fmt.Sprintf("error %s plugin %s: %v", step, name, err)}); err != nil { + return err } + return err +} +// SyncLoadedPlugin syncs a loaded plugin to the Bifrost client and updates the plugin status +func (s *BifrostHTTPServer) SyncLoadedPlugin(ctx context.Context, name string, plugin schemas.BasePlugin) error { // 2. Register (replaces old version atomically) - if err := s.Config.RegisterPlugin(plugin); err != nil { - return updateError("registering", err) + if err := s.Config.ReloadPlugin(plugin); err != nil { + return s.updatePluginErrorStatus(plugin.GetName(), "registering", err) } - // 3. Update Bifrost client - if err := s.reloadBifrostPlugins(); err != nil { - return updateError("reloading bifrost config for", err) + if err := s.Client.ReloadPlugin(plugin, InferPluginTypes(plugin)); err != nil { + return s.updatePluginErrorStatus(plugin.GetName(), "reloading bifrost config for", err) } - // 4. Special handling for observability plugins if _, ok := plugin.(schemas.ObservabilityPlugin); ok { s.reloadObservabilityPlugins() } - // 5. Update plugin status s.Config.UpdatePluginOverallStatus(plugin.GetName(), name, schemas.PluginStatusActive, - []string{fmt.Sprintf("plugin %s reloaded successfully", name)}, getPluginTypes(plugin)) + []string{fmt.Sprintf("plugin %s reloaded successfully", name)}, InferPluginTypes(plugin)) return nil } +// ReloadPlugin reloads a plugin with new instance and updates Bifrost core. +// The plugin is checked for LLM and MCP interfaces independently and registered +// to the appropriate arrays based on which interfaces it implements. +func (s *BifrostHTTPServer) ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any) error { + logger.Debug("reloading plugin %s", name) + // 1. Instantiate new version + plugin, err := InstantiatePlugin(ctx, name, path, pluginConfig, s.Config) + if err != nil { + return s.updatePluginErrorStatus(name, "loading", err) + } + return s.SyncLoadedPlugin(ctx, name, plugin) +} + // RemovePlugin removes a plugin from the server. // The plugin is removed from both LLM and MCP arrays independently if it exists in them. func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, displayName string) error { @@ -706,7 +724,9 @@ func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, displayName string // Check if plugin implements ObservabilityPlugin before removal var isObservability bool - if plugin, err := s.Config.FindPluginByName(name); err == nil { + var err error + var plugin schemas.BasePlugin + if plugin, err = s.Config.FindPluginByName(name); err == nil { _, isObservability = plugin.(schemas.ObservabilityPlugin) } @@ -716,7 +736,7 @@ func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, displayName string } // 2. Update Bifrost client - if err := s.reloadBifrostPlugins(); err != nil { + if err := s.Client.RemovePlugin(name, InferPluginTypes(plugin)); err != nil { logger.Warn("failed to reload bifrost config after plugin removal: %v", err) } @@ -958,14 +978,14 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { } } // Load all plugins - if err := s.InstantiatePlugins(ctx); err != nil { + if err := s.LoadPlugins(ctx); err != nil { return fmt.Errorf("failed to instantiate plugins: %v", err) } tableMCPConfig := s.Config.MCPConfig var mcpConfig *schemas.MCPConfig if tableMCPConfig != nil { - mcpConfig = configstore.ConvertTableMCPConfigToSchemas(tableMCPConfig) + mcpConfig = s.Config.MCPConfig if mcpConfig != nil { mcpConfig.FetchNewRequestIDFunc = func(ctx *schemas.BifrostContext) string { return uuid.New().String() @@ -1033,7 +1053,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { // Registering inference middlewares inferenceMiddlewares = append([]schemas.BifrostHTTPMiddleware{handlers.TransportInterceptorMiddleware(s.Config)}, inferenceMiddlewares...) // Curating observability plugins - observabilityPlugins := s.collectObservabilityPlugins() + observabilityPlugins := s.CollectObservabilityPlugins() // This enables the central streaming accumulator for both use cases // Initializing tracer with embedded streaming accumulator traceStore := tracing.NewTraceStore(60*time.Minute, logger) diff --git a/transports/bifrost-http/server/utils.go b/transports/bifrost-http/server/utils.go index 4f2c939784..4abd274e85 100644 --- a/transports/bifrost-http/server/utils.go +++ b/transports/bifrost-http/server/utils.go @@ -93,14 +93,14 @@ func (s *BifrostHTTPServer) registerPluginWithStatus(ctx context.Context, name s return nil } - s.Config.RegisterPlugin(plugin) + s.Config.ReloadPlugin(plugin) s.Config.UpdatePluginOverallStatus(name, name, schemas.PluginStatusActive, - []string{fmt.Sprintf("%s plugin initialized successfully", name)}, getPluginTypes(plugin)) + []string{fmt.Sprintf("%s plugin initialized successfully", name)}, InferPluginTypes(plugin)) return nil } -// collectObservabilityPlugins gathers all loaded plugins that implement ObservabilityPlugin interface -func (s *BifrostHTTPServer) collectObservabilityPlugins() []schemas.ObservabilityPlugin { +// CollectObservabilityPlugins gathers all loaded plugins that implement ObservabilityPlugin interface +func (s *BifrostHTTPServer) CollectObservabilityPlugins() []schemas.ObservabilityPlugin { var observabilityPlugins []schemas.ObservabilityPlugin // Check LLM plugins diff --git a/ui/app/_fallbacks/enterprise/components/mcp-tool-groups/mcpToolGroups.tsx b/ui/app/_fallbacks/enterprise/components/mcp-tool-groups/mcpToolGroups.tsx new file mode 100644 index 0000000000..512a31c0ab --- /dev/null +++ b/ui/app/_fallbacks/enterprise/components/mcp-tool-groups/mcpToolGroups.tsx @@ -0,0 +1,16 @@ +import { ToolCase } from "lucide-react"; +import ContactUsView from "../views/contactUsView"; + +export default function MCPToolGroups() { + return ( +
+ } + title="Unlock MCP Tool Groups" + description="This feature is a part of the Bifrost enterprise license. Configure tool groups for MCP servers to organize your MCP tools and govern them across your organization." + readmeLink="https://docs.getbifrost.ai/mcp/overview" + /> +
+ ); +} diff --git a/ui/app/workspace/config/views/mcpView.tsx b/ui/app/workspace/config/views/mcpView.tsx index 206f26be28..04307c2f37 100644 --- a/ui/app/workspace/config/views/mcpView.tsx +++ b/ui/app/workspace/config/views/mcpView.tsx @@ -20,10 +20,12 @@ export default function MCPView() { mcp_agent_depth: string; mcp_tool_execution_timeout: string; mcp_code_mode_binding_level: string; + mcp_tool_sync_interval: string; }>({ mcp_agent_depth: "10", mcp_tool_execution_timeout: "30", mcp_code_mode_binding_level: "server", + mcp_tool_sync_interval: "10", }); useEffect(() => { @@ -33,6 +35,7 @@ export default function MCPView() { mcp_agent_depth: config?.mcp_agent_depth?.toString() || "10", mcp_tool_execution_timeout: config?.mcp_tool_execution_timeout?.toString() || "30", mcp_code_mode_binding_level: config?.mcp_code_mode_binding_level || "server", + mcp_tool_sync_interval: config?.mcp_tool_sync_interval?.toString() || "10", }); } }, [config, bifrostConfig]); @@ -42,7 +45,8 @@ export default function MCPView() { return ( localConfig.mcp_agent_depth !== config.mcp_agent_depth || localConfig.mcp_tool_execution_timeout !== config.mcp_tool_execution_timeout || - localConfig.mcp_code_mode_binding_level !== (config.mcp_code_mode_binding_level || "server") + localConfig.mcp_code_mode_binding_level !== (config.mcp_code_mode_binding_level || "server") || + localConfig.mcp_tool_sync_interval !== (config.mcp_tool_sync_interval ?? 10) ); }, [config, localConfig]); @@ -69,6 +73,14 @@ export default function MCPView() { } }, []); + const handleToolSyncIntervalChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_tool_sync_interval: value })); + const numValue = Number.parseInt(value); + if (!isNaN(numValue) && numValue >= 0) { + setLocalConfig((prev) => ({ ...prev, mcp_tool_sync_interval: numValue })); + } + }, []); + const handleSave = useCallback(async () => { try { const agentDepth = Number.parseInt(localValues.mcp_agent_depth); @@ -143,6 +155,26 @@ export default function MCPView() { /> + {/* Tool Sync Interval */} +
+
+ +

+ How often to refresh tool lists from MCP servers. Set to 0 to disable. +

+
+ handleToolSyncIntervalChange(e.target.value)} + min="0" + /> +
+ {/* Code Mode Binding Level */}
diff --git a/ui/app/workspace/logs/views/logDetailsSheet.tsx b/ui/app/workspace/logs/views/logDetailsSheet.tsx index dfce7c5a0c..80b98b488b 100644 --- a/ui/app/workspace/logs/views/logDetailsSheet.tsx +++ b/ui/app/workspace/logs/views/logDetailsSheet.tsx @@ -46,24 +46,30 @@ interface LogDetailSheetProps { // Helper to detect container operations (for hiding irrelevant fields like Model/Tokens) const isContainerOperation = (object: string) => { const containerTypes = [ - 'container_create', 'container_list', 'container_retrieve', 'container_delete', - 'container_file_create', 'container_file_list', 'container_file_retrieve', - 'container_file_content', 'container_file_delete' - ] - return containerTypes.includes(object?.toLowerCase()) -} + "container_create", + "container_list", + "container_retrieve", + "container_delete", + "container_file_create", + "container_file_list", + "container_file_retrieve", + "container_file_content", + "container_file_delete", + ]; + return containerTypes.includes(object?.toLowerCase()); +}; export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDetailSheetProps) { if (!log) return null; - const isContainer = isContainerOperation(log.object) + const isContainer = isContainerOperation(log.object); // Taking out tool call let toolsParameter = null; if (log.params?.tools) { try { toolsParameter = JSON.stringify(log.params.tools, null, 2); - } catch (ignored) { } + } catch (ignored) {} } const copyRequestBody = async () => { @@ -193,7 +199,7 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet return ( - +
@@ -286,8 +292,9 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet label="Type" value={
{RequestTypeLabels[log.object as keyof typeof RequestTypeLabels] ?? log.object ?? "unknown"}
diff --git a/ui/app/workspace/mcp-logs/views/mcpLogDetailsSheet.tsx b/ui/app/workspace/mcp-logs/views/mcpLogDetailsSheet.tsx index 3e2bf1f428..873a773bb2 100644 --- a/ui/app/workspace/mcp-logs/views/mcpLogDetailsSheet.tsx +++ b/ui/app/workspace/mcp-logs/views/mcpLogDetailsSheet.tsx @@ -1,6 +1,6 @@ "use client"; -import type { ReactNode } from "react"; +import { CodeEditor } from "@/app/workspace/logs/views/codeEditor"; import { AlertDialog, AlertDialogAction, @@ -19,10 +19,10 @@ import { DottedSeparator } from "@/components/ui/separator"; import { Sheet, SheetContent, SheetHeader, SheetTitle } from "@/components/ui/sheet"; import { Status, StatusColors, Statuses } from "@/lib/constants/logs"; import type { MCPToolLogEntry } from "@/lib/types/logs"; -import { FileText, MoreVertical, Timer, Trash2 } from "lucide-react"; +import { MoreVertical, Trash2 } from "lucide-react"; import moment from "moment"; +import type { ReactNode } from "react"; import { toast } from "sonner"; -import { CodeEditor } from "@/app/workspace/logs/views/codeEditor"; interface MCPLogDetailSheetProps { log: MCPToolLogEntry | null; @@ -62,7 +62,7 @@ export function MCPLogDetailSheet({ log, open, onOpenChange, handleDelete }: MCP return ( - +
diff --git a/ui/app/workspace/mcp-registry/views/mcpClientSheet.tsx b/ui/app/workspace/mcp-registry/views/mcpClientSheet.tsx index 26c83209eb..74e9c802b1 100644 --- a/ui/app/workspace/mcp-registry/views/mcpClientSheet.tsx +++ b/ui/app/workspace/mcp-registry/views/mcpClientSheet.tsx @@ -14,7 +14,7 @@ import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/comp import { TriStateCheckbox } from "@/components/ui/tristateCheckbox"; import { useToast } from "@/hooks/use-toast"; import { MCP_STATUS_COLORS } from "@/lib/constants/config"; -import { getErrorMessage, useUpdateMCPClientMutation } from "@/lib/store"; +import { getErrorMessage, useGetCoreConfigQuery, useUpdateMCPClientMutation } from "@/lib/store"; import { MCPClient } from "@/lib/types/mcp"; import { mcpClientUpdateSchema, type MCPClientUpdateSchema } from "@/lib/types/schemas"; import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; @@ -32,6 +32,8 @@ interface MCPClientSheetProps { export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: MCPClientSheetProps) { const hasUpdateMCPClientAccess = useRbac(RbacResource.MCPGateway, RbacOperation.Update); const [updateMCPClient, { isLoading: isUpdating }] = useUpdateMCPClientMutation(); + const { data: bifrostConfig } = useGetCoreConfigQuery({ fromDB: true }); + const globalToolSyncInterval = bifrostConfig?.client_config?.mcp_tool_sync_interval ?? 10; const { toast } = useToast(); const [expandedTools, setExpandedTools] = useState>(new Set()); @@ -58,6 +60,7 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: tools_to_execute: mcpClient.config.tools_to_execute || [], tools_to_auto_execute: mcpClient.config.tools_to_auto_execute || [], tool_pricing: mcpClient.config.tool_pricing || {}, + tool_sync_interval: mcpClient.config.tool_sync_interval, }, }); @@ -71,6 +74,7 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: tools_to_execute: mcpClient.config.tools_to_execute || [], tools_to_auto_execute: mcpClient.config.tools_to_auto_execute || [], tool_pricing: mcpClient.config.tool_pricing || {}, + tool_sync_interval: mcpClient.config.tool_sync_interval, }); }, [form, mcpClient]); @@ -86,6 +90,7 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: tools_to_execute: data.tools_to_execute, tools_to_auto_execute: data.tools_to_auto_execute, tool_pricing: data.tool_pricing, + tool_sync_interval: data.tool_sync_interval, }, }).unwrap(); @@ -235,7 +240,23 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: name="name" render={({ field }) => ( - Name +
+ Name + + + + + + +

+ Use a descriptive, meaningful name that clearly identifies the server. For example, use "google_drive" + instead of "gdrive", or "hacker_news" instead of "hn". This name is used as the Python module name in code + mode. +

+
+
+
+
@@ -284,6 +305,51 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: )} /> + { + const isUsingGlobal = field.value === undefined || field.value === null || field.value === 0; + return ( + +
+
+
+ Tool Sync Interval (minutes) +
+ + + + + + +

+ Override the global tool sync interval for this server. Leave empty to use global setting. Set to -1 to + disable sync for this server. +

+
+
+
+
+
{isUsingGlobal &&

Using global setting

}
+
+ + { + const val = e.target.value === "" ? undefined : parseInt(e.target.value); + field.onChange(val); + }} + min="-1" + /> + +
+ ); + }} + /> -
Registered MCP Servers
+

MCP server catalog

@@ -213,19 +213,14 @@ export default function MCPClientsTable({ mcpClients, refetch }: MCPClientsTable )} - - {c.state} - + {c.state} e.stopPropagation()}> + +
+
+ ))} +
+ )} + + {/* Empty state */} + {selectedServersWithInfo.length === 0 && ( +
+ No servers selected. Use the search above to add MCP servers. +
+ )} +
+ ); +} diff --git a/ui/components/ui/mcpToolSelector.tsx b/ui/components/ui/mcpToolSelector.tsx new file mode 100644 index 0000000000..e54b85517a --- /dev/null +++ b/ui/components/ui/mcpToolSelector.tsx @@ -0,0 +1,323 @@ +"use client" + +import { CodeEditor } from "@/app/workspace/logs/views/codeEditor"; +import { ChevronDown, ChevronRight, X } from "lucide-react" +import { useCallback, useMemo, useState } from "react" +import { components, OptionProps } from "react-select"; +import { AsyncMultiSelect } from "./asyncMultiselect" +import { Badge } from "./badge" +import { Button } from "./button" +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "./collapsible" +import { Option } from "./multiselectUtils"; +import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "./table"; +import { cn } from "./utils" + +// Types +export interface SelectedTool { + mcpClientId: string + toolName: string +} + +export interface ToolFunction { + name: string + description?: string + // Parameters can be any object type to support various schema formats + parameters?: Record | object + strict?: boolean +} + +export interface MCPClientInfo { + config: { + client_id: string + name: string + connection_type?: string + } + tools: ToolFunction[] + state?: string +} + +interface ToolOptionMeta { + mcpClientId: string + mcpClientName: string + toolName: string + description?: string + parameters?: Record | object +} + +interface MCPToolSelectorProps { + value: SelectedTool[] + onChange: (tools: SelectedTool[]) => void + mcpClients: MCPClientInfo[] + placeholder?: string + disabled?: boolean + className?: string +} + +export function MCPToolSelector({ + value, + onChange, + mcpClients, + placeholder = "Search and select tools...", + disabled = false, + className, +}: MCPToolSelectorProps) { + const [expandedTools, setExpandedTools] = useState>(new Set()) + + // Flatten all tools from all MCP clients into searchable options + // Using meta field for complex data as per Option type definition + const allToolOptions = useMemo(() => { + const options: Option[] = [] + + for (const client of mcpClients) { + if (!client.tools) continue + + for (const tool of client.tools) { + const key = `${client.config.client_id}:${tool.name}` + + options.push({ + label: `${client.config.name} / ${tool.name}`, + value: key, + meta: { + mcpClientId: client.config.client_id, + mcpClientName: client.config.name, + toolName: tool.name, + description: tool.description, + parameters: tool.parameters, + }, + }) + } + } + + return options + }, [mcpClients]) + + // Get full tool info for selected tools + const selectedToolsWithInfo = useMemo(() => { + return value.map((selected) => { + const client = mcpClients.find((c) => c.config.client_id === selected.mcpClientId) + const tool = client?.tools?.find((t) => t.name === selected.toolName) + return { + ...selected, + mcpClientName: client?.config.name || selected.mcpClientId, + description: tool?.description, + parameters: tool?.parameters, + } + }) + }, [value, mcpClients]) + + // Filter out already selected tools from options + const availableOptions = useMemo(() => { + const selectedKeys = new Set( + value.map((t) => `${t.mcpClientId}:${t.toolName}`) + ) + return allToolOptions.filter((opt) => !selectedKeys.has(opt.value)) + }, [allToolOptions, value]) + + const toggleExpanded = useCallback((key: string) => { + setExpandedTools((prev) => { + const next = new Set(prev) + if (next.has(key)) { + next.delete(key) + } else { + next.add(key) + } + return next + }) + }, []) + + const handleSelectTool = useCallback( + (selected: Option[]) => { + if (selected.length === 0) return + + const newTool = selected[selected.length - 1] + if (!newTool?.meta) return + + const newSelectedTool: SelectedTool = { + mcpClientId: newTool.meta.mcpClientId, + toolName: newTool.meta.toolName, + } + + // Check if already selected + const exists = value.some( + (t) => t.mcpClientId === newSelectedTool.mcpClientId && t.toolName === newSelectedTool.toolName + ) + + if (!exists) { + onChange([...value, newSelectedTool]) + } + }, + [value, onChange] + ) + + const handleRemoveTool = useCallback( + (mcpClientId: string, toolName: string) => { + onChange(value.filter((t) => !(t.mcpClientId === mcpClientId && t.toolName === toolName))) + }, + [value, onChange] + ) + + const reload = useCallback( + (query: string, callback: (options: Option[]) => void) => { + const lowerQuery = query.toLowerCase() + const filtered = availableOptions.filter( + (opt) => + opt.label.toLowerCase().includes(lowerQuery) || + opt.meta?.description?.toLowerCase().includes(lowerQuery) + ) + callback(filtered) + }, + [availableOptions] + ) + + return ( +
+ {/* Search dropdown */} + + placeholder={placeholder} + disabled={disabled} + defaultOptions={availableOptions} + reload={reload} + debounce={150} + onChange={handleSelectTool} + value={[]} + isClearable={false} + closeMenuOnSelect={true} + hideSelectedOptions={true} + controlShouldRenderValue={false} + noOptionsMessage={() => "No results found"} + views={{ + option: (optionProps: OptionProps) => { + const { Option } = components; + // Access data as Option since that's the actual runtime type + const data = optionProps.data as unknown as Option; + return ( + + ); + }, + }} + /> + + {/* Selected tools table */} + {selectedToolsWithInfo.length > 0 && ( +
+ + + + + Tool + Server + + + + + {selectedToolsWithInfo.map((tool) => { + const key = `${tool.mcpClientId}:${tool.toolName}`; + const isExpanded = expandedTools.has(key); + + return ( + toggleExpanded(key)} asChild> + <> + + + + + + + +
+
{tool.toolName}
+ {tool.description &&

{tool.description}

} +
+
+ + {tool.mcpClientName} + + + + +
+ +
+ + + + + + ); + })} + +
+
+
Parameters Schema
+ {tool.parameters ? ( + + ) : ( +
No parameters defined
+ )} +
+
+
+ )} + + {/* Empty state */} + {selectedToolsWithInfo.length === 0 && ( +
+ No tools selected. Use the search above to add tools. +
+ )} +
+ ); +} diff --git a/ui/lib/store/apis/baseApi.ts b/ui/lib/store/apis/baseApi.ts index d7cbf0748e..7491481c17 100644 --- a/ui/lib/store/apis/baseApi.ts +++ b/ui/lib/store/apis/baseApi.ts @@ -171,6 +171,7 @@ export const baseApi = createApi({ "Permissions", "APIKeys", "OAuth2Config", + "MCPToolGroups", ], endpoints: () => ({}), }); diff --git a/ui/lib/store/apis/mcpApi.ts b/ui/lib/store/apis/mcpApi.ts index e342f2cd64..291e2dcece 100644 --- a/ui/lib/store/apis/mcpApi.ts +++ b/ui/lib/store/apis/mcpApi.ts @@ -57,8 +57,8 @@ export const mcpApi = baseApi.injectEndpoints({ // Complete OAuth flow for MCP client completeOAuthFlow: builder.mutation<{ status: string; message: string }, string>({ - query: (mcpClientId) => ({ - url: `/mcp/client/${mcpClientId}/complete-oauth`, + query: (oauthConfigId) => ({ + url: `/mcp/client/${oauthConfigId}/complete-oauth`, method: "POST", }), invalidatesTags: ["MCPClients"], diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index 18fd5de88f..d378306d35 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -377,6 +377,7 @@ export interface CoreConfig { mcp_agent_depth: number; mcp_tool_execution_timeout: number; mcp_code_mode_binding_level?: string; + mcp_tool_sync_interval: number; header_filter_config?: GlobalHeaderFilterConfig; } @@ -397,6 +398,7 @@ export const DefaultCoreConfig: CoreConfig = { mcp_agent_depth: 10, mcp_tool_execution_timeout: 30, mcp_code_mode_binding_level: "server", + mcp_tool_sync_interval: 10, allowed_headers: [], }; diff --git a/ui/lib/types/mcp.ts b/ui/lib/types/mcp.ts index a241754810..43b3c8a45e 100644 --- a/ui/lib/types/mcp.ts +++ b/ui/lib/types/mcp.ts @@ -39,6 +39,7 @@ export interface MCPClientConfig { headers?: Record; is_ping_available?: boolean; tool_pricing?: Record; + tool_sync_interval?: number; // Per-client override in minutes (0 = use global, -1 = disabled) } export interface MCPClient { @@ -88,4 +89,17 @@ export interface UpdateMCPClientRequest { tools_to_auto_execute?: string[]; is_ping_available?: boolean; tool_pricing?: Record; + tool_sync_interval?: number; // Per-client override in minutes (0 = use global, -1 = disabled) +} + +// Types for MCP Tool Selector component +export interface SelectedTool { + mcpClientId: string; + toolName: string; +} + +// MCP Tool Spec for tool groups (matches backend schema) +export interface MCPToolSpec { + mcp_client_id: string; + tool_names: string[]; } diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index 2c41955a61..08f45068b0 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -660,6 +660,7 @@ export const mcpClientUpdateSchema = z.object({ { message: "Duplicate tool names are not allowed" }, ), tool_pricing: z.record(z.string(), z.number().min(0, "Cost must be non-negative")).optional(), + tool_sync_interval: z.number().optional(), // -1 = disabled, 0 = use global, >0 = custom interval in minutes }); // Global proxy type schema