From 39ee6ec2f0c0783700e097f613068973f8d2418e Mon Sep 17 00:00:00 2001 From: lbbniu Date: Fri, 12 Sep 2025 16:14:05 +0800 Subject: [PATCH 1/2] feat: add custom handler support for all MCP server methods - Add custom handler fields to MCPServer struct for all basic MCP methods - Implement custom handler logic in all handle* methods with proper error handling - Support custom handlers for: Initialize, Ping, SetLevel, ListResources, ListResourceTemplates, ReadResource, ListPrompts, GetPrompt, ListTools, CallTool, and Notification methods - Maintain backward compatibility by falling back to default behavior when custom handlers are not set - Enable more flexible server customization and middleware integration --- server/server.go | 142 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 138 insertions(+), 4 deletions(-) diff --git a/server/server.go b/server/server.go index f45c0353..29dca0df 100644 --- a/server/server.go +++ b/server/server.go @@ -166,6 +166,19 @@ type MCPServer struct { paginationLimit *int sessions sync.Map hooks *Hooks + + // custom handlers for basic methods + InitializeHandler func(ctx context.Context, request mcp.InitializeRequest) (*mcp.InitializeResult, error) + PingHandler func(ctx context.Context, request mcp.PingRequest) (*mcp.EmptyResult, error) + ListResourcesHandler func(ctx context.Context, request mcp.ListResourcesRequest) (*mcp.ListResourcesResult, error) + ListResourceTemplatesHandler func(ctx context.Context, request mcp.ListResourceTemplatesRequest) (*mcp.ListResourceTemplatesResult, error) + ReadResourceHandler func(ctx context.Context, request mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) + ListPromptsHandler func(ctx context.Context, request mcp.ListPromptsRequest) (*mcp.ListPromptsResult, error) + GetPromptHandler func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) + ListToolsHandler func(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error) + CallToolHandler func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) + SetLevelHandler func(ctx context.Context, request mcp.SetLevelRequest) (*mcp.EmptyResult, error) + NotificationHandler func(ctx context.Context, notification mcp.JSONRPCNotification) } // WithPaginationLimit sets the pagination limit for the server. @@ -650,9 +663,21 @@ func (s *MCPServer) AddNotificationHandler( func (s *MCPServer) handleInitialize( ctx context.Context, - _ any, + id any, request mcp.InitializeRequest, ) (*mcp.InitializeResult, *requestError) { + if s.InitializeHandler != nil { + result, err := s.InitializeHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + capabilities := mcp.ServerCapabilities{} // Only add resource capabilities if they're configured @@ -736,10 +761,21 @@ func (s *MCPServer) protocolVersion(clientVersion string) string { } func (s *MCPServer) handlePing( - _ context.Context, - _ any, - _ mcp.PingRequest, + ctx context.Context, + id any, + request mcp.PingRequest, ) (*mcp.EmptyResult, *requestError) { + if s.PingHandler != nil { + result, err := s.PingHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } return &mcp.EmptyResult{}, nil } @@ -748,6 +784,18 @@ func (s *MCPServer) handleSetLevel( id any, request mcp.SetLevelRequest, ) (*mcp.EmptyResult, *requestError) { + if s.SetLevelHandler != nil { + result, err := s.SetLevelHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + clientSession := ClientSessionFromContext(ctx) if clientSession == nil || !clientSession.Initialized() { return nil, &requestError{ @@ -827,6 +875,18 @@ func (s *MCPServer) handleListResources( id any, request mcp.ListResourcesRequest, ) (*mcp.ListResourcesResult, *requestError) { + if s.ListResourcesHandler != nil { + result, err := s.ListResourcesHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + s.resourcesMu.RLock() resourceMap := make(map[string]mcp.Resource, len(s.resources)) for uri, entry := range s.resources { @@ -880,6 +940,18 @@ func (s *MCPServer) handleListResourceTemplates( id any, request mcp.ListResourceTemplatesRequest, ) (*mcp.ListResourceTemplatesResult, *requestError) { + if s.ListResourceTemplatesHandler != nil { + result, err := s.ListResourceTemplatesHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + s.resourcesMu.RLock() templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates)) for _, entry := range s.resourceTemplates { @@ -916,6 +988,18 @@ func (s *MCPServer) handleReadResource( id any, request mcp.ReadResourceRequest, ) (*mcp.ReadResourceResult, *requestError) { + if s.ReadResourceHandler != nil { + result, err := s.ReadResourceHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + s.resourcesMu.RLock() // First check session-specific resources @@ -1030,6 +1114,18 @@ func (s *MCPServer) handleListPrompts( id any, request mcp.ListPromptsRequest, ) (*mcp.ListPromptsResult, *requestError) { + if s.ListPromptsHandler != nil { + result, err := s.ListPromptsHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + s.promptsMu.RLock() prompts := make([]mcp.Prompt, 0, len(s.prompts)) for _, prompt := range s.prompts { @@ -1068,6 +1164,18 @@ func (s *MCPServer) handleGetPrompt( id any, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, *requestError) { + if s.GetPromptHandler != nil { + result, err := s.GetPromptHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } + s.promptsMu.RLock() handler, ok := s.promptHandlers[request.Params.Name] s.promptsMu.RUnlock() @@ -1097,6 +1205,17 @@ func (s *MCPServer) handleListTools( id any, request mcp.ListToolsRequest, ) (*mcp.ListToolsResult, *requestError) { + if s.ListToolsHandler != nil { + result, err := s.ListToolsHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } // Get the base tools from the server s.toolsMu.RLock() tools := make([]mcp.Tool, 0, len(s.tools)) @@ -1187,6 +1306,17 @@ func (s *MCPServer) handleToolCall( id any, request mcp.CallToolRequest, ) (*mcp.CallToolResult, *requestError) { + if s.CallToolHandler != nil { + result, err := s.CallToolHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return result, nil + } // First check session-specific tools var tool ServerTool var ok bool @@ -1246,6 +1376,10 @@ func (s *MCPServer) handleNotification( ctx context.Context, notification mcp.JSONRPCNotification, ) mcp.JSONRPCMessage { + if s.NotificationHandler != nil { + s.NotificationHandler(ctx, notification) + return nil + } s.notificationHandlersMu.RLock() handler, ok := s.notificationHandlers[notification.Method] s.notificationHandlersMu.RUnlock() From 6b9c30fbaeb854f801d462bd6d443a18c8af4756 Mon Sep 17 00:00:00 2001 From: lbbniu Date: Fri, 12 Sep 2025 23:21:38 +0800 Subject: [PATCH 2/2] fix: message endpoint generate --- server/sse.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/server/sse.go b/server/sse.go index 250141ce..36316ea5 100644 --- a/server/sse.go +++ b/server/sse.go @@ -496,8 +496,12 @@ func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID strin if s.useFullURLForMessageEndpoint && s.baseURL != "" { endpointPath = s.baseURL + endpointPath } - - return fmt.Sprintf("%s?sessionId=%s", endpointPath, sessionID) + if strings.Contains(endpointPath, "?") { + endpointPath += "&" + } else { + endpointPath += "?" + } + return fmt.Sprintf("%ssessionId=%s", endpointPath, sessionID) } // handleMessage processes incoming JSON-RPC messages from clients and sends responses