Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 87 additions & 80 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ type Bifrost struct {
oauth2Provider schemas.OAuth2Provider // OAuth provider instance
logger schemas.Logger // logger instance, default logger is used if not provided
tracer atomic.Value // tracer for distributed tracing (stores schemas.Tracer, NoOpTracer if not configured)
mcpManager *mcp.MCPManager // MCP integration manager (nil if MCP not configured)
McpManager mcp.MCPManagerInterface // MCP integration manager (nil if MCP not configured)
mcpInitOnce sync.Once // Ensures MCP manager is initialized only once
dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
keySelector schemas.KeySelector // Custom key selector function
Expand Down Expand Up @@ -186,6 +186,12 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
logger: config.Logger,
}
bifrost.tracer.Store(&tracerWrapper{tracer: tracer})
if config.LLMPlugins == nil {
config.LLMPlugins = make([]schemas.LLMPlugin, 0)
}
if config.MCPPlugins == nil {
config.MCPPlugins = make([]schemas.MCPPlugin, 0)
}
bifrost.llmPlugins.Store(&config.LLMPlugins)
bifrost.mcpPlugins.Store(&config.MCPPlugins)

Expand Down Expand Up @@ -280,7 +286,7 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
}
}
codeMode := starlark.NewStarlarkCodeMode(codeModeConfig)
bifrost.mcpManager = mcp.NewMCPManager(bifrostCtx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode)
bifrost.McpManager = mcp.NewMCPManager(bifrostCtx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode)
bifrost.logger.Info("MCP integration initialized successfully")
})
}
Expand Down Expand Up @@ -338,21 +344,6 @@ func (bifrost *Bifrost) getTracer() schemas.Tracer {
// We will keep on adding other aspects as required
func (bifrost *Bifrost) ReloadConfig(config schemas.BifrostConfig) error {
bifrost.dropExcessRequests.Store(config.DropExcessRequests)

// Update LLM plugins atomically
if config.LLMPlugins != nil {
llmPluginsCopy := make([]schemas.LLMPlugin, len(config.LLMPlugins))
copy(llmPluginsCopy, config.LLMPlugins)
bifrost.llmPlugins.Store(&llmPluginsCopy)
}

// Update MCP plugins atomically
if config.MCPPlugins != nil {
mcpPluginsCopy := make([]schemas.MCPPlugin, len(config.MCPPlugins))
copy(mcpPluginsCopy, config.MCPPlugins)
bifrost.mcpPlugins.Store(&mcpPluginsCopy)
}

return nil
}

Expand Down Expand Up @@ -722,8 +713,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx *schemas.BifrostContext, req *
}

