Skip to content

Commit c9443eb

Browse files
authored
Merge pull request #24 from marcnuri-forks/feat/tools-handling
feat(server): enhance tool management with ServerTool struct and related methods
2 parents 35e60f7 + 6d9c794 commit c9443eb

File tree

2 files changed

+144
-11
lines changed

2 files changed

+144
-11
lines changed

server/server.go

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (
3737
// ToolHandlerFunc handles tool calls with given arguments.
3838
type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
3939

40+
// ServerTool combines a Tool with its ToolHandlerFunc.
41+
type ServerTool struct {
42+
Tool mcp.Tool
43+
Handler ToolHandlerFunc
44+
}
45+
4046
// NotificationContext provides client identification for notifications
4147
type NotificationContext struct {
4248
ClientID string
@@ -61,8 +67,7 @@ type MCPServer struct {
6167
resourceTemplates map[string]resourceTemplateEntry
6268
prompts map[string]mcp.Prompt
6369
promptHandlers map[string]PromptHandlerFunc
64-
tools map[string]mcp.Tool
65-
toolHandlers map[string]ToolHandlerFunc
70+
tools map[string]ServerTool
6671
notificationHandlers map[string]NotificationHandlerFunc
6772
capabilities serverCapabilities
6873
notifications chan ServerNotification
@@ -174,8 +179,7 @@ func NewMCPServer(
174179
resourceTemplates: make(map[string]resourceTemplateEntry),
175180
prompts: make(map[string]mcp.Prompt),
176181
promptHandlers: make(map[string]PromptHandlerFunc),
177-
tools: make(map[string]mcp.Tool),
178-
toolHandlers: make(map[string]ToolHandlerFunc),
182+
tools: make(map[string]ServerTool),
179183
name: name,
180184
version: version,
181185
notificationHandlers: make(map[string]NotificationHandlerFunc),
@@ -421,8 +425,34 @@ func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) {
421425

422426
// AddTool registers a new tool and its handler
423427
func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) {
424-
s.tools[tool.Name] = tool
425-
s.toolHandlers[tool.Name] = handler
428+
s.AddTools(ServerTool{Tool: tool, Handler: handler})
429+
}
430+
431+
// AddTools registers multiple tools at once
432+
func (s *MCPServer) AddTools(tools ...ServerTool) {
433+
for _, entry := range tools {
434+
s.tools[entry.Tool.Name] = entry
435+
}
436+
437+
// Send notification if server is already initialized
438+
if s.initialized {
439+
if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil {
440+
// We can't return the error, but in a future version we could log it
441+
}
442+
}
443+
}
444+
445+
// SetTools replaces all existing tools with the provided list
446+
func (s *MCPServer) SetTools(tools ...ServerTool) {
447+
s.tools = make(map[string]ServerTool)
448+
s.AddTools(tools...)
449+
}
450+
451+
// DeleteTools removes a tool from the server
452+
func (s *MCPServer) DeleteTools(names ...string) {
453+
for _, name := range names {
454+
delete(s.tools, name)
455+
}
426456

427457
// Send notification if server is already initialized
428458
if s.initialized {
@@ -630,7 +660,7 @@ func (s *MCPServer) handleListTools(
630660
) mcp.JSONRPCMessage {
631661
tools := make([]mcp.Tool, 0, len(s.tools))
632662
for name := range s.tools {
633-
tools = append(tools, s.tools[name])
663+
tools = append(tools, s.tools[name].Tool)
634664
}
635665

636666
result := mcp.ListToolsResult{
@@ -647,7 +677,7 @@ func (s *MCPServer) handleToolCall(
647677
id interface{},
648678
request mcp.CallToolRequest,
649679
) mcp.JSONRPCMessage {
650-
handler, ok := s.toolHandlers[request.Params.Name]
680+
tool, ok := s.tools[request.Params.Name]
651681
if !ok {
652682
return createErrorResponse(
653683
id,
@@ -656,7 +686,7 @@ func (s *MCPServer) handleToolCall(
656686
)
657687
}
658688

659-
result, err := handler(ctx, request)
689+
result, err := tool.Handler(ctx, request)
660690
if err != nil {
661691
return createErrorResponse(id, mcp.INTERNAL_ERROR, err.Error())
662692
}

server/server_test.go

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ package server
33
import (
44
"context"
55
"encoding/json"
6-
"testing"
7-
86
"github.com/mark3labs/mcp-go/mcp"
97
"github.com/stretchr/testify/assert"
8+
"testing"
9+
"time"
1010
)
1111

1212
func TestMCPServer_NewMCPServer(t *testing.T) {
@@ -105,6 +105,109 @@ func TestMCPServer_Capabilities(t *testing.T) {
105105
})
106106
}
107107
}
108+
func TestMCPServer_Tools(t *testing.T) {
109+
tests := []struct {
110+
name string
111+
action func(*MCPServer)
112+
expectedNotifications int
113+
validate func(*testing.T, []ServerNotification, mcp.JSONRPCMessage)
114+
}{
115+
{
116+
name: "SetTools sends single notifications/tools/list_changed",
117+
action: func(server *MCPServer) {
118+
server.SetTools(ServerTool{
119+
Tool: mcp.NewTool("test-tool-1"),
120+
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
121+
return &mcp.CallToolResult{}, nil
122+
},
123+
}, ServerTool{
124+
Tool: mcp.NewTool("test-tool-2"),
125+
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
126+
return &mcp.CallToolResult{}, nil
127+
},
128+
})
129+
},
130+
expectedNotifications: 1,
131+
validate: func(t *testing.T, notifications []ServerNotification, toolsList mcp.JSONRPCMessage) {
132+
assert.Equal(t, "notifications/tools/list_changed", notifications[0].Notification.Method)
133+
tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools
134+
assert.Len(t, tools, 2)
135+
assert.Equal(t, "test-tool-1", tools[0].Name)
136+
assert.Equal(t, "test-tool-2", tools[1].Name)
137+
},
138+
},
139+
{
140+
name: "AddTool sends multiple notifications/tools/list_changed",
141+
action: func(server *MCPServer) {
142+
server.AddTool(mcp.NewTool("test-tool-1"),
143+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
144+
return &mcp.CallToolResult{}, nil
145+
})
146+
server.AddTool(mcp.NewTool("test-tool-2"),
147+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
148+
return &mcp.CallToolResult{}, nil
149+
})
150+
},
151+
expectedNotifications: 2,
152+
validate: func(t *testing.T, notifications []ServerNotification, toolsList mcp.JSONRPCMessage) {
153+
assert.Equal(t, "notifications/tools/list_changed", notifications[0].Notification.Method)
154+
tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools
155+
assert.Len(t, tools, 2)
156+
assert.Equal(t, "test-tool-1", tools[0].Name)
157+
assert.Equal(t, "test-tool-2", tools[1].Name)
158+
},
159+
},
160+
{
161+
name: "DeleteTools sends single notifications/tools/list_changed",
162+
action: func(server *MCPServer) {
163+
server.SetTools(
164+
ServerTool{Tool: mcp.NewTool("test-tool-1")},
165+
ServerTool{Tool: mcp.NewTool("test-tool-2")})
166+
server.DeleteTools("test-tool-1", "test-tool-2")
167+
},
168+
expectedNotifications: 2,
169+
validate: func(t *testing.T, notifications []ServerNotification, toolsList mcp.JSONRPCMessage) {
170+
// One for SetTools
171+
assert.Equal(t, "notifications/tools/list_changed", notifications[0].Notification.Method)
172+
// One for DeleteTools
173+
assert.Equal(t, "notifications/tools/list_changed", notifications[1].Notification.Method)
174+
assert.Equal(t, "Tools not supported", toolsList.(mcp.JSONRPCError).Error.Message)
175+
},
176+
},
177+
}
178+
for _, tt := range tests {
179+
t.Run(tt.name, func(t *testing.T) {
180+
ctx := context.Background()
181+
server := NewMCPServer("test-server", "1.0.0")
182+
_ = server.HandleMessage(ctx, []byte(`{
183+
"jsonrpc": "2.0",
184+
"id": 1,
185+
"method": "initialize"
186+
}`))
187+
notifications := make([]ServerNotification, 0)
188+
tt.action(server)
189+
for done := false; !done; {
190+
select {
191+
case serverNotification := <-server.notifications:
192+
notifications = append(notifications, serverNotification)
193+
if len(notifications) == tt.expectedNotifications {
194+
done = true
195+
}
196+
case <-time.After(1 * time.Second):
197+
done = true
198+
}
199+
}
200+
assert.Len(t, notifications, tt.expectedNotifications)
201+
toolsList := server.HandleMessage(ctx, []byte(`{
202+
"jsonrpc": "2.0",
203+
"id": 1,
204+
"method": "tools/list"
205+
}`))
206+
tt.validate(t, notifications, toolsList.(mcp.JSONRPCMessage))
207+
})
208+
209+
}
210+
}
108211

109212
func TestMCPServer_HandleValidMessages(t *testing.T) {
110213
server := NewMCPServer("test-server", "1.0.0",

0 commit comments

Comments
 (0)