Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 138 additions & 4 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,19 @@ type MCPServer struct {
paginationLimit *int
sessions sync.Map
hooks *Hooks

// custom handlers for basic methods
InitializeHandler func(ctx context.Context, request mcp.InitializeRequest) (*mcp.InitializeResult, error)
PingHandler func(ctx context.Context, request mcp.PingRequest) (*mcp.EmptyResult, error)
ListResourcesHandler func(ctx context.Context, request mcp.ListResourcesRequest) (*mcp.ListResourcesResult, error)
ListResourceTemplatesHandler func(ctx context.Context, request mcp.ListResourceTemplatesRequest) (*mcp.ListResourceTemplatesResult, error)
ReadResourceHandler func(ctx context.Context, request mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error)
ListPromptsHandler func(ctx context.Context, request mcp.ListPromptsRequest) (*mcp.ListPromptsResult, error)
GetPromptHandler func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error)
ListToolsHandler func(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
CallToolHandler func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
SetLevelHandler func(ctx context.Context, request mcp.SetLevelRequest) (*mcp.EmptyResult, error)
NotificationHandler func(ctx context.Context, notification mcp.JSONRPCNotification)
}
Comment on lines +169 to 182
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add GoDoc comments for exported handler fields.

All 11 exported handler fields lack documentation. Per coding guidelines, exported identifiers must have GoDoc comments starting with the identifier name. Users need to understand when these handlers are invoked, how errors are handled, and that they bypass default behavior (including middleware, session-specific resources/tools, and filters).

Example documentation pattern:

+	// InitializeHandler, if set, replaces the default initialization logic.
+	// It is invoked during the initialize request and bypasses all default capability
+	// negotiation. Errors are wrapped as INTERNAL_ERROR responses.
 	InitializeHandler            func(ctx context.Context, request mcp.InitializeRequest) (*mcp.InitializeResult, error)
+	// PingHandler, if set, replaces the default ping logic. Errors are wrapped as INTERNAL_ERROR responses.
 	PingHandler                  func(ctx context.Context, request mcp.PingRequest) (*mcp.EmptyResult, error)

As per coding guidelines.

Committable suggestion skipped: line range outside the PR's diff.


// WithPaginationLimit sets the pagination limit for the server.
Expand Down Expand Up @@ -650,9 +663,21 @@ func (s *MCPServer) AddNotificationHandler(

func (s *MCPServer) handleInitialize(
ctx context.Context,
_ any,
id any,
request mcp.InitializeRequest,
) (*mcp.InitializeResult, *requestError) {
if s.InitializeHandler != nil {
result, err := s.InitializeHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

capabilities := mcp.ServerCapabilities{}

// Only add resource capabilities if they're configured
Expand Down Expand Up @@ -736,10 +761,21 @@ func (s *MCPServer) protocolVersion(clientVersion string) string {
}

func (s *MCPServer) handlePing(
_ context.Context,
_ any,
_ mcp.PingRequest,
ctx context.Context,
id any,
request mcp.PingRequest,
) (*mcp.EmptyResult, *requestError) {
if s.PingHandler != nil {
result, err := s.PingHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}
return &mcp.EmptyResult{}, nil
}

Expand All @@ -748,6 +784,18 @@ func (s *MCPServer) handleSetLevel(
id any,
request mcp.SetLevelRequest,
) (*mcp.EmptyResult, *requestError) {
if s.SetLevelHandler != nil {
result, err := s.SetLevelHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

clientSession := ClientSessionFromContext(ctx)
if clientSession == nil || !clientSession.Initialized() {
return nil, &requestError{
Expand Down Expand Up @@ -827,6 +875,18 @@ func (s *MCPServer) handleListResources(
id any,
request mcp.ListResourcesRequest,
) (*mcp.ListResourcesResult, *requestError) {
if s.ListResourcesHandler != nil {
result, err := s.ListResourcesHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

s.resourcesMu.RLock()
resourceMap := make(map[string]mcp.Resource, len(s.resources))
for uri, entry := range s.resources {
Expand Down Expand Up @@ -880,6 +940,18 @@ func (s *MCPServer) handleListResourceTemplates(
id any,
request mcp.ListResourceTemplatesRequest,
) (*mcp.ListResourceTemplatesResult, *requestError) {
if s.ListResourceTemplatesHandler != nil {
result, err := s.ListResourceTemplatesHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

s.resourcesMu.RLock()
templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates))
for _, entry := range s.resourceTemplates {
Expand Down Expand Up @@ -916,6 +988,18 @@ func (s *MCPServer) handleReadResource(
id any,
request mcp.ReadResourceRequest,
) (*mcp.ReadResourceResult, *requestError) {
if s.ReadResourceHandler != nil {
result, err := s.ReadResourceHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

s.resourcesMu.RLock()

// First check session-specific resources
Expand Down Expand Up @@ -1030,6 +1114,18 @@ func (s *MCPServer) handleListPrompts(
id any,
request mcp.ListPromptsRequest,
) (*mcp.ListPromptsResult, *requestError) {
if s.ListPromptsHandler != nil {
result, err := s.ListPromptsHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

s.promptsMu.RLock()
prompts := make([]mcp.Prompt, 0, len(s.prompts))
for _, prompt := range s.prompts {
Expand Down Expand Up @@ -1068,6 +1164,18 @@ func (s *MCPServer) handleGetPrompt(
id any,
request mcp.GetPromptRequest,
) (*mcp.GetPromptResult, *requestError) {
if s.GetPromptHandler != nil {
result, err := s.GetPromptHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}

s.promptsMu.RLock()
handler, ok := s.promptHandlers[request.Params.Name]
s.promptsMu.RUnlock()
Expand Down Expand Up @@ -1097,6 +1205,17 @@ func (s *MCPServer) handleListTools(
id any,
request mcp.ListToolsRequest,
) (*mcp.ListToolsResult, *requestError) {
if s.ListToolsHandler != nil {
result, err := s.ListToolsHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}
// Get the base tools from the server
s.toolsMu.RLock()
tools := make([]mcp.Tool, 0, len(s.tools))
Expand Down Expand Up @@ -1187,6 +1306,17 @@ func (s *MCPServer) handleToolCall(
id any,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, *requestError) {
if s.CallToolHandler != nil {
result, err := s.CallToolHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
code: mcp.INTERNAL_ERROR,
err: err,
}
}
return result, nil
}
// First check session-specific tools
var tool ServerTool
var ok bool
Expand Down Expand Up @@ -1246,6 +1376,10 @@ func (s *MCPServer) handleNotification(
ctx context.Context,
notification mcp.JSONRPCNotification,
) mcp.JSONRPCMessage {
if s.NotificationHandler != nil {
s.NotificationHandler(ctx, notification)
return nil
}
s.notificationHandlersMu.RLock()
handler, ok := s.notificationHandlers[notification.Method]
s.notificationHandlersMu.RUnlock()
Expand Down
8 changes: 6 additions & 2 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,12 @@ func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID strin
if s.useFullURLForMessageEndpoint && s.baseURL != "" {
endpointPath = s.baseURL + endpointPath
}

return fmt.Sprintf("%s?sessionId=%s", endpointPath, sessionID)
if strings.Contains(endpointPath, "?") {
endpointPath += "&"
} else {
endpointPath += "?"
}
return fmt.Sprintf("%ssessionId=%s", endpointPath, sessionID)
}

// handleMessage processes incoming JSON-RPC messages from clients and sends responses
Expand Down