Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions core/schemas/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ import (

// OAuth-related errors
var (
ErrOAuth2ConfigNotFound = errors.New("oauth2 config not found")
ErrOAuth2ProviderNotAvailable = errors.New("oauth2 provider not available")
ErrOAuth2TokenExpired = errors.New("oauth2 token expired")
ErrOAuth2TokenInvalid = errors.New("oauth2 token invalid")
ErrOAuth2RefreshFailed = errors.New("oauth2 token refresh failed")
ErrOAuth2NotPerUserSession = errors.New("state does not match a per-user oauth session")
ErrOAuth2TokenNotFound = errors.New("per-user oauth token not found for this identity and mcp server")
ErrPerUserOAuthPendingFlowExpired = errors.New("per-user oauth pending flow has expired")
ErrOAuth2ConfigNotFound = errors.New("oauth2 config not found")
ErrOAuth2ProviderNotAvailable = errors.New("oauth2 provider not available")
ErrOAuth2TokenExpired = errors.New("oauth2 token expired")
ErrOAuth2TokenInvalid = errors.New("oauth2 token invalid")
ErrOAuth2RefreshFailed = errors.New("oauth2 token refresh failed")
ErrOAuth2NotPerUserSession = errors.New("state does not match a per-user oauth session")
ErrOAuth2TokenNotFound = errors.New("per-user oauth token not found for this identity and mcp server")
ErrPerUserOAuthPendingFlowExpired = errors.New("per-user oauth pending flow has expired")
)

// MCPUserOAuthRequiredError is returned when a per-user OAuth MCP server requires
Expand Down Expand Up @@ -119,14 +119,14 @@ type MCPClientConfig struct {
// - nil/omitted => treated as [] (no tools)
// - ["tool1", "tool2"] => auto-execute only the specified tools
// Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped.
IsPingAvailable *bool `json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks (nil/true = ping; false = listTools). Defaults to true.
ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled)
ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution)
ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized)
AllowOnAllVirtualKeys bool `json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys
IsPingAvailable *bool `json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks (nil/true = ping; false = listTools). Defaults to true.
ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled)
ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution)
ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized)
AllowOnAllVirtualKeys bool `json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys

// Discovered tools for per-user OAuth clients (persisted so they survive restart)
DiscoveredTools map[string]ChatTool `json:"-"` // Discovered tool schemas keyed by prefixed name
DiscoveredTools map[string]ChatTool `json:"-"` // Discovered tool schemas keyed by prefixed name
DiscoveredToolNameMapping map[string]string `json:"-"` // Mapping from sanitized tool names to original MCP names
}

Expand Down
82 changes: 47 additions & 35 deletions framework/configstore/rdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,35 @@ func (s *RDBConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tabl
return &mcpClient, nil
}

// GetMCPClientConfigByID retrieves an MCP client by ID and converts it to a schemas.MCPClientConfig.
// Unlike GetMCPClientByID, this includes DiscoveredTools and DiscoveredToolNameMapping.
func (s *RDBConfigStore) GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error) {
dbClient, err := s.GetMCPClientByID(ctx, id)
if err != nil {
return nil, err
}
return &schemas.MCPClientConfig{
ID: dbClient.ClientID,
Name: dbClient.Name,
IsCodeModeClient: dbClient.IsCodeModeClient,
ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType),
ConnectionString: dbClient.ConnectionString,
StdioConfig: dbClient.StdioConfig,
AuthType: schemas.MCPAuthType(dbClient.AuthType),
OauthConfigID: dbClient.OauthConfigID,
ToolsToExecute: dbClient.ToolsToExecute,
ToolsToAutoExecute: dbClient.ToolsToAutoExecute,
Headers: dbClient.Headers,
AllowedExtraHeaders: dbClient.AllowedExtraHeaders,
IsPingAvailable: dbClient.IsPingAvailable,
ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute,
AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys,
ToolPricing: dbClient.ToolPricing,
DiscoveredTools: dbClient.DiscoveredTools,
DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping,
}, nil
}

// GetMCPClientByName retrieves an MCP client by name from the database.
func (s *RDBConfigStore) GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error) {
var mcpClient tables.TableMCPClient
Expand All @@ -1282,21 +1311,24 @@ func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig
}
// Create new client
dbClient := tables.TableMCPClient{
ClientID: clientConfigCopy.ID,
Name: clientConfigCopy.Name,
IsCodeModeClient: clientConfigCopy.IsCodeModeClient,
ConnectionType: string(clientConfigCopy.ConnectionType),
ConnectionString: clientConfigCopy.ConnectionString,
StdioConfig: clientConfigCopy.StdioConfig,
AuthType: string(clientConfigCopy.AuthType),
OauthConfigID: clientConfigCopy.OauthConfigID,
ToolsToExecute: clientConfigCopy.ToolsToExecute,
ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute,
Headers: clientConfigCopy.Headers,
AllowedExtraHeaders: clientConfigCopy.AllowedExtraHeaders,
IsPingAvailable: clientConfigCopy.IsPingAvailable,
ToolSyncInterval: int(clientConfigCopy.ToolSyncInterval.Minutes()),
AllowOnAllVirtualKeys: clientConfigCopy.AllowOnAllVirtualKeys,
ClientID: clientConfigCopy.ID,
Name: clientConfigCopy.Name,
IsCodeModeClient: clientConfigCopy.IsCodeModeClient,
ConnectionType: string(clientConfigCopy.ConnectionType),
ConnectionString: clientConfigCopy.ConnectionString,
StdioConfig: clientConfigCopy.StdioConfig,
AuthType: string(clientConfigCopy.AuthType),
OauthConfigID: clientConfigCopy.OauthConfigID,
ToolsToExecute: clientConfigCopy.ToolsToExecute,
ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute,
Headers: clientConfigCopy.Headers,
AllowedExtraHeaders: clientConfigCopy.AllowedExtraHeaders,
IsPingAvailable: clientConfigCopy.IsPingAvailable,
ToolSyncInterval: int(clientConfigCopy.ToolSyncInterval.Minutes()),
AllowOnAllVirtualKeys: clientConfigCopy.AllowOnAllVirtualKeys,
// DiscoveredTools has json:"-" so deepCopy loses it; use original clientConfig
DiscoveredTools: clientConfig.DiscoveredTools,
DiscoveredToolNameMapping: clientConfig.DiscoveredToolNameMapping,
}
if err := tx.WithContext(ctx).Create(&dbClient).Error; err != nil {
return s.parseGormError(err)
Expand Down Expand Up @@ -1411,26 +1443,6 @@ func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, c
})
}

