diff --git a/mcptest/mcptest.go b/mcptest/mcptest.go index 19f492cd..232eac5d 100644 --- a/mcptest/mcptest.go +++ b/mcptest/mcptest.go @@ -24,7 +24,6 @@ type Server struct { prompts []server.ServerPrompt resources []server.ServerResource - ctx context.Context cancel func() serverReader *io.PipeReader @@ -45,7 +44,8 @@ func NewServer(t *testing.T, tools ...server.ServerTool) (*Server, error) { server := NewUnstartedServer(t) server.AddTools(tools...) - if err := server.Start(); err != nil { + // TODO: use t.Context() once go.mod is upgraded to go 1.24+ + if err := server.Start(context.TODO()); err != nil { return nil, err } @@ -59,12 +59,6 @@ func NewUnstartedServer(t *testing.T) *Server { name: t.Name(), } - // Use t.Context() once we switch to go >= 1.24 - ctx := context.TODO() - - // Set up context with cancellation, used to stop the server - server.ctx, server.cancel = context.WithCancel(ctx) - // Set up pipes for client-server communication server.serverReader, server.clientWriter = io.Pipe() server.clientReader, server.serverWriter = io.Pipe() @@ -114,9 +108,11 @@ func (s *Server) AddResources(resources ...server.ServerResource) { // Start starts the server in a goroutine. Make sure to defer Close() after Start(). // When using NewServer(), the returned server is already started. -func (s *Server) Start() error { +func (s *Server) Start(ctx context.Context) error { s.wg.Add(1) + ctx, s.cancel = context.WithCancel(ctx) + // Start the MCP server in a goroutine go func() { defer s.wg.Done() @@ -132,13 +128,13 @@ func (s *Server) Start() error { stdioServer := server.NewStdioServer(mcpServer) stdioServer.SetErrorLogger(logger) - if err := stdioServer.Listen(s.ctx, s.serverReader, s.serverWriter); err != nil { + if err := stdioServer.Listen(ctx, s.serverReader, s.serverWriter); err != nil { logger.Println("StdioServer.Listen failed:", err) } }() s.transport = transport.NewIO(s.clientReader, s.clientWriter, io.NopCloser(&s.logBuffer)) - if err := s.transport.Start(s.ctx); err != nil { + if err := s.transport.Start(ctx); err != nil { return fmt.Errorf("transport.Start(): %w", err) } @@ -146,7 +142,7 @@ func (s *Server) Start() error { var initReq mcp.InitializeRequest initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - if _, err := s.client.Initialize(s.ctx, initReq); err != nil { + if _, err := s.client.Initialize(ctx, initReq); err != nil { return fmt.Errorf("client.Initialize(): %w", err) } diff --git a/mcptest/mcptest_test.go b/mcptest/mcptest_test.go index 129ef39f..0ab9b276 100644 --- a/mcptest/mcptest_test.go +++ b/mcptest/mcptest_test.go @@ -109,7 +109,7 @@ func TestServerWithPrompt(t *testing.T) { srv.AddPrompt(prompt, handler) - err := srv.Start() + err := srv.Start(ctx) if err != nil { t.Fatal(err) } @@ -164,7 +164,7 @@ func TestServerWithResource(t *testing.T) { srv.AddResource(resource, handler) - err := srv.Start() + err := srv.Start(ctx) if err != nil { t.Fatal(err) }