diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index af87cdc743..0f08fcaf91 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -15,14 +15,14 @@ import ( // OAuth-related errors var ( - ErrOAuth2ConfigNotFound = errors.New("oauth2 config not found") - ErrOAuth2ProviderNotAvailable = errors.New("oauth2 provider not available") - ErrOAuth2TokenExpired = errors.New("oauth2 token expired") - ErrOAuth2TokenInvalid = errors.New("oauth2 token invalid") - ErrOAuth2RefreshFailed = errors.New("oauth2 token refresh failed") - ErrOAuth2NotPerUserSession = errors.New("state does not match a per-user oauth session") - ErrOAuth2TokenNotFound = errors.New("per-user oauth token not found for this identity and mcp server") - ErrPerUserOAuthPendingFlowExpired = errors.New("per-user oauth pending flow has expired") + ErrOAuth2ConfigNotFound = errors.New("oauth2 config not found") + ErrOAuth2ProviderNotAvailable = errors.New("oauth2 provider not available") + ErrOAuth2TokenExpired = errors.New("oauth2 token expired") + ErrOAuth2TokenInvalid = errors.New("oauth2 token invalid") + ErrOAuth2RefreshFailed = errors.New("oauth2 token refresh failed") + ErrOAuth2NotPerUserSession = errors.New("state does not match a per-user oauth session") + ErrOAuth2TokenNotFound = errors.New("per-user oauth token not found for this identity and mcp server") + ErrPerUserOAuthPendingFlowExpired = errors.New("per-user oauth pending flow has expired") ) // MCPUserOAuthRequiredError is returned when a per-user OAuth MCP server requires @@ -119,14 +119,14 @@ type MCPClientConfig struct { // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => auto-execute only the specified tools // Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped. - IsPingAvailable *bool `json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks (nil/true = ping; false = listTools). Defaults to true. - ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled) - ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution) - ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) - AllowOnAllVirtualKeys bool `json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys + IsPingAvailable *bool `json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks (nil/true = ping; false = listTools). Defaults to true. + ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled) + ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution) + ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) + AllowOnAllVirtualKeys bool `json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys // Discovered tools for per-user OAuth clients (persisted so they survive restart) - DiscoveredTools map[string]ChatTool `json:"-"` // Discovered tool schemas keyed by prefixed name + DiscoveredTools map[string]ChatTool `json:"-"` // Discovered tool schemas keyed by prefixed name DiscoveredToolNameMapping map[string]string `json:"-"` // Mapping from sanitized tool names to original MCP names } diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index 0eff1506a9..9030a2f774 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -1256,6 +1256,35 @@ func (s *RDBConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tabl return &mcpClient, nil } +// GetMCPClientConfigByID retrieves an MCP client by ID and converts it to a schemas.MCPClientConfig. +// Unlike GetMCPClientByID, this includes DiscoveredTools and DiscoveredToolNameMapping. +func (s *RDBConfigStore) GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error) { + dbClient, err := s.GetMCPClientByID(ctx, id) + if err != nil { + return nil, err + } + return &schemas.MCPClientConfig{ + ID: dbClient.ClientID, + Name: dbClient.Name, + IsCodeModeClient: dbClient.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), + ConnectionString: dbClient.ConnectionString, + StdioConfig: dbClient.StdioConfig, + AuthType: schemas.MCPAuthType(dbClient.AuthType), + OauthConfigID: dbClient.OauthConfigID, + ToolsToExecute: dbClient.ToolsToExecute, + ToolsToAutoExecute: dbClient.ToolsToAutoExecute, + Headers: dbClient.Headers, + AllowedExtraHeaders: dbClient.AllowedExtraHeaders, + IsPingAvailable: dbClient.IsPingAvailable, + ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, + AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys, + ToolPricing: dbClient.ToolPricing, + DiscoveredTools: dbClient.DiscoveredTools, + DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping, + }, nil +} + // GetMCPClientByName retrieves an MCP client by name from the database. func (s *RDBConfigStore) GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error) { var mcpClient tables.TableMCPClient @@ -1282,21 +1311,24 @@ func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig } // Create new client dbClient := tables.TableMCPClient{ - ClientID: clientConfigCopy.ID, - Name: clientConfigCopy.Name, - IsCodeModeClient: clientConfigCopy.IsCodeModeClient, - ConnectionType: string(clientConfigCopy.ConnectionType), - ConnectionString: clientConfigCopy.ConnectionString, - StdioConfig: clientConfigCopy.StdioConfig, - AuthType: string(clientConfigCopy.AuthType), - OauthConfigID: clientConfigCopy.OauthConfigID, - ToolsToExecute: clientConfigCopy.ToolsToExecute, - ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute, - Headers: clientConfigCopy.Headers, - AllowedExtraHeaders: clientConfigCopy.AllowedExtraHeaders, - IsPingAvailable: clientConfigCopy.IsPingAvailable, - ToolSyncInterval: int(clientConfigCopy.ToolSyncInterval.Minutes()), - AllowOnAllVirtualKeys: clientConfigCopy.AllowOnAllVirtualKeys, + ClientID: clientConfigCopy.ID, + Name: clientConfigCopy.Name, + IsCodeModeClient: clientConfigCopy.IsCodeModeClient, + ConnectionType: string(clientConfigCopy.ConnectionType), + ConnectionString: clientConfigCopy.ConnectionString, + StdioConfig: clientConfigCopy.StdioConfig, + AuthType: string(clientConfigCopy.AuthType), + OauthConfigID: clientConfigCopy.OauthConfigID, + ToolsToExecute: clientConfigCopy.ToolsToExecute, + ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute, + Headers: clientConfigCopy.Headers, + AllowedExtraHeaders: clientConfigCopy.AllowedExtraHeaders, + IsPingAvailable: clientConfigCopy.IsPingAvailable, + ToolSyncInterval: int(clientConfigCopy.ToolSyncInterval.Minutes()), + AllowOnAllVirtualKeys: clientConfigCopy.AllowOnAllVirtualKeys, + // DiscoveredTools has json:"-" so deepCopy loses it; use original clientConfig + DiscoveredTools: clientConfig.DiscoveredTools, + DiscoveredToolNameMapping: clientConfig.DiscoveredToolNameMapping, } if err := tx.WithContext(ctx).Create(&dbClient).Error; err != nil { return s.parseGormError(err) @@ -1411,26 +1443,6 @@ func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, c }) } -// UpdateMCPClientDiscoveredTools persists discovered tools for a per-user OAuth MCP client. -func (s *RDBConfigStore) UpdateMCPClientDiscoveredTools(ctx context.Context, clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) error { - toolsJSON, err := json.Marshal(tools) - if err != nil { - return fmt.Errorf("failed to marshal discovered tools: %w", err) - } - mappingJSON, err := json.Marshal(toolNameMapping) - if err != nil { - return fmt.Errorf("failed to marshal tool name mapping: %w", err) - } - return s.DB().WithContext(ctx). - Model(&tables.TableMCPClient{}). - Where("client_id = ?", clientID). - Updates(map[string]interface{}{ - "discovered_tools_json": string(toolsJSON), - "tool_name_mapping_json": string(mappingJSON), - "updated_at": time.Now(), - }).Error -} - // DeleteMCPClientConfig deletes an MCP client configuration from the database. func (s *RDBConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) error { return s.DB().Transaction(func(tx *gorm.DB) error { diff --git a/framework/configstore/store.go b/framework/configstore/store.go index 16cedc6b6a..ee5b1f0eda 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -115,11 +115,11 @@ type ConfigStore interface { // MCP config CRUD GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error) + GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error) GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error) GetMCPClientsPaginated(ctx context.Context, params MCPClientsQueryParams) ([]tables.TableMCPClient, int64, error) CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error - UpdateMCPClientDiscoveredTools(ctx context.Context, clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) error DeleteMCPClientConfig(ctx context.Context, id string) error // Vector store config CRUD diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index 64ce383eb9..e380082920 100644 --- a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -1097,7 +1097,11 @@ func (h *MCPHandler) completeMCPClientOAuth(ctx *fasthttp.RequestCtx) { return } - // Persist MCP client config in config store + // Attach discovered tools before persisting so the DB row includes them from the start. + mcpClientConfig.DiscoveredTools = tools + mcpClientConfig.DiscoveredToolNameMapping = toolNameMapping + + // Persist MCP client config in config store (BeforeSave hook serializes DiscoveredTools) if h.store.ConfigStore != nil { if err := h.store.ConfigStore.CreateMCPClientConfig(ctx, mcpClientConfig); err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create MCP config: %v", err)) @@ -1122,13 +1126,6 @@ func (h *MCPHandler) completeMCPClientOAuth(ctx *fasthttp.RequestCtx) { // Set discovered tools on the client h.mcpManager.SetClientTools(mcpClientConfig.ID, tools, toolNameMapping) - // Persist discovered tools to DB so they survive restart - if h.store.ConfigStore != nil { - if err := h.store.ConfigStore.UpdateMCPClientDiscoveredTools(ctx, mcpClientConfig.ID, tools, toolNameMapping); err != nil { - logger.Warn(fmt.Sprintf("[OAuth Complete] Failed to persist discovered tools for %s: %v", mcpClientConfig.ID, err)) - } - } - logger.Debug(fmt.Sprintf("[OAuth Complete] Per-user OAuth MCP client verified and created: %s (%d tools)", mcpClientConfig.ID, len(tools))) SendJSON(ctx, map[string]any{ "status": "success", diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index 3835c9f192..fae2b334be 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -547,6 +547,10 @@ func (m *MockConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tab return nil, nil } +func (m *MockConfigStore) GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error) { + return nil, nil +} + func (m *MockConfigStore) GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error) { return nil, nil } @@ -613,10 +617,6 @@ func (m *MockConfigStore) GetMCPClientsPaginated(ctx context.Context, params con return nil, 0, nil } -func (m *MockConfigStore) UpdateMCPClientDiscoveredTools(ctx context.Context, clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) error { - return nil -} - func (m *MockConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) error { return nil }