diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index a85227116..60b74f2bf 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -12,4 +12,4 @@ jobs:
       - uses: actions/setup-go@v5
         with:
           go-version-file: 'go.mod'
-      - run: go test ./...
+      - run: go test ./... -race
diff --git a/server/server.go b/server/server.go
index b3d738177..464365fe6 100644
--- a/server/server.go
+++ b/server/server.go
@@ -6,6 +6,9 @@ import (
 	"encoding/json"
 	"fmt"
 	"regexp"
+	"sort"
+	"sync"
+	"sync/atomic"
 
 	"github.com/mark3labs/mcp-go/mcp"
 )
@@ -61,6 +64,7 @@ type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCN
 // MCPServer implements a Model Control Protocol server that can handle various types of requests
 // including resources, prompts, and tools.
 type MCPServer struct {
+	mu                   sync.RWMutex // Add mutex for protecting shared resources
 	name                 string
 	version              string
 	resources            map[string]resourceEntry
@@ -71,8 +75,9 @@ type MCPServer struct {
 	notificationHandlers map[string]NotificationHandlerFunc
 	capabilities         serverCapabilities
 	notifications        chan ServerNotification
+	clientMu             sync.Mutex // Separate mutex for client context
 	currentClient        NotificationContext
-	initialized          bool
+	initialized          atomic.Bool // Use atomic for the initialized flag
 }
 
 // serverKey is the context key for storing the server instance
@@ -91,7 +96,9 @@ func (s *MCPServer) WithContext(
 	ctx context.Context,
 	notifCtx NotificationContext,
 ) context.Context {
+	s.clientMu.Lock()
 	s.currentClient = notifCtx
+	s.clientMu.Unlock()
 	return ctx
 }
 
@@ -104,6 +111,10 @@ func (s *MCPServer) SendNotificationToClient(
 		return fmt.Errorf("notification channel not initialized")
 	}
 
+	s.clientMu.Lock()
+	clientContext := s.currentClient
+	s.clientMu.Unlock()
+
 	notification := mcp.JSONRPCNotification{
 		JSONRPC:      mcp.JSONRPC_VERSION,
 		Notification: mcp.Notification{
@@ -116,7 +127,7 @@
 
 	select {
 	case s.notifications <- ServerNotification{
-		Context:      s.currentClient,
+		Context:      clientContext,
 		Notification: notification,
 	}:
 		return nil
@@ -394,6 +405,8 @@ func (s *MCPServer) AddResource(
 	if s.capabilities.resources == nil {
 		panic("Resource capabilities not enabled")
 	}
+	s.mu.Lock()
+	defer s.mu.Unlock()
 	s.resources[resource.URI] = resourceEntry{
 		resource: resource,
 		handler:  handler,
@@ -408,6 +421,8 @@ func (s *MCPServer) AddResourceTemplate(
 	if s.capabilities.resources == nil {
 		panic("Resource capabilities not enabled")
 	}
+	s.mu.Lock()
+	defer s.mu.Unlock()
 	s.resourceTemplates[template.URITemplate] = resourceTemplateEntry{
 		template: template,
 		handler:  handler,
@@ -419,6 +434,8 @@ func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) {
 	if s.capabilities.prompts == nil {
 		panic("Prompt capabilities not enabled")
 	}
+	s.mu.Lock()
+	defer s.mu.Unlock()
 	s.prompts[prompt.Name] = prompt
 	s.promptHandlers[prompt.Name] = handler
 }
@@ -430,12 +447,15 @@ func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) {
 
 // AddTools registers multiple tools at once
 func (s *MCPServer) AddTools(tools ...ServerTool) {
+	s.mu.Lock()
 	for _, entry := range tools {
 		s.tools[entry.Tool.Name] = entry
 	}
+	initialized := s.initialized.Load()
+	s.mu.Unlock()
 
 	// Send notification if server is already initialized
-	if s.initialized {
+	if initialized {
 		if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil {
 			// We can't return the error, but in a future version we could log it
 		}
@@ -444,18 +464,23 @@ func (s *MCPServer) AddTools(tools ...ServerTool) {
 
 // SetTools replaces all existing tools with the provided list
 func (s *MCPServer) SetTools(tools ...ServerTool) {
+	s.mu.Lock()
 	s.tools = make(map[string]ServerTool)
+	s.mu.Unlock()
 	s.AddTools(tools...)
 }
 
 // DeleteTools removes a tool from the server
 func (s *MCPServer) DeleteTools(names ...string) {
+	s.mu.Lock()
 	for _, name := range names {
 		delete(s.tools, name)
 	}
+	initialized := s.initialized.Load()
+	s.mu.Unlock()
 
 	// Send notification if server is already initialized
-	if s.initialized {
+	if initialized {
 		if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil {
 			// We can't return the error, but in a future version we could log it
 		}
@@ -467,6 +492,8 @@ func (s *MCPServer) AddNotificationHandler(
 	method string,
 	handler NotificationHandlerFunc,
 ) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
 	s.notificationHandlers[method] = handler
 }
 
@@ -510,7 +537,7 @@ func (s *MCPServer) handleInitialize(
 		Capabilities: capabilities,
 	}
 
-	s.initialized = true
+	s.initialized.Store(true)
 	return createResponse(id, result)
 }
 
@@ -527,10 +554,12 @@ func (s *MCPServer) handleListResources(
 	id interface{},
 	request mcp.ListResourcesRequest,
 ) mcp.JSONRPCMessage {
+	s.mu.RLock()
 	resources := make([]mcp.Resource, 0, len(s.resources))
 	for _, entry := range s.resources {
 		resources = append(resources, entry.resource)
 	}
+	s.mu.RUnlock()
 
 	result := mcp.ListResourcesResult{
 		Resources: resources,
@@ -546,10 +575,12 @@ func (s *MCPServer) handleListResourceTemplates(
 	id interface{},
 	request mcp.ListResourceTemplatesRequest,
 ) mcp.JSONRPCMessage {
+	s.mu.RLock()
 	templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates))
 	for _, entry := range s.resourceTemplates {
 		templates = append(templates, entry.template)
 	}
+	s.mu.RUnlock()
 
 	result := mcp.ListResourceTemplatesResult{
 		ResourceTemplates: templates,
@@ -565,9 +596,12 @@ func (s *MCPServer) handleReadResource(
 	id interface{},
 	request mcp.ReadResourceRequest,
 ) mcp.JSONRPCMessage {
+	s.mu.RLock()
 	// First try direct resource handlers
 	if entry, ok := s.resources[request.Params.URI]; ok {
-		contents, err := entry.handler(ctx, request)
+		handler := entry.handler
+		s.mu.RUnlock()
+		contents, err := handler(ctx, request)
 		if err != nil {
 			return createErrorResponse(id, mcp.INTERNAL_ERROR, err.Error())
 		}
@@ -575,18 +609,27 @@ func (s *MCPServer) handleReadResource(
 	}
 
 	// If no direct handler found, try matching against templates
+	var matchedHandler ResourceTemplateHandlerFunc
+	var matched bool
 	for uriTemplate, entry := range s.resourceTemplates {
 		if matchesTemplate(request.Params.URI, uriTemplate) {
-			contents, err := entry.handler(ctx, request)
-			if err != nil {
-				return createErrorResponse(id, mcp.INTERNAL_ERROR, err.Error())
-			}
-			return createResponse(
-				id,
-				mcp.ReadResourceResult{Contents: contents},
-			)
+			matchedHandler = entry.handler
+			matched = true
+			break
 		}
 	}
+	s.mu.RUnlock()
+
+	if matched {
+		contents, err := matchedHandler(ctx, request)
+		if err != nil {
+			return createErrorResponse(id, mcp.INTERNAL_ERROR, err.Error())
+		}
+		return createResponse(
+			id,
+			mcp.ReadResourceResult{Contents: contents},
+		)
+	}
 
 	return createErrorResponse(
 		id,
@@ -617,10 +660,12 @@ func (s *MCPServer) handleListPrompts(
 	id interface{},
 	request mcp.ListPromptsRequest,
 ) mcp.JSONRPCMessage {
+	s.mu.RLock()
 	prompts := make([]mcp.Prompt, 0, len(s.prompts))
 	for _, prompt := range s.prompts {
 		prompts = append(prompts, prompt)
 	}
+	s.mu.RUnlock()
 
 	result := mcp.ListPromptsResult{
 		Prompts: prompts,
@@ -636,7 +681,10 @@ func (s *MCPServer) handleGetPrompt(
 	id interface{},
 	request mcp.GetPromptRequest,
 ) mcp.JSONRPCMessage {
+	s.mu.RLock()
 	handler, ok := s.promptHandlers[request.Params.Name]
+	s.mu.RUnlock()
+
 	if !ok {
 		return createErrorResponse(
 			id,
@@ -658,10 +706,23 @@ func (s *MCPServer) handleListTools(
 	id interface{},
 	request mcp.ListToolsRequest,
 ) mcp.JSONRPCMessage {
+	s.mu.RLock()
 	tools := make([]mcp.Tool, 0, len(s.tools))
+
+	// Get all tool names for consistent ordering
+	toolNames := make([]string, 0, len(s.tools))
 	for name := range s.tools {
+		toolNames = append(toolNames, name)
+	}
+
+	// Sort the tool names for consistent ordering
+	sort.Strings(toolNames)
+
+	// Add tools in sorted order
+	for _, name := range toolNames {
 		tools = append(tools, s.tools[name].Tool)
 	}
+	s.mu.RUnlock()
 
 	result := mcp.ListToolsResult{
 		Tools: tools,
@@ -677,7 +738,10 @@ func (s *MCPServer) handleToolCall(
 	id interface{},
 	request mcp.CallToolRequest,
 ) mcp.JSONRPCMessage {
+	s.mu.RLock()
 	tool, ok := s.tools[request.Params.Name]
+	s.mu.RUnlock()
+
 	if !ok {
 		return createErrorResponse(
 			id,
@@ -698,7 +762,11 @@ func (s *MCPServer) handleNotification(
 	ctx context.Context,
 	notification mcp.JSONRPCNotification,
 ) mcp.JSONRPCMessage {
-	if handler, ok := s.notificationHandlers[notification.Method]; ok {
+	s.mu.RLock()
+	handler, ok := s.notificationHandlers[notification.Method]
+	s.mu.RUnlock()
+
+	if ok {
 		handler(ctx, notification)
 	}
 	return nil
diff --git a/server/sse.go b/server/sse.go
index ee7db2e99..dd197b026 100644
--- a/server/sse.go
+++ b/server/sse.go
@@ -23,9 +23,10 @@ type SSEServer struct {
 
 // sseSession represents an active SSE connection.
 type sseSession struct {
-	writer  http.ResponseWriter
-	flusher http.Flusher
-	done    chan struct{}
+	writer     http.ResponseWriter
+	flusher    http.Flusher
+	done       chan struct{}
+	eventQueue chan string // Channel for queuing events
 }
 
 // NewSSEServer creates a new SSE server instance with the given MCP server and base URL.
@@ -112,9 +113,10 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
 
 	sessionID := uuid.New().String()
 	session := &sseSession{
-		writer:  w,
-		flusher: flusher,
-		done:    make(chan struct{}),
+		writer:     w,
+		flusher:    flusher,
+		done:       make(chan struct{}),
+		eventQueue: make(chan string, 100), // Buffer for events
 	}
 
 	s.sessions.Store(sessionID, session)
@@ -127,10 +129,15 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
 			case serverNotification := <-s.server.notifications:
 				// Only forward notifications meant for this session
 				if serverNotification.Context.SessionID == sessionID {
-					s.SendEventToSession(
-						sessionID,
-						serverNotification.Notification,
-					)
+					eventData, err := json.Marshal(serverNotification.Notification)
+					if err == nil {
+						select {
+						case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
+							// Event queued successfully
+						case <-session.done:
+							return
+						}
+					}
 				}
 			case <-session.done:
 				return
@@ -145,11 +152,23 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
 		s.baseURL,
 		sessionID,
 	)
+
+	// Send the initial endpoint event
 	fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", messageEndpoint)
 	flusher.Flush()
 
-	<-r.Context().Done()
-	close(session.done)
+	// Main event loop - this runs in the HTTP handler goroutine
+	for {
+		select {
+		case event := <-session.eventQueue:
+			// Write the event to the response
+			fmt.Fprint(w, event)
+			flusher.Flush()
+		case <-r.Context().Done():
+			close(session.done)
+			return
+		}
+	}
 }
 
 // handleMessage processes incoming JSON-RPC messages from clients and sends responses
@@ -192,8 +211,16 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
 	// Only send response if there is one (not for notifications)
 	if response != nil {
 		eventData, _ := json.Marshal(response)
-		fmt.Fprintf(session.writer, "event: message\ndata: %s\n\n", eventData)
-		session.flusher.Flush()
+
+		// Queue the event for sending via SSE
+		select {
+		case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
+			// Event queued successfully
+		case <-session.done:
+			// Session is closed, don't try to queue
+		default:
+			// Queue is full, could log this
+		}
 
 		// Send HTTP response
 		w.Header().Set("Content-Type", "application/json")
@@ -235,12 +262,13 @@ func (s *SSEServer) SendEventToSession(
 		return err
 	}
 
+	// Queue the event for sending via SSE
 	select {
+	case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
+		return nil
 	case <-session.done:
 		return fmt.Errorf("session closed")
 	default:
-		fmt.Fprintf(session.writer, "event: message\ndata: %s\n\n", eventData)
-		session.flusher.Flush()
-		return nil
+		return fmt.Errorf("event queue full")
 	}
}
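Illustrative note, not part of the diff: the new -race flag in CI is what the locking above is meant to satisfy. A minimal test sketch along these lines, assuming the package's NewMCPServer constructor and leaving tool handlers nil, mutates the tool map from several goroutines so the race detector can flag unsynchronized access:

package server_test

import (
	"fmt"
	"sync"
	"testing"

	"github.com/mark3labs/mcp-go/mcp"
	"github.com/mark3labs/mcp-go/server"
)

// TestConcurrentToolMutation is a sketch, not part of this change: it registers and
// deletes tools from several goroutines so `go test -race` can observe the locking
// added to MCPServer. The NewMCPServer constructor signature is assumed here.
func TestConcurrentToolMutation(t *testing.T) {
	s := server.NewMCPServer("example", "0.0.1")

	var wg sync.WaitGroup
	for i := 0; i < 16; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			name := fmt.Sprintf("tool-%d", i)
			// Concurrent writers on the server's tool map; only safe when
			// AddTools/DeleteTools take MCPServer.mu as in the change above.
			s.AddTools(server.ServerTool{Tool: mcp.Tool{Name: name}})
			s.DeleteTools(name)
		}(i)
	}
	wg.Wait()
}

Run with go test ./... -race; with the mutexes in place this kind of test should pass cleanly, while the pre-change code would be reported as a data race on the tools map.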