// Check if we should enter agent mode
if bifrost.mcpManager != nil {
return bifrost.mcpManager.CheckAndExecuteAgentForChatRequest(
if bifrost.McpManager != nil {
return bifrost.McpManager.CheckAndExecuteAgentForChatRequest(
ctx,
req,
response,
Expand Down Expand Up @@ -819,8 +810,8 @@ func (bifrost *Bifrost) ResponsesRequest(ctx *schemas.BifrostContext, req *schem
}

// Check if we should enter agent mode
if bifrost.mcpManager != nil {
return bifrost.mcpManager.CheckAndExecuteAgentForResponsesRequest(
if bifrost.McpManager != nil {
return bifrost.McpManager.CheckAndExecuteAgentForResponsesRequest(
ctx,
req,
response,
Expand Down Expand Up @@ -2116,14 +2107,22 @@ func (bifrost *Bifrost) ContainerFileDeleteRequest(ctx *schemas.BifrostContext,
}

// RemovePlugin removes a plugin from the server.
func (bifrost *Bifrost) RemovePlugin(name string, pluginType schemas.PluginType) error {
switch pluginType {
case schemas.PluginTypeLLM:
return bifrost.removeLLMPlugin(name)
case schemas.PluginTypeMCP:
return bifrost.removeMCPPlugin(name)
}
return fmt.Errorf("unsupported plugin type: %v", pluginType)
func (bifrost *Bifrost) RemovePlugin(name string, pluginTypes []schemas.PluginType) error {
for _, pluginType := range pluginTypes {
switch pluginType {
case schemas.PluginTypeLLM:
err := bifrost.removeLLMPlugin(name)
if err != nil {
return err
}
case schemas.PluginTypeMCP:
err := bifrost.removeMCPPlugin(name)
if err != nil {
return err
}
}
}
return nil
Comment on lines +2110 to +2125
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Validate pluginTypes and reject unknown values (avoid silent no-ops).

Both methods now accept a slice but silently succeed on empty or unsupported types, which can mask misuse. Consider returning a clear error in those cases.

🛠️ Suggested fix
 func (bifrost *Bifrost) RemovePlugin(name string, pluginTypes []schemas.PluginType) error {
+	if len(pluginTypes) == 0 {
+		return fmt.Errorf("pluginTypes is required")
+	}
 	for _, pluginType := range pluginTypes {
 		switch pluginType {
 		case schemas.PluginTypeLLM:
 			err := bifrost.removeLLMPlugin(name)
 			if err != nil {
 				return err
 			}
 		case schemas.PluginTypeMCP:
 			err := bifrost.removeMCPPlugin(name)
 			if err != nil {
 				return err
 			}
+		default:
+			return fmt.Errorf("unsupported plugin type: %s", pluginType)
 		}
 	}
 	return nil
 }
 
 func (bifrost *Bifrost) ReloadPlugin(plugin schemas.BasePlugin, pluginTypes []schemas.PluginType) error {
+	if len(pluginTypes) == 0 {
+		return fmt.Errorf("pluginTypes is required")
+	}
 	for _, pluginType := range pluginTypes {
 		switch pluginType {
 		case schemas.PluginTypeLLM:
 			llmPlugin, ok := plugin.(schemas.LLMPlugin)
 			if !ok {
 				return fmt.Errorf("plugin %s is not an LLMPlugin", plugin.GetName())
 			}
 			err := bifrost.reloadLLMPlugin(llmPlugin)
 			if err != nil {
 				return err
 			}
 		case schemas.PluginTypeMCP:
 			mcpPlugin, ok := plugin.(schemas.MCPPlugin)
 			if !ok {
 				return fmt.Errorf("plugin %s is not an MCPPlugin", plugin.GetName())
 			}
 			err := bifrost.reloadMCPPlugin(mcpPlugin)
 			if err != nil {
 				return err
 			}
+		default:
+			return fmt.Errorf("unsupported plugin type: %s", pluginType)
 		}
 	}
 	return nil
 }

Also applies to: 2196-2219

🤖 Prompt for AI Agents
In `@core/bifrost.go` around lines 2104 - 2119, The RemovePlugin method currently
ignores empty or unsupported pluginTypes which can hide misuse; update
RemovePlugin (and the other similar method around the second occurrence) to
validate pluginTypes: if the slice is empty return an error like "no plugin
types provided", and inside the loop add a default case that returns an error
for unknown schemas.PluginType values instead of silently continuing; reference
the existing helper methods removeLLMPlugin and removeMCPPlugin when describing
allowed types in the error message so callers know valid options.

}

// removeLLMPlugin removes an LLM plugin from the server.
Expand Down Expand Up @@ -2200,22 +2199,30 @@ func (bifrost *Bifrost) removeMCPPlugin(name string) error {

// ReloadPlugin reloads a plugin with new instance
// During the reload - it's stop the world phase where we take a global lock on the plugin mutex
func (bifrost *Bifrost) ReloadPlugin(plugin schemas.BasePlugin, pluginType schemas.PluginType) error {
switch pluginType {
case schemas.PluginTypeLLM:
llmPlugin, ok := plugin.(schemas.LLMPlugin)
if !ok {
return fmt.Errorf("plugin %s is not an LLMPlugin", plugin.GetName())
}
return bifrost.reloadLLMPlugin(llmPlugin)
case schemas.PluginTypeMCP:
mcpPlugin, ok := plugin.(schemas.MCPPlugin)
if !ok {
return fmt.Errorf("plugin %s is not an MCPPlugin", plugin.GetName())
}
return bifrost.reloadMCPPlugin(mcpPlugin)
}
return fmt.Errorf("unsupported plugin type: %v", pluginType)
func (bifrost *Bifrost) ReloadPlugin(plugin schemas.BasePlugin, pluginTypes []schemas.PluginType) error {
for _, pluginType := range pluginTypes {
switch pluginType {
case schemas.PluginTypeLLM:
llmPlugin, ok := plugin.(schemas.LLMPlugin)
if !ok {
return fmt.Errorf("plugin %s is not an LLMPlugin", plugin.GetName())
}
err := bifrost.reloadLLMPlugin(llmPlugin)
if err != nil {
return err
}
case schemas.PluginTypeMCP:
mcpPlugin, ok := plugin.(schemas.MCPPlugin)
if !ok {
return fmt.Errorf("plugin %s is not an MCPPlugin", plugin.GetName())
}
err := bifrost.reloadMCPPlugin(mcpPlugin)
if err != nil {
return err
}
}
}
return nil
}

// reloadLLMPlugin reloads an LLM plugin with new instance
Expand Down Expand Up @@ -2654,11 +2661,11 @@ func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *syn
// return args.Message, nil
// }, toolSchema)
func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.ChatTool) error {
if bifrost.mcpManager == nil {
if bifrost.McpManager == nil {
return fmt.Errorf("MCP is not configured in this Bifrost instance")
}

return bifrost.mcpManager.RegisterTool(name, description, handler, toolSchema)
return bifrost.McpManager.RegisterTool(name, description, handler, toolSchema)
}

// IMPORTANT: Running the MCP client management operations (GetMCPClients, AddMCPClient, RemoveMCPClient, EditMCPClientTools)
Expand All @@ -2672,11 +2679,11 @@ func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(a
// - []schemas.MCPClient: List of all MCP clients
// - error: Any retrieval error
func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) {
if bifrost.mcpManager == nil {
if bifrost.McpManager == nil {
return nil, fmt.Errorf("MCP is not configured in this Bifrost instance")
}

clients := bifrost.mcpManager.GetClients()
clients := bifrost.McpManager.GetClients()
clientsInConfig := make([]schemas.MCPClient, 0, len(clients))

for _, client := range clients {
Expand Down Expand Up @@ -2714,10 +2721,10 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) {
// Returns:
// - []schemas.ChatTool: List of available tools
func (bifrost *Bifrost) GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool {
if bifrost.mcpManager == nil {
if bifrost.McpManager == nil {
return nil
}
return bifrost.mcpManager.GetAvailableTools(ctx)
return bifrost.McpManager.GetAvailableTools(ctx)
}

// AddMCPClient adds a new MCP client to the Bifrost instance.
Expand All @@ -2736,13 +2743,13 @@ func (bifrost *Bifrost) GetAvailableMCPTools(ctx context.Context) []schemas.Chat
// ConnectionType: schemas.MCPConnectionTypeHTTP,
// ConnectionString: &url,
// })
func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error {
if bifrost.mcpManager == nil {
func (bifrost *Bifrost) AddMCPClient(config *schemas.MCPClientConfig) error {
if bifrost.McpManager == nil {
// Use sync.Once to ensure thread-safe initialization
bifrost.mcpInitOnce.Do(func() {
// Initialize with empty config - client will be added via AddClient below
mcpConfig := schemas.MCPConfig{
ClientConfigs: []schemas.MCPClientConfig{},
ClientConfigs: []*schemas.MCPClientConfig{},
}
// Set up plugin pipeline provider functions for executeCode tool hooks
mcpConfig.PluginPipelineProvider = func() interface{} {
Expand All @@ -2756,16 +2763,16 @@ func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error {
// Create Starlark CodeMode for code execution (with default config)
starlark.SetLogger(bifrost.logger)
codeMode := starlark.NewStarlarkCodeMode(nil)
bifrost.mcpManager = mcp.NewMCPManager(bifrost.ctx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode)
bifrost.McpManager = mcp.NewMCPManager(bifrost.ctx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode)
})
}

// Handle case where initialization succeeded elsewhere but manager is still nil
if bifrost.mcpManager == nil {
if bifrost.McpManager == nil {
return fmt.Errorf("MCP manager is not initialized")
}

return bifrost.mcpManager.AddClient(config)
return bifrost.McpManager.AddClient(config)
}

// RemoveMCPClient removes an MCP client from the Bifrost instance.
Expand All @@ -2784,23 +2791,23 @@ func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error {
// log.Fatalf("Failed to remove MCP client: %v", err)
// }
func (bifrost *Bifrost) RemoveMCPClient(id string) error {
if bifrost.mcpManager == nil {
if bifrost.McpManager == nil {
return fmt.Errorf("MCP is not configured in this Bifrost instance")
}

return bifrost.mcpManager.RemoveClient(id)
return bifrost.McpManager.RemoveClient(id)
}

// SetMCPManager sets the MCP manager for this Bifrost instance.
// This is primarily used for testing purposes to inject a custom MCP manager.
// This allows injecting a custom MCP manager implementation (e.g., for enterprise features).
//
// Parameters:
// - manager: The MCP manager to set
func (bifrost *Bifrost) SetMCPManager(manager *mcp.MCPManager) {
bifrost.mcpManager = manager
// - manager: The MCP manager to set (must implement MCPManagerInterface)
func (bifrost *Bifrost) SetMCPManager(manager mcp.MCPManagerInterface) {
bifrost.McpManager = manager
}

// EditMCPClient edits the tools of an MCP client.
// UpdateMCPClient updates the MCP client.
// This allows for dynamic MCP client tool management at runtime.
//
// Parameters:
Expand All @@ -2812,16 +2819,16 @@ func (bifrost *Bifrost) SetMCPManager(manager *mcp.MCPManager) {
//
// Example:
//
// err := bifrost.EditMCPClient("my-mcp-client-id", schemas.MCPClientConfig{
// err := bifrost.UpdateMCPClient("my-mcp-client-id", schemas.MCPClientConfig{
// Name: "my-mcp-client-name",
// ToolsToExecute: []string{"tool1", "tool2"},
// })
func (bifrost *Bifrost) EditMCPClient(id string, updatedConfig schemas.MCPClientConfig) error {
if bifrost.mcpManager == nil {
func (bifrost *Bifrost) EditMCPClient(id string, updatedConfig *schemas.MCPClientConfig) error {
if bifrost.McpManager == nil {
return fmt.Errorf("MCP is not configured in this Bifrost instance")
}

return bifrost.mcpManager.EditClient(id, updatedConfig)
return bifrost.McpManager.EditClient(id, updatedConfig)
}

// ReconnectMCPClient attempts to reconnect an MCP client if it is disconnected.
Expand All @@ -2832,21 +2839,21 @@ func (bifrost *Bifrost) EditMCPClient(id string, updatedConfig schemas.MCPClient
// Returns:
// - error: Any reconnection error
func (bifrost *Bifrost) ReconnectMCPClient(id string) error {
if bifrost.mcpManager == nil {
if bifrost.McpManager == nil {
return fmt.Errorf("MCP is not configured in this Bifrost instance")
}

return bifrost.mcpManager.ReconnectClient(id)
return bifrost.McpManager.ReconnectClient(id)
}

// UpdateToolManagerConfig updates the tool manager config for the MCP manager.
// This allows for hot-reloading of the tool manager config at runtime.
func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error {
if bifrost.mcpManager == nil {
if bifrost.McpManager == nil {
return fmt.Errorf("MCP is not configured in this Bifrost instance")
}

bifrost.mcpManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{
bifrost.McpManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{
MaxAgentDepth: maxAgentDepth,
ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second,
CodeModeBindingLevel: schemas.CodeModeBindingLevel(codeModeBindingLevel),
Expand Down Expand Up @@ -3438,8 +3445,8 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif
}

// Add MCP tools to request if MCP is configured and requested
if bifrost.mcpManager != nil {
req = bifrost.mcpManager.AddToolsToRequest(ctx, req)
if bifrost.McpManager != nil {
req = bifrost.McpManager.AddToolsToRequest(ctx, req)
}

tracer := bifrost.getTracer()
Expand Down Expand Up @@ -3628,8 +3635,8 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem
}

// Add MCP tools to request if MCP is configured and requested
if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.mcpManager != nil {
req = bifrost.mcpManager.AddToolsToRequest(ctx, req)
if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.McpManager != nil {
req = bifrost.McpManager.AddToolsToRequest(ctx, req)
}

tracer := bifrost.getTracer()
Expand Down Expand Up @@ -4404,7 +4411,7 @@ func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, r
// - *schemas.BifrostMCPResponse: The MCP response after all hooks
// - *schemas.BifrostError: Any execution error
func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpRequest *schemas.BifrostMCPRequest, requestType schemas.RequestType) (*schemas.BifrostMCPResponse, *schemas.BifrostError) {
if bifrost.mcpManager == nil {
if bifrost.McpManager == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Expand Down Expand Up @@ -4463,7 +4470,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR
}

// Execute tool with modified request
result, err := bifrost.mcpManager.ExecuteToolCall(ctx, preReq)
result, err := bifrost.McpManager.ExecuteToolCall(ctx, preReq)

// Prepare MCP response and error for post-hooks
var mcpResp *schemas.BifrostMCPResponse
Expand Down Expand Up @@ -5271,8 +5278,8 @@ func (bifrost *Bifrost) Shutdown() {
})

// Cleanup MCP manager
if bifrost.mcpManager != nil {
err := bifrost.mcpManager.Cleanup()
if bifrost.McpManager != nil {
err := bifrost.McpManager.Cleanup()
if err != nil {
bifrost.logger.Warn("Error cleaning up MCP manager: %s", err.Error())
}
Expand Down
4 changes: 2 additions & 2 deletions core/internal/mcptests/agent_filtering_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ func TestAgent_FilteringWithMultipleClients(t *testing.T) {
tempConfig := GetTemperatureMCPClientConfig("")
tempConfig.ToolsToAutoExecute = []string{} // Not auto-executed

err = manager.AddClient(tempConfig)
err = manager.AddClient(&tempConfig)
if err != nil {
t.Skipf("Skipping test - temperature server not available: %v", err)
return
Expand Down Expand Up @@ -620,7 +620,7 @@ func TestAgent_ToolConflictInAgentMode(t *testing.T) {
tempConfig := GetTemperatureMCPClientConfig("")
tempConfig.ToolsToAutoExecute = []string{} // Not auto

err = manager.AddClient(tempConfig)
err = manager.AddClient(&tempConfig)
if err != nil {
t.Skipf("Skipping test - temperature server not available: %v", err)
return
Expand Down
8 changes: 7 additions & 1 deletion core/internal/mcptests/agent_request_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@ func setupMCPManagerWithRequestIDFunc(t *testing.T, fetchNewRequestIDFunc func(c

logger := &testLogger{t: t}

// Convert to pointer slice for MCPConfig
clientConfigPtrs := make([]*schemas.MCPClientConfig, len(clientConfigs))
for i := range clientConfigs {
clientConfigPtrs[i] = &clientConfigs[i]
}

// Create MCP config with request ID function
mcpConfig := &schemas.MCPConfig{
ClientConfigs: clientConfigs,
ClientConfigs: clientConfigPtrs,
FetchNewRequestIDFunc: fetchNewRequestIDFunc,
}

Expand Down
Loading