diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 5485eff2..8529be5b 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -1725,3 +1725,36 @@ func TestPointerArgEquivalence(t *testing.T) { func ptr[T any](v T) *T { return &v } + +func TestComplete(t *testing.T) { + completionValues := []string{"python", "pytorch", "pyside"} + + serverOpts := &ServerOptions{ + CompletionHandler: func(_ context.Context, request *CompleteRequest) (*CompleteResult, error) { + return &CompleteResult{ + Completion: CompletionResultDetails{ + Values: completionValues, + }, + }, nil + }, + } + server := NewServer(testImpl, serverOpts) + cs, _ := basicClientServerConnection(t, nil, server, func(s *Server) {}) + result, err := cs.Complete(context.Background(), &CompleteParams{ + Argument: CompleteParamsArgument{ + Name: "language", + Value: "py", + }, + Ref: &CompleteReference{ + Type: "ref/prompt", + Name: "code_review", + }, + }) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(completionValues, result.Completion.Values); diff != "" { + t.Errorf("Complete() mismatch (-want +got):\n%s", diff) + } +} diff --git a/mcp/server.go b/mcp/server.go index c1440339..dd5d807c 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -407,7 +407,7 @@ func (s *Server) capabilities() *ServerCapabilities { return caps } -func (s *Server) complete(ctx context.Context, req *CompleteRequest) (Result, error) { +func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteResult, error) { if s.opts.CompletionHandler == nil { return nil, jsonrpc2.ErrMethodNotFound }