Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
345 changes: 265 additions & 80 deletions core/bifrost.go

Large diffs are not rendered by default.

147 changes: 95 additions & 52 deletions core/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ type MCPManager struct {

// MCPClient represents a connected MCP client with its configuration and tools.
type MCPClient struct {
Name string // Unique name for this client
Conn *client.Client // Active MCP client connection
ExecutionConfig schemas.MCPClientConfig // Tool filtering settings
ToolMap map[string]schemas.Tool // Available tools mapped by name
ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management
cancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized)
Name string // Unique name for this client
Conn *client.Client // Active MCP client connection
ExecutionConfig schemas.MCPClientConfig // Tool filtering settings
ToolMap map[string]schemas.ChatTool // Available tools mapped by name
ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management
cancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized)
}

// MCPClientConnectionInfo stores metadata about how a client is connected.
Expand Down Expand Up @@ -173,7 +173,7 @@ func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error {
m.clientMap[config.Name] = &MCPClient{
Name: config.Name,
ExecutionConfig: config,
ToolMap: make(map[string]schemas.Tool),
ToolMap: make(map[string]schemas.ChatTool),
}

// Temporarily unlock for the connection attempt
Expand Down Expand Up @@ -228,7 +228,7 @@ func (m *MCPManager) removeClientUnsafe(name string) error {
}

// Clear client tool map
client.ToolMap = make(map[string]schemas.Tool)
client.ToolMap = make(map[string]schemas.ChatTool)

delete(m.clientMap, name)
return nil
Expand Down Expand Up @@ -256,7 +256,7 @@ func (m *MCPManager) EditClientTools(name string, toolsToAdd []string, toolsToRe
client.ExecutionConfig = config

// Clear current tool map
client.ToolMap = make(map[string]schemas.Tool)
client.ToolMap = make(map[string]schemas.ChatTool)

// Temporarily unlock for the network call
m.mu.Unlock()
Expand Down Expand Up @@ -288,7 +288,7 @@ func (m *MCPManager) EditClientTools(name string, toolsToAdd []string, toolsToRe

// getAvailableTools returns all tools from connected MCP clients.
// Applies client filtering if specified in the context.
func (m *MCPManager) getAvailableTools(ctx context.Context) []schemas.Tool {
func (m *MCPManager) getAvailableTools(ctx context.Context) []schemas.ChatTool {
m.mu.RLock()
defer m.mu.RUnlock()

Expand All @@ -303,7 +303,7 @@ func (m *MCPManager) getAvailableTools(ctx context.Context) []schemas.Tool {
excludeClients = existingExcludeClients
}

tools := make([]schemas.Tool, 0)
tools := make([]schemas.ChatTool, 0)
for clientName, client := range m.clientMap {
// Apply client filtering logic
if !m.shouldIncludeClient(clientName, includeClients, excludeClients) {
Expand Down Expand Up @@ -348,7 +348,7 @@ func (m *MCPManager) getAvailableTools(ctx context.Context) []schemas.Tool {
// func(args EchoArgs) (string, error) {
// return args.Message, nil
// }, toolSchema)
func (m *MCPManager) registerTool(name, description string, handler MCPToolHandler[any], toolSchema schemas.Tool) error {
func (m *MCPManager) registerTool(name, description string, handler MCPToolHandler[any], toolSchema schemas.ChatTool) error {
// Ensure local server is set up
if err := m.setupLocalHost(); err != nil {
return fmt.Errorf("failed to setup local host: %w", err)
Expand Down Expand Up @@ -453,7 +453,7 @@ func (m *MCPManager) createLocalMCPClient() (*MCPClient, error) {
ExecutionConfig: schemas.MCPClientConfig{
Name: BifrostMCPClientName,
},
ToolMap: make(map[string]schemas.Tool),
ToolMap: make(map[string]schemas.ChatTool),
ConnectionInfo: MCPClientConnectionInfo{
Type: schemas.MCPConnectionTypeInProcess, // Accurate: in-process (in-memory) transport
},
Expand Down Expand Up @@ -524,9 +524,9 @@ func (m *MCPManager) startLocalMCPServer() error {
// - toolCall: The tool call to execute (from assistant message)
//
// Returns:
// - schemas.BifrostMessage: Tool message with execution result
// - schemas.ChatMessage: Tool message with execution result
// - error: Any execution error
func (m *MCPManager) executeTool(ctx context.Context, toolCall schemas.ToolCall) (*schemas.BifrostMessage, error) {
func (m *MCPManager) executeTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) {
if toolCall.Function.Name == nil {
return nil, fmt.Errorf("tool call missing function name")
}
Expand Down Expand Up @@ -600,7 +600,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error {
m.clientMap[config.Name] = &MCPClient{
Name: config.Name,
ExecutionConfig: config,
ToolMap: make(map[string]schemas.Tool),
ToolMap: make(map[string]schemas.ChatTool),
ConnectionInfo: MCPClientConnectionInfo{
Type: config.ConnectionType,
},
Expand Down Expand Up @@ -679,7 +679,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error {
if err != nil {
m.logger.Warn(fmt.Sprintf("%s Failed to retrieve tools from %s: %v", MCPLogPrefix, config.Name, err))
// Continue with connection even if tool retrieval fails
tools = make(map[string]schemas.Tool)
tools = make(map[string]schemas.ChatTool)
}

// Second lock: Update client with final connection details and tools
Expand Down Expand Up @@ -711,7 +711,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error {
}

// retrieveExternalTools retrieves and filters tools from an external MCP server without holding locks.
func (m *MCPManager) retrieveExternalTools(ctx context.Context, client *client.Client, config schemas.MCPClientConfig) (map[string]schemas.Tool, error) {
func (m *MCPManager) retrieveExternalTools(ctx context.Context, client *client.Client, config schemas.MCPClientConfig) (map[string]schemas.ChatTool, error) {
// Get available tools from external server
listRequest := mcp.ListToolsRequest{
PaginatedRequest: mcp.PaginatedRequest{
Expand All @@ -727,10 +727,10 @@ func (m *MCPManager) retrieveExternalTools(ctx context.Context, client *client.C
}

if toolsResponse == nil {
return make(map[string]schemas.Tool), nil // No tools available
return make(map[string]schemas.ChatTool), nil // No tools available
}

tools := make(map[string]schemas.Tool)
tools := make(map[string]schemas.ChatTool)

// toolsResponse is already a ListToolsResult
for _, mcpTool := range toolsResponse.Tools {
Expand Down Expand Up @@ -796,13 +796,13 @@ func (m *MCPManager) shouldSkipToolForRequest(toolName string, ctx context.Conte
}

// convertMCPToolToBifrostSchema converts an MCP tool definition to Bifrost format.
func (m *MCPManager) convertMCPToolToBifrostSchema(mcpTool *mcp.Tool) schemas.Tool {
return schemas.Tool{
Type: "function",
Function: schemas.Function{
func (m *MCPManager) convertMCPToolToBifrostSchema(mcpTool *mcp.Tool) schemas.ChatTool {
return schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: mcpTool.Name,
Description: mcpTool.Description,
Parameters: schemas.FunctionParameters{
Description: Ptr(mcpTool.Description),
Parameters: &schemas.ToolFunctionParameters{
Type: mcpTool.InputSchema.Type,
Properties: mcpTool.InputSchema.Properties,
Required: mcpTool.InputSchema.Required,
Expand Down Expand Up @@ -852,13 +852,13 @@ func (m *MCPManager) extractTextFromMCPResponse(toolResponse *mcp.CallToolResult
}

// createToolResponseMessage creates a tool response message with the execution result.
func (m *MCPManager) createToolResponseMessage(toolCall schemas.ToolCall, responseText string) *schemas.BifrostMessage {
return &schemas.BifrostMessage{
Role: schemas.ModelChatMessageRoleTool,
Content: schemas.MessageContent{
func (m *MCPManager) createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage {
return &schemas.ChatMessage{
Role: schemas.ChatMessageRoleTool,
Content: schemas.ChatMessageContent{
ContentStr: &responseText,
},
ToolMessage: &schemas.ToolMessage{
ChatToolMessage: &schemas.ChatToolMessage{
ToolCallID: toolCall.ID,
},
}
Expand All @@ -867,31 +867,74 @@ func (m *MCPManager) createToolResponseMessage(toolCall schemas.ToolCall, respon
func (m *MCPManager) addMCPToolsToBifrostRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest {
mcpTools := m.getAvailableTools(ctx)
if len(mcpTools) > 0 {
// Initialize tools array if needed
if req.Params == nil {
req.Params = &schemas.ModelParameters{}
}
if req.Params.Tools == nil {
req.Params.Tools = &[]schemas.Tool{}
}
tools := *req.Params.Tools
switch req.RequestType {
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
// Only allocate new Params if it's nil to preserve caller-supplied settings
if req.ChatRequest.Params == nil {
req.ChatRequest.Params = &schemas.ChatParameters{}
}

// Create a map of existing tool names for O(1) lookup
existingToolsMap := make(map[string]bool)
for _, tool := range tools {
existingToolsMap[tool.Function.Name] = true
}
tools := req.ChatRequest.Params.Tools

// Add MCP tools that are not already present
for _, mcpTool := range mcpTools {
if !existingToolsMap[mcpTool.Function.Name] {
tools = append(tools, mcpTool)
// Update the map to prevent duplicates within MCP tools as well
existingToolsMap[mcpTool.Function.Name] = true
// Create a map of existing tool names for O(1) lookup
existingToolsMap := make(map[string]bool)
for _, tool := range tools {
if tool.Function != nil && tool.Function.Name != "" {
existingToolsMap[tool.Function.Name] = true
}
}

// Add MCP tools that are not already present
for _, mcpTool := range mcpTools {
// Skip tools with nil Function or empty Name
if mcpTool.Function == nil || mcpTool.Function.Name == "" {
continue
}

if !existingToolsMap[mcpTool.Function.Name] {
tools = append(tools, mcpTool)
// Update the map to prevent duplicates within MCP tools as well
existingToolsMap[mcpTool.Function.Name] = true
}
}
req.ChatRequest.Params.Tools = tools
case schemas.ResponsesRequest, schemas.ResponsesStreamRequest:
// Only allocate new Params if it's nil to preserve caller-supplied settings
if req.ResponsesRequest.Params == nil {
req.ResponsesRequest.Params = &schemas.ResponsesParameters{}
}
}
req.Params.Tools = &tools

tools := req.ResponsesRequest.Params.Tools

// Create a map of existing tool names for O(1) lookup
existingToolsMap := make(map[string]bool)
for _, tool := range tools {
if tool.Name != nil {
existingToolsMap[*tool.Name] = true
}
}

// Add MCP tools that are not already present
for _, mcpTool := range mcpTools {
// Skip tools with nil Function or empty Name
if mcpTool.Function == nil || mcpTool.Function.Name == "" {
continue
}

if !existingToolsMap[mcpTool.Function.Name] {
responsesTool := mcpTool.ToResponsesTool()
// Skip if the converted tool has nil Name
if responsesTool.Name == nil {
continue
}

tools = append(tools, *responsesTool)
// Update the map to prevent duplicates within MCP tools as well
existingToolsMap[*responsesTool.Name] = true
}
}
req.ResponsesRequest.Params.Tools = tools
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
return req
}
Expand Down
Loading