Skip to content

Commit 5307cdc

Browse files
octopottekkat
andauthored
feat(mcptest): Change Server.Start to accept a context.Context. (#339)
* feat(mcptest): Change `Server.Start` to accept a `context.Context`. Previously, `mcptest` used `context.TODO()` as the context with this server code was executed (with a comment to upgrade to `testing.T.Context()` eventually. This effectively prevents tests from preparing a context to be used with server code. This is imporant, because `WithHTTPContextFunc` and friends appear to be the intended way to extract things like authentication information from the raw HTTP request and pass it to the server code. Without this change, such code is effectively untestable, at least with this package. The `NewServer` convenience method has been left unchanged (i.e. it does not accept a context) to avoid breaking users taking the happy path. Chaning the signature of `Start()` is however a breaking change. * test(mcptest): Update tests to use the new `Start()` signature. --------- Co-authored-by: Navendu Pottekkat <[email protected]>
1 parent 4295cec commit 5307cdc

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

mcptest/mcptest.go

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ type Server struct {
2424
prompts []server.ServerPrompt
2525
resources []server.ServerResource
2626

27-
ctx context.Context
2827
cancel func()
2928

3029
serverReader *io.PipeReader
@@ -45,7 +44,8 @@ func NewServer(t *testing.T, tools ...server.ServerTool) (*Server, error) {
4544
server := NewUnstartedServer(t)
4645
server.AddTools(tools...)
4746

48-
if err := server.Start(); err != nil {
47+
// TODO: use t.Context() once go.mod is upgraded to go 1.24+
48+
if err := server.Start(context.TODO()); err != nil {
4949
return nil, err
5050
}
5151

@@ -59,12 +59,6 @@ func NewUnstartedServer(t *testing.T) *Server {
5959
name: t.Name(),
6060
}
6161

62-
// Use t.Context() once we switch to go >= 1.24
63-
ctx := context.TODO()
64-
65-
// Set up context with cancellation, used to stop the server
66-
server.ctx, server.cancel = context.WithCancel(ctx)
67-
6862
// Set up pipes for client-server communication
6963
server.serverReader, server.clientWriter = io.Pipe()
7064
server.clientReader, server.serverWriter = io.Pipe()
@@ -114,9 +108,11 @@ func (s *Server) AddResources(resources ...server.ServerResource) {
114108

115109
// Start starts the server in a goroutine. Make sure to defer Close() after Start().
116110
// When using NewServer(), the returned server is already started.
117-
func (s *Server) Start() error {
111+
func (s *Server) Start(ctx context.Context) error {
118112
s.wg.Add(1)
119113

114+
ctx, s.cancel = context.WithCancel(ctx)
115+
120116
// Start the MCP server in a goroutine
121117
go func() {
122118
defer s.wg.Done()
@@ -132,21 +128,21 @@ func (s *Server) Start() error {
132128
stdioServer := server.NewStdioServer(mcpServer)
133129
stdioServer.SetErrorLogger(logger)
134130

135-
if err := stdioServer.Listen(s.ctx, s.serverReader, s.serverWriter); err != nil {
131+
if err := stdioServer.Listen(ctx, s.serverReader, s.serverWriter); err != nil {
136132
logger.Println("StdioServer.Listen failed:", err)
137133
}
138134
}()
139135

140136
s.transport = transport.NewIO(s.clientReader, s.clientWriter, io.NopCloser(&s.logBuffer))
141-
if err := s.transport.Start(s.ctx); err != nil {
137+
if err := s.transport.Start(ctx); err != nil {
142138
return fmt.Errorf("transport.Start(): %w", err)
143139
}
144140

145141
s.client = client.NewClient(s.transport)
146142

147143
var initReq mcp.InitializeRequest
148144
initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
149-
if _, err := s.client.Initialize(s.ctx, initReq); err != nil {
145+
if _, err := s.client.Initialize(ctx, initReq); err != nil {
150146
return fmt.Errorf("client.Initialize(): %w", err)
151147
}
152148

mcptest/mcptest_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func TestServerWithPrompt(t *testing.T) {
109109

110110
srv.AddPrompt(prompt, handler)
111111

112-
err := srv.Start()
112+
err := srv.Start(ctx)
113113
if err != nil {
114114
t.Fatal(err)
115115
}
@@ -164,7 +164,7 @@ func TestServerWithResource(t *testing.T) {
164164

165165
srv.AddResource(resource, handler)
166166

167-
err := srv.Start()
167+
err := srv.Start(ctx)
168168
if err != nil {
169169
t.Fatal(err)
170170
}

0 commit comments

Comments
 (0)