diff --git a/core/bifrost.go b/core/bifrost.go index 568061d696..d6b264f242 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -73,7 +73,7 @@ type Bifrost struct { oauth2Provider schemas.OAuth2Provider // OAuth provider instance logger schemas.Logger // logger instance, default logger is used if not provided tracer atomic.Value // tracer for distributed tracing (stores schemas.Tracer, NoOpTracer if not configured) - mcpManager *mcp.MCPManager // MCP integration manager (nil if MCP not configured) + McpManager mcp.MCPManagerInterface // MCP integration manager (nil if MCP not configured) mcpInitOnce sync.Once // Ensures MCP manager is initialized only once dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. keySelector schemas.KeySelector // Custom key selector function @@ -186,6 +186,12 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { logger: config.Logger, } bifrost.tracer.Store(&tracerWrapper{tracer: tracer}) + if config.LLMPlugins == nil { + config.LLMPlugins = make([]schemas.LLMPlugin, 0) + } + if config.MCPPlugins == nil { + config.MCPPlugins = make([]schemas.MCPPlugin, 0) + } bifrost.llmPlugins.Store(&config.LLMPlugins) bifrost.mcpPlugins.Store(&config.MCPPlugins) @@ -280,7 +286,7 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { } } codeMode := starlark.NewStarlarkCodeMode(codeModeConfig) - bifrost.mcpManager = mcp.NewMCPManager(bifrostCtx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode) + bifrost.McpManager = mcp.NewMCPManager(bifrostCtx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode) bifrost.logger.Info("MCP integration initialized successfully") }) } @@ -338,21 +344,6 @@ func (bifrost *Bifrost) getTracer() schemas.Tracer { // We will keep on adding other aspects as required func (bifrost *Bifrost) ReloadConfig(config schemas.BifrostConfig) error { bifrost.dropExcessRequests.Store(config.DropExcessRequests) - - // Update LLM plugins atomically - if config.LLMPlugins != nil { - llmPluginsCopy := make([]schemas.LLMPlugin, len(config.LLMPlugins)) - copy(llmPluginsCopy, config.LLMPlugins) - bifrost.llmPlugins.Store(&llmPluginsCopy) - } - - // Update MCP plugins atomically - if config.MCPPlugins != nil { - mcpPluginsCopy := make([]schemas.MCPPlugin, len(config.MCPPlugins)) - copy(mcpPluginsCopy, config.MCPPlugins) - bifrost.mcpPlugins.Store(&mcpPluginsCopy) - } - return nil } @@ -722,8 +713,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx *schemas.BifrostContext, req * } // Check if we should enter agent mode - if bifrost.mcpManager != nil { - return bifrost.mcpManager.CheckAndExecuteAgentForChatRequest( + if bifrost.McpManager != nil { + return bifrost.McpManager.CheckAndExecuteAgentForChatRequest( ctx, req, response, @@ -819,8 +810,8 @@ func (bifrost *Bifrost) ResponsesRequest(ctx *schemas.BifrostContext, req *schem } // Check if we should enter agent mode - if bifrost.mcpManager != nil { - return bifrost.mcpManager.CheckAndExecuteAgentForResponsesRequest( + if bifrost.McpManager != nil { + return bifrost.McpManager.CheckAndExecuteAgentForResponsesRequest( ctx, req, response, @@ -2116,14 +2107,22 @@ func (bifrost *Bifrost) ContainerFileDeleteRequest(ctx *schemas.BifrostContext, } // RemovePlugin removes a plugin from the server. -func (bifrost *Bifrost) RemovePlugin(name string, pluginType schemas.PluginType) error { - switch pluginType { - case schemas.PluginTypeLLM: - return bifrost.removeLLMPlugin(name) - case schemas.PluginTypeMCP: - return bifrost.removeMCPPlugin(name) - } - return fmt.Errorf("unsupported plugin type: %v", pluginType) +func (bifrost *Bifrost) RemovePlugin(name string, pluginTypes []schemas.PluginType) error { + for _, pluginType := range pluginTypes { + switch pluginType { + case schemas.PluginTypeLLM: + err := bifrost.removeLLMPlugin(name) + if err != nil { + return err + } + case schemas.PluginTypeMCP: + err := bifrost.removeMCPPlugin(name) + if err != nil { + return err + } + } + } + return nil } // removeLLMPlugin removes an LLM plugin from the server. @@ -2200,22 +2199,30 @@ func (bifrost *Bifrost) removeMCPPlugin(name string) error { // ReloadPlugin reloads a plugin with new instance // During the reload - it's stop the world phase where we take a global lock on the plugin mutex -func (bifrost *Bifrost) ReloadPlugin(plugin schemas.BasePlugin, pluginType schemas.PluginType) error { - switch pluginType { - case schemas.PluginTypeLLM: - llmPlugin, ok := plugin.(schemas.LLMPlugin) - if !ok { - return fmt.Errorf("plugin %s is not an LLMPlugin", plugin.GetName()) - } - return bifrost.reloadLLMPlugin(llmPlugin) - case schemas.PluginTypeMCP: - mcpPlugin, ok := plugin.(schemas.MCPPlugin) - if !ok { - return fmt.Errorf("plugin %s is not an MCPPlugin", plugin.GetName()) - } - return bifrost.reloadMCPPlugin(mcpPlugin) - } - return fmt.Errorf("unsupported plugin type: %v", pluginType) +func (bifrost *Bifrost) ReloadPlugin(plugin schemas.BasePlugin, pluginTypes []schemas.PluginType) error { + for _, pluginType := range pluginTypes { + switch pluginType { + case schemas.PluginTypeLLM: + llmPlugin, ok := plugin.(schemas.LLMPlugin) + if !ok { + return fmt.Errorf("plugin %s is not an LLMPlugin", plugin.GetName()) + } + err := bifrost.reloadLLMPlugin(llmPlugin) + if err != nil { + return err + } + case schemas.PluginTypeMCP: + mcpPlugin, ok := plugin.(schemas.MCPPlugin) + if !ok { + return fmt.Errorf("plugin %s is not an MCPPlugin", plugin.GetName()) + } + err := bifrost.reloadMCPPlugin(mcpPlugin) + if err != nil { + return err + } + } + } + return nil } // reloadLLMPlugin reloads an LLM plugin with new instance @@ -2654,11 +2661,11 @@ func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *syn // return args.Message, nil // }, toolSchema) func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.ChatTool) error { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return fmt.Errorf("MCP is not configured in this Bifrost instance") } - return bifrost.mcpManager.RegisterTool(name, description, handler, toolSchema) + return bifrost.McpManager.RegisterTool(name, description, handler, toolSchema) } // IMPORTANT: Running the MCP client management operations (GetMCPClients, AddMCPClient, RemoveMCPClient, EditMCPClientTools) @@ -2672,11 +2679,11 @@ func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(a // - []schemas.MCPClient: List of all MCP clients // - error: Any retrieval error func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return nil, fmt.Errorf("MCP is not configured in this Bifrost instance") } - clients := bifrost.mcpManager.GetClients() + clients := bifrost.McpManager.GetClients() clientsInConfig := make([]schemas.MCPClient, 0, len(clients)) for _, client := range clients { @@ -2714,10 +2721,10 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { // Returns: // - []schemas.ChatTool: List of available tools func (bifrost *Bifrost) GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return nil } - return bifrost.mcpManager.GetAvailableTools(ctx) + return bifrost.McpManager.GetAvailableTools(ctx) } // AddMCPClient adds a new MCP client to the Bifrost instance. @@ -2736,13 +2743,13 @@ func (bifrost *Bifrost) GetAvailableMCPTools(ctx context.Context) []schemas.Chat // ConnectionType: schemas.MCPConnectionTypeHTTP, // ConnectionString: &url, // }) -func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error { - if bifrost.mcpManager == nil { +func (bifrost *Bifrost) AddMCPClient(config *schemas.MCPClientConfig) error { + if bifrost.McpManager == nil { // Use sync.Once to ensure thread-safe initialization bifrost.mcpInitOnce.Do(func() { // Initialize with empty config - client will be added via AddClient below mcpConfig := schemas.MCPConfig{ - ClientConfigs: []schemas.MCPClientConfig{}, + ClientConfigs: []*schemas.MCPClientConfig{}, } // Set up plugin pipeline provider functions for executeCode tool hooks mcpConfig.PluginPipelineProvider = func() interface{} { @@ -2756,16 +2763,16 @@ func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error { // Create Starlark CodeMode for code execution (with default config) starlark.SetLogger(bifrost.logger) codeMode := starlark.NewStarlarkCodeMode(nil) - bifrost.mcpManager = mcp.NewMCPManager(bifrost.ctx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode) + bifrost.McpManager = mcp.NewMCPManager(bifrost.ctx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode) }) } // Handle case where initialization succeeded elsewhere but manager is still nil - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return fmt.Errorf("MCP manager is not initialized") } - return bifrost.mcpManager.AddClient(config) + return bifrost.McpManager.AddClient(config) } // RemoveMCPClient removes an MCP client from the Bifrost instance. @@ -2784,23 +2791,23 @@ func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error { // log.Fatalf("Failed to remove MCP client: %v", err) // } func (bifrost *Bifrost) RemoveMCPClient(id string) error { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return fmt.Errorf("MCP is not configured in this Bifrost instance") } - return bifrost.mcpManager.RemoveClient(id) + return bifrost.McpManager.RemoveClient(id) } // SetMCPManager sets the MCP manager for this Bifrost instance. -// This is primarily used for testing purposes to inject a custom MCP manager. +// This allows injecting a custom MCP manager implementation (e.g., for enterprise features). // // Parameters: -// - manager: The MCP manager to set -func (bifrost *Bifrost) SetMCPManager(manager *mcp.MCPManager) { - bifrost.mcpManager = manager +// - manager: The MCP manager to set (must implement MCPManagerInterface) +func (bifrost *Bifrost) SetMCPManager(manager mcp.MCPManagerInterface) { + bifrost.McpManager = manager } -// EditMCPClient edits the tools of an MCP client. +// UpdateMCPClient updates the MCP client. // This allows for dynamic MCP client tool management at runtime. // // Parameters: @@ -2812,16 +2819,16 @@ func (bifrost *Bifrost) SetMCPManager(manager *mcp.MCPManager) { // // Example: // -// err := bifrost.EditMCPClient("my-mcp-client-id", schemas.MCPClientConfig{ +// err := bifrost.UpdateMCPClient("my-mcp-client-id", schemas.MCPClientConfig{ // Name: "my-mcp-client-name", // ToolsToExecute: []string{"tool1", "tool2"}, // }) -func (bifrost *Bifrost) EditMCPClient(id string, updatedConfig schemas.MCPClientConfig) error { - if bifrost.mcpManager == nil { +func (bifrost *Bifrost) EditMCPClient(id string, updatedConfig *schemas.MCPClientConfig) error { + if bifrost.McpManager == nil { return fmt.Errorf("MCP is not configured in this Bifrost instance") } - return bifrost.mcpManager.EditClient(id, updatedConfig) + return bifrost.McpManager.EditClient(id, updatedConfig) } // ReconnectMCPClient attempts to reconnect an MCP client if it is disconnected. @@ -2832,21 +2839,21 @@ func (bifrost *Bifrost) EditMCPClient(id string, updatedConfig schemas.MCPClient // Returns: // - error: Any reconnection error func (bifrost *Bifrost) ReconnectMCPClient(id string) error { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return fmt.Errorf("MCP is not configured in this Bifrost instance") } - return bifrost.mcpManager.ReconnectClient(id) + return bifrost.McpManager.ReconnectClient(id) } // UpdateToolManagerConfig updates the tool manager config for the MCP manager. // This allows for hot-reloading of the tool manager config at runtime. func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return fmt.Errorf("MCP is not configured in this Bifrost instance") } - bifrost.mcpManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{ + bifrost.McpManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{ MaxAgentDepth: maxAgentDepth, ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second, CodeModeBindingLevel: schemas.CodeModeBindingLevel(codeModeBindingLevel), @@ -3438,8 +3445,8 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif } // Add MCP tools to request if MCP is configured and requested - if bifrost.mcpManager != nil { - req = bifrost.mcpManager.AddToolsToRequest(ctx, req) + if bifrost.McpManager != nil { + req = bifrost.McpManager.AddToolsToRequest(ctx, req) } tracer := bifrost.getTracer() @@ -3628,8 +3635,8 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem } // Add MCP tools to request if MCP is configured and requested - if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.mcpManager != nil { - req = bifrost.mcpManager.AddToolsToRequest(ctx, req) + if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.McpManager != nil { + req = bifrost.McpManager.AddToolsToRequest(ctx, req) } tracer := bifrost.getTracer() @@ -4404,7 +4411,7 @@ func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, r // - *schemas.BifrostMCPResponse: The MCP response after all hooks // - *schemas.BifrostError: Any execution error func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpRequest *schemas.BifrostMCPRequest, requestType schemas.RequestType) (*schemas.BifrostMCPResponse, *schemas.BifrostError) { - if bifrost.mcpManager == nil { + if bifrost.McpManager == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ @@ -4463,7 +4470,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR } // Execute tool with modified request - result, err := bifrost.mcpManager.ExecuteToolCall(ctx, preReq) + result, err := bifrost.McpManager.ExecuteToolCall(ctx, preReq) // Prepare MCP response and error for post-hooks var mcpResp *schemas.BifrostMCPResponse @@ -5271,8 +5278,8 @@ func (bifrost *Bifrost) Shutdown() { }) // Cleanup MCP manager - if bifrost.mcpManager != nil { - err := bifrost.mcpManager.Cleanup() + if bifrost.McpManager != nil { + err := bifrost.McpManager.Cleanup() if err != nil { bifrost.logger.Warn("Error cleaning up MCP manager: %s", err.Error()) } diff --git a/core/internal/mcptests/agent_filtering_test.go b/core/internal/mcptests/agent_filtering_test.go index 01c1d0e8e2..1058f5a7f2 100644 --- a/core/internal/mcptests/agent_filtering_test.go +++ b/core/internal/mcptests/agent_filtering_test.go @@ -531,7 +531,7 @@ func TestAgent_FilteringWithMultipleClients(t *testing.T) { tempConfig := GetTemperatureMCPClientConfig("") tempConfig.ToolsToAutoExecute = []string{} // Not auto-executed - err = manager.AddClient(tempConfig) + err = manager.AddClient(&tempConfig) if err != nil { t.Skipf("Skipping test - temperature server not available: %v", err) return @@ -620,7 +620,7 @@ func TestAgent_ToolConflictInAgentMode(t *testing.T) { tempConfig := GetTemperatureMCPClientConfig("") tempConfig.ToolsToAutoExecute = []string{} // Not auto - err = manager.AddClient(tempConfig) + err = manager.AddClient(&tempConfig) if err != nil { t.Skipf("Skipping test - temperature server not available: %v", err) return diff --git a/core/internal/mcptests/agent_request_id_test.go b/core/internal/mcptests/agent_request_id_test.go index eb882c12f4..ded2180b6c 100644 --- a/core/internal/mcptests/agent_request_id_test.go +++ b/core/internal/mcptests/agent_request_id_test.go @@ -23,9 +23,15 @@ func setupMCPManagerWithRequestIDFunc(t *testing.T, fetchNewRequestIDFunc func(c logger := &testLogger{t: t} + // Convert to pointer slice for MCPConfig + clientConfigPtrs := make([]*schemas.MCPClientConfig, len(clientConfigs)) + for i := range clientConfigs { + clientConfigPtrs[i] = &clientConfigs[i] + } + // Create MCP config with request ID function mcpConfig := &schemas.MCPConfig{ - ClientConfigs: clientConfigs, + ClientConfigs: clientConfigPtrs, FetchNewRequestIDFunc: fetchNewRequestIDFunc, } diff --git a/core/internal/mcptests/client_management_test.go b/core/internal/mcptests/client_management_test.go index b871d7d001..1d0e452ffa 100644 --- a/core/internal/mcptests/client_management_test.go +++ b/core/internal/mcptests/client_management_test.go @@ -25,11 +25,11 @@ func TestAddClientDuplicate(t *testing.T) { // Add client clientConfig := GetSampleHTTPClientConfig(config.HTTPServerURL) - err := manager.AddClient(clientConfig) + err := manager.AddClient(&clientConfig) require.NoError(t, err, "should add client first time") // Try to add same client again - err = manager.AddClient(clientConfig) + err = manager.AddClient(&clientConfig) // Should either return error or be idempotent if err == nil { clients := manager.GetClients() @@ -183,7 +183,7 @@ func TestEditClient(t *testing.T) { updatedConfig.Name = "UpdatedName" updatedConfig.ToolsToExecute = []string{"calculator", "echo"} - err := manager.EditClient(clientID, updatedConfig) + err := manager.EditClient(clientID, &updatedConfig) require.NoError(t, err, "should edit client") // Verify changes @@ -199,7 +199,7 @@ func TestEditClientInvalidID(t *testing.T) { // Try to edit non-existent client clientConfig := GetSampleHTTPClientConfig("http://example.com") - err := manager.EditClient("non-existent-id", clientConfig) + err := manager.EditClient("non-existent-id", &clientConfig) assert.Error(t, err, "should error when editing non-existent client") } @@ -225,7 +225,7 @@ func TestEditClientInvalidConfig(t *testing.T) { // Missing ConnectionString } - err := manager.EditClient(clientID, invalidConfig) + err := manager.EditClient(clientID, &invalidConfig) // Should return error or leave client unchanged if err == nil { clients = manager.GetClients() @@ -257,7 +257,7 @@ func TestEditClientChangeConnectionType(t *testing.T) { updatedConfig := clientConfig updatedConfig.ConnectionType = schemas.MCPConnectionTypeSSE - err := manager.EditClient(clientID, updatedConfig) + err := manager.EditClient(clientID, &updatedConfig) assert.Error(t, err, "should not allow connection type change") clients = manager.GetClients() if len(clients) > 0 { @@ -428,7 +428,7 @@ func TestConcurrentClientOperations(t *testing.T) { clientConfig := GetSampleHTTPClientConfig(config.HTTPServerURL) clientConfig.ID = string(rune('a'+id)) + "-concurrent-client" - err := manager.AddClient(clientConfig) + err := manager.AddClient(&clientConfig) if err != nil { errors <- err } diff --git a/core/internal/mcptests/concurrency_advanced_test.go b/core/internal/mcptests/concurrency_advanced_test.go index 0190ab26c9..54057ce4a5 100644 --- a/core/internal/mcptests/concurrency_advanced_test.go +++ b/core/internal/mcptests/concurrency_advanced_test.go @@ -163,7 +163,7 @@ func TestConcurrent_AddRemoveClients(t *testing.T) { ToolsToAutoExecute: []string{}, } - err := manager.AddClient(clientConfig) + err := manager.AddClient(&clientConfig) if err != nil { // InProcess connections without a server instance will fail // This is expected - we're just testing that the operations are concurrent and don't deadlock diff --git a/core/internal/mcptests/concurrency_test.go b/core/internal/mcptests/concurrency_test.go index f81cc714eb..d71e1a501f 100644 --- a/core/internal/mcptests/concurrency_test.go +++ b/core/internal/mcptests/concurrency_test.go @@ -437,7 +437,7 @@ func TestConcurrent_EditClientDuringExecution(t *testing.T) { // Edit client - update name (must not contain spaces) updatedConfig := clientConfig updatedConfig.Name = "UpdatedClientName" - err := manager.EditClient(clientConfig.ID, updatedConfig) + err := manager.EditClient(clientConfig.ID, &updatedConfig) if err != nil { errors <- fmt.Errorf("failed to edit client: %v", err) } else { diff --git a/core/internal/mcptests/error_handling_protocol_test.go b/core/internal/mcptests/error_handling_protocol_test.go index 508f04de5f..bb6ad929d9 100644 --- a/core/internal/mcptests/error_handling_protocol_test.go +++ b/core/internal/mcptests/error_handling_protocol_test.go @@ -289,7 +289,7 @@ func TestErrorHandling_STDIO_MCPErrorResponse(t *testing.T) { errorServerConfig := GetErrorTestServerConfig(bifrostRoot) manager := setupMCPManager(t) - err := manager.AddClient(errorServerConfig) + err := manager.AddClient(&errorServerConfig) if err != nil { t.Skipf("error-test-server not available: %v (build with: cd examples/mcps/error-test-server && go build -o bin/error-test-server)", err) } @@ -367,7 +367,7 @@ func TestErrorHandling_STDIO_TimeoutScenario(t *testing.T) { ToolExecutionTimeout: 2 * time.Second, // 2 second timeout }) - err := manager.AddClient(errorServerConfig) + err := manager.AddClient(&errorServerConfig) if err != nil { t.Skipf("error-test-server not available: %v", err) } @@ -420,7 +420,7 @@ func TestErrorHandling_STDIO_MalformedJSON(t *testing.T) { errorServerConfig := GetErrorTestServerConfig(bifrostRoot) manager := setupMCPManager(t) - err := manager.AddClient(errorServerConfig) + err := manager.AddClient(&errorServerConfig) if err != nil { t.Skipf("error-test-server not available: %v", err) } @@ -469,7 +469,7 @@ func TestErrorHandling_STDIO_IntermittentFailures(t *testing.T) { errorServerConfig := GetErrorTestServerConfig(bifrostRoot) manager := setupMCPManager(t) - err := manager.AddClient(errorServerConfig) + err := manager.AddClient(&errorServerConfig) if err != nil { t.Skipf("error-test-server not available: %v", err) } diff --git a/core/internal/mcptests/fixtures.go b/core/internal/mcptests/fixtures.go index fcf98c1739..38d85f6ee8 100644 --- a/core/internal/mcptests/fixtures.go +++ b/core/internal/mcptests/fixtures.go @@ -1466,9 +1466,15 @@ func setupMCPManager(t *testing.T, clientConfigs ...schemas.MCPClientConfig) *mc logger := &testLogger{t: t} + // Convert to pointer slice for MCPConfig + clientConfigPtrs := make([]*schemas.MCPClientConfig, len(clientConfigs)) + for i := range clientConfigs { + clientConfigPtrs[i] = &clientConfigs[i] + } + // Create MCP config mcpConfig := &schemas.MCPConfig{ - ClientConfigs: clientConfigs, + ClientConfigs: clientConfigPtrs, } // Create Starlark CodeMode diff --git a/core/internal/mcptests/health_monitoring_test.go b/core/internal/mcptests/health_monitoring_test.go index f0834c24d2..7e7155d04b 100644 --- a/core/internal/mcptests/health_monitoring_test.go +++ b/core/internal/mcptests/health_monitoring_test.go @@ -50,7 +50,7 @@ func TestHealthCheckSTDIOServerDropAndRecoverIn20Seconds(t *testing.T) { t.Logf("✅ Health monitor detected server drop") // 5. Restart STDIO process (re-add client) - err = manager.AddClient(clientConfig) + err = manager.AddClient(&clientConfig) require.NoError(t, err, "should re-add client to simulate server recovery") t.Logf("🔄 Simulated STDIO server recovery by re-adding client") @@ -181,7 +181,7 @@ func TestHealthCheckStateTransitions(t *testing.T) { assert.Len(t, clients, 0, "client should be removed") // Re-add client (simulates reconnection) - err = manager.AddClient(clientConfig) + err = manager.AddClient(&clientConfig) require.NoError(t, err, "should re-add client") // Verify client is connected again @@ -392,7 +392,7 @@ func TestHealthCheckReconnectAfterFailure(t *testing.T) { time.Sleep(2 * time.Second) // Re-add client (manual reconnection) - err = manager.AddClient(clientConfig) + err = manager.AddClient(&clientConfig) require.NoError(t, err, "should re-add client") // Wait for health monitoring to stabilize diff --git a/core/internal/mcptests/integration_test.go b/core/internal/mcptests/integration_test.go index 6045070c31..776a0ab80f 100644 --- a/core/internal/mcptests/integration_test.go +++ b/core/internal/mcptests/integration_test.go @@ -34,7 +34,7 @@ func TestIntegration_FullChatWorkflow(t *testing.T) { httpConfig := GetSampleHTTPClientConfig(config.HTTPServerURL) httpConfig.ID = "http-integration-test" applyTestConfigHeaders(t, &httpConfig) - err := manager.AddClient(httpConfig) + err := manager.AddClient(&httpConfig) if err != nil { t.Logf("Could not add HTTP client: %v", err) } @@ -341,7 +341,7 @@ func TestIntegration_ReconnectDuringExecution(t *testing.T) { httpConfig := GetSampleHTTPClientConfig(config.HTTPServerURL) httpConfig.ID = "reconnect-test-client" applyTestConfigHeaders(t, &httpConfig) - err := manager.AddClient(httpConfig) + err := manager.AddClient(&httpConfig) require.NoError(t, err, "should add HTTP client") // Wait for client to connect diff --git a/core/internal/mcptests/tool_filtering_test.go b/core/internal/mcptests/tool_filtering_test.go index 2fd950e980..037738679f 100644 --- a/core/internal/mcptests/tool_filtering_test.go +++ b/core/internal/mcptests/tool_filtering_test.go @@ -335,7 +335,7 @@ func TestFilteringChangesAfterClientEdit(t *testing.T) { // Edit client to only allow second tool clientConfig.ToolsToExecute = []string{"hash"} - err := manager.EditClient(clientConfig.ID, clientConfig) + err := manager.EditClient(clientConfig.ID, &clientConfig) require.NoError(t, err, "edit should succeed") // Verify configuration changed diff --git a/core/mcp/agent.go b/core/mcp/agent.go index da7185c597..2bb8c8974d 100644 --- a/core/mcp/agent.go +++ b/core/mcp/agent.go @@ -182,8 +182,8 @@ func executeAgent( toolName := *toolCall.Function.Name client := clientManager.GetClientForTool(toolName) if client == nil { - // Allow code mode list and read tool tools - if toolName == ToolTypeListToolFiles || toolName == ToolTypeReadToolFile { + // Allow code mode list, read, and docs tools (all read-only operations) + if toolName == ToolTypeListToolFiles || toolName == ToolTypeReadToolFile || toolName == ToolTypeGetToolDocs { autoExecutableTools = append(autoExecutableTools, toolCall) logger.Debug("Tool %s can be auto-executed", toolName) continue diff --git a/core/mcp/clientmanager.go b/core/mcp/clientmanager.go index 4fa39cc780..9fd9f52805 100644 --- a/core/mcp/clientmanager.go +++ b/core/mcp/clientmanager.go @@ -72,8 +72,8 @@ func (m *MCPManager) ReconnectClient(id string) error { // // Returns: // - error: Any error that occurred during client addition or connection -func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error { - if err := validateMCPClientConfig(&config); err != nil { +func (m *MCPManager) AddClient(config *schemas.MCPClientConfig) error { + if err := validateMCPClientConfig(config); err != nil { return fmt.Errorf("invalid MCP client configuration: %w", err) } @@ -89,10 +89,13 @@ func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error { // Create placeholder entry m.clientMap[config.ID] = &schemas.MCPClientState{ - Name: config.Name, - ExecutionConfig: config, - ToolMap: make(map[string]schemas.ChatTool), - ToolNameMapping: make(map[string]string), + Name: config.Name, + ExecutionConfig: config, + ToolMap: make(map[string]schemas.ChatTool), + ToolNameMapping: make(map[string]string), + ConnectionInfo: &schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + }, } // Temporarily unlock for the connection attempt @@ -119,8 +122,8 @@ func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error { // // Returns: // - error: Any error that occurred during client addition or connection -func (m *MCPManager) AddClientInMemory(config schemas.MCPClientConfig) error { - if err := validateMCPClientConfig(&config); err != nil { +func (m *MCPManager) AddClientInMemory(config *schemas.MCPClientConfig) error { + if err := validateMCPClientConfig(config); err != nil { return fmt.Errorf("invalid MCP client configuration: %w", err) } @@ -136,10 +139,13 @@ func (m *MCPManager) AddClientInMemory(config schemas.MCPClientConfig) error { // Create placeholder entry m.clientMap[config.ID] = &schemas.MCPClientState{ - Name: config.Name, - ExecutionConfig: config, - ToolMap: make(map[string]schemas.ChatTool), - ToolNameMapping: make(map[string]string), + Name: config.Name, + ExecutionConfig: config, + ToolMap: make(map[string]schemas.ChatTool), + ToolNameMapping: make(map[string]string), + ConnectionInfo: &schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + }, } // Temporarily unlock for the connection attempt @@ -228,7 +234,7 @@ func (m *MCPManager) removeClientUnsafe(id string) error { // // Returns: // - error: Any error that occurred during client update or tool retrieval -func (m *MCPManager) EditClient(id string, updatedConfig schemas.MCPClientConfig) error { +func (m *MCPManager) EditClient(id string, updatedConfig *schemas.MCPClientConfig) error { m.mu.Lock() defer m.mu.Unlock() @@ -424,7 +430,7 @@ func (m *MCPManager) RegisterTool(name, description string, toolFunction MCPTool // connectToMCPClient establishes a connection to an external MCP server and // registers its available tools with the manager. -func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { +func (m *MCPManager) connectToMCPClient(config *schemas.MCPClientConfig) error { // First lock: Initialize or validate client entry m.mu.Lock() @@ -443,11 +449,11 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { } // Create new client entry with configuration m.clientMap[config.ID] = &schemas.MCPClientState{ - Name: config.Name, - ExecutionConfig: config, - ToolMap: make(map[string]schemas.ChatTool), - ToolNameMapping: make(map[string]string), - ConnectionInfo: schemas.MCPClientConnectionInfo{ + Name: config.Name, + ExecutionConfig: config, + ToolMap: make(map[string]schemas.ChatTool), + ToolNameMapping: make(map[string]string), + ConnectionInfo: &schemas.MCPClientConnectionInfo{ Type: config.ConnectionType, }, } @@ -455,7 +461,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { // Heavy operations performed outside lock var externalClient *client.Client - var connectionInfo schemas.MCPClientConnectionInfo + var connectionInfo *schemas.MCPClientConnectionInfo var err error // Create appropriate transport based on connection type @@ -559,7 +565,6 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { // Second lock: Update client with final connection details and tools m.mu.Lock() - defer m.mu.Unlock() // Verify client still exists (could have been cleaned up during heavy operations) if client, exists := m.clientMap[config.ID]; exists { @@ -584,6 +589,8 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { logger.Debug("%s [%s] Registering %d tools. Client config - ID: %s, Name: %s, IsCodeModeClient: %v", MCPLogPrefix, config.Name, len(tools), config.ID, config.Name, config.IsCodeModeClient) logger.Info("%s Connected to MCP server '%s'", MCPLogPrefix, config.Name) } else { + // Release lock before cleanup and return + m.mu.Unlock() // Clean up resources before returning error: client was removed during connection setup // Cancel long-lived context if it was created if (config.ConnectionType == schemas.MCPConnectionTypeSSE || config.ConnectionType == schemas.MCPConnectionTypeSTDIO) && cancel != nil { @@ -598,6 +605,10 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { return fmt.Errorf("client %s was removed during connection setup", config.Name) } + // Release lock BEFORE starting monitors to prevent deadlock + // (StartMonitoring -> Start() tries to acquire RLock on the same mutex) + m.mu.Unlock() + // Register OnConnectionLost hook for SSE connections to detect idle timeouts if config.ConnectionType == schemas.MCPConnectionTypeSSE && externalClient != nil { externalClient.OnConnectionLost(func(err error) { @@ -628,37 +639,32 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { } // createHTTPConnection creates an HTTP-based MCP client connection without holding locks. -func (m *MCPManager) createHTTPConnection(ctx context.Context, config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { +func (m *MCPManager) createHTTPConnection(ctx context.Context, config *schemas.MCPClientConfig) (*client.Client, *schemas.MCPClientConnectionInfo, error) { if config.ConnectionString == nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("HTTP connection string is required") + return nil, nil, fmt.Errorf("HTTP connection string is required") } - // Prepare connection info - connectionInfo := schemas.MCPClientConnectionInfo{ + connectionInfo := &schemas.MCPClientConnectionInfo{ Type: config.ConnectionType, ConnectionURL: config.ConnectionString.GetValuePtr(), } - headers, err := config.HttpHeaders(ctx, m.oauth2Provider) if err != nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to get HTTP headers: %w", err) + return nil, nil, fmt.Errorf("failed to get HTTP headers: %w", err) } - // Create StreamableHTTP transport httpTransport, err := transport.NewStreamableHTTP(config.ConnectionString.GetValue(), transport.WithHTTPHeaders(headers)) if err != nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create HTTP transport: %w", err) + return nil, nil, fmt.Errorf("failed to create HTTP transport: %w", err) } - client := client.NewClient(httpTransport) - return client, connectionInfo, nil } // createSTDIOConnection creates a STDIO-based MCP client connection without holding locks. -func (m *MCPManager) createSTDIOConnection(_ context.Context, config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { +func (m *MCPManager) createSTDIOConnection(_ context.Context, config *schemas.MCPClientConfig) (*client.Client, *schemas.MCPClientConnectionInfo, error) { if config.StdioConfig == nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("stdio config is required") + return nil, nil, fmt.Errorf("stdio config is required") } // Prepare STDIO command info for display @@ -667,7 +673,7 @@ func (m *MCPManager) createSTDIOConnection(_ context.Context, config schemas.MCP // Check if environment variables are set for _, env := range config.StdioConfig.Envs { if os.Getenv(env) == "" { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("environment variable %s is not set for MCP client %s", env, config.Name) + return nil, nil, fmt.Errorf("environment variable %s is not set for MCP client %s", env, config.Name) } } @@ -679,7 +685,7 @@ func (m *MCPManager) createSTDIOConnection(_ context.Context, config schemas.MCP ) // Prepare connection info - connectionInfo := schemas.MCPClientConnectionInfo{ + connectionInfo := &schemas.MCPClientConnectionInfo{ Type: config.ConnectionType, StdioCommandString: &cmdString, } @@ -691,26 +697,26 @@ func (m *MCPManager) createSTDIOConnection(_ context.Context, config schemas.MCP } // createSSEConnection creates a SSE-based MCP client connection without holding locks. -func (m *MCPManager) createSSEConnection(ctx context.Context, config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { +func (m *MCPManager) createSSEConnection(ctx context.Context, config *schemas.MCPClientConfig) (*client.Client, *schemas.MCPClientConnectionInfo, error) { if config.ConnectionString == nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("SSE connection string is required") + return nil, nil, fmt.Errorf("SSE connection string is required") } // Prepare connection info - connectionInfo := schemas.MCPClientConnectionInfo{ + connectionInfo := &schemas.MCPClientConnectionInfo{ Type: config.ConnectionType, ConnectionURL: config.ConnectionString.GetValuePtr(), // Reuse HTTPConnectionURL field for SSE URL display } headers, err := config.HttpHeaders(ctx, m.oauth2Provider) if err != nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to get HTTP headers: %w", err) + return nil, nil, fmt.Errorf("failed to get HTTP headers: %w", err) } // Create SSE transport sseTransport, err := transport.NewSSE(config.ConnectionString.GetValue(), transport.WithHeaders(headers)) if err != nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create SSE transport: %w", err) + return nil, nil, fmt.Errorf("failed to create SSE transport: %w", err) } client := client.NewClient(sseTransport) @@ -721,19 +727,19 @@ func (m *MCPManager) createSSEConnection(ctx context.Context, config schemas.MCP // createInProcessConnection creates an in-process MCP client connection without holding locks. // This allows direct connection to an MCP server running in the same process, providing // the lowest latency and highest performance for tool execution. -func (m *MCPManager) createInProcessConnection(_ context.Context, config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { +func (m *MCPManager) createInProcessConnection(_ context.Context, config *schemas.MCPClientConfig) (*client.Client, *schemas.MCPClientConnectionInfo, error) { if config.InProcessServer == nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("InProcess connection requires a server instance") + return nil, nil, fmt.Errorf("InProcess connection requires a server instance") } // Create in-process client directly connected to the provided server inProcessClient, err := client.NewInProcessClient(config.InProcessServer) if err != nil { - return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create in-process client: %w", err) + return nil, nil, fmt.Errorf("failed to create in-process client: %w", err) } // Prepare connection info - connectionInfo := schemas.MCPClientConnectionInfo{ + connectionInfo := &schemas.MCPClientConnectionInfo{ Type: config.ConnectionType, } @@ -816,14 +822,14 @@ func (m *MCPManager) createLocalMCPClient() (*schemas.MCPClientState, error) { // Don't create the actual client connection here - it will be created // after the server is ready using NewInProcessClient return &schemas.MCPClientState{ - ExecutionConfig: schemas.MCPClientConfig{ + ExecutionConfig: &schemas.MCPClientConfig{ ID: BifrostMCPClientKey, Name: BifrostMCPClientKey, // Use same value as ID for consistent prefixing ToolsToExecute: []string{"*"}, // Allow all tools for internal client }, ToolMap: make(map[string]schemas.ChatTool), ToolNameMapping: make(map[string]string), - ConnectionInfo: schemas.MCPClientConnectionInfo{ + ConnectionInfo: &schemas.MCPClientConnectionInfo{ Type: schemas.MCPConnectionTypeInProcess, // Accurate: in-process (in-memory) transport }, }, nil diff --git a/core/mcp/codemode/starlark/readfile.go b/core/mcp/codemode/starlark/readfile.go index 4a96b4e53e..6b199bff9a 100644 --- a/core/mcp/codemode/starlark/readfile.go +++ b/core/mcp/codemode/starlark/readfile.go @@ -27,7 +27,9 @@ func (s *StarlarkCodeMode) createReadToolFileTool() schemas.ChatTool { "The function performs case-insensitive matching and removes the .pyi extension. " + "Each tool can be accessed in code via: serverName.tool_name(param=value). " + "If the compact signature is not enough to understand a tool, use getToolDocs for detailed documentation. " + - "Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode." + "Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + + "IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " + + "do NOT call this tool again with startLine/endLine - you already have the complete file." } else { fileNameDescription = "The virtual filename from listToolFiles in format: servers//.pyi (e.g., 'calculator/add.pyi')" toolDescription = "Reads a virtual .pyi stub file for a specific tool, returning its compact Python function signature. " + @@ -35,7 +37,9 @@ func (s *StarlarkCodeMode) createReadToolFileTool() schemas.ChatTool { "The function performs case-insensitive matching and removes the .pyi extension. " + "The tool can be accessed in code via: serverName.tool_name(param=value). " + "If the compact signature is not enough to understand the tool, use getToolDocs for detailed documentation. " + - "Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode." + "Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + + "IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " + + "do NOT call this tool again with startLine/endLine - you already have the complete file." } readToolFileProps := schemas.OrderedMap{ @@ -45,11 +49,11 @@ func (s *StarlarkCodeMode) createReadToolFileTool() schemas.ChatTool { }, "startLine": map[string]interface{}{ "type": "number", - "description": "Optional 1-based starting line number for partial file read (inclusive). Note: Line numbers start at 1, not 0. The first line is line 1.", + "description": "Optional 1-based starting line number for partial file read. Usually not needed - omit to read the entire file. Files are typically small (under 50 lines).", }, "endLine": map[string]interface{}{ "type": "number", - "description": "Optional 1-based ending line number for partial file read (inclusive)", + "description": "Optional 1-based ending line number for partial file read. Usually not needed - omit to read the entire file. Will be clamped to actual file size if too large.", }, } return schemas.ChatTool{ @@ -197,6 +201,12 @@ func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall sche lines := strings.Split(fileContent, "\n") totalLines := len(lines) + // Prepend total lines info so LLM knows the file size upfront + fileContent = fmt.Sprintf("# Total lines: %d (this is the complete file, no need to paginate)\n%s", totalLines+1, fileContent) + // Recalculate lines after prepending + lines = strings.Split(fileContent, "\n") + totalLines = len(lines) + // Handle line slicing if provided var startLine, endLine *int if sl, ok := arguments["startLine"].(float64); ok { @@ -218,21 +228,23 @@ func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall sche end = *endLine } - // Validate line numbers - if start < 1 || start > totalLines { - errorMsg := fmt.Sprintf("Invalid startLine: %d. Must be between 1 and %d (total lines in file). Provided: startLine=%d, endLine=%v, totalLines=%d", - start, totalLines, start, endLine, totalLines) - return createToolResponseMessage(toolCall, errorMsg), nil + // Clamp values to valid range instead of erroring + // This handles cases where LLM requests more lines than exist + if start < 1 { + start = 1 + } + if start > totalLines { + start = totalLines + } + if end < 1 { + end = 1 } - if end < 1 || end > totalLines { - errorMsg := fmt.Sprintf("Invalid endLine: %d. Must be between 1 and %d (total lines in file). Provided: startLine=%d, endLine=%d, totalLines=%d", - end, totalLines, start, end, totalLines) - return createToolResponseMessage(toolCall, errorMsg), nil + if end > totalLines { + end = totalLines } if start > end { - errorMsg := fmt.Sprintf("Invalid line range: startLine (%d) must be less than or equal to endLine (%d). Total lines in file: %d", - start, end, totalLines) - return createToolResponseMessage(toolCall, errorMsg), nil + // If start > end after clamping, just return the start line + end = start } // Slice lines (convert to 0-based indexing) @@ -289,7 +301,8 @@ func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isTo sb.WriteString(fmt.Sprintf("# %s server tools\n", clientName)) } sb.WriteString(fmt.Sprintf("# Usage: %s.tool_name(param=value)\n", clientName)) - sb.WriteString(fmt.Sprintf("# For detailed docs: use getToolDocs(server=\"%s\", tool=\"tool_name\")\n\n", clientName)) + sb.WriteString(fmt.Sprintf("# For detailed docs: use getToolDocs(server=\"%s\", tool=\"tool_name\")\n", clientName)) + sb.WriteString("# Note: Descriptions may be truncated. Use getToolDocs for full details.\n\n") for _, tool := range tools { if tool.Function == nil || tool.Function.Name == "" { diff --git a/core/mcp/healthmonitor.go b/core/mcp/healthmonitor.go index fb1e91a060..6d4eda73d5 100644 --- a/core/mcp/healthmonitor.go +++ b/core/mcp/healthmonitor.go @@ -2,6 +2,7 @@ package mcp import ( "context" + "fmt" "sync" "time" @@ -81,7 +82,7 @@ func (chm *ClientHealthMonitor) Start() { chm.ticker = time.NewTicker(chm.interval) go chm.monitorLoop() - logger.Debug("%s Health monitor started for client %s (interval: %v)", MCPLogPrefix, chm.clientID, chm.interval) + logger.Debug("%s Health monitor started for client %s", MCPLogPrefix, clientState.ExecutionConfig.Name) } // Stop stops monitoring the client's health @@ -112,7 +113,13 @@ func (chm *ClientHealthMonitor) Stop() { if chm.cancel != nil { chm.cancel() } - logger.Debug("%s Health monitor stopped for client %s", MCPLogPrefix, chm.clientID) + + if !exists { + logger.Error("%s Health monitor failed to stop for client %s, client not found in manager", MCPLogPrefix, displayName) + return + } + + logger.Debug("%s Health monitor stopped for client %s", MCPLogPrefix, displayName) } // monitorLoop runs the health check loop @@ -198,7 +205,7 @@ func (chm *ClientHealthMonitor) updateClientState(state schemas.MCPConnectionSta // Log after releasing the lock if stateChanged { - logger.Info("%s Client %s connection state changed to: %s", MCPLogPrefix, chm.clientID, state) + logger.Info(fmt.Sprintf("%s Client %s connection state changed to: %s", MCPLogPrefix, clientState.ExecutionConfig.Name, state)) } } @@ -270,4 +277,4 @@ func (hmm *HealthMonitorManager) StopAll() { monitor.Stop() } hmm.monitors = make(map[string]*ClientHealthMonitor) -} +} \ No newline at end of file diff --git a/core/mcp/interface.go b/core/mcp/interface.go new file mode 100644 index 0000000000..b74e559b9c --- /dev/null +++ b/core/mcp/interface.go @@ -0,0 +1,73 @@ +//go:build !tinygo && !wasm + +package mcp + +import ( + "context" + + "github.com/maximhq/bifrost/core/schemas" +) + +// MCPManagerInterface defines the interface for MCP management functionality. +// This interface allows different implementations (OSS and Enterprise) to be used +// interchangeably in the Bifrost core. +type MCPManagerInterface interface { + // Tool Operations + // AddToolsToRequest parses available MCP tools and adds them to the request + AddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest + + // GetAvailableTools returns all available MCP tools for the given context + GetAvailableTools(ctx context.Context) []schemas.ChatTool + + // ExecuteToolCall executes a single tool call and returns the result + ExecuteToolCall(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) + + // UpdateToolManagerConfig updates the configuration for the tool manager + UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig) + + // Agent Mode Operations + // CheckAndExecuteAgentForChatRequest handles agent mode for Chat Completions API + CheckAndExecuteAgentForChatRequest( + ctx *schemas.BifrostContext, + req *schemas.BifrostChatRequest, + response *schemas.BifrostChatResponse, + makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), + executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), + ) (*schemas.BifrostChatResponse, *schemas.BifrostError) + + // CheckAndExecuteAgentForResponsesRequest handles agent mode for Responses API + CheckAndExecuteAgentForResponsesRequest( + ctx *schemas.BifrostContext, + req *schemas.BifrostResponsesRequest, + response *schemas.BifrostResponsesResponse, + makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError), + executeTool func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), + ) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) + + // Client Management + // GetClients returns all MCP clients + GetClients() []schemas.MCPClientState + + // AddClient adds a new MCP client with the given configuration + AddClient(config *schemas.MCPClientConfig) error + + // RemoveClient removes an MCP client by ID + RemoveClient(id string) error + + // EditClient updates an existing MCP client configuration + EditClient(id string, updatedConfig *schemas.MCPClientConfig) error + + // ReconnectClient reconnects an MCP client by ID + ReconnectClient(id string) error + + // Tool Registration + // RegisterTool registers a local tool with the MCP server + RegisterTool(name, description string, toolFunction MCPToolFunction[any], toolSchema schemas.ChatTool) error + + // Lifecycle + // Cleanup performs cleanup of all MCP resources + Cleanup() error +} + +// Ensure MCPManager implements MCPManagerInterface +var _ MCPManagerInterface = (*MCPManager)(nil) diff --git a/core/mcp/mcp.go b/core/mcp/mcp.go index 15f91eec69..45d833efbd 100644 --- a/core/mcp/mcp.go +++ b/core/mcp/mcp.go @@ -120,7 +120,7 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider wg := sync.WaitGroup{} wg.Add(len(config.ClientConfigs)) for _, clientConfig := range config.ClientConfigs { - go func(clientConfig schemas.MCPClientConfig) { + go func(clientConfig *schemas.MCPClientConfig) { defer wg.Done() if err := manager.AddClient(clientConfig); err != nil { logger.Warn("%s Failed to add MCP client %s: %v", MCPLogPrefix, clientConfig.Name, err) diff --git a/core/mcp/toolsync.go b/core/mcp/toolsync.go index 82a22a68b3..2f1c0815e5 100644 --- a/core/mcp/toolsync.go +++ b/core/mcp/toolsync.go @@ -222,7 +222,7 @@ func (tsm *ToolSyncManager) StopAll() { // - Positive value: use this interval // // Returns 0 if sync is disabled for this client. -func ResolveToolSyncInterval(clientConfig schemas.MCPClientConfig, globalInterval time.Duration) time.Duration { +func ResolveToolSyncInterval(clientConfig *schemas.MCPClientConfig, globalInterval time.Duration) time.Duration { // Per-client explicitly disabled (negative value) if clientConfig.ToolSyncInterval < 0 { return 0 // Disabled for this client diff --git a/core/mcp/utils.go b/core/mcp/utils.go index 531ded3d7b..a46990f8e2 100644 --- a/core/mcp/utils.go +++ b/core/mcp/utils.go @@ -202,7 +202,10 @@ func shouldIncludeClient(clientName string, includeClients []string) bool { } // shouldSkipToolForConfig checks if a tool should be skipped based on client configuration (without accessing clientMap). -func shouldSkipToolForConfig(toolName string, config schemas.MCPClientConfig) bool { +func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) bool { + if config == nil { + return true // No tools allowed + } // If ToolsToExecute is specified (not nil), apply filtering if config.ToolsToExecute != nil { // Handle empty array [] - means no tools are allowed @@ -229,7 +232,7 @@ func shouldSkipToolForConfig(toolName string, config schemas.MCPClientConfig) bo // canAutoExecuteTool checks if a tool can be auto-executed based on client configuration. // Returns true if the tool can be auto-executed, false otherwise. -func canAutoExecuteTool(toolName string, config schemas.MCPClientConfig) bool { +func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool { // First check if tool is in ToolsToExecute (must be executable first) if shouldSkipToolForConfig(toolName, config) { return false // Tool is not in ToolsToExecute, so it cannot be auto-executed diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index 755c3dbb08..a8f6ffe60d 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -26,7 +26,7 @@ var ( // MCPConfig represents the configuration for MCP integration in Bifrost. // It enables tool auto-discovery and execution from local and external MCP servers. type MCPConfig struct { - ClientConfigs []MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations + ClientConfigs []*MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations ToolManagerConfig *MCPToolManagerConfig `json:"tool_manager_config,omitempty"` // MCP tool manager configuration ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Global default interval for syncing tools from MCP servers (0 = use default 10 min) @@ -76,7 +76,7 @@ const ( // MCPClientConfig defines tool filtering for an MCP client. type MCPClientConfig struct { - ID string `json:"id"` // Client ID + ID string `json:"client_id"` // Client ID Name string `json:"name"` // Client name IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) @@ -100,9 +100,10 @@ type MCPClientConfig struct { // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => auto-execute only the specified tools // Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped. - IsPingAvailable bool `json:"is_ping_available"` // Whether the MCP server supports ping for health checks (default: true). If false, uses listTools for health checks. - ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled) - ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) + IsPingAvailable bool `json:"is_ping_available"` // Whether the MCP server supports ping for health checks (default: true). If false, uses listTools for health checks. + ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled) + ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution) + ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) } // NewMCPClientConfigFromMap creates a new MCP client config from a map[string]any. @@ -187,14 +188,14 @@ const ( // MCPClientState represents a connected MCP client with its configuration and tools. // It is used internally by the MCP manager to track the state of a connected MCP client. type MCPClientState struct { - Name string // Unique name for this client - Conn *client.Client // Active MCP client connection - ExecutionConfig MCPClientConfig // Tool filtering settings - ToolMap map[string]ChatTool // Available tools mapped by name - ToolNameMapping map[string]string // Maps sanitized_name -> original_mcp_name (e.g., "notion_search" -> "notion-search") - ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management - CancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized) - State MCPConnectionState // Connection state (connected, disconnected, error) + Name string // Unique name for this client + Conn *client.Client // Active MCP client connection + ExecutionConfig *MCPClientConfig // Tool filtering settings + ToolMap map[string]ChatTool // Available tools mapped by name + ToolNameMapping map[string]string // Maps sanitized_name -> original_mcp_name (e.g., "notion_search" -> "notion-search") + ConnectionInfo *MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management + CancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized) + State MCPConnectionState // Connection state (connected, disconnected, error) } // MCPClientConnectionInfo stores metadata about how a client is connected. @@ -208,7 +209,7 @@ type MCPClientConnectionInfo struct { // and connection information, after it has been initialized. // It is returned by GetMCPClients() method in bifrost. type MCPClient struct { - Config MCPClientConfig `json:"config"` // Tool filtering settings + Config *MCPClientConfig `json:"config"` // Tool filtering settings Tools []ChatToolFunction `json:"tools"` // Available tools State MCPConnectionState `json:"state"` // Connection state } diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index e617594a6e..d9ce9ba521 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -170,6 +170,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddToolSyncIntervalColumns(ctx, db); err != nil { return err } + if err := migrationAddMCPClientConfigToOAuthConfig(ctx, db); err != nil { + return err + } return nil } @@ -3105,3 +3108,37 @@ func migrationAddToolSyncIntervalColumns(ctx context.Context, db *gorm.DB) error } return nil } + +// migrationAddMCPClientConfigToOAuthConfig adds the mcp_client_config_json column to oauth_configs table +// This enables multi-instance support by storing pending MCP client config in the database +// instead of in-memory, so OAuth callbacks can be handled by any server instance +func migrationAddMCPClientConfigToOAuthConfig(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_mcp_client_config_to_oauth_config", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableOauthConfig{}, "mcp_client_config_json") { + if err := migrator.AddColumn(&tables.TableOauthConfig{}, "mcp_client_config_json"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if migrator.HasColumn(&tables.TableOauthConfig{}, "mcp_client_config_json") { + if err := migrator.DropColumn(&tables.TableOauthConfig{}, "mcp_client_config_json"); err != nil { + return err + } + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running mcp client config oauth migration: %s", err.Error()) + } + return nil +} diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index a75e1370b0..413f81abec 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -749,7 +749,7 @@ func (s *RDBConfigStore) GetProviderByName(ctx context.Context, name string) (*t } // GetMCPConfig retrieves the MCP configuration from the database. -func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*tables.MCPConfig, error) { +func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) { var dbMCPClients []tables.TableMCPClient // Get all MCP clients if err := s.db.WithContext(ctx).Find(&dbMCPClients).Error; err != nil { @@ -763,8 +763,27 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*tables.MCPConfig, e if errors.Is(err, gorm.ErrRecordNotFound) { // Return MCP config with default ToolManagerConfig if no client config exists // This will never happen, but just in case. - return &tables.MCPConfig{ - ClientConfigs: dbMCPClients, + clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients)) + for i, dbClient := range dbMCPClients { + clientConfigs[i] = &schemas.MCPClientConfig{ + ID: dbClient.ClientID, + Name: dbClient.Name, + IsCodeModeClient: dbClient.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), + ConnectionString: dbClient.ConnectionString, + StdioConfig: dbClient.StdioConfig, + AuthType: schemas.MCPAuthType(dbClient.AuthType), + OauthConfigID: dbClient.OauthConfigID, + ToolsToExecute: dbClient.ToolsToExecute, + ToolsToAutoExecute: dbClient.ToolsToAutoExecute, + Headers: dbClient.Headers, + IsPingAvailable: dbClient.IsPingAvailable, + ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, + ToolPricing: dbClient.ToolPricing, + } + } + return &schemas.MCPConfig{ + ClientConfigs: clientConfigs, ToolManagerConfig: &schemas.MCPToolManagerConfig{ ToolExecutionTimeout: 30 * time.Second, // default from TableClientConfig MaxAgentDepth: 10, // default from TableClientConfig @@ -778,20 +797,9 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*tables.MCPConfig, e MaxAgentDepth: clientConfig.MCPAgentDepth, CodeModeBindingLevel: schemas.CodeModeBindingLevel(clientConfig.MCPCodeModeBindingLevel), } - return &tables.MCPConfig{ - ClientConfigs: dbMCPClients, - ToolManagerConfig: &toolManagerConfig, - }, nil -} - -// ConvertTableMCPConfigToSchemas converts tables.MCPConfig to schemas.MCPConfig -func ConvertTableMCPConfigToSchemas(tableConfig *tables.MCPConfig, globalToolSyncInterval int) *schemas.MCPConfig { - if tableConfig == nil { - return nil - } - clientConfigs := make([]schemas.MCPClientConfig, len(tableConfig.ClientConfigs)) - for i, dbClient := range tableConfig.ClientConfigs { - clientConfigs[i] = schemas.MCPClientConfig{ + clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients)) + for i, dbClient := range dbMCPClients { + clientConfigs[i] = &schemas.MCPClientConfig{ ID: dbClient.ClientID, Name: dbClient.Name, IsCodeModeClient: dbClient.IsCodeModeClient, @@ -805,15 +813,16 @@ func ConvertTableMCPConfigToSchemas(tableConfig *tables.MCPConfig, globalToolSyn Headers: dbClient.Headers, IsPingAvailable: dbClient.IsPingAvailable, ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, + ToolPricing: dbClient.ToolPricing, } } return &schemas.MCPConfig{ ClientConfigs: clientConfigs, - ToolManagerConfig: tableConfig.ToolManagerConfig, - ToolSyncInterval: time.Duration(globalToolSyncInterval) * time.Minute, - } + ToolManagerConfig: &toolManagerConfig, + }, nil } + // GetMCPClientByID retrieves an MCP client by ID from the database. func (s *RDBConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error) { var mcpClient tables.TableMCPClient @@ -839,10 +848,10 @@ func (s *RDBConfigStore) GetMCPClientByName(ctx context.Context, name string) (* } // CreateMCPClientConfig creates a new MCP client configuration in the database. -func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig schemas.MCPClientConfig) error { +func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error { return s.db.Transaction(func(tx *gorm.DB) error { // Create a deep copy to avoid modifying the original - clientConfigCopy, err := deepCopy(clientConfig) + clientConfigCopy, err := deepCopy(*clientConfig) if err != nil { return err } @@ -870,7 +879,7 @@ func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig } // UpdateMCPClientConfig updates an existing MCP client configuration in the database. -func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, clientConfig tables.TableMCPClient) error { +func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error { return s.db.Transaction(func(tx *gorm.DB) error { // Find existing client var existingClient tables.TableMCPClient diff --git a/framework/configstore/store.go b/framework/configstore/store.go index b39d18ca5b..52d1a68974 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -38,11 +38,11 @@ type ConfigStore interface { GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error) // MCP config CRUD - GetMCPConfig(ctx context.Context) (*tables.MCPConfig, error) + GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error) GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error) - CreateMCPClientConfig(ctx context.Context, clientConfig schemas.MCPClientConfig) error - UpdateMCPClientConfig(ctx context.Context, id string, clientConfig tables.TableMCPClient) error + CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error + UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error DeleteMCPClientConfig(ctx context.Context, id string) error // Vector store config CRUD diff --git a/framework/configstore/tables/mcp.go b/framework/configstore/tables/mcp.go index 2b0da6dda2..e85c21e4cd 100644 --- a/framework/configstore/tables/mcp.go +++ b/framework/configstore/tables/mcp.go @@ -9,14 +9,6 @@ import ( "gorm.io/gorm" ) -type MCPConfig struct { - ClientConfigs []TableMCPClient `json:"client_configs,omitempty"` // Per-client execution configurations - ToolManagerConfig *schemas.MCPToolManagerConfig `json:"tool_manager_config,omitempty"` // MCP tool manager configuration - FetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string `json:"-"` - PluginPipelineProvider func() interface{} `json:"-"` - ReleasePluginPipeline func(pipeline interface{}) `json:"-"` -} - // TableMCPClient represents an MCP client configuration in the database type TableMCPClient struct { ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. diff --git a/framework/configstore/tables/oauth.go b/framework/configstore/tables/oauth.go index ada014a206..b3b897bb50 100644 --- a/framework/configstore/tables/oauth.go +++ b/framework/configstore/tables/oauth.go @@ -22,9 +22,10 @@ type TableOauthConfig struct { CodeChallenge string `gorm:"type:varchar(255)" json:"code_challenge"` // PKCE code challenge (sent to provider) Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired", "revoked" TokenID *string `gorm:"type:varchar(255);index" json:"token_id"` // Foreign key to oauth_tokens.ID (set after callback) - ServerURL string `gorm:"type:text" json:"server_url"` // MCP server URL for OAuth discovery - UseDiscovery bool `gorm:"default:false" json:"use_discovery"` // Flag to enable OAuth discovery - CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + ServerURL string `gorm:"type:text" json:"server_url"` // MCP server URL for OAuth discovery + UseDiscovery bool `gorm:"default:false" json:"use_discovery"` // Flag to enable OAuth discovery + MCPClientConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized MCPClientConfig for multi-instance support (pending MCP client waiting for OAuth completion) + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // State expiry (15 min) } diff --git a/framework/oauth2/main.go b/framework/oauth2/main.go index db3e210f0e..598dff04da 100644 --- a/framework/oauth2/main.go +++ b/framework/oauth2/main.go @@ -20,19 +20,11 @@ import ( "github.com/maximhq/bifrost/framework/configstore/tables" ) -// PendingMCPClient represents an MCP client waiting for OAuth completion -type PendingMCPClient struct { - MCPClientConfig schemas.MCPClientConfig - OauthConfigID string - CreatedAt time.Time -} - // OAuth2Provider implements the schemas.OAuth2Provider interface // It provides OAuth 2.0 authentication functionality with database persistence type OAuth2Provider struct { - configStore configstore.ConfigStore - mu sync.RWMutex - pendingMCPClients map[string]*PendingMCPClient // Key: mcp_client_id + configStore configstore.ConfigStore + mu sync.RWMutex } // NewOAuth2Provider creates a new OAuth provider instance @@ -41,15 +33,9 @@ func NewOAuth2Provider(configStore configstore.ConfigStore, logger schemas.Logge logger = bifrost.NewDefaultLogger(schemas.LogLevelInfo) } SetLogger(logger) - p := &OAuth2Provider{ - configStore: configStore, - pendingMCPClients: make(map[string]*PendingMCPClient), + return &OAuth2Provider{ + configStore: configStore, } - - // Start background cleanup goroutine for expired pending clients - go p.cleanupExpiredPendingClients() - - return p } // GetAccessToken retrieves the access token for a given oauth_config_id @@ -213,53 +199,115 @@ func (p *OAuth2Provider) RevokeToken(ctx context.Context, oauthConfigID string) } // StorePendingMCPClient stores an MCP client config that's waiting for OAuth completion -func (p *OAuth2Provider) StorePendingMCPClient(mcpClientID string, mcpClientConfig schemas.MCPClientConfig) { - p.mu.Lock() - defer p.mu.Unlock() - oauthConfigID := "" - if mcpClientConfig.OauthConfigID != nil { - oauthConfigID = *mcpClientConfig.OauthConfigID +// The config is persisted in the database (oauth_configs.mcp_client_config_json) to support +// multi-instance deployments where OAuth callback may hit a different server instance. +func (p *OAuth2Provider) StorePendingMCPClient(oauthConfigID string, mcpClientConfig schemas.MCPClientConfig) error { + ctx := context.Background() + + oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID) + if err != nil { + return fmt.Errorf("failed to get oauth config: %w", err) } - p.pendingMCPClients[mcpClientID] = &PendingMCPClient{ - MCPClientConfig: mcpClientConfig, - OauthConfigID: oauthConfigID, - CreatedAt: time.Now(), + if oauthConfig == nil { + return fmt.Errorf("oauth config not found: %s", oauthConfigID) } -} -// GetPendingMCPClient retrieves an MCP client config by mcp_client_id -func (p *OAuth2Provider) GetPendingMCPClient(mcpClientID string) *schemas.MCPClientConfig { - p.mu.RLock() - defer p.mu.RUnlock() - if pending, exists := p.pendingMCPClients[mcpClientID]; exists { - return &pending.MCPClientConfig + configJSON, err := json.Marshal(mcpClientConfig) + if err != nil { + return fmt.Errorf("failed to marshal MCP client config: %w", err) + } + configStr := string(configJSON) + oauthConfig.MCPClientConfigJSON = &configStr + + if err := p.configStore.UpdateOauthConfig(ctx, oauthConfig); err != nil { + return fmt.Errorf("failed to update oauth config with MCP client config: %w", err) } + + logger.Debug("Stored pending MCP client config", "oauth_config_id", oauthConfigID) return nil } -// RemovePendingMCPClient removes a pending MCP client after OAuth completion -func (p *OAuth2Provider) RemovePendingMCPClient(mcpClientID string) { - p.mu.Lock() - defer p.mu.Unlock() - delete(p.pendingMCPClients, mcpClientID) +// GetPendingMCPClient retrieves an MCP client config by oauth_config_id +// Returns nil if no pending config is found or if the oauth config has expired +func (p *OAuth2Provider) GetPendingMCPClient(oauthConfigID string) (*schemas.MCPClientConfig, error) { + ctx := context.Background() + + oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID) + if err != nil { + return nil, fmt.Errorf("failed to get oauth config: %w", err) + } + if oauthConfig == nil { + return nil, nil + } + + // Check if expired + if time.Now().After(oauthConfig.ExpiresAt) { + return nil, nil + } + + if oauthConfig.MCPClientConfigJSON == nil || *oauthConfig.MCPClientConfigJSON == "" { + return nil, nil + } + + var config schemas.MCPClientConfig + if err := json.Unmarshal([]byte(*oauthConfig.MCPClientConfigJSON), &config); err != nil { + return nil, fmt.Errorf("failed to unmarshal MCP client config: %w", err) + } + + return &config, nil } -// cleanupExpiredPendingClients removes pending MCP clients older than 5 minutes -func (p *OAuth2Provider) cleanupExpiredPendingClients() { - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() - - for range ticker.C { - p.mu.Lock() - now := time.Now() - for mcpClientID, pending := range p.pendingMCPClients { - if now.Sub(pending.CreatedAt) > 5*time.Minute { - delete(p.pendingMCPClients, mcpClientID) - logger.Debug("Cleaned up expired pending MCP client", "mcp_client_id", mcpClientID) - } - } - p.mu.Unlock() +// GetPendingMCPClientByState retrieves an MCP client config by OAuth state token +// This is useful when the callback only has the state parameter +func (p *OAuth2Provider) GetPendingMCPClientByState(state string) (*schemas.MCPClientConfig, string, error) { + ctx := context.Background() + + oauthConfig, err := p.configStore.GetOauthConfigByState(ctx, state) + if err != nil { + return nil, "", fmt.Errorf("failed to get oauth config by state: %w", err) + } + if oauthConfig == nil { + return nil, "", nil + } + + // Check if expired + if time.Now().After(oauthConfig.ExpiresAt) { + return nil, "", nil + } + + if oauthConfig.MCPClientConfigJSON == nil || *oauthConfig.MCPClientConfigJSON == "" { + return nil, oauthConfig.ID, nil + } + + var config schemas.MCPClientConfig + if err := json.Unmarshal([]byte(*oauthConfig.MCPClientConfigJSON), &config); err != nil { + return nil, oauthConfig.ID, fmt.Errorf("failed to unmarshal MCP client config: %w", err) } + + return &config, oauthConfig.ID, nil +} + +// RemovePendingMCPClient clears the pending MCP client config from the oauth config +// This is called after OAuth completion to clean up +func (p *OAuth2Provider) RemovePendingMCPClient(oauthConfigID string) error { + ctx := context.Background() + + oauthConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID) + if err != nil { + return fmt.Errorf("failed to get oauth config: %w", err) + } + if oauthConfig == nil { + return nil // Already removed or doesn't exist + } + + oauthConfig.MCPClientConfigJSON = nil + + if err := p.configStore.UpdateOauthConfig(ctx, oauthConfig); err != nil { + return fmt.Errorf("failed to clear pending MCP client config: %w", err) + } + + logger.Debug("Removed pending MCP client config", "oauth_config_id", oauthConfigID) + return nil } // InitiateOAuthFlow creates an OAuth config and returns the authorization URL diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index 8fa68881ef..1042efdb52 100644 --- a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -9,6 +9,7 @@ import ( "slices" "sort" "strings" + "time" "github.com/fasthttp/router" "github.com/google/uuid" @@ -20,10 +21,10 @@ import ( ) type MCPManager interface { - ReconnectMCPClient(ctx context.Context, id string) error - AddMCPClient(ctx context.Context, clientConfig schemas.MCPClientConfig) error + AddMCPClient(ctx context.Context, clientConfig *schemas.MCPClientConfig) error RemoveMCPClient(ctx context.Context, id string) error - EditMCPClient(ctx context.Context, id string, updatedConfig configstoreTables.TableMCPClient) error + UpdateMCPClient(ctx context.Context, id string, updatedConfig *schemas.MCPClientConfig) error + ReconnectMCPClient(ctx context.Context, id string) error } // MCPHandler manages HTTP requests for MCP tool operations @@ -48,21 +49,25 @@ func NewMCPHandler(mcpManager MCPManager, client *bifrost.Bifrost, store *lib.Co func (h *MCPHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.GET("/api/mcp/clients", lib.ChainMiddlewares(h.getMCPClients, middlewares...)) r.POST("/api/mcp/client", lib.ChainMiddlewares(h.addMCPClient, middlewares...)) - r.PUT("/api/mcp/client/{id}", lib.ChainMiddlewares(h.editMCPClient, middlewares...)) - r.DELETE("/api/mcp/client/{id}", lib.ChainMiddlewares(h.removeMCPClient, middlewares...)) + r.PUT("/api/mcp/client/{id}", lib.ChainMiddlewares(h.updateMCPClient, middlewares...)) + r.DELETE("/api/mcp/client/{id}", lib.ChainMiddlewares(h.deleteMCPClient, middlewares...)) r.POST("/api/mcp/client/{id}/reconnect", lib.ChainMiddlewares(h.reconnectMCPClient, middlewares...)) r.POST("/api/mcp/client/{id}/complete-oauth", lib.ChainMiddlewares(h.completeMCPClientOAuth, middlewares...)) } // MCPClientResponse represents the response structure for MCP clients type MCPClientResponse struct { - Config configstoreTables.TableMCPClient `json:"config"` - Tools []schemas.ChatToolFunction `json:"tools"` - State schemas.MCPConnectionState `json:"state"` + Config *schemas.MCPClientConfig `json:"config"` + Tools []schemas.ChatToolFunction `json:"tools"` + State schemas.MCPConnectionState `json:"state"` } // getMCPClients handles GET /api/mcp/clients - Get all MCP clients func (h *MCPHandler) getMCPClients(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendJSON(ctx, []MCPClientResponse{}) + return + } // Get clients from store config configsInStore := h.store.MCPConfig if configsInStore == nil { @@ -85,9 +90,8 @@ func (h *MCPHandler) getMCPClients(ctx *fasthttp.RequestCtx) { for _, configClient := range configsInStore.ClientConfigs { // Redact sensitive fields before sending to UI - redactedConfig := h.store.RedactTableMCPClient(configClient) - - if connectedClient, exists := connectedClientsMap[configClient.ClientID]; exists { + redactedConfig := h.store.RedactMCPClientConfig(configClient) + if connectedClient, exists := connectedClientsMap[configClient.ID]; exists { // Sort tools alphabetically by name sortedTools := make([]schemas.ChatToolFunction, len(connectedClient.Tools)) copy(sortedTools, connectedClient.Tools) @@ -114,6 +118,10 @@ func (h *MCPHandler) getMCPClients(ctx *fasthttp.RequestCtx) { // reconnectMCPClient handles POST /api/mcp/client/{id}/reconnect - Reconnect an MCP client func (h *MCPHandler) reconnectMCPClient(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "MCP operations unavailable: config store is disabled") + return + } id, err := getIDFromCtx(ctx) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid id: %v", err)) @@ -147,6 +155,10 @@ type MCPClientRequest struct { // addMCPClient handles POST /api/mcp/client - Add a new MCP client func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "MCP operations unavailable: config store is disabled") + return + } var req MCPClientRequest if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err)) @@ -238,8 +250,12 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { Headers: req.Headers, } - // Store in OAuth provider memory (will auto-cleanup after 5 minutes if not completed) - h.oauthHandler.StorePendingMCPClient(req.ClientID, pendingConfig) + // Store pending config in database (associated with oauth_config_id for multi-instance support) + if err := h.oauthHandler.StorePendingMCPClient(flowInitiation.OauthConfigID, pendingConfig); err != nil { + logger.Error(fmt.Sprintf("[Add MCP Client] Failed to store pending MCP client: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to store pending MCP client: %v", err)) + return + } // Return OAuth flow initiation response SendJSON(ctx, map[string]any{ @@ -253,19 +269,36 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { return } - // For "none" or "headers" auth, proceed with immediate connection - schemasConfig := schemas.MCPClientConfig{ + toolSyncInterval := 10 * time.Minute + if req.ToolSyncInterval != 0 { + toolSyncInterval = time.Duration(req.ToolSyncInterval) * time.Minute + } else { + config, err := h.store.ConfigStore.GetClientConfig(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get client config: %v", err)) + return + } + if config != nil { + toolSyncInterval = time.Duration(config.MCPToolSyncInterval) * time.Minute + } + + } + // Convert to schemas.MCPClientConfig for runtime bifrost client (without tool_pricing) + schemasConfig := &schemas.MCPClientConfig{ ID: req.ClientID, Name: req.Name, IsCodeModeClient: req.IsCodeModeClient, ConnectionType: schemas.MCPConnectionType(req.ConnectionType), ConnectionString: req.ConnectionString, StdioConfig: req.StdioConfig, - AuthType: schemas.MCPAuthType(req.AuthType), - OauthConfigID: nil, ToolsToExecute: req.ToolsToExecute, ToolsToAutoExecute: req.ToolsToAutoExecute, Headers: req.Headers, + AuthType: schemas.MCPAuthType(req.AuthType), + OauthConfigID: req.OauthConfigID, + IsPingAvailable: req.IsPingAvailable, + ToolSyncInterval: toolSyncInterval, + ToolPricing: req.ToolPricing, } if err := h.mcpManager.AddMCPClient(ctx, schemasConfig); err != nil { @@ -274,34 +307,36 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { } // Creating MCP client config in config store if h.store.ConfigStore != nil { - if err := h.store.ConfigStore.CreateMCPClientConfig(ctx, req); err != nil { + if err := h.store.ConfigStore.CreateMCPClientConfig(ctx, schemasConfig); err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Added to core but failed to create MCP config in database: %v, please restart bifrost to keep core and database in sync", err)) return } } + SendJSON(ctx, map[string]any{ "status": "success", "message": "MCP client connected successfully", }) } -// editMCPClient handles PUT /api/mcp/client/{id} - Edit MCP client -func (h *MCPHandler) editMCPClient(ctx *fasthttp.RequestCtx) { +// updateMCPClient handles PUT /api/mcp/client/{id} - Edit MCP client +func (h *MCPHandler) updateMCPClient(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "MCP operations unavailable: config store is disabled") + return + } id, err := getIDFromCtx(ctx) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid id: %v", err)) return } - // Accept the full table client config to support tool_pricing - var req configstoreTables.TableMCPClient + var req *configstoreTables.TableMCPClient if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err)) return } - req.ClientID = id - // Validate tools_to_execute if err := validateToolsToExecute(req.ToolsToExecute); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_execute: %v", err)) @@ -322,53 +357,85 @@ func (h *MCPHandler) editMCPClient(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid client name: %v", err)) return } - // Get existing config to handle redacted values - var existingConfig *configstoreTables.TableMCPClient + var existingConfig *schemas.MCPClientConfig if h.store.MCPConfig != nil { for i, client := range h.store.MCPConfig.ClientConfigs { - if client.ClientID == id { - existingConfig = &h.store.MCPConfig.ClientConfigs[i] + if client.ID == id { + existingConfig = h.store.MCPConfig.ClientConfigs[i] break } } } - if existingConfig == nil { SendError(ctx, fasthttp.StatusNotFound, "MCP client not found") return } // Merge redacted values - preserve old values if incoming values are redacted and unchanged - req = mergeMCPRedactedValues(req, *existingConfig, h.store.RedactTableMCPClient(*existingConfig)) - - // EditMCPClient internally handles database update, env vars, and runtime update - if err := h.mcpManager.EditMCPClient(ctx, id, req); err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to edit MCP client: %v", err)) - return - } - // Update MCP client config in config store with merged values + req = mergeMCPRedactedValues(req, existingConfig, h.store.RedactMCPClientConfig(existingConfig)) + // Persist changes to config store if h.store.ConfigStore != nil { - if err := h.store.ConfigStore.UpdateMCPClientConfig(ctx, id, mergedConfig); err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Updated in core but failed to save MCP config in database: %v, please restart bifrost to keep core and database in sync", err)) + if err := h.store.ConfigStore.UpdateMCPClientConfig(ctx, id, req); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update mcp client config in store: %v", err)) return } } + toolSyncInterval := 10 * time.Minute + if req.ToolSyncInterval != 0 { + toolSyncInterval = time.Duration(req.ToolSyncInterval) * time.Minute + } else { + config, err := h.store.ConfigStore.GetClientConfig(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get client config: %v", err)) + return + } + if config != nil { + toolSyncInterval = time.Duration(config.MCPToolSyncInterval) * time.Minute + } + + } + // Convert to schemas.MCPClientConfig for runtime bifrost client (without tool_pricing) + schemasConfig := &schemas.MCPClientConfig{ + ID: req.ClientID, + Name: req.Name, + IsCodeModeClient: req.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(req.ConnectionType), + ConnectionString: req.ConnectionString, + StdioConfig: req.StdioConfig, + ToolsToExecute: req.ToolsToExecute, + ToolsToAutoExecute: req.ToolsToAutoExecute, + Headers: req.Headers, + AuthType: schemas.MCPAuthType(req.AuthType), + OauthConfigID: req.OauthConfigID, + IsPingAvailable: req.IsPingAvailable, + ToolSyncInterval: toolSyncInterval, + ToolPricing: req.ToolPricing, + } + // Update MCP client in memory + if err := h.mcpManager.UpdateMCPClient(ctx, id, schemasConfig); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update mcp client: %v", err)) + return + } SendJSON(ctx, map[string]any{ "status": "success", "message": "MCP client edited successfully", }) } -// removeMCPClient handles DELETE /api/mcp/client/{id} - Remove an MCP client -func (h *MCPHandler) removeMCPClient(ctx *fasthttp.RequestCtx) { +// deleteMCPClient handles DELETE /api/mcp/client/{id} - Remove an MCP client +func (h *MCPHandler) deleteMCPClient(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "MCP operations unavailable: config store is disabled") + return + } id, err := getIDFromCtx(ctx) if err != nil { - SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid id: %v", err)) + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid id: %v", err)) return } if err := h.mcpManager.RemoveMCPClient(ctx, id); err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to remove MCP client: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to remove MCP client: %v", err)) return } // Deleting MCP client config from config store @@ -572,7 +639,7 @@ func validateMCPClientName(name string) error { // mergeMCPRedactedValues merges incoming MCP client config with existing config, // preserving old values when incoming values are redacted and unchanged. // This follows the same pattern as provider config updates. -func mergeMCPRedactedValues(incoming, oldRaw, oldRedacted configstoreTables.TableMCPClient) configstoreTables.TableMCPClient { +func mergeMCPRedactedValues(incoming *configstoreTables.TableMCPClient, oldRaw, oldRedacted *schemas.MCPClientConfig) *configstoreTables.TableMCPClient { merged := incoming // Handle ConnectionString - if incoming is redacted and equals old redacted, keep old raw value @@ -606,30 +673,23 @@ func mergeMCPRedactedValues(incoming, oldRaw, oldRedacted configstoreTables.Tabl } // completeMCPClientOAuth handles POST /api/mcp/client/{id}/complete-oauth - Complete MCP client creation after OAuth authorization +// The {id} parameter is the oauth_config_id returned from the initial addMCPClient call func (h *MCPHandler) completeMCPClientOAuth(ctx *fasthttp.RequestCtx) { - id, err := getIDFromCtx(ctx) - if err != nil { - logger.Error(fmt.Sprintf("[OAuth Complete] Invalid id: %v", err)) - SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid id: %v", err)) + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "MCP operations unavailable: config store is disabled") return } - - logger.Debug(fmt.Sprintf("[OAuth Complete] Completing OAuth for MCP client: %s", id)) - - // Get MCP client config from OAuth provider memory - mcpClientConfig := h.oauthHandler.GetPendingMCPClient(id) - if mcpClientConfig == nil { - SendError(ctx, fasthttp.StatusNotFound, "MCP client not found in pending OAuth clients") + oauthConfigID, err := getIDFromCtx(ctx) + if err != nil { + logger.Error(fmt.Sprintf("[OAuth Complete] Invalid oauth_config_id: %v", err)) + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid oauth_config_id: %v", err)) return } - if mcpClientConfig.OauthConfigID == nil { - SendError(ctx, fasthttp.StatusBadRequest, "No OAuth config linked to this MCP client") - return - } + logger.Debug(fmt.Sprintf("[OAuth Complete] Completing OAuth for oauth_config_id: %s", oauthConfigID)) // Check if OAuth flow is authorized - oauthConfig, err := h.store.ConfigStore.GetOauthConfigByID(ctx, *mcpClientConfig.OauthConfigID) + oauthConfig, err := h.store.ConfigStore.GetOauthConfigByID(ctx, oauthConfigID) if err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get OAuth config: %v", err)) return @@ -645,17 +705,32 @@ func (h *MCPHandler) completeMCPClientOAuth(ctx *fasthttp.RequestCtx) { return } + // Get MCP client config from database (stored with oauth_config for multi-instance support) + mcpClientConfig, err := h.oauthHandler.GetPendingMCPClient(oauthConfigID) + if err != nil { + logger.Error(fmt.Sprintf("[OAuth Complete] Failed to get pending MCP client: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get pending MCP client: %v", err)) + return + } + if mcpClientConfig == nil { + SendError(ctx, fasthttp.StatusNotFound, "MCP client not found in pending OAuth clients. The OAuth flow may have expired or already been completed.") + return + } + // Add MCP client to Bifrost (this will save to database and connect) - if err := h.mcpManager.AddMCPClient(ctx, *mcpClientConfig); err != nil { + if err := h.mcpManager.AddMCPClient(ctx, mcpClientConfig); err != nil { logger.Error(fmt.Sprintf("[OAuth Complete] Failed to connect MCP client: %v", err)) SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to connect MCP client: %v", err)) return } - // Remove from pending OAuth clients memory - h.oauthHandler.RemovePendingMCPClient(id) + // Clear pending MCP client config from oauth_config (cleanup) + if err := h.oauthHandler.RemovePendingMCPClient(oauthConfigID); err != nil { + logger.Warn(fmt.Sprintf("[OAuth Complete] Failed to clear pending MCP client config: %v", err)) + // Don't fail the request - the MCP client was successfully created + } - logger.Debug(fmt.Sprintf("[OAuth Complete] MCP client connected successfully: %s", id)) + logger.Debug(fmt.Sprintf("[OAuth Complete] MCP client connected successfully: %s", mcpClientConfig.ID)) SendJSON(ctx, map[string]any{ "status": "success", "message": "MCP client connected successfully with OAuth", diff --git a/transports/bifrost-http/handlers/oauth2.go b/transports/bifrost-http/handlers/oauth2.go index faf5c86654..7e47330c0a 100644 --- a/transports/bifrost-http/handlers/oauth2.go +++ b/transports/bifrost-http/handlers/oauth2.go @@ -223,17 +223,23 @@ func (h *OAuthHandler) InitiateOAuthFlow(ctx context.Context, req OAuthInitiatio return h.oauthProvider.InitiateOAuthFlow(ctx, config) } -// StorePendingMCPClient stores an MCP client config in OAuth provider memory while waiting for OAuth completion -func (h *OAuthHandler) StorePendingMCPClient(mcpClientID string, mcpClientConfig schemas.MCPClientConfig) { - h.oauthProvider.StorePendingMCPClient(mcpClientID, mcpClientConfig) +// StorePendingMCPClient stores an MCP client config in the database while waiting for OAuth completion +// This supports multi-instance deployments where OAuth callback may hit a different server instance +func (h *OAuthHandler) StorePendingMCPClient(oauthConfigID string, mcpClientConfig schemas.MCPClientConfig) error { + return h.oauthProvider.StorePendingMCPClient(oauthConfigID, mcpClientConfig) } -// GetPendingMCPClient retrieves a pending MCP client config by mcp_client_id -func (h *OAuthHandler) GetPendingMCPClient(mcpClientID string) *schemas.MCPClientConfig { - return h.oauthProvider.GetPendingMCPClient(mcpClientID) +// GetPendingMCPClient retrieves a pending MCP client config by oauth_config_id +func (h *OAuthHandler) GetPendingMCPClient(oauthConfigID string) (*schemas.MCPClientConfig, error) { + return h.oauthProvider.GetPendingMCPClient(oauthConfigID) +} + +// GetPendingMCPClientByState retrieves a pending MCP client config by OAuth state token +func (h *OAuthHandler) GetPendingMCPClientByState(state string) (*schemas.MCPClientConfig, string, error) { + return h.oauthProvider.GetPendingMCPClientByState(state) } // RemovePendingMCPClient removes a pending MCP client after OAuth completion -func (h *OAuthHandler) RemovePendingMCPClient(mcpClientID string) { - h.oauthProvider.RemovePendingMCPClient(mcpClientID) +func (h *OAuthHandler) RemovePendingMCPClient(oauthConfigID string) error { + return h.oauthProvider.RemovePendingMCPClient(oauthConfigID) } diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 6300ce423e..e774c52a5d 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -90,7 +90,7 @@ type ConfigData struct { AuthConfig *configstore.AuthConfig `json:"auth_config,omitempty"` Providers map[string]configstore.ProviderConfig `json:"providers"` FrameworkConfig *framework.FrameworkConfig `json:"framework,omitempty"` - MCP *configstoreTables.MCPConfig `json:"mcp,omitempty"` + MCP *schemas.MCPConfig `json:"mcp,omitempty"` Governance *configstore.GovernanceConfig `json:"governance,omitempty"` VectorStoreConfig *vectorstore.Config `json:"vector_store,omitempty"` ConfigStoreConfig *configstore.Config `json:"config_store,omitempty"` @@ -109,7 +109,7 @@ func (cd *ConfigData) UnmarshalJSON(data []byte) error { EncryptionKey string `json:"encryption_key"` AuthConfig *configstore.AuthConfig `json:"auth_config,omitempty"` Providers map[string]configstore.ProviderConfig `json:"providers"` - MCP *configstoreTables.MCPConfig `json:"mcp,omitempty"` + MCP *schemas.MCPConfig `json:"mcp,omitempty"` Governance *configstore.GovernanceConfig `json:"governance,omitempty"` VectorStoreConfig json.RawMessage `json:"vector_store,omitempty"` ConfigStoreConfig json.RawMessage `json:"config_store,omitempty"` @@ -245,7 +245,7 @@ type Config struct { // In-memory storage ClientConfig configstore.ClientConfig Providers map[schemas.ModelProvider]configstore.ProviderConfig - MCPConfig *configstoreTables.MCPConfig + MCPConfig *schemas.MCPConfig GovernanceConfig *configstore.GovernanceConfig FrameworkConfig *framework.FrameworkConfig ProxyConfig *configstoreTables.GlobalProxyConfig @@ -268,8 +268,8 @@ type Config struct { pluginStatusMu sync.RWMutex pluginStatus map[string]schemas.PluginStatus // name -> status - OAuthProvider *oauth2.OAuth2Provider - TokenRefreshWorker *oauth2.TokenRefreshWorker + OAuthProvider *oauth2.OAuth2Provider + TokenRefreshWorker *oauth2.TokenRefreshWorker // Catalog managers ModelCatalog *modelcatalog.ModelCatalog @@ -806,6 +806,12 @@ func reconcileProviderKeys(provider schemas.ModelProvider, fileKeys, dbKeys []sc // loadMCPConfigFromFile loads and merges MCP config from file func loadMCPConfigFromFile(ctx context.Context, config *Config, configData *ConfigData) { + if config.ConfigStore == nil { + if configData.MCP != nil && len(configData.MCP.ClientConfigs) > 0 { + logger.Warn("config store is disabled - MCP manager will not be initialized. MCP clients require config store for persistence.") + } + return + } if config.ConfigStore != nil { logger.Debug("getting MCP config from store") tableMCPConfig, err := config.ConfigStore.GetMCPConfig(ctx) @@ -828,11 +834,8 @@ func loadMCPConfigFromFile(ctx context.Context, config *Config, configData *Conf if config.ConfigStore != nil && config.MCPConfig != nil { logger.Debug("updating MCP config in store") for _, clientConfig := range config.MCPConfig.ClientConfigs { - schemasClientConfig := configstore.ConvertTableMCPConfigToSchemas(&configstoreTables.MCPConfig{ - ClientConfigs: []configstoreTables.TableMCPClient{clientConfig}, - }, config.ClientConfig.MCPToolSyncInterval) - if schemasClientConfig != nil && len(schemasClientConfig.ClientConfigs) > 0 { - if err := config.ConfigStore.CreateMCPClientConfig(ctx, schemasClientConfig.ClientConfigs[0]); err != nil { + if clientConfig != nil { + if err := config.ConfigStore.CreateMCPClientConfig(ctx, clientConfig); err != nil { logger.Warn("failed to create MCP client config: %v", err) } } @@ -842,7 +845,7 @@ func loadMCPConfigFromFile(ctx context.Context, config *Config, configData *Conf } // mergeMCPConfig merges MCP config from file with store -func mergeMCPConfig(ctx context.Context, config *Config, configData *ConfigData, mcpConfig *configstoreTables.MCPConfig) { +func mergeMCPConfig(ctx context.Context, config *Config, configData *ConfigData, mcpConfig *schemas.MCPConfig) { logger.Debug("merging MCP config from config file with store") if configData.MCP == nil { @@ -851,11 +854,11 @@ func mergeMCPConfig(ctx context.Context, config *Config, configData *ConfigData, tempMCPConfig := configData.MCP config.MCPConfig = tempMCPConfig // Merge ClientConfigs arrays by ClientID or Name - clientConfigsToAdd := make([]configstoreTables.TableMCPClient, 0) + clientConfigsToAdd := make([]*schemas.MCPClientConfig, 0) for _, newClientConfig := range tempMCPConfig.ClientConfigs { found := false for _, existingClientConfig := range mcpConfig.ClientConfigs { - if (newClientConfig.ClientID != "" && existingClientConfig.ClientID == newClientConfig.ClientID) || + if (newClientConfig.ID != "" && existingClientConfig.ID == newClientConfig.ID) || (newClientConfig.Name != "" && existingClientConfig.Name == newClientConfig.Name) { found = true break @@ -871,11 +874,8 @@ func mergeMCPConfig(ctx context.Context, config *Config, configData *ConfigData, if config.ConfigStore != nil && len(clientConfigsToAdd) > 0 { logger.Debug("updating MCP config in store with %d new client configs", len(clientConfigsToAdd)) for _, clientConfig := range clientConfigsToAdd { - schemasClientConfig := configstore.ConvertTableMCPConfigToSchemas(&configstoreTables.MCPConfig{ - ClientConfigs: []configstoreTables.TableMCPClient{clientConfig}, - }, config.ClientConfig.MCPToolSyncInterval) - if schemasClientConfig != nil && len(schemasClientConfig.ClientConfigs) > 0 { - if err := config.ConfigStore.CreateMCPClientConfig(ctx, schemasClientConfig.ClientConfigs[0]); err != nil { + if clientConfig != nil { + if err := config.ConfigStore.CreateMCPClientConfig(ctx, clientConfig); err != nil { logger.Warn("failed to create MCP client config: %v", err) } } @@ -1486,8 +1486,8 @@ func mergePluginsFromFile(ctx context.Context, config *Config, configData *Confi } // convertSchemasMCPClientConfigToTable converts schemas.MCPClientConfig to tables.TableMCPClient -func convertSchemasMCPClientConfigToTable(clientConfig schemas.MCPClientConfig) configstoreTables.TableMCPClient { - return configstoreTables.TableMCPClient{ +func convertSchemasMCPClientConfigToTable(clientConfig *schemas.MCPClientConfig) *configstoreTables.TableMCPClient { + return &configstoreTables.TableMCPClient{ ClientID: clientConfig.ID, Name: clientConfig.Name, IsCodeModeClient: clientConfig.IsCodeModeClient, @@ -1860,11 +1860,8 @@ func loadDefaultMCPConfig(ctx context.Context, config *Config) error { if tableMCPConfig == nil { if config.MCPConfig != nil { for _, clientConfig := range config.MCPConfig.ClientConfigs { - schemasClientConfig := configstore.ConvertTableMCPConfigToSchemas(&configstoreTables.MCPConfig{ - ClientConfigs: []configstoreTables.TableMCPClient{clientConfig}, - }, config.ClientConfig.MCPToolSyncInterval) - if schemasClientConfig != nil && len(schemasClientConfig.ClientConfigs) > 0 { - if err := config.ConfigStore.CreateMCPClientConfig(ctx, schemasClientConfig.ClientConfigs[0]); err != nil { + if clientConfig != nil { + if err := config.ConfigStore.CreateMCPClientConfig(ctx, clientConfig); err != nil { logger.Warn("failed to create MCP client config: %v", err) continue } @@ -2414,11 +2411,11 @@ func (c *Config) GetPluginStatusByName(name string) (schemas.PluginStatus, bool) return status, ok } -// RegisterPlugin adds or updates a plugin in the registry +// ReloadPlugin adds or updates a plugin in the registry // This is the single entry point for all plugin additions/updates // If a plugin with the same name exists, it will be replaced (atomic find-and-replace) // If no plugin exists with that name, it will be added -func (c *Config) RegisterPlugin(plugin schemas.BasePlugin) error { +func (c *Config) ReloadPlugin(plugin schemas.BasePlugin) error { c.pluginsMu.Lock() defer c.pluginsMu.Unlock() @@ -2775,13 +2772,8 @@ func (c *Config) GetMCPClient(id string) (*schemas.MCPClientConfig, error) { } for _, clientConfig := range c.MCPConfig.ClientConfigs { - if clientConfig.ClientID == id { - schemasConfig := configstore.ConvertTableMCPConfigToSchemas(&configstoreTables.MCPConfig{ - ClientConfigs: []configstoreTables.TableMCPClient{clientConfig}, - }, c.ClientConfig.MCPToolSyncInterval) - if schemasConfig != nil && len(schemasConfig.ClientConfigs) > 0 { - return &schemasConfig.ClientConfigs[0], nil - } + if clientConfig.ID == id { + return clientConfig, nil } } @@ -2795,17 +2787,17 @@ func (c *Config) GetMCPClient(id string) (*schemas.MCPClientConfig, error) { // - Validates that the MCP client doesn't already exist // - Processes environment variables in the MCP client configuration // - Stores the processed configuration in memory -func (c *Config) AddMCPClient(ctx context.Context, clientConfig schemas.MCPClientConfig) error { +func (c *Config) AddMCPClient(ctx context.Context, clientConfig *schemas.MCPClientConfig) error { if c.client == nil { return fmt.Errorf("bifrost client not set") } c.muMCP.Lock() defer c.muMCP.Unlock() if c.MCPConfig == nil { - c.MCPConfig = &configstoreTables.MCPConfig{} + c.MCPConfig = &schemas.MCPConfig{} } // Track new environment variables - c.MCPConfig.ClientConfigs = append(c.MCPConfig.ClientConfigs, convertSchemasMCPClientConfigToTable(clientConfig)) + c.MCPConfig.ClientConfigs = append(c.MCPConfig.ClientConfigs, clientConfig) // Config with processed env vars if err := c.client.AddMCPClient(clientConfig); err != nil { c.MCPConfig.ClientConfigs = c.MCPConfig.ClientConfigs[:len(c.MCPConfig.ClientConfigs)-1] @@ -2813,10 +2805,17 @@ func (c *Config) AddMCPClient(ctx context.Context, clientConfig schemas.MCPClien } // Updating in config store if c.ConfigStore != nil { - if err := c.ConfigStore.CreateMCPClientConfig(ctx, clientConfig); err != nil { - return fmt.Errorf("failed to create MCP client config in store: %w", err) + skipDBUpdate := false + if ctx.Value(schemas.BifrostContextKeySkipDBUpdate) != nil { + if skip, ok := ctx.Value(schemas.BifrostContextKeySkipDBUpdate).(bool); ok { + skipDBUpdate = skip + } + } + if !skipDBUpdate { + if err := c.ConfigStore.CreateMCPClientConfig(ctx, clientConfig); err != nil { + return fmt.Errorf("failed to create MCP client config in store: %w", err) + } } - // Update MCP catalog pricing data for the new client if c.MCPCatalog != nil { // Get the created client config from store to get tool_pricing @@ -2862,10 +2861,10 @@ func (c *Config) RemoveMCPClient(ctx context.Context, id string) error { } } // Find and remove client from in-memory config - var removedClientConfig *configstoreTables.TableMCPClient + var removedClientConfig *schemas.MCPClientConfig for i, clientConfig := range c.MCPConfig.ClientConfigs { - if clientConfig.ClientID == id { - removedClientConfig = &clientConfig + if clientConfig.ID == id { + removedClientConfig = clientConfig c.MCPConfig.ClientConfigs = append(c.MCPConfig.ClientConfigs[:i], c.MCPConfig.ClientConfigs[i+1:]...) break } @@ -2886,13 +2885,13 @@ func (c *Config) RemoveMCPClient(ctx context.Context, id string) error { return nil } -// EditMCPClient edits an MCP client configuration. +// UpdateMCPClient edits an MCP client configuration. // This allows for dynamic MCP client management at runtime with proper env var handling. // // Parameters: // - id: ID of the client to edit // - updatedConfig: Updated MCP client configuration -func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig configstoreTables.TableMCPClient) error { +func (c *Config) UpdateMCPClient(ctx context.Context, id string, updatedConfig *schemas.MCPClientConfig) error { if c.client == nil { return fmt.Errorf("bifrost client not set") } @@ -2903,11 +2902,11 @@ func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig con return fmt.Errorf("no MCP config found") } // Find the existing client config - var oldConfig configstoreTables.TableMCPClient + var oldConfig *schemas.MCPClientConfig var found bool var configIndex int for i, clientConfig := range c.MCPConfig.ClientConfigs { - if clientConfig.ClientID == id { + if clientConfig.ID == id { oldConfig = clientConfig configIndex = i found = true @@ -2917,27 +2916,11 @@ func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig con if !found { return fmt.Errorf("MCP client '%s' not found", id) } - - // Convert to schemas.MCPClientConfig for runtime bifrost client (without tool_pricing) - // Use oldConfig for connection info since those fields are read-only and not sent in update request - schemasConfig := schemas.MCPClientConfig{ - ID: updatedConfig.ClientID, - Name: updatedConfig.Name, - IsCodeModeClient: updatedConfig.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(oldConfig.ConnectionType), - ConnectionString: oldConfig.ConnectionString, - StdioConfig: oldConfig.StdioConfig, - ToolsToExecute: updatedConfig.ToolsToExecute, - ToolsToAutoExecute: updatedConfig.ToolsToAutoExecute, - Headers: updatedConfig.Headers, - IsPingAvailable: updatedConfig.IsPingAvailable, - } - // Check if client is registered in Bifrost (can be not registered if client initialization failed) if clients, err := c.client.GetMCPClients(); err == nil && len(clients) > 0 { for _, client := range clients { if client.Config.ID == id { - if err := c.client.EditMCPClient(id, schemasConfig); err != nil { + if err := c.client.EditMCPClient(id, updatedConfig); err != nil { // Rollback in-memory changes c.MCPConfig.ClientConfigs[configIndex] = oldConfig return fmt.Errorf("failed to edit MCP client: %w", err) @@ -2946,66 +2929,28 @@ func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig con } } } - return nil -} - -// RemoveMCPClient removes an MCP client from the configuration. -// This method is called when an MCP client is removed via the HTTP API. -// -// The method: -// - Validates that the MCP client exists -// - Removes the MCP client from the configuration -// - Removes the MCP client from the Bifrost client -func (c *Config) RemoveMCPClient(ctx context.Context, id string) error { - if c.client == nil { - return fmt.Errorf("bifrost client not set") - } - c.muMCP.Lock() - defer c.muMCP.Unlock() - if c.MCPConfig == nil { - return fmt.Errorf("no MCP config found") - } - // Check if client is registered in Bifrost (can be not registered if client initialization failed) - if clients, err := c.client.GetMCPClients(); err == nil && len(clients) > 0 { - for _, client := range clients { - if client.Config.ID == id { - if err := c.client.RemoveMCPClient(id); err != nil { - return fmt.Errorf("failed to remove MCP client: %w", err) - } - break + // Update MCP catalog pricing data for the edited client + if c.MCPCatalog != nil { + // If the client name has changed, delete all old pricing entries under the old name + if updatedConfig.Name != oldConfig.Name { + for toolName := range oldConfig.ToolPricing { + c.MCPCatalog.DeletePricingData(oldConfig.Name, toolName) } - } - } - for i, clientConfig := range c.MCPConfig.ClientConfigs { - if clientConfig.ID == id { - c.MCPConfig.ClientConfigs = append(c.MCPConfig.ClientConfigs[:i], c.MCPConfig.ClientConfigs[i+1:]...) - break - } - - // Update MCP catalog pricing data for the edited client - if c.MCPCatalog != nil { - // If the client name has changed, delete all old pricing entries under the old name - if updatedConfig.Name != oldConfig.Name { - for toolName := range oldConfig.ToolPricing { - c.MCPCatalog.DeletePricingData(oldConfig.Name, toolName) - } - logger.Debug("deleted old MCP catalog pricing for renamed client: %s -> %s (%d tools)", oldConfig.Name, updatedConfig.Name, len(oldConfig.ToolPricing)) - } else { - // If name hasn't changed, remove pricing entries that were deleted - for toolName := range oldConfig.ToolPricing { - if _, exists := updatedConfig.ToolPricing[toolName]; !exists { - c.MCPCatalog.DeletePricingData(updatedConfig.Name, toolName) - } + logger.Debug("deleted old MCP catalog pricing for renamed client: %s -> %s (%d tools)", oldConfig.Name, updatedConfig.Name, len(oldConfig.ToolPricing)) + } else { + // If name hasn't changed, remove pricing entries that were deleted + for toolName := range oldConfig.ToolPricing { + if _, exists := updatedConfig.ToolPricing[toolName]; !exists { + c.MCPCatalog.DeletePricingData(updatedConfig.Name, toolName) } } - // Then, add or update pricing entries from the new config (with new name if changed) - for toolName, costPerExecution := range updatedConfig.ToolPricing { - c.MCPCatalog.UpdatePricingData(updatedConfig.Name, toolName, costPerExecution) - } - logger.Debug("updated MCP catalog pricing for client: %s (%d tools)", updatedConfig.Name, len(updatedConfig.ToolPricing)) } + // Then, add or update pricing entries from the new config (with new name if changed) + for toolName, costPerExecution := range updatedConfig.ToolPricing { + c.MCPCatalog.UpdatePricingData(updatedConfig.Name, toolName, costPerExecution) + } + logger.Debug("updated MCP catalog pricing for client: %s (%d tools)", updatedConfig.Name, len(updatedConfig.ToolPricing)) } - // Update the in-memory configuration with only the fields that were changed // Preserve connection info (connection_type, connection_string, stdio_config) from oldConfig // as these are read-only and not sent in the update request @@ -3019,11 +2964,12 @@ func (c *Config) RemoveMCPClient(ctx context.Context, id string) error { return nil } -// RedactTableMCPClient creates a redacted copy of a TableMCPClient configuration. +// RedactMCPClientConfig creates a redacted copy of a MCPClientConfig configuration. // Connection strings and headers are redacted for safe external exposure. -func (c *Config) RedactTableMCPClient(config configstoreTables.TableMCPClient) configstoreTables.TableMCPClient { - // Create a shallow copy - configCopy := config +func (c *Config) RedactMCPClientConfig(config *schemas.MCPClientConfig) *schemas.MCPClientConfig { + // Create an actual copy of the struct (not just a pointer copy) + // This prevents modifying the original config when redacting + configCopy := *config // Redact connection string if present if config.ConnectionString != nil { @@ -3038,7 +2984,7 @@ func (c *Config) RedactTableMCPClient(config configstoreTables.TableMCPClient) c } } - return configCopy + return &configCopy } // autoDetectProviders automatically detects common environment variables and sets up providers diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index d26cf3df94..0f93abb401 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -350,7 +350,7 @@ import ( type MockConfigStore struct { clientConfig *configstore.ClientConfig providers map[schemas.ModelProvider]configstore.ProviderConfig - mcpConfig *tables.MCPConfig + mcpConfig *schemas.MCPConfig governanceConfig *configstore.GovernanceConfig authConfig *configstore.AuthConfig frameworkConfig *tables.TableFrameworkConfig @@ -361,7 +361,7 @@ type MockConfigStore struct { // Track update calls for verification clientConfigUpdated bool providersConfigUpdated bool - mcpConfigsCreated []schemas.MCPClientConfig + mcpConfigsCreated []*schemas.MCPClientConfig mcpClientConfigUpdates []struct { ID string Config tables.TableMCPClient @@ -438,7 +438,7 @@ func (m *MockConfigStore) DeleteProvider(ctx context.Context, provider schemas.M } // MCP config -func (m *MockConfigStore) GetMCPConfig(ctx context.Context) (*tables.MCPConfig, error) { +func (m *MockConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) { return m.mcpConfig, nil } @@ -446,25 +446,8 @@ func (m *MockConfigStore) GetMCPClientByName(ctx context.Context, name string) ( return nil, nil } -func (m *MockConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig schemas.MCPClientConfig) error { - if m.mcpConfig == nil { - m.mcpConfig = &tables.MCPConfig{ - ClientConfigs: []tables.TableMCPClient{}, - } - } - // Convert schemas.MCPClientConfig to tables.TableMCPClient - tableClient := tables.TableMCPClient{ - ClientID: clientConfig.ID, - Name: clientConfig.Name, - IsCodeModeClient: clientConfig.IsCodeModeClient, - ConnectionType: string(clientConfig.ConnectionType), - ConnectionString: clientConfig.ConnectionString, - StdioConfig: clientConfig.StdioConfig, - ToolsToExecute: clientConfig.ToolsToExecute, - ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, - Headers: clientConfig.Headers, - } - m.mcpConfig.ClientConfigs = append(m.mcpConfig.ClientConfigs, tableClient) +func (m *MockConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error { + m.mcpConfig.ClientConfigs = append(m.mcpConfig.ClientConfigs, clientConfig) m.mcpConfigsCreated = append(m.mcpConfigsCreated, clientConfig) return nil } @@ -480,20 +463,20 @@ func (m *MockConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, // Initialize m.mcpConfig if nil (same pattern as CreateMCPClientConfig) if m.mcpConfig == nil { - m.mcpConfig = &tables.MCPConfig{ - ClientConfigs: []tables.TableMCPClient{}, + m.mcpConfig = &schemas.MCPConfig{ + ClientConfigs: []*schemas.MCPClientConfig{}, } } // Update the in-memory state to ensure GetMCPConfig returns updated data for i := range m.mcpConfig.ClientConfigs { - if m.mcpConfig.ClientConfigs[i].ClientID == id { + if m.mcpConfig.ClientConfigs[i].ID == id { // Found the entry, update it with the new config - m.mcpConfig.ClientConfigs[i] = tables.TableMCPClient{ - ClientID: clientConfig.ClientID, + m.mcpConfig.ClientConfigs[i] = &schemas.MCPClientConfig{ + ID: clientConfig.ClientID, Name: clientConfig.Name, IsCodeModeClient: clientConfig.IsCodeModeClient, - ConnectionType: clientConfig.ConnectionType, + ConnectionType: schemas.MCPConnectionType(clientConfig.ConnectionType), ConnectionString: clientConfig.ConnectionString, StdioConfig: clientConfig.StdioConfig, Headers: clientConfig.Headers, @@ -504,11 +487,11 @@ func (m *MockConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, } } // If not found, create a new entry (similar to CreateMCPClientConfig behavior) - m.mcpConfig.ClientConfigs = append(m.mcpConfig.ClientConfigs, tables.TableMCPClient{ - ClientID: clientConfig.ClientID, + m.mcpConfig.ClientConfigs = append(m.mcpConfig.ClientConfigs, &schemas.MCPClientConfig{ + ID: clientConfig.ClientID, Name: clientConfig.Name, IsCodeModeClient: clientConfig.IsCodeModeClient, - ConnectionType: clientConfig.ConnectionType, + ConnectionType: schemas.MCPConnectionType(clientConfig.ConnectionType), ConnectionString: clientConfig.ConnectionString, StdioConfig: clientConfig.StdioConfig, Headers: clientConfig.Headers, @@ -1431,7 +1414,7 @@ func TestLoadConfig_Providers_Merge(t *testing.T) { func TestLoadConfig_MCP_Merge(t *testing.T) { // Setup DB MCP config dbMCPConfig := &schemas.MCPConfig{ - ClientConfigs: []schemas.MCPClientConfig{ + ClientConfigs: []*schemas.MCPClientConfig{ { ID: "mcp-1", Name: "db-client-1", @@ -1447,7 +1430,7 @@ func TestLoadConfig_MCP_Merge(t *testing.T) { // Setup file MCP config with some overlapping and some new fileMCPConfig := &schemas.MCPConfig{ - ClientConfigs: []schemas.MCPClientConfig{ + ClientConfigs: []*schemas.MCPClientConfig{ { ID: "mcp-1", // Same ID - should be skipped Name: "different-name", @@ -1467,7 +1450,7 @@ func TestLoadConfig_MCP_Merge(t *testing.T) { } // Simulate merge logic - clientConfigsToAdd := make([]schemas.MCPClientConfig, 0) + clientConfigsToAdd := make([]*schemas.MCPClientConfig, 0) for _, newClientConfig := range fileMCPConfig.ClientConfigs { found := false for _, existingClientConfig := range dbMCPConfig.ClientConfigs { @@ -9001,7 +8984,7 @@ func TestSQLite_VirtualKey_WithMCPConfigs(t *testing.T) { Name: "test-mcp-client", ConnectionType: schemas.MCPConnectionTypeHTTP, } - err = config1.ConfigStore.CreateMCPClientConfig(ctx, mcpClientConfig) + err = config1.ConfigStore.CreateMCPClientConfig(ctx, &mcpClientConfig) if err != nil { t.Fatalf("Failed to create MCP client: %v", err) } @@ -9090,7 +9073,7 @@ func TestSQLite_VKMCPConfig_Reconciliation(t *testing.T) { Name: "mcp-client-1", ConnectionType: schemas.MCPConnectionTypeHTTP, } - err = config1.ConfigStore.CreateMCPClientConfig(ctx, mcpClientConfig1) + err = config1.ConfigStore.CreateMCPClientConfig(ctx, &mcpClientConfig1) if err != nil { t.Fatalf("Failed to create MCP client 1: %v", err) } @@ -9100,7 +9083,7 @@ func TestSQLite_VKMCPConfig_Reconciliation(t *testing.T) { Name: "mcp-client-2", ConnectionType: schemas.MCPConnectionTypeHTTP, } - err = config1.ConfigStore.CreateMCPClientConfig(ctx, mcpClientConfig2) + err = config1.ConfigStore.CreateMCPClientConfig(ctx, &mcpClientConfig2) if err != nil { t.Fatalf("Failed to create MCP client 2: %v", err) } @@ -9412,7 +9395,7 @@ func TestSQLite_VirtualKey_DashboardMCPConfig_DeletedOnFileChange(t *testing.T) Name: "mcp-client-1", ConnectionType: schemas.MCPConnectionTypeHTTP, } - err = config1.ConfigStore.CreateMCPClientConfig(ctx, mcpClient1Config) + err = config1.ConfigStore.CreateMCPClientConfig(ctx, &mcpClient1Config) if err != nil { t.Fatalf("Failed to create MCP client 1: %v", err) } @@ -9422,7 +9405,7 @@ func TestSQLite_VirtualKey_DashboardMCPConfig_DeletedOnFileChange(t *testing.T) Name: "mcp-client-2", ConnectionType: schemas.MCPConnectionTypeHTTP, } - err = config1.ConfigStore.CreateMCPClientConfig(ctx, mcpClient2Config) + err = config1.ConfigStore.CreateMCPClientConfig(ctx, &mcpClient2Config) if err != nil { t.Fatalf("Failed to create MCP client 2: %v", err) } @@ -9577,8 +9560,8 @@ func TestSQLite_VKMCPConfig_AddRemove(t *testing.T) { } // Create MCP clients - config1.ConfigStore.CreateMCPClientConfig(ctx, schemas.MCPClientConfig{ID: "mcp-1", Name: "mcp-1", ConnectionType: schemas.MCPConnectionTypeHTTP}) - config1.ConfigStore.CreateMCPClientConfig(ctx, schemas.MCPClientConfig{ID: "mcp-2", Name: "mcp-2", ConnectionType: schemas.MCPConnectionTypeHTTP}) + config1.ConfigStore.CreateMCPClientConfig(ctx, &schemas.MCPClientConfig{ID: "mcp-1", Name: "mcp-1", ConnectionType: schemas.MCPConnectionTypeHTTP}) + config1.ConfigStore.CreateMCPClientConfig(ctx, &schemas.MCPClientConfig{ID: "mcp-2", Name: "mcp-2", ConnectionType: schemas.MCPConnectionTypeHTTP}) mcpClient1, _ := config1.ConfigStore.GetMCPClientByName(ctx, "mcp-1") mcpClient2, _ := config1.ConfigStore.GetMCPClientByName(ctx, "mcp-2") @@ -9699,7 +9682,7 @@ func TestSQLite_VKMCPConfig_UpdateTools(t *testing.T) { } // Create MCP client - config1.ConfigStore.CreateMCPClientConfig(ctx, schemas.MCPClientConfig{ID: "mcp-client", Name: "mcp-client", ConnectionType: schemas.MCPConnectionTypeHTTP}) + config1.ConfigStore.CreateMCPClientConfig(ctx, &schemas.MCPClientConfig{ID: "mcp-client", Name: "mcp-client", ConnectionType: schemas.MCPConnectionTypeHTTP}) mcpClient, _ := config1.ConfigStore.GetMCPClientByName(ctx, "mcp-client") // Create VK with MCP config @@ -9793,7 +9776,7 @@ func TestSQLite_VK_ProviderAndMCPConfigs_Combined(t *testing.T) { } // Create MCP client - config1.ConfigStore.CreateMCPClientConfig(ctx, schemas.MCPClientConfig{ID: "mcp-client", Name: "mcp-client", ConnectionType: schemas.MCPConnectionTypeHTTP}) + config1.ConfigStore.CreateMCPClientConfig(ctx, &schemas.MCPClientConfig{ID: "mcp-client", Name: "mcp-client", ConnectionType: schemas.MCPConnectionTypeHTTP}) mcpClient, _ := config1.ConfigStore.GetMCPClientByName(ctx, "mcp-client") config1.ConfigStore.Close(ctx) diff --git a/transports/bifrost-http/server/plugins.go b/transports/bifrost-http/server/plugins.go index 0019b8ebfd..15ee838836 100644 --- a/transports/bifrost-http/server/plugins.go +++ b/transports/bifrost-http/server/plugins.go @@ -17,10 +17,9 @@ import ( "github.com/maximhq/bifrost/transports/bifrost-http/lib" ) -// getPluginTypes determines which interface types a plugin implements -func getPluginTypes(plugin schemas.BasePlugin) []schemas.PluginType { +// InferPluginTypes determines which interface types a plugin implements +func InferPluginTypes(plugin schemas.BasePlugin) []schemas.PluginType { var types []schemas.PluginType - if _, ok := plugin.(schemas.LLMPlugin); ok { types = append(types, schemas.PluginTypeLLM) } @@ -30,7 +29,6 @@ func getPluginTypes(plugin schemas.BasePlugin) []schemas.PluginType { if _, ok := plugin.(schemas.HTTPTransportPlugin); ok { types = append(types, schemas.PluginTypeHTTP) } - return types } @@ -122,17 +120,16 @@ func loadCustomPlugin(ctx context.Context, path *string, pluginConfig any, bifro // InstantiatePlugins loads all plugins from configuration // This is called once during Bootstrap -func (s *BifrostHTTPServer) InstantiatePlugins(ctx context.Context) error { + +func (s *BifrostHTTPServer) LoadPlugins(ctx context.Context) error { // Load built-in plugins first (order matters) if err := s.loadBuiltinPlugins(ctx); err != nil { return err } - // Load custom plugins from config if err := s.loadCustomPlugins(ctx); err != nil { return err } - return nil } @@ -173,7 +170,6 @@ func (s *BifrostHTTPServer) loadCustomPlugins(ctx context.Context) error { if isBuiltinPlugin(cfg.Name) { continue } - // Handle disabled plugins if !cfg.Enabled { // For custom plugins with a path, verify to get the real plugin name @@ -218,9 +214,9 @@ func (s *BifrostHTTPServer) loadCustomPlugins(ctx context.Context) error { } // Register enabled plugin and mark as active - s.Config.RegisterPlugin(plugin) + s.Config.ReloadPlugin(plugin) s.Config.UpdatePluginOverallStatus(plugin.GetName(), cfg.Name, schemas.PluginStatusActive, - []string{fmt.Sprintf("plugin %s initialized successfully", cfg.Name)}, getPluginTypes(plugin)) + []string{fmt.Sprintf("plugin %s initialized successfully", cfg.Name)}, InferPluginTypes(plugin)) } return nil } diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index a41d9b8283..d51aed898b 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -47,33 +47,43 @@ var enterprisePlugins = []string{ // ServerCallbacks is a interface that defines the callbacks for the server. type ServerCallbacks interface { + // Plugins callbacks ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any) error RemovePlugin(ctx context.Context, name string) error GetPluginStatus(ctx context.Context) map[string]schemas.PluginStatus - GetModelsForProvider(provider schemas.ModelProvider) []string + // Auth related callbacks UpdateAuthConfig(ctx context.Context, authConfig *configstore.AuthConfig) error ReloadClientConfigFromConfigStore(ctx context.Context) error + // Pricing related callbacks ReloadPricingManager(ctx context.Context) error ForceReloadPricing(ctx context.Context) error + // Proxy related callbacks ReloadProxyConfig(ctx context.Context, config *tables.GlobalProxyConfig) error + // Client config related callbacks ReloadHeaderFilterConfig(ctx context.Context, config *tables.GlobalHeaderFilterConfig) error UpdateDropExcessRequests(ctx context.Context, value bool) - UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error + // Governance related callbacks + GetGovernanceData() *governance.GovernanceData ReloadTeam(ctx context.Context, id string) (*tables.TableTeam, error) RemoveTeam(ctx context.Context, id string) error ReloadCustomer(ctx context.Context, id string) (*tables.TableCustomer, error) RemoveCustomer(ctx context.Context, id string) error + // Virtual key related callbacks ReloadVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error) RemoveVirtualKey(ctx context.Context, id string) error + // Provider related callbacks + GetModelsForProvider(provider schemas.ModelProvider) []string ReloadModelConfig(ctx context.Context, id string) (*tables.TableModelConfig, error) RemoveModelConfig(ctx context.Context, id string) error ReloadProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error) RemoveProvider(ctx context.Context, provider schemas.ModelProvider) error - GetGovernanceData() *governance.GovernanceData - ReconnectMCPClient(ctx context.Context, id string) error - AddMCPClient(ctx context.Context, clientConfig schemas.MCPClientConfig) error + // MCP related callbacks + AddMCPClient(ctx context.Context, clientConfig *schemas.MCPClientConfig) error RemoveMCPClient(ctx context.Context, id string) error - EditMCPClient(ctx context.Context, id string, updatedConfig tables.TableMCPClient) error + UpdateMCPClient(ctx context.Context, id string, updatedConfig *schemas.MCPClientConfig) error + UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error + ReconnectMCPClient(ctx context.Context, id string) error + // Logging related callbacks NewLogEntryAdded(ctx context.Context, logEntry *logstore.Log) error } @@ -139,7 +149,7 @@ func (s *GovernanceInMemoryStore) GetConfiguredProviders() map[schemas.ModelProv } // AddMCPClient adds a new MCP client to the in-memory store -func (s *BifrostHTTPServer) AddMCPClient(ctx context.Context, clientConfig schemas.MCPClientConfig) error { +func (s *BifrostHTTPServer) AddMCPClient(ctx context.Context, clientConfig *schemas.MCPClientConfig) error { if err := s.Config.AddMCPClient(ctx, clientConfig); err != nil { return err } @@ -149,9 +159,36 @@ func (s *BifrostHTTPServer) AddMCPClient(ctx context.Context, clientConfig schem return nil } -// EditMCPClient edits an MCP client in the in-memory store -func (s *BifrostHTTPServer) EditMCPClient(ctx context.Context, id string, updatedConfig tables.TableMCPClient) error { - if err := s.Config.EditMCPClient(ctx, id, updatedConfig); err != nil { +// ReconnectMCPClient reconnects an MCP client to the in-memory store +func (s *BifrostHTTPServer) ReconnectMCPClient(ctx context.Context, id string) error { + // Check if client is registered in Bifrost (can be not registered if client initialization failed) + if clients, err := s.Client.GetMCPClients(); err == nil && len(clients) > 0 { + for _, client := range clients { + if client.Config.ID == id { + if err := s.Client.ReconnectMCPClient(id); err != nil { + return err + } + return nil + } + } + } + // Config exists in store, but not in Bifrost (can happen if client initialization failed) + clientConfig, err := s.Config.GetMCPClient(id) + if err != nil { + return err + } + if err := s.Client.AddMCPClient(clientConfig); err != nil { + return err + } + if err := s.MCPServerHandler.SyncAllMCPServers(ctx); err != nil { + logger.Warn("failed to sync MCP servers after adding client: %v", err) + } + return nil +} + +// UpdateMCPClient updates an MCP client in the in-memory store +func (s *BifrostHTTPServer) UpdateMCPClient(ctx context.Context, id string, updatedConfig *schemas.MCPClientConfig) error { + if err := s.Config.UpdateMCPClient(ctx, id, updatedConfig); err != nil { return err } if err := s.MCPServerHandler.SyncAllMCPServers(ctx); err != nil { @@ -502,7 +539,7 @@ func (s *BifrostHTTPServer) ReloadClientConfigFromConfigStore(ctx context.Contex account := lib.NewBaseAccount(s.Config) var mcpConfig *schemas.MCPConfig if s.Config.MCPConfig != nil { - mcpConfig = configstore.ConvertTableMCPConfigToSchemas(s.Config.MCPConfig, s.Config.ClientConfig.MCPToolSyncInterval) + mcpConfig = s.Config.MCPConfig } s.Client.ReloadConfig(schemas.BifrostConfig{ Account: account, @@ -564,28 +601,9 @@ func (s *BifrostHTTPServer) UpdateMCPToolManagerConfig(ctx context.Context, maxA return s.Client.UpdateToolManagerConfig(maxAgentDepth, toolExecutionTimeoutInSeconds, codeModeBindingLevel) } -// reloadBifrostPlugins syncs Config plugins to Bifrost client -func (s *BifrostHTTPServer) reloadBifrostPlugins() error { - account := lib.NewBaseAccount(s.Config) - var mcpConfig *schemas.MCPConfig - if s.Config.MCPConfig != nil { - mcpConfig = configstore.ConvertTableMCPConfigToSchemas(s.Config.MCPConfig, s.Config.ClientConfig.MCPToolSyncInterval) - } - - return s.Client.ReloadConfig(schemas.BifrostConfig{ - Account: account, - InitialPoolSize: s.Config.ClientConfig.InitialPoolSize, - DropExcessRequests: s.Config.ClientConfig.DropExcessRequests, - LLMPlugins: s.Config.GetLoadedLLMPlugins(), - MCPPlugins: s.Config.GetLoadedMCPPlugins(), - MCPConfig: mcpConfig, - Logger: logger, - }) -} - // reloadObservabilityPlugins reloads all observability plugins in the tracing middleware func (s *BifrostHTTPServer) reloadObservabilityPlugins() { - observabilityPlugins := s.collectObservabilityPlugins() + observabilityPlugins := s.CollectObservabilityPlugins() // Always update the tracing middleware, even with empty slice, to clear stale plugins s.TracingMiddleware.SetObservabilityPlugins(observabilityPlugins) } @@ -651,50 +669,50 @@ func (s *BifrostHTTPServer) GetPluginStatus(ctx context.Context) map[string]sche return s.Config.GetPluginStatus() } -// ReloadPlugin reloads a plugin with new instance and updates Bifrost core. -// The plugin is checked for LLM and MCP interfaces independently and registered -// to the appropriate arrays based on which interfaces it implements. -func (s *BifrostHTTPServer) ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any) error { - logger.Debug("reloading plugin %s", name) - - // Helper to update error status - updateError := func(step string, err error) error { - if err := s.Config.UpdatePluginStatus(name, schemas.PluginStatusError); err != nil { - return err - } - if err := s.Config.AppendPluginStateLogs(name, []string{fmt.Sprintf("error %s plugin %s: %v", step, name, err)}); err != nil { - return err - } +// Helper to update error status +func (s *BifrostHTTPServer) updatePluginErrorStatus(name, step string, err error) error { + if err := s.Config.UpdatePluginStatus(name, schemas.PluginStatusError); err != nil { return err } - - // 1. Instantiate new version - plugin, err := InstantiatePlugin(ctx, name, path, pluginConfig, s.Config) - if err != nil { - return updateError("loading", err) + if err := s.Config.AppendPluginStateLogs(name, []string{fmt.Sprintf("error %s plugin %s: %v", step, name, err)}); err != nil { + return err } + return err +} +// SyncLoadedPlugin syncs a loaded plugin to the Bifrost client and updates the plugin status +func (s *BifrostHTTPServer) SyncLoadedPlugin(ctx context.Context, name string, plugin schemas.BasePlugin) error { // 2. Register (replaces old version atomically) - if err := s.Config.RegisterPlugin(plugin); err != nil { - return updateError("registering", err) + if err := s.Config.ReloadPlugin(plugin); err != nil { + return s.updatePluginErrorStatus(plugin.GetName(), "registering", err) } - // 3. Update Bifrost client - if err := s.reloadBifrostPlugins(); err != nil { - return updateError("reloading bifrost config for", err) + if err := s.Client.ReloadPlugin(plugin, InferPluginTypes(plugin)); err != nil { + return s.updatePluginErrorStatus(plugin.GetName(), "reloading bifrost config for", err) } - // 4. Special handling for observability plugins if _, ok := plugin.(schemas.ObservabilityPlugin); ok { s.reloadObservabilityPlugins() } - // 5. Update plugin status s.Config.UpdatePluginOverallStatus(plugin.GetName(), name, schemas.PluginStatusActive, - []string{fmt.Sprintf("plugin %s reloaded successfully", name)}, getPluginTypes(plugin)) + []string{fmt.Sprintf("plugin %s reloaded successfully", name)}, InferPluginTypes(plugin)) return nil } +// ReloadPlugin reloads a plugin with new instance and updates Bifrost core. +// The plugin is checked for LLM and MCP interfaces independently and registered +// to the appropriate arrays based on which interfaces it implements. +func (s *BifrostHTTPServer) ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any) error { + logger.Debug("reloading plugin %s", name) + // 1. Instantiate new version + plugin, err := InstantiatePlugin(ctx, name, path, pluginConfig, s.Config) + if err != nil { + return s.updatePluginErrorStatus(name, "loading", err) + } + return s.SyncLoadedPlugin(ctx, name, plugin) +} + // RemovePlugin removes a plugin from the server. // The plugin is removed from both LLM and MCP arrays independently if it exists in them. func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, displayName string) error { @@ -706,7 +724,9 @@ func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, displayName string // Check if plugin implements ObservabilityPlugin before removal var isObservability bool - if plugin, err := s.Config.FindPluginByName(name); err == nil { + var err error + var plugin schemas.BasePlugin + if plugin, err = s.Config.FindPluginByName(name); err == nil { _, isObservability = plugin.(schemas.ObservabilityPlugin) } @@ -716,7 +736,7 @@ func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, displayName string } // 2. Update Bifrost client - if err := s.reloadBifrostPlugins(); err != nil { + if err := s.Client.RemovePlugin(name, InferPluginTypes(plugin)); err != nil { logger.Warn("failed to reload bifrost config after plugin removal: %v", err) } @@ -958,14 +978,14 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { } } // Load all plugins - if err := s.InstantiatePlugins(ctx); err != nil { + if err := s.LoadPlugins(ctx); err != nil { return fmt.Errorf("failed to instantiate plugins: %v", err) } tableMCPConfig := s.Config.MCPConfig var mcpConfig *schemas.MCPConfig if tableMCPConfig != nil { - mcpConfig = configstore.ConvertTableMCPConfigToSchemas(tableMCPConfig, s.Config.ClientConfig.MCPToolSyncInterval) + mcpConfig = s.Config.MCPConfig if mcpConfig != nil { mcpConfig.FetchNewRequestIDFunc = func(ctx *schemas.BifrostContext) string { return uuid.New().String() @@ -1033,7 +1053,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { // Registering inference middlewares inferenceMiddlewares = append([]schemas.BifrostHTTPMiddleware{handlers.TransportInterceptorMiddleware(s.Config)}, inferenceMiddlewares...) // Curating observability plugins - observabilityPlugins := s.collectObservabilityPlugins() + observabilityPlugins := s.CollectObservabilityPlugins() // This enables the central streaming accumulator for both use cases // Initializing tracer with embedded streaming accumulator traceStore := tracing.NewTraceStore(60*time.Minute, logger) diff --git a/transports/bifrost-http/server/utils.go b/transports/bifrost-http/server/utils.go index 4f2c939784..4abd274e85 100644 --- a/transports/bifrost-http/server/utils.go +++ b/transports/bifrost-http/server/utils.go @@ -93,14 +93,14 @@ func (s *BifrostHTTPServer) registerPluginWithStatus(ctx context.Context, name s return nil } - s.Config.RegisterPlugin(plugin) + s.Config.ReloadPlugin(plugin) s.Config.UpdatePluginOverallStatus(name, name, schemas.PluginStatusActive, - []string{fmt.Sprintf("%s plugin initialized successfully", name)}, getPluginTypes(plugin)) + []string{fmt.Sprintf("%s plugin initialized successfully", name)}, InferPluginTypes(plugin)) return nil } -// collectObservabilityPlugins gathers all loaded plugins that implement ObservabilityPlugin interface -func (s *BifrostHTTPServer) collectObservabilityPlugins() []schemas.ObservabilityPlugin { +// CollectObservabilityPlugins gathers all loaded plugins that implement ObservabilityPlugin interface +func (s *BifrostHTTPServer) CollectObservabilityPlugins() []schemas.ObservabilityPlugin { var observabilityPlugins []schemas.ObservabilityPlugin // Check LLM plugins diff --git a/ui/app/workspace/mcp-registry/views/oauth2Authorizer.tsx b/ui/app/workspace/mcp-registry/views/oauth2Authorizer.tsx index a69d4e5017..2416537a42 100644 --- a/ui/app/workspace/mcp-registry/views/oauth2Authorizer.tsx +++ b/ui/app/workspace/mcp-registry/views/oauth2Authorizer.tsx @@ -51,8 +51,9 @@ export const OAuth2Authorizer: React.FC = ({ } // Call complete-oauth endpoint using RTK Query mutation + // Use oauthConfigId instead of mcpClientId for multi-instance support try { - await completeOAuth(mcpClientId).unwrap() + await completeOAuth(oauthConfigId).unwrap() setStatus("success") onSuccess() setTimeout(() => { @@ -64,7 +65,7 @@ export const OAuth2Authorizer: React.FC = ({ setErrorMessage(errMsg) onError(errMsg) } - }, [mcpClientId, completeOAuth, onSuccess, onClose, onError]) + }, [oauthConfigId, completeOAuth, onSuccess, onClose, onError]) // Handle OAuth failure const handleOAuthFailed = useCallback((reason: string) => { diff --git a/ui/components/ui/entityAssociationSelect.tsx b/ui/components/ui/entityAssociationSelect.tsx new file mode 100644 index 0000000000..fb3f1f98af --- /dev/null +++ b/ui/components/ui/entityAssociationSelect.tsx @@ -0,0 +1,262 @@ +"use client" + +import { useCallback, useMemo } from "react" +import { AsyncMultiSelect } from "./asyncMultiselect" +import { Option } from "./multiselectUtils" +import { cn } from "./utils" + +// Entity types supported by this component +export type EntityType = "virtualKey" | "team" | "customer" | "user" | "provider" | "apiKey" + +// Generic entity option +export interface EntityOption { + id: string | number + label: string + description?: string + metadata?: Record +} + +// Meta type for AsyncMultiSelect +interface EntityOptionMeta { + id: string | number + description?: string + metadata?: Record +} + +interface EntityAssociationSelectProps { + /** The type of entity being selected */ + entityType: EntityType + /** Currently selected entity IDs */ + value: (string | number)[] + /** Callback when selection changes */ + onChange: (ids: (string | number)[]) => void + /** Placeholder text */ + placeholder?: string + /** Whether the component is disabled */ + disabled?: boolean + /** Additional CSS classes */ + className?: string + /** + * Custom reload function for fetching options. + * If provided, this will be used instead of the default behavior. + * The function should call the callback with filtered options. + */ + customReload?: (query: string, callback: (options: Option[]) => void) => void + /** + * Static options to use when customReload is not provided. + * This is useful for simple use cases where all options are already available. + */ + options?: EntityOption[] + /** + * Whether to allow creating new options + */ + isCreatable?: boolean + /** + * Callback when a new option is created + */ + onCreateOption?: (value: string) => void + /** + * Format function for creating new option labels + */ + formatCreateLabel?: (inputValue: string) => string + /** + * Message to display when no options are available + */ + noOptionsMessage?: () => string +} + +// Default placeholder text for each entity type +const defaultPlaceholders: Record = { + virtualKey: "Add virtual key names...", + team: "Add team names...", + customer: "Add customer names...", + user: "Add user names...", + provider: "Add provider names...", + apiKey: "Add API key names...", +} + +// Default no options messages for each entity type +const defaultNoOptionsMessages: Record = { + virtualKey: "No virtual keys found", + team: "No teams found", + customer: "No customers found", + user: "No users found", + provider: "No providers found", + apiKey: "No API keys found", +} + +// Label text for each entity type +export const entityTypeLabels: Record = { + virtualKey: "Virtual Keys", + team: "Teams", + customer: "Customers", + user: "Users", + provider: "Providers", + apiKey: "API Keys", +} + +export function EntityAssociationSelect({ + entityType, + value, + onChange, + placeholder, + disabled = false, + className, + customReload, + options = [], + isCreatable = false, + onCreateOption, + formatCreateLabel, + noOptionsMessage, +}: EntityAssociationSelectProps) { + // Convert static options to AsyncMultiSelect format using meta for complex data + const defaultOptions = useMemo((): Option[] => { + return options.map((opt) => ({ + label: opt.label, + value: String(opt.id), + meta: { + id: opt.id, + description: opt.description, + metadata: opt.metadata, + }, + })) + }, [options]) + + // Convert selected IDs to Option format + const selectedValues = useMemo((): Option[] => { + return value.map((id) => { + // Try to find the option in the provided options + const existingOption = options.find((opt) => opt.id === id) + if (existingOption) { + return { + label: existingOption.label, + value: String(existingOption.id), + meta: { + id: existingOption.id, + description: existingOption.description, + metadata: existingOption.metadata, + }, + } + } + // If not found, create a basic option + return { + label: String(id), + value: String(id), + meta: { + id, + }, + } + }) + }, [value, options]) + + // Filter options based on query + const filterOptions = useCallback( + (query: string, callback: (options: Option[]) => void) => { + const lowerQuery = query.toLowerCase() + const filtered = defaultOptions.filter( + (opt) => + opt.label.toLowerCase().includes(lowerQuery) || + opt.meta?.description?.toLowerCase().includes(lowerQuery) + ) + callback(filtered) + }, + [defaultOptions] + ) + + // Use custom reload if provided, otherwise use local filter + const reload = customReload || filterOptions + + // Handle selection change + const handleChange = useCallback( + (selected: Option[]) => { + const ids = selected.map((opt) => opt.meta?.id ?? opt.value) + onChange(ids) + }, + [onChange] + ) + + // Handle creating new option + const handleCreateOption = useCallback( + (inputValue: string) => { + if (onCreateOption) { + onCreateOption(inputValue) + } + }, + [onCreateOption] + ) + + return ( +
+ + placeholder={placeholder || defaultPlaceholders[entityType]} + disabled={disabled} + defaultOptions={defaultOptions} + reload={reload} + debounce={200} + onChange={handleChange} + value={selectedValues} + isClearable + closeMenuOnSelect={false} + hideSelectedOptions={false} + isCreatable={isCreatable} + onCreateOption={handleCreateOption} + formatCreateLabel={formatCreateLabel || ((value) => `Add "${value}"`)} + noOptionsMessage={noOptionsMessage || (() => defaultNoOptionsMessages[entityType])} + views={{ + option: (props) => { + // Access data as Option since that's the actual runtime type + const data = props.data as unknown as Option + return ( +
props.selectOption(props.data)} + > +
+ + {data.label} + + {props.isSelected && ( + Selected + )} +
+ {data.meta?.description && ( + + {data.meta.description} + + )} +
+ ) + }, + }} + /> +
+ ) +} + +/** + * Helper function to create EntityOption from simple ID list + * Useful when you just have IDs without full entity info + */ +export function createOptionsFromIds(ids: (string | number)[]): EntityOption[] { + return ids.map((id) => ({ + id, + label: String(id), + })) +} + +/** + * Helper function to create EntityOption with label mapping + */ +export function createOptionsWithLabels( + items: { id: string | number; name?: string; label?: string; description?: string }[] +): EntityOption[] { + return items.map((item) => ({ + id: item.id, + label: item.name || item.label || String(item.id), + description: item.description, + })) +} diff --git a/ui/components/ui/mcpServerSelector.tsx b/ui/components/ui/mcpServerSelector.tsx new file mode 100644 index 0000000000..361502bd34 --- /dev/null +++ b/ui/components/ui/mcpServerSelector.tsx @@ -0,0 +1,241 @@ +"use client" + +import { ExternalLink, X } from "lucide-react" +import { useCallback, useMemo } from "react" +import { AsyncMultiSelect } from "./asyncMultiselect" +import { Badge } from "./badge" +import { Button } from "./button" +import { Option } from "./multiselectUtils" +import { cn } from "./utils" + +// Types +export interface MCPServerInfo { + config: { + client_id: string + name: string + connection_type?: string + } + tools?: { name: string }[] + state?: string +} + +interface ServerOptionMeta { + clientId: string + name: string + connectionType?: string + toolCount: number + state?: string +} + +interface MCPServerSelectorProps { + value: string[] + onChange: (serverIds: string[]) => void + mcpClients: MCPServerInfo[] + placeholder?: string + disabled?: boolean + className?: string + /** Base URL path for the MCP registry page (default: /workspace/mcp-registry) */ + registryPath?: string +} + +export function MCPServerSelector({ + value, + onChange, + mcpClients, + placeholder = "Search and select MCP servers...", + disabled = false, + className, + registryPath = "/workspace/mcp-registry", +}: MCPServerSelectorProps) { + // Create options from MCP clients using meta for complex data + const allServerOptions = useMemo((): Option[] => { + return mcpClients.map((client) => ({ + label: client.config.name, + value: client.config.client_id, + meta: { + clientId: client.config.client_id, + name: client.config.name, + connectionType: client.config.connection_type, + toolCount: client.tools?.length || 0, + state: client.state, + }, + })) + }, [mcpClients]) + + // Get full server info for selected servers + const selectedServersWithInfo = useMemo(() => { + return value.map((serverId) => { + const client = mcpClients.find((c) => c.config.client_id === serverId) + return { + clientId: serverId, + name: client?.config.name || serverId, + connectionType: client?.config.connection_type, + toolCount: client?.tools?.length || 0, + state: client?.state, + } + }) + }, [value, mcpClients]) + + // Filter out already selected servers from options + const availableOptions = useMemo(() => { + const selectedSet = new Set(value) + return allServerOptions.filter((opt) => !selectedSet.has(opt.value)) + }, [allServerOptions, value]) + + const handleSelectServer = useCallback( + (selected: Option[]) => { + if (selected.length === 0) return + + const newServer = selected[selected.length - 1] + if (!newServer?.meta) return + + // Check if already selected + if (!value.includes(newServer.meta.clientId)) { + onChange([...value, newServer.meta.clientId]) + } + }, + [value, onChange] + ) + + const handleRemoveServer = useCallback( + (serverId: string) => { + onChange(value.filter((id) => id !== serverId)) + }, + [value, onChange] + ) + + const reload = useCallback( + (query: string, callback: (options: Option[]) => void) => { + const lowerQuery = query.toLowerCase() + const filtered = availableOptions.filter((opt) => + opt.label.toLowerCase().includes(lowerQuery) + ) + callback(filtered) + }, + [availableOptions] + ) + + const getConnectionTypeBadgeVariant = (type?: string) => { + switch (type) { + case "http": + return "default" + case "stdio": + return "secondary" + case "sse": + return "outline" + default: + return "outline" + } + } + + const getStateBadgeVariant = (state?: string) => { + switch (state) { + case "connected": + return "success" + case "error": + return "destructive" + default: + return "secondary" + } + } + + return ( +
+ {/* Search dropdown */} + + placeholder={placeholder} + disabled={disabled} + defaultOptions={availableOptions} + reload={reload} + debounce={150} + onChange={handleSelectServer} + value={[]} + isClearable={false} + closeMenuOnSelect={true} + hideSelectedOptions={true} + controlShouldRenderValue={false} + noOptionsMessage={() => "No MCP servers found"} + views={{ + option: (props) => { + // Access data as Option since that's the actual runtime type + const data = props.data as unknown as Option; + return ( +
props.selectOption(props.data)} + > +
+ {data.meta?.name} + {data.meta?.connectionType && ( + + {data.meta.connectionType} + + )} +
+ {data.meta?.toolCount} tools +
+ ); + }, + }} + /> + + {/* Selected servers list */} + {selectedServersWithInfo.length > 0 && ( +
+ {selectedServersWithInfo.map((server) => ( +
+
+ {server.name} + {server.connectionType && ( + + {server.connectionType} + + )} + {server.state && ( + + {server.state} + + )} + {server.toolCount} tools +
+
+ + +
+
+ ))} +
+ )} + + {/* Empty state */} + {selectedServersWithInfo.length === 0 && ( +
+ No servers selected. Use the search above to add MCP servers. +
+ )} +
+ ); +} diff --git a/ui/components/ui/mcpToolSelector.tsx b/ui/components/ui/mcpToolSelector.tsx new file mode 100644 index 0000000000..e54b85517a --- /dev/null +++ b/ui/components/ui/mcpToolSelector.tsx @@ -0,0 +1,323 @@ +"use client" + +import { CodeEditor } from "@/app/workspace/logs/views/codeEditor"; +import { ChevronDown, ChevronRight, X } from "lucide-react" +import { useCallback, useMemo, useState } from "react" +import { components, OptionProps } from "react-select"; +import { AsyncMultiSelect } from "./asyncMultiselect" +import { Badge } from "./badge" +import { Button } from "./button" +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "./collapsible" +import { Option } from "./multiselectUtils"; +import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "./table"; +import { cn } from "./utils" + +// Types +export interface SelectedTool { + mcpClientId: string + toolName: string +} + +export interface ToolFunction { + name: string + description?: string + // Parameters can be any object type to support various schema formats + parameters?: Record | object + strict?: boolean +} + +export interface MCPClientInfo { + config: { + client_id: string + name: string + connection_type?: string + } + tools: ToolFunction[] + state?: string +} + +interface ToolOptionMeta { + mcpClientId: string + mcpClientName: string + toolName: string + description?: string + parameters?: Record | object +} + +interface MCPToolSelectorProps { + value: SelectedTool[] + onChange: (tools: SelectedTool[]) => void + mcpClients: MCPClientInfo[] + placeholder?: string + disabled?: boolean + className?: string +} + +export function MCPToolSelector({ + value, + onChange, + mcpClients, + placeholder = "Search and select tools...", + disabled = false, + className, +}: MCPToolSelectorProps) { + const [expandedTools, setExpandedTools] = useState>(new Set()) + + // Flatten all tools from all MCP clients into searchable options + // Using meta field for complex data as per Option type definition + const allToolOptions = useMemo(() => { + const options: Option[] = [] + + for (const client of mcpClients) { + if (!client.tools) continue + + for (const tool of client.tools) { + const key = `${client.config.client_id}:${tool.name}` + + options.push({ + label: `${client.config.name} / ${tool.name}`, + value: key, + meta: { + mcpClientId: client.config.client_id, + mcpClientName: client.config.name, + toolName: tool.name, + description: tool.description, + parameters: tool.parameters, + }, + }) + } + } + + return options + }, [mcpClients]) + + // Get full tool info for selected tools + const selectedToolsWithInfo = useMemo(() => { + return value.map((selected) => { + const client = mcpClients.find((c) => c.config.client_id === selected.mcpClientId) + const tool = client?.tools?.find((t) => t.name === selected.toolName) + return { + ...selected, + mcpClientName: client?.config.name || selected.mcpClientId, + description: tool?.description, + parameters: tool?.parameters, + } + }) + }, [value, mcpClients]) + + // Filter out already selected tools from options + const availableOptions = useMemo(() => { + const selectedKeys = new Set( + value.map((t) => `${t.mcpClientId}:${t.toolName}`) + ) + return allToolOptions.filter((opt) => !selectedKeys.has(opt.value)) + }, [allToolOptions, value]) + + const toggleExpanded = useCallback((key: string) => { + setExpandedTools((prev) => { + const next = new Set(prev) + if (next.has(key)) { + next.delete(key) + } else { + next.add(key) + } + return next + }) + }, []) + + const handleSelectTool = useCallback( + (selected: Option[]) => { + if (selected.length === 0) return + + const newTool = selected[selected.length - 1] + if (!newTool?.meta) return + + const newSelectedTool: SelectedTool = { + mcpClientId: newTool.meta.mcpClientId, + toolName: newTool.meta.toolName, + } + + // Check if already selected + const exists = value.some( + (t) => t.mcpClientId === newSelectedTool.mcpClientId && t.toolName === newSelectedTool.toolName + ) + + if (!exists) { + onChange([...value, newSelectedTool]) + } + }, + [value, onChange] + ) + + const handleRemoveTool = useCallback( + (mcpClientId: string, toolName: string) => { + onChange(value.filter((t) => !(t.mcpClientId === mcpClientId && t.toolName === toolName))) + }, + [value, onChange] + ) + + const reload = useCallback( + (query: string, callback: (options: Option[]) => void) => { + const lowerQuery = query.toLowerCase() + const filtered = availableOptions.filter( + (opt) => + opt.label.toLowerCase().includes(lowerQuery) || + opt.meta?.description?.toLowerCase().includes(lowerQuery) + ) + callback(filtered) + }, + [availableOptions] + ) + + return ( +
+ {/* Search dropdown */} + + placeholder={placeholder} + disabled={disabled} + defaultOptions={availableOptions} + reload={reload} + debounce={150} + onChange={handleSelectTool} + value={[]} + isClearable={false} + closeMenuOnSelect={true} + hideSelectedOptions={true} + controlShouldRenderValue={false} + noOptionsMessage={() => "No results found"} + views={{ + option: (optionProps: OptionProps) => { + const { Option } = components; + // Access data as Option since that's the actual runtime type + const data = optionProps.data as unknown as Option; + return ( + + ); + }, + }} + /> + + {/* Selected tools table */} + {selectedToolsWithInfo.length > 0 && ( +
+ + + + + Tool + Server + + + + + {selectedToolsWithInfo.map((tool) => { + const key = `${tool.mcpClientId}:${tool.toolName}`; + const isExpanded = expandedTools.has(key); + + return ( + toggleExpanded(key)} asChild> + <> + + + + + + + +
+
{tool.toolName}
+ {tool.description &&

{tool.description}

} +
+
+ + {tool.mcpClientName} + + + + +
+ +
+ + + + + + ); + })} + +
+
+
Parameters Schema
+ {tool.parameters ? ( + + ) : ( +
No parameters defined
+ )} +
+
+
+ )} + + {/* Empty state */} + {selectedToolsWithInfo.length === 0 && ( +
+ No tools selected. Use the search above to add tools. +
+ )} +
+ ); +} diff --git a/ui/lib/store/apis/baseApi.ts b/ui/lib/store/apis/baseApi.ts index d7cbf0748e..7491481c17 100644 --- a/ui/lib/store/apis/baseApi.ts +++ b/ui/lib/store/apis/baseApi.ts @@ -171,6 +171,7 @@ export const baseApi = createApi({ "Permissions", "APIKeys", "OAuth2Config", + "MCPToolGroups", ], endpoints: () => ({}), }); diff --git a/ui/lib/store/apis/mcpApi.ts b/ui/lib/store/apis/mcpApi.ts index e342f2cd64..291e2dcece 100644 --- a/ui/lib/store/apis/mcpApi.ts +++ b/ui/lib/store/apis/mcpApi.ts @@ -57,8 +57,8 @@ export const mcpApi = baseApi.injectEndpoints({ // Complete OAuth flow for MCP client completeOAuthFlow: builder.mutation<{ status: string; message: string }, string>({ - query: (mcpClientId) => ({ - url: `/mcp/client/${mcpClientId}/complete-oauth`, + query: (oauthConfigId) => ({ + url: `/mcp/client/${oauthConfigId}/complete-oauth`, method: "POST", }), invalidatesTags: ["MCPClients"], diff --git a/ui/lib/types/mcp.ts b/ui/lib/types/mcp.ts index a31cb27cb5..43b3c8a45e 100644 --- a/ui/lib/types/mcp.ts +++ b/ui/lib/types/mcp.ts @@ -91,3 +91,15 @@ export interface UpdateMCPClientRequest { tool_pricing?: Record; tool_sync_interval?: number; // Per-client override in minutes (0 = use global, -1 = disabled) } + +// Types for MCP Tool Selector component +export interface SelectedTool { + mcpClientId: string; + toolName: string; +} + +// MCP Tool Spec for tool groups (matches backend schema) +export interface MCPToolSpec { + mcp_client_id: string; + tool_names: string[]; +}