diff --git a/server/server.go b/server/server.go index 6005738b..b913b1f7 100644 --- a/server/server.go +++ b/server/server.go @@ -403,13 +403,33 @@ func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { s.promptHandlers[prompt.Name] = handler s.promptsMu.Unlock() - // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification. + // When the list of available prompts changes, servers that declared the listChanged capability SHOULD send a notification. if s.capabilities.prompts.listChanged { // Send notification to all initialized sessions s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) } } +// DeletePrompts removes prompts from the server +func (s *MCPServer) DeletePrompts(names ...string) { + s.promptsMu.Lock() + var exists bool + for _, name := range names { + if _, ok := s.prompts[name]; ok { + delete(s.prompts, name) + delete(s.promptHandlers, name) + exists = true + } + } + s.promptsMu.Unlock() + + // Send notification to all initialized sessions if listChanged capability is enabled, and we actually remove a prompt + if exists && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) + } +} + // AddTool registers a new tool and its handler func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) { s.AddTools(ServerTool{Tool: tool, Handler: handler}) @@ -460,7 +480,7 @@ func (s *MCPServer) SetTools(tools ...ServerTool) { s.AddTools(tools...) } -// DeleteTools removes a tool from the server +// DeleteTools removes tools from the server func (s *MCPServer) DeleteTools(names ...string) { s.toolsMu.Lock() var exists bool diff --git a/server/server_race_test.go b/server/server_race_test.go index b5593c81..4e0be43a 100644 --- a/server/server_race_test.go +++ b/server/server_race_test.go @@ -44,6 +44,17 @@ func TestRaceConditions(t *testing.T) { }) }) + runConcurrentOperation(&wg, testDuration, "delete-prompts", func() { + name := fmt.Sprintf("delete-prompt-%d", time.Now().UnixNano()) + srv.AddPrompt(mcp.Prompt{ + Name: name, + Description: "Temporary prompt", + }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{}, nil + }) + srv.DeletePrompts(name) + }) + runConcurrentOperation(&wg, testDuration, "add-tools", func() { name := fmt.Sprintf("tool-%d", time.Now().UnixNano()) srv.AddTool(mcp.Tool{ diff --git a/server/server_test.go b/server/server_test.go index 25daeb80..c0ececc9 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -809,6 +809,194 @@ func TestMCPServer_PromptHandling(t *testing.T) { } } +func TestMCPServer_Prompts(t *testing.T) { + tests := []struct { + name string + action func(*testing.T, *MCPServer, chan mcp.JSONRPCNotification) + expectedNotifications int + validate func(*testing.T, []mcp.JSONRPCNotification, mcp.JSONRPCMessage) + }{ + { + name: "DeletePrompts sends single notifications/prompts/list_changed", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + server.AddPrompt( + mcp.Prompt{ + Name: "test-prompt-1", + Description: "A test prompt", + Arguments: []mcp.PromptArgument{ + { + Name: "arg1", + Description: "First argument", + }, + }, + }, + nil, + ) + server.DeletePrompts("test-prompt-1") + }, + expectedNotifications: 2, + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) { + // One for AddPrompt + assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method) + // One for DeletePrompts + assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[1].Method) + + // Expect a successful response with an empty list of prompts + resp, ok := promptsList.(mcp.JSONRPCResponse) + assert.True(t, ok, "Expected JSONRPCResponse, got %T", promptsList) + + result, ok := resp.Result.(mcp.ListPromptsResult) + assert.True(t, ok, "Expected ListPromptsResult, got %T", resp.Result) + + assert.Empty(t, result.Prompts, "Expected empty prompts list") + }, + }, + { + name: "DeletePrompts removes the first prompt and retains the other", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + server.AddPrompt( + mcp.Prompt{ + Name: "test-prompt-1", + Description: "A test prompt", + Arguments: []mcp.PromptArgument{ + { + Name: "arg1", + Description: "First argument", + }, + }, + }, + nil, + ) + server.AddPrompt( + mcp.Prompt{ + Name: "test-prompt-2", + Description: "A test prompt", + Arguments: []mcp.PromptArgument{ + { + Name: "arg1", + Description: "First argument", + }, + }, + }, + nil, + ) + // Remove non-existing prompts + server.DeletePrompts("test-prompt-1") + }, + expectedNotifications: 3, + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) { + // first notification expected for AddPrompt test-prompt-1 + assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method) + // second notification expected for AddPrompt test-prompt-2 + assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[1].Method) + // second notification expected for DeletePrompts test-prompt-1 + assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[2].Method) + + // Confirm the prompt list does not change + prompts := promptsList.(mcp.JSONRPCResponse).Result.(mcp.ListPromptsResult).Prompts + assert.Len(t, prompts, 1) + assert.Equal(t, "test-prompt-2", prompts[0].Name) + }, + }, + { + name: "DeletePrompts with non-existent prompts does nothing and not receives notifications from MCPServer", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + server.AddPrompt( + mcp.Prompt{ + Name: "test-prompt-1", + Description: "A test prompt", + Arguments: []mcp.PromptArgument{ + { + Name: "arg1", + Description: "First argument", + }, + }, + }, + nil, + ) + server.AddPrompt( + mcp.Prompt{ + Name: "test-prompt-2", + Description: "A test prompt", + Arguments: []mcp.PromptArgument{ + { + Name: "arg1", + Description: "First argument", + }, + }, + }, + nil, + ) + // Remove non-existing prompts + server.DeletePrompts("test-prompt-3", "test-prompt-4") + }, + expectedNotifications: 2, + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) { + // first notification expected for AddPrompt test-prompt-1 + assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method) + // second notification expected for AddPrompt test-prompt-2 + assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[1].Method) + + // Confirm the prompt list does not change + prompts := promptsList.(mcp.JSONRPCResponse).Result.(mcp.ListPromptsResult).Prompts + assert.Len(t, prompts, 2) + assert.Equal(t, "test-prompt-1", prompts[0].Name) + assert.Equal(t, "test-prompt-2", prompts[1].Name) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true)) + _ = server.HandleMessage(ctx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" + }`)) + notificationChannel := make(chan mcp.JSONRPCNotification, 100) + notifications := make([]mcp.JSONRPCNotification, 0) + tt.action(t, server, notificationChannel) + for done := false; !done; { + select { + case serverNotification := <-notificationChannel: + notifications = append(notifications, serverNotification) + if len(notifications) == tt.expectedNotifications { + done = true + } + case <-time.After(1 * time.Second): + done = true + } + } + assert.Len(t, notifications, tt.expectedNotifications) + promptsList := server.HandleMessage(ctx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "prompts/list" + }`)) + tt.validate(t, notifications, promptsList) + }) + } +} + func TestMCPServer_HandleInvalidMessages(t *testing.T) { var errs []error hooks := &Hooks{}