diff --git a/server/server.go b/server/server.go index 229f8926d..b3d738177 100644 --- a/server/server.go +++ b/server/server.go @@ -37,6 +37,12 @@ type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) ( // ToolHandlerFunc handles tool calls with given arguments. type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) +// ServerTool combines a Tool with its ToolHandlerFunc. +type ServerTool struct { + Tool mcp.Tool + Handler ToolHandlerFunc +} + // NotificationContext provides client identification for notifications type NotificationContext struct { ClientID string @@ -61,8 +67,7 @@ type MCPServer struct { resourceTemplates map[string]resourceTemplateEntry prompts map[string]mcp.Prompt promptHandlers map[string]PromptHandlerFunc - tools map[string]mcp.Tool - toolHandlers map[string]ToolHandlerFunc + tools map[string]ServerTool notificationHandlers map[string]NotificationHandlerFunc capabilities serverCapabilities notifications chan ServerNotification @@ -174,8 +179,7 @@ func NewMCPServer( resourceTemplates: make(map[string]resourceTemplateEntry), prompts: make(map[string]mcp.Prompt), promptHandlers: make(map[string]PromptHandlerFunc), - tools: make(map[string]mcp.Tool), - toolHandlers: make(map[string]ToolHandlerFunc), + tools: make(map[string]ServerTool), name: name, version: version, notificationHandlers: make(map[string]NotificationHandlerFunc), @@ -421,8 +425,34 @@ func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { // AddTool registers a new tool and its handler func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) { - s.tools[tool.Name] = tool - s.toolHandlers[tool.Name] = handler + s.AddTools(ServerTool{Tool: tool, Handler: handler}) +} + +// AddTools registers multiple tools at once +func (s *MCPServer) AddTools(tools ...ServerTool) { + for _, entry := range tools { + s.tools[entry.Tool.Name] = entry + } + + // Send notification if server is already initialized + if s.initialized { + if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil { + // We can't return the error, but in a future version we could log it + } + } +} + +// SetTools replaces all existing tools with the provided list +func (s *MCPServer) SetTools(tools ...ServerTool) { + s.tools = make(map[string]ServerTool) + s.AddTools(tools...) +} + +// DeleteTools removes a tool from the server +func (s *MCPServer) DeleteTools(names ...string) { + for _, name := range names { + delete(s.tools, name) + } // Send notification if server is already initialized if s.initialized { @@ -630,7 +660,7 @@ func (s *MCPServer) handleListTools( ) mcp.JSONRPCMessage { tools := make([]mcp.Tool, 0, len(s.tools)) for name := range s.tools { - tools = append(tools, s.tools[name]) + tools = append(tools, s.tools[name].Tool) } result := mcp.ListToolsResult{ @@ -647,7 +677,7 @@ func (s *MCPServer) handleToolCall( id interface{}, request mcp.CallToolRequest, ) mcp.JSONRPCMessage { - handler, ok := s.toolHandlers[request.Params.Name] + tool, ok := s.tools[request.Params.Name] if !ok { return createErrorResponse( id, @@ -656,7 +686,7 @@ func (s *MCPServer) handleToolCall( ) } - result, err := handler(ctx, request) + result, err := tool.Handler(ctx, request) if err != nil { return createErrorResponse(id, mcp.INTERNAL_ERROR, err.Error()) } diff --git a/server/server_test.go b/server/server_test.go index 30d3c56b1..ff2bf299a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -3,10 +3,10 @@ package server import ( "context" "encoding/json" - "testing" - "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/assert" + "testing" + "time" ) func TestMCPServer_NewMCPServer(t *testing.T) { @@ -105,6 +105,109 @@ func TestMCPServer_Capabilities(t *testing.T) { }) } } +func TestMCPServer_Tools(t *testing.T) { + tests := []struct { + name string + action func(*MCPServer) + expectedNotifications int + validate func(*testing.T, []ServerNotification, mcp.JSONRPCMessage) + }{ + { + name: "SetTools sends single notifications/tools/list_changed", + action: func(server *MCPServer) { + server.SetTools(ServerTool{ + Tool: mcp.NewTool("test-tool-1"), + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }, + }, ServerTool{ + Tool: mcp.NewTool("test-tool-2"), + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }, + }) + }, + expectedNotifications: 1, + validate: func(t *testing.T, notifications []ServerNotification, toolsList mcp.JSONRPCMessage) { + assert.Equal(t, "notifications/tools/list_changed", notifications[0].Notification.Method) + tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools + assert.Len(t, tools, 2) + assert.Equal(t, "test-tool-1", tools[0].Name) + assert.Equal(t, "test-tool-2", tools[1].Name) + }, + }, + { + name: "AddTool sends multiple notifications/tools/list_changed", + action: func(server *MCPServer) { + server.AddTool(mcp.NewTool("test-tool-1"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }) + server.AddTool(mcp.NewTool("test-tool-2"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }) + }, + expectedNotifications: 2, + validate: func(t *testing.T, notifications []ServerNotification, toolsList mcp.JSONRPCMessage) { + assert.Equal(t, "notifications/tools/list_changed", notifications[0].Notification.Method) + tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools + assert.Len(t, tools, 2) + assert.Equal(t, "test-tool-1", tools[0].Name) + assert.Equal(t, "test-tool-2", tools[1].Name) + }, + }, + { + name: "DeleteTools sends single notifications/tools/list_changed", + action: func(server *MCPServer) { + server.SetTools( + ServerTool{Tool: mcp.NewTool("test-tool-1")}, + ServerTool{Tool: mcp.NewTool("test-tool-2")}) + server.DeleteTools("test-tool-1", "test-tool-2") + }, + expectedNotifications: 2, + validate: func(t *testing.T, notifications []ServerNotification, toolsList mcp.JSONRPCMessage) { + // One for SetTools + assert.Equal(t, "notifications/tools/list_changed", notifications[0].Notification.Method) + // One for DeleteTools + assert.Equal(t, "notifications/tools/list_changed", notifications[1].Notification.Method) + assert.Equal(t, "Tools not supported", toolsList.(mcp.JSONRPCError).Error.Message) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + server := NewMCPServer("test-server", "1.0.0") + _ = server.HandleMessage(ctx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" + }`)) + notifications := make([]ServerNotification, 0) + tt.action(server) + for done := false; !done; { + select { + case serverNotification := <-server.notifications: + notifications = append(notifications, serverNotification) + if len(notifications) == tt.expectedNotifications { + done = true + } + case <-time.After(1 * time.Second): + done = true + } + } + assert.Len(t, notifications, tt.expectedNotifications) + toolsList := server.HandleMessage(ctx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list" + }`)) + tt.validate(t, notifications, toolsList.(mcp.JSONRPCMessage)) + }) + + } +} func TestMCPServer_HandleValidMessages(t *testing.T) { server := NewMCPServer("test-server", "1.0.0",