// UpdateMCPClientDiscoveredTools persists discovered tools for a per-user OAuth MCP client.
func (s *RDBConfigStore) UpdateMCPClientDiscoveredTools(ctx context.Context, clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) error {
toolsJSON, err := json.Marshal(tools)
if err != nil {
return fmt.Errorf("failed to marshal discovered tools: %w", err)
}
mappingJSON, err := json.Marshal(toolNameMapping)
if err != nil {
return fmt.Errorf("failed to marshal tool name mapping: %w", err)
}
return s.DB().WithContext(ctx).
Model(&tables.TableMCPClient{}).
Where("client_id = ?", clientID).
Updates(map[string]interface{}{
"discovered_tools_json": string(toolsJSON),
"tool_name_mapping_json": string(mappingJSON),
"updated_at": time.Now(),
}).Error
}

// DeleteMCPClientConfig deletes an MCP client configuration from the database.
func (s *RDBConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) error {
return s.DB().Transaction(func(tx *gorm.DB) error {
Expand Down
2 changes: 1 addition & 1 deletion framework/configstore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ type ConfigStore interface {
// MCP config CRUD
GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error)
GetMCPClientByID(ctx context.Context, id string) (*tables.TableMCPClient, error)
GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error)
GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error)
GetMCPClientsPaginated(ctx context.Context, params MCPClientsQueryParams) ([]tables.TableMCPClient, int64, error)
CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error
UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error
UpdateMCPClientDiscoveredTools(ctx context.Context, clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) error
DeleteMCPClientConfig(ctx context.Context, id string) error

// Vector store config CRUD
Expand Down
13 changes: 5 additions & 8 deletions transports/bifrost-http/handlers/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -1097,7 +1097,11 @@ func (h *MCPHandler) completeMCPClientOAuth(ctx *fasthttp.RequestCtx) {
return
}

// Persist MCP client config in config store
// Attach discovered tools before persisting so the DB row includes them from the start.
mcpClientConfig.DiscoveredTools = tools
mcpClientConfig.DiscoveredToolNameMapping = toolNameMapping

// Persist MCP client config in config store (BeforeSave hook serializes DiscoveredTools)
if h.store.ConfigStore != nil {
if err := h.store.ConfigStore.CreateMCPClientConfig(ctx, mcpClientConfig); err != nil {
SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create MCP config: %v", err))
Expand All @@ -1122,13 +1126,6 @@ func (h *MCPHandler) completeMCPClientOAuth(ctx *fasthttp.RequestCtx) {
// Set discovered tools on the client
h.mcpManager.SetClientTools(mcpClientConfig.ID, tools, toolNameMapping)

// Persist discovered tools to DB so they survive restart
if h.store.ConfigStore != nil {
if err := h.store.ConfigStore.UpdateMCPClientDiscoveredTools(ctx, mcpClientConfig.ID, tools, toolNameMapping); err != nil {
logger.Warn(fmt.Sprintf("[OAuth Complete] Failed to persist discovered tools for %s: %v", mcpClientConfig.ID, err))
}
}

logger.Debug(fmt.Sprintf("[OAuth Complete] Per-user OAuth MCP client verified and created: %s (%d tools)", mcpClientConfig.ID, len(tools)))
SendJSON(ctx, map[string]any{
"status": "success",
Expand Down
8 changes: 4 additions & 4 deletions transports/bifrost-http/lib/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,10 @@ func (m *MockConfigStore) GetMCPClientByID(ctx context.Context, id string) (*tab
return nil, nil
}

func (m *MockConfigStore) GetMCPClientConfigByID(ctx context.Context, id string) (*schemas.MCPClientConfig, error) {
return nil, nil
}

func (m *MockConfigStore) GetMCPClientByName(ctx context.Context, name string) (*tables.TableMCPClient, error) {
return nil, nil
}
Expand Down Expand Up @@ -613,10 +617,6 @@ func (m *MockConfigStore) GetMCPClientsPaginated(ctx context.Context, params con
return nil, 0, nil
}

func (m *MockConfigStore) UpdateMCPClientDiscoveredTools(ctx context.Context, clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) error {
return nil
}

func (m *MockConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) error {
return nil
}
Expand Down
Loading