diff --git a/server/server.go b/server/server.go index e0cbf3065..5b2d739dc 100644 --- a/server/server.go +++ b/server/server.go @@ -141,7 +141,14 @@ type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCN // MCPServer implements a Model Control Protocol server that can handle various types of requests // including resources, prompts, and tools. type MCPServer struct { - mu sync.RWMutex // Add mutex for protecting shared resources + // Separate mutexes for different resource types + resourcesMu sync.RWMutex + promptsMu sync.RWMutex + toolsMu sync.RWMutex + middlewareMu sync.RWMutex + notificationHandlersMu sync.RWMutex + capabilitiesMu sync.RWMutex + name string version string instructions string @@ -301,7 +308,9 @@ func WithToolHandlerMiddleware( toolHandlerMiddleware ToolHandlerMiddleware, ) ServerOption { return func(s *MCPServer) { + s.middlewareMu.Lock() s.toolHandlerMiddlewares = append(s.toolHandlerMiddlewares, toolHandlerMiddleware) + s.middlewareMu.Unlock() } } @@ -396,11 +405,14 @@ func (s *MCPServer) AddResource( resource mcp.Resource, handler ResourceHandlerFunc, ) { + s.capabilitiesMu.Lock() if s.capabilities.resources == nil { s.capabilities.resources = &resourceCapabilities{} } - s.mu.Lock() - defer s.mu.Unlock() + s.capabilitiesMu.Unlock() + + s.resourcesMu.Lock() + defer s.resourcesMu.Unlock() s.resources[resource.URI] = resourceEntry{ resource: resource, handler: handler, @@ -412,11 +424,14 @@ func (s *MCPServer) AddResourceTemplate( template mcp.ResourceTemplate, handler ResourceTemplateHandlerFunc, ) { + s.capabilitiesMu.Lock() if s.capabilities.resources == nil { s.capabilities.resources = &resourceCapabilities{} } - s.mu.Lock() - defer s.mu.Unlock() + s.capabilitiesMu.Unlock() + + s.resourcesMu.Lock() + defer s.resourcesMu.Unlock() s.resourceTemplates[template.URITemplate.Raw()] = resourceTemplateEntry{ template: template, handler: handler, @@ -425,11 +440,14 @@ func (s *MCPServer) AddResourceTemplate( // AddPrompt registers a new prompt handler with the given name func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { + s.capabilitiesMu.Lock() if s.capabilities.prompts == nil { s.capabilities.prompts = &promptCapabilities{} } - s.mu.Lock() - defer s.mu.Unlock() + s.capabilitiesMu.Unlock() + + s.promptsMu.Lock() + defer s.promptsMu.Unlock() s.prompts[prompt.Name] = prompt s.promptHandlers[prompt.Name] = handler } @@ -441,14 +459,17 @@ func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) { // AddTools registers multiple tools at once func (s *MCPServer) AddTools(tools ...ServerTool) { + s.capabilitiesMu.Lock() if s.capabilities.tools == nil { s.capabilities.tools = &toolCapabilities{} } - s.mu.Lock() + s.capabilitiesMu.Unlock() + + s.toolsMu.Lock() for _, entry := range tools { s.tools[entry.Tool.Name] = entry } - s.mu.Unlock() + s.toolsMu.Unlock() // Send notification to all initialized sessions s.sendNotificationToAllClients("notifications/tools/list_changed", nil) @@ -456,19 +477,19 @@ func (s *MCPServer) AddTools(tools ...ServerTool) { // SetTools replaces all existing tools with the provided list func (s *MCPServer) SetTools(tools ...ServerTool) { - s.mu.Lock() + s.toolsMu.Lock() s.tools = make(map[string]ServerTool) - s.mu.Unlock() + s.toolsMu.Unlock() s.AddTools(tools...) } // DeleteTools removes a tool from the server func (s *MCPServer) DeleteTools(names ...string) { - s.mu.Lock() + s.toolsMu.Lock() for _, name := range names { delete(s.tools, name) } - s.mu.Unlock() + s.toolsMu.Unlock() // Send notification to all initialized sessions s.sendNotificationToAllClients("notifications/tools/list_changed", nil) @@ -479,8 +500,8 @@ func (s *MCPServer) AddNotificationHandler( method string, handler NotificationHandlerFunc, ) { - s.mu.Lock() - defer s.mu.Unlock() + s.notificationHandlersMu.Lock() + defer s.notificationHandlersMu.Unlock() s.notificationHandlers[method] = handler } @@ -589,12 +610,12 @@ func (s *MCPServer) handleListResources( id interface{}, request mcp.ListResourcesRequest, ) (*mcp.ListResourcesResult, *requestError) { - s.mu.RLock() + s.resourcesMu.RLock() resources := make([]mcp.Resource, 0, len(s.resources)) for _, entry := range s.resources { resources = append(resources, entry.resource) } - s.mu.RUnlock() + s.resourcesMu.RUnlock() // Sort the resources by name sort.Slice(resources, func(i, j int) bool { @@ -622,12 +643,12 @@ func (s *MCPServer) handleListResourceTemplates( id interface{}, request mcp.ListResourceTemplatesRequest, ) (*mcp.ListResourceTemplatesResult, *requestError) { - s.mu.RLock() + s.resourcesMu.RLock() templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates)) for _, entry := range s.resourceTemplates { templates = append(templates, entry.template) } - s.mu.RUnlock() + s.resourcesMu.RUnlock() sort.Slice(templates, func(i, j int) bool { return templates[i].Name < templates[j].Name }) @@ -653,11 +674,11 @@ func (s *MCPServer) handleReadResource( id interface{}, request mcp.ReadResourceRequest, ) (*mcp.ReadResourceResult, *requestError) { - s.mu.RLock() + s.resourcesMu.RLock() // First try direct resource handlers if entry, ok := s.resources[request.Params.URI]; ok { handler := entry.handler - s.mu.RUnlock() + s.resourcesMu.RUnlock() contents, err := handler(ctx, request) if err != nil { return nil, &requestError{ @@ -686,7 +707,7 @@ func (s *MCPServer) handleReadResource( break } } - s.mu.RUnlock() + s.resourcesMu.RUnlock() if matched { contents, err := matchedHandler(ctx, request) @@ -717,12 +738,12 @@ func (s *MCPServer) handleListPrompts( id interface{}, request mcp.ListPromptsRequest, ) (*mcp.ListPromptsResult, *requestError) { - s.mu.RLock() + s.promptsMu.RLock() prompts := make([]mcp.Prompt, 0, len(s.prompts)) for _, prompt := range s.prompts { prompts = append(prompts, prompt) } - s.mu.RUnlock() + s.promptsMu.RUnlock() // sort prompts by name sort.Slice(prompts, func(i, j int) bool { @@ -750,9 +771,9 @@ func (s *MCPServer) handleGetPrompt( id interface{}, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, *requestError) { - s.mu.RLock() + s.promptsMu.RLock() handler, ok := s.promptHandlers[request.Params.Name] - s.mu.RUnlock() + s.promptsMu.RUnlock() if !ok { return nil, &requestError{ @@ -779,7 +800,7 @@ func (s *MCPServer) handleListTools( id interface{}, request mcp.ListToolsRequest, ) (*mcp.ListToolsResult, *requestError) { - s.mu.RLock() + s.toolsMu.RLock() tools := make([]mcp.Tool, 0, len(s.tools)) // Get all tool names for consistent ordering @@ -795,6 +816,8 @@ func (s *MCPServer) handleListTools( for _, name := range toolNames { tools = append(tools, s.tools[name].Tool) } + s.toolsMu.RUnlock() + toolsToReturn, nextCursor, err := listByPagination[mcp.Tool](ctx, s, request.Params.Cursor, tools) if err != nil { return nil, &requestError{ @@ -817,9 +840,9 @@ func (s *MCPServer) handleToolCall( id interface{}, request mcp.CallToolRequest, ) (*mcp.CallToolResult, *requestError) { - s.mu.RLock() + s.toolsMu.RLock() tool, ok := s.tools[request.Params.Name] - s.mu.RUnlock() + s.toolsMu.RUnlock() if !ok { return nil, &requestError{ @@ -830,9 +853,16 @@ func (s *MCPServer) handleToolCall( } finalHandler := tool.Handler - for i := len(s.toolHandlerMiddlewares) - 1; i >= 0; i-- { - finalHandler = s.toolHandlerMiddlewares[i](finalHandler) + + s.middlewareMu.RLock() + mw := s.toolHandlerMiddlewares + s.middlewareMu.RUnlock() + + // Apply middlewares in reverse order + for i := len(mw) - 1; i >= 0; i-- { + finalHandler = mw[i](finalHandler) } + result, err := finalHandler(ctx, request) if err != nil { return nil, &requestError{ @@ -849,9 +879,9 @@ func (s *MCPServer) handleNotification( ctx context.Context, notification mcp.JSONRPCNotification, ) mcp.JSONRPCMessage { - s.mu.RLock() + s.notificationHandlersMu.RLock() handler, ok := s.notificationHandlers[notification.Method] - s.mu.RUnlock() + s.notificationHandlersMu.RUnlock() if ok { handler(ctx, notification) diff --git a/server/server_race_test.go b/server/server_race_test.go new file mode 100644 index 000000000..8cc29476c --- /dev/null +++ b/server/server_race_test.go @@ -0,0 +1,190 @@ +package server + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestRaceConditions attempts to trigger race conditions by performing +// concurrent operations on different resources of the MCPServer. +func TestRaceConditions(t *testing.T) { + // Create a server with all capabilities + srv := NewMCPServer("test-server", "1.0.0", + WithResourceCapabilities(true, true), + WithPromptCapabilities(true), + WithToolCapabilities(true), + WithLogging(), + WithRecovery(), + ) + + // Create a context + ctx := context.Background() + + // Create a sync.WaitGroup to coordinate test goroutines + var wg sync.WaitGroup + + // Define test duration + testDuration := 300 * time.Millisecond + + // Start goroutines to perform concurrent operations + runConcurrentOperation(&wg, testDuration, "add-prompts", func() { + name := fmt.Sprintf("prompt-%d", time.Now().UnixNano()) + srv.AddPrompt(mcp.Prompt{ + Name: name, + Description: "Test prompt", + }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{}, nil + }) + }) + + runConcurrentOperation(&wg, testDuration, "add-tools", func() { + name := fmt.Sprintf("tool-%d", time.Now().UnixNano()) + srv.AddTool(mcp.Tool{ + Name: name, + Description: "Test tool", + }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }) + }) + + runConcurrentOperation(&wg, testDuration, "delete-tools", func() { + name := fmt.Sprintf("delete-tool-%d", time.Now().UnixNano()) + // Add and immediately delete + srv.AddTool(mcp.Tool{ + Name: name, + Description: "Temporary tool", + }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }) + srv.DeleteTools(name) + }) + + runConcurrentOperation(&wg, testDuration, "add-middleware", func() { + middleware := func(next ToolHandlerFunc) ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return next(ctx, req) + } + } + WithToolHandlerMiddleware(middleware)(srv) + }) + + runConcurrentOperation(&wg, testDuration, "list-tools", func() { + result, reqErr := srv.handleListTools(ctx, "123", mcp.ListToolsRequest{}) + require.Nil(t, reqErr, "List tools operation should not return an error") + require.NotNil(t, result, "List tools result should not be nil") + }) + + runConcurrentOperation(&wg, testDuration, "list-prompts", func() { + result, reqErr := srv.handleListPrompts(ctx, "123", mcp.ListPromptsRequest{}) + require.Nil(t, reqErr, "List prompts operation should not return an error") + require.NotNil(t, result, "List prompts result should not be nil") + }) + + // Add a persistent tool for testing tool calls + srv.AddTool(mcp.Tool{ + Name: "persistent-tool", + Description: "Test tool that always exists", + }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }) + + runConcurrentOperation(&wg, testDuration, "call-tools", func() { + req := mcp.CallToolRequest{} + req.Params.Name = "persistent-tool" + req.Params.Arguments = map[string]interface{}{"param": "test"} + result, reqErr := srv.handleToolCall(ctx, "123", req) + require.Nil(t, reqErr, "Tool call operation should not return an error") + require.NotNil(t, result, "Tool call result should not be nil") + }) + + runConcurrentOperation(&wg, testDuration, "add-resources", func() { + uri := fmt.Sprintf("resource-%d", time.Now().UnixNano()) + srv.AddResource(mcp.Resource{ + URI: uri, + Name: uri, + Description: "Test resource", + }, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: uri, + Text: "Test content", + }, + }, nil + }) + }) + + // Wait for all operations to complete + wg.Wait() + t.Log("No race conditions detected") +} + +// Helper function to run an operation concurrently for a specified duration +func runConcurrentOperation(wg *sync.WaitGroup, duration time.Duration, name string, operation func()) { + wg.Add(1) + go func() { + defer wg.Done() + + done := time.After(duration) + for { + select { + case <-done: + return + default: + operation() + } + } + }() +} + +// TestConcurrentPromptAdd specifically tests for the deadlock scenario where adding a prompt +// from a goroutine can cause a deadlock +func TestConcurrentPromptAdd(t *testing.T) { + srv := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true)) + ctx := context.Background() + + // Add a prompt with a handler that adds another prompt in a goroutine + srv.AddPrompt(mcp.Prompt{ + Name: "initial-prompt", + Description: "Initial prompt", + }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + go func() { + srv.AddPrompt(mcp.Prompt{ + Name: fmt.Sprintf("new-prompt-%d", time.Now().UnixNano()), + Description: "Added from handler", + }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{}, nil + }) + }() + return &mcp.GetPromptResult{}, nil + }) + + // Create request and channel to track completion + req := mcp.GetPromptRequest{} + req.Params.Name = "initial-prompt" + done := make(chan struct{}) + + // Try to get the prompt - this would deadlock with a single mutex + go func() { + result, reqErr := srv.handleGetPrompt(ctx, "123", req) + require.Nil(t, reqErr, "Get prompt operation should not return an error") + require.NotNil(t, result, "Get prompt result should not be nil") + close(done) + }() + + // Assert the operation completes without deadlock + assert.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, 1*time.Second, 10*time.Millisecond, "Deadlock detected: operation did not complete in time") +}