From bee9f90bab8622796cb7e9348acdaaebcc3cd7ed Mon Sep 17 00:00:00 2001 From: sunerpy Date: Wed, 9 Jul 2025 10:33:46 +0800 Subject: [PATCH] fix(streamable_http): ensure graceful shutdown to prevent close request errors Added sync.WaitGroup to wait for asynchronous session cleanup goroutine to finish wg.Add(1) is used before spawning the goroutine; wg.Done() is called upon completion wg.Wait() ensures Close() blocks until the cleanup finishes, preserving correct shutdown order Fixes an issue in tests where closing the client and server in quick succession could lead to connection refused errors due to the client's asynchronous close attempting to reach a server that has already shut down Removed redundant blank lines to streamline the code structure --- client/transport/streamable_http.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index e358751b3..8ceb84208 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -114,6 +114,7 @@ type StreamableHTTP struct { // OAuth support oauthHandler *OAuthHandler + wg sync.WaitGroup } // NewStreamableHTTP creates a new Streamable HTTP transport with the given server URL. @@ -182,9 +183,10 @@ func (c *StreamableHTTP) Close() error { sessionId := c.sessionID.Load().(string) if sessionId != "" { c.sessionID.Store("") - + c.wg.Add(1) // notify server session closed go func() { + defer c.wg.Done() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil) @@ -201,7 +203,7 @@ func (c *StreamableHTTP) Close() error { res.Body.Close() }() } - + c.wg.Wait() return nil } @@ -231,7 +233,6 @@ func (c *StreamableHTTP) SendRequest( ctx context.Context, request JSONRPCRequest, ) (*JSONRPCResponse, error) { - // Marshal request requestBody, err := json.Marshal(request) if err != nil { @@ -316,7 +317,6 @@ func (c *StreamableHTTP) sendHTTP( body io.Reader, acceptType string, ) (resp *http.Response, err error) { - // Create HTTP request req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body) if err != nil { @@ -374,7 +374,6 @@ func (c *StreamableHTTP) sendHTTP( // It returns the final result for the request once received, or an error. // If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening. func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) { - // Create a channel for this specific request responseChan := make(chan *JSONRPCResponse, 1) @@ -387,7 +386,6 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl defer close(responseChan) c.readSSE(ctx, reader, func(event, data string) { - // (unsupported: batching) var message JSONRPCResponse @@ -490,7 +488,6 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand } func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { - // Marshal request requestBody, err := json.Marshal(notification) if err != nil { @@ -577,7 +574,6 @@ var ( ) func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error { - resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream") if err != nil { return fmt.Errorf("failed to send request: %w", err)