Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,13 +403,33 @@ func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) {
s.promptHandlers[prompt.Name] = handler
s.promptsMu.Unlock()

// When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification.
// When the list of available prompts changes, servers that declared the listChanged capability SHOULD send a notification.
if s.capabilities.prompts.listChanged {
// Send notification to all initialized sessions
s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil)
}
}

// DeletePrompts removes prompts from the server
func (s *MCPServer) DeletePrompts(names ...string) {
s.promptsMu.Lock()
var exists bool
for _, name := range names {
if _, ok := s.prompts[name]; ok {
delete(s.prompts, name)
delete(s.promptHandlers, name)
exists = true
}
}
s.promptsMu.Unlock()

// Send notification to all initialized sessions if listChanged capability is enabled, and we actually remove a prompt
if exists && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged {
// Send notification to all initialized sessions
s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil)
}
}

// AddTool registers a new tool and its handler
func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) {
s.AddTools(ServerTool{Tool: tool, Handler: handler})
Expand Down Expand Up @@ -460,7 +480,7 @@ func (s *MCPServer) SetTools(tools ...ServerTool) {
s.AddTools(tools...)
}

// DeleteTools removes a tool from the server
// DeleteTools removes tools from the server
func (s *MCPServer) DeleteTools(names ...string) {
s.toolsMu.Lock()
var exists bool
Expand Down
11 changes: 11 additions & 0 deletions server/server_race_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ func TestRaceConditions(t *testing.T) {
})
})

runConcurrentOperation(&wg, testDuration, "delete-prompts", func() {
name := fmt.Sprintf("delete-prompt-%d", time.Now().UnixNano())
srv.AddPrompt(mcp.Prompt{
Name: name,
Description: "Temporary prompt",
}, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
return &mcp.GetPromptResult{}, nil
})
srv.DeletePrompts(name)
})

runConcurrentOperation(&wg, testDuration, "add-tools", func() {
name := fmt.Sprintf("tool-%d", time.Now().UnixNano())
srv.AddTool(mcp.Tool{
Expand Down
188 changes: 188 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,194 @@ func TestMCPServer_PromptHandling(t *testing.T) {
}
}

func TestMCPServer_Prompts(t *testing.T) {
tests := []struct {
name string
action func(*testing.T, *MCPServer, chan mcp.JSONRPCNotification)
expectedNotifications int
validate func(*testing.T, []mcp.JSONRPCNotification, mcp.JSONRPCMessage)
}{
{
name: "DeletePrompts sends single notifications/prompts/list_changed",
action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) {
err := server.RegisterSession(context.TODO(), &fakeSession{
sessionID: "test",
notificationChannel: notificationChannel,
initialized: true,
})
require.NoError(t, err)
server.AddPrompt(
mcp.Prompt{
Name: "test-prompt-1",
Description: "A test prompt",
Arguments: []mcp.PromptArgument{
{
Name: "arg1",
Description: "First argument",
},
},
},
nil,
)
server.DeletePrompts("test-prompt-1")
},
expectedNotifications: 2,
validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) {
// One for AddPrompt
assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method)
// One for DeletePrompts
assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[1].Method)

// Expect a successful response with an empty list of prompts
resp, ok := promptsList.(mcp.JSONRPCResponse)
assert.True(t, ok, "Expected JSONRPCResponse, got %T", promptsList)

result, ok := resp.Result.(mcp.ListPromptsResult)
assert.True(t, ok, "Expected ListPromptsResult, got %T", resp.Result)

assert.Empty(t, result.Prompts, "Expected empty prompts list")
},
},
{
name: "DeletePrompts removes the first prompt and retains the other",
action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) {
err := server.RegisterSession(context.TODO(), &fakeSession{
sessionID: "test",
notificationChannel: notificationChannel,
initialized: true,
})
require.NoError(t, err)
server.AddPrompt(
mcp.Prompt{
Name: "test-prompt-1",
Description: "A test prompt",
Arguments: []mcp.PromptArgument{
{
Name: "arg1",
Description: "First argument",
},
},
},
nil,
)
server.AddPrompt(
mcp.Prompt{
Name: "test-prompt-2",
Description: "A test prompt",
Arguments: []mcp.PromptArgument{
{
Name: "arg1",
Description: "First argument",
},
},
},
nil,
)
// Remove non-existing prompts
server.DeletePrompts("test-prompt-1")
},
expectedNotifications: 3,
validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) {
// first notification expected for AddPrompt test-prompt-1
assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method)
// second notification expected for AddPrompt test-prompt-2
assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[1].Method)
// second notification expected for DeletePrompts test-prompt-1
assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[2].Method)

// Confirm the prompt list does not change
prompts := promptsList.(mcp.JSONRPCResponse).Result.(mcp.ListPromptsResult).Prompts
assert.Len(t, prompts, 1)
assert.Equal(t, "test-prompt-2", prompts[0].Name)
},
},
{
name: "DeletePrompts with non-existent prompts does nothing and not receives notifications from MCPServer",
action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) {
err := server.RegisterSession(context.TODO(), &fakeSession{
sessionID: "test",
notificationChannel: notificationChannel,
initialized: true,
})
require.NoError(t, err)
server.AddPrompt(
mcp.Prompt{
Name: "test-prompt-1",
Description: "A test prompt",
Arguments: []mcp.PromptArgument{
{
Name: "arg1",
Description: "First argument",
},
},
},
nil,
)
server.AddPrompt(
mcp.Prompt{
Name: "test-prompt-2",
Description: "A test prompt",
Arguments: []mcp.PromptArgument{
{
Name: "arg1",
Description: "First argument",
},
},
},
nil,
)
// Remove non-existing prompts
server.DeletePrompts("test-prompt-3", "test-prompt-4")
},
expectedNotifications: 2,
validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) {
// first notification expected for AddPrompt test-prompt-1
assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method)
// second notification expected for AddPrompt test-prompt-2
assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[1].Method)

// Confirm the prompt list does not change
prompts := promptsList.(mcp.JSONRPCResponse).Result.(mcp.ListPromptsResult).Prompts
assert.Len(t, prompts, 2)
assert.Equal(t, "test-prompt-1", prompts[0].Name)
assert.Equal(t, "test-prompt-2", prompts[1].Name)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true))
_ = server.HandleMessage(ctx, []byte(`{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize"
}`))
notificationChannel := make(chan mcp.JSONRPCNotification, 100)
notifications := make([]mcp.JSONRPCNotification, 0)
tt.action(t, server, notificationChannel)
for done := false; !done; {
select {
case serverNotification := <-notificationChannel:
notifications = append(notifications, serverNotification)
if len(notifications) == tt.expectedNotifications {
done = true
}
case <-time.After(1 * time.Second):
done = true
}
}
assert.Len(t, notifications, tt.expectedNotifications)
promptsList := server.HandleMessage(ctx, []byte(`{
"jsonrpc": "2.0",
"id": 1,
"method": "prompts/list"
}`))
tt.validate(t, notifications, promptsList)
})
}
}

func TestMCPServer_HandleInvalidMessages(t *testing.T) {
var errs []error
hooks := &Hooks{}
Expand Down