Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (
// ToolHandlerFunc handles tool calls with given arguments.
type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)

// ServerTool combines a Tool with its ToolHandlerFunc.
type ServerTool struct {
Tool mcp.Tool
Handler ToolHandlerFunc
}

// NotificationContext provides client identification for notifications
type NotificationContext struct {
ClientID string
Expand All @@ -61,8 +67,7 @@ type MCPServer struct {
resourceTemplates map[string]resourceTemplateEntry
prompts map[string]mcp.Prompt
promptHandlers map[string]PromptHandlerFunc
tools map[string]mcp.Tool
toolHandlers map[string]ToolHandlerFunc
tools map[string]ServerTool
notificationHandlers map[string]NotificationHandlerFunc
capabilities serverCapabilities
notifications chan ServerNotification
Expand Down Expand Up @@ -174,8 +179,7 @@ func NewMCPServer(
resourceTemplates: make(map[string]resourceTemplateEntry),
prompts: make(map[string]mcp.Prompt),
promptHandlers: make(map[string]PromptHandlerFunc),
tools: make(map[string]mcp.Tool),
toolHandlers: make(map[string]ToolHandlerFunc),
tools: make(map[string]ServerTool),
name: name,
version: version,
notificationHandlers: make(map[string]NotificationHandlerFunc),
Expand Down Expand Up @@ -421,8 +425,34 @@ func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) {

// AddTool registers a new tool and its handler
func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) {
s.tools[tool.Name] = tool
s.toolHandlers[tool.Name] = handler
s.AddTools(ServerTool{Tool: tool, Handler: handler})
}

// AddTools registers multiple tools at once
func (s *MCPServer) AddTools(tools ...ServerTool) {
for _, entry := range tools {
s.tools[entry.Tool.Name] = entry
}

// Send notification if server is already initialized
if s.initialized {
if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil {
// We can't return the error, but in a future version we could log it
}
}
}

// SetTools replaces all existing tools with the provided list
func (s *MCPServer) SetTools(tools ...ServerTool) {
s.tools = make(map[string]ServerTool)
s.AddTools(tools...)
}

// DeleteTools removes a tool from the server
func (s *MCPServer) DeleteTools(names ...string) {
for _, name := range names {
delete(s.tools, name)
}

// Send notification if server is already initialized
if s.initialized {
Expand Down Expand Up @@ -630,7 +660,7 @@ func (s *MCPServer) handleListTools(
) mcp.JSONRPCMessage {
tools := make([]mcp.Tool, 0, len(s.tools))
for name := range s.tools {
tools = append(tools, s.tools[name])
tools = append(tools, s.tools[name].Tool)
}

result := mcp.ListToolsResult{
Expand All @@ -647,7 +677,7 @@ func (s *MCPServer) handleToolCall(
id interface{},
request mcp.CallToolRequest,
) mcp.JSONRPCMessage {
handler, ok := s.toolHandlers[request.Params.Name]
tool, ok := s.tools[request.Params.Name]
if !ok {
return createErrorResponse(
id,
Expand All @@ -656,7 +686,7 @@ func (s *MCPServer) handleToolCall(
)
}

result, err := handler(ctx, request)
result, err := tool.Handler(ctx, request)
if err != nil {
return createErrorResponse(id, mcp.INTERNAL_ERROR, err.Error())
}
Expand Down
107 changes: 105 additions & 2 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package server
import (
"context"
"encoding/json"
"testing"

"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"testing"
"time"
)

func TestMCPServer_NewMCPServer(t *testing.T) {
Expand Down Expand Up @@ -105,6 +105,109 @@ func TestMCPServer_Capabilities(t *testing.T) {
})
}
}
func TestMCPServer_Tools(t *testing.T) {
tests := []struct {
name string
action func(*MCPServer)
expectedNotifications int
validate func(*testing.T, []ServerNotification, mcp.JSONRPCMessage)
}{
{
name: "SetTools sends single notifications/tools/list_changed",
action: func(server *MCPServer) {
server.SetTools(ServerTool{
Tool: mcp.NewTool("test-tool-1"),
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{}, nil
},
}, ServerTool{
Tool: mcp.NewTool("test-tool-2"),
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{}, nil
},
})
},
expectedNotifications: 1,
validate: func(t *testing.T, notifications []ServerNotification, toolsList mcp.JSONRPCMessage) {
assert.Equal(t, "notifications/tools/list_changed", notifications[0].Notification.Method)
tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools
assert.Len(t, tools, 2)
assert.Equal(t, "test-tool-1", tools[0].Name)
assert.Equal(t, "test-tool-2", tools[1].Name)
},
},
{
name: "AddTool sends multiple notifications/tools/list_changed",
action: func(server *MCPServer) {
server.AddTool(mcp.NewTool("test-tool-1"),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{}, nil
})
server.AddTool(mcp.NewTool("test-tool-2"),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{}, nil
})
},
expectedNotifications: 2,
validate: func(t *testing.T, notifications []ServerNotification, toolsList mcp.JSONRPCMessage) {
assert.Equal(t, "notifications/tools/list_changed", notifications[0].Notification.Method)
tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools
assert.Len(t, tools, 2)
assert.Equal(t, "test-tool-1", tools[0].Name)
assert.Equal(t, "test-tool-2", tools[1].Name)
},
},
{
name: "DeleteTools sends single notifications/tools/list_changed",
action: func(server *MCPServer) {
server.SetTools(
ServerTool{Tool: mcp.NewTool("test-tool-1")},
ServerTool{Tool: mcp.NewTool("test-tool-2")})
server.DeleteTools("test-tool-1", "test-tool-2")
},
expectedNotifications: 2,
validate: func(t *testing.T, notifications []ServerNotification, toolsList mcp.JSONRPCMessage) {
// One for SetTools
assert.Equal(t, "notifications/tools/list_changed", notifications[0].Notification.Method)
// One for DeleteTools
assert.Equal(t, "notifications/tools/list_changed", notifications[1].Notification.Method)
assert.Equal(t, "Tools not supported", toolsList.(mcp.JSONRPCError).Error.Message)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
server := NewMCPServer("test-server", "1.0.0")
_ = server.HandleMessage(ctx, []byte(`{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize"
}`))
notifications := make([]ServerNotification, 0)
tt.action(server)
for done := false; !done; {
select {
case serverNotification := <-server.notifications:
notifications = append(notifications, serverNotification)
if len(notifications) == tt.expectedNotifications {
done = true
}
case <-time.After(1 * time.Second):
done = true
}
}
assert.Len(t, notifications, tt.expectedNotifications)
toolsList := server.HandleMessage(ctx, []byte(`{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list"
}`))
tt.validate(t, notifications, toolsList.(mcp.JSONRPCMessage))
})

}
}

func TestMCPServer_HandleValidMessages(t *testing.T) {
server := NewMCPServer("test-server", "1.0.0",
Expand Down