From 21882fa4d25d7ce977e61da7304da57b6dc22cc7 Mon Sep 17 00:00:00 2001 From: Hengkang Qiao <2468439195@qq.com> Date: Fri, 18 Apr 2025 11:30:04 +0800 Subject: [PATCH 1/5] Add mutex to avoid data race bewteen Start and Shutdown in SSEServer struct --- server/sse.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/server/sse.go b/server/sse.go index f69451c6d..876841914 100644 --- a/server/sse.go +++ b/server/sse.go @@ -65,6 +65,8 @@ type SSEServer struct { keepAlive bool keepAliveInterval time.Duration + + mu sync.RWMutex } // SSEOption defines a function type for configuring SSEServer @@ -189,10 +191,12 @@ func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { // Start begins serving SSE connections on the specified address. // It sets up HTTP handlers for SSE and message endpoints. func (s *SSEServer) Start(addr string) error { + s.mu.Lock() s.srv = &http.Server{ Addr: addr, Handler: s, } + s.mu.Unlock() return s.srv.ListenAndServe() } @@ -200,7 +204,12 @@ func (s *SSEServer) Start(addr string) error { // Shutdown gracefully stops the SSE server, closing all active sessions // and shutting down the HTTP server. func (s *SSEServer) Shutdown(ctx context.Context) error { - if s.srv != nil { + s.mu.RLock() + srv := s.srv + s.mu.RUnlock() + + if srv != nil { + // 关闭所有会话 s.sessions.Range(func(key, value interface{}) bool { if session, ok := value.(*sseSession); ok { close(session.done) @@ -209,7 +218,7 @@ func (s *SSEServer) Shutdown(ctx context.Context) error { return true }) - return s.srv.Shutdown(ctx) + return srv.Shutdown(ctx) } return nil } @@ -336,7 +345,9 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { return } + s.mu.RLock() sessionI, ok := s.sessions.Load(sessionID) + s.mu.RUnlock() if !ok { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID") return From ef892e3e6d4bbab6234644a06d22a595b27fd909 Mon Sep 17 00:00:00 2001 From: Hengkang Qiao <2468439195@qq.com> Date: Fri, 18 Apr 2025 12:43:20 +0800 Subject: [PATCH 2/5] delete the mutex for session --- server/sse.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/server/sse.go b/server/sse.go index 876841914..5e9161d0f 100644 --- a/server/sse.go +++ b/server/sse.go @@ -344,10 +344,7 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") return } - - s.mu.RLock() sessionI, ok := s.sessions.Load(sessionID) - s.mu.RUnlock() if !ok { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID") return From 31c2eb9dc4470db07d493c4c6b34188b820be851 Mon Sep 17 00:00:00 2001 From: Hengkang Qiao <2468439195@qq.com> Date: Fri, 18 Apr 2025 13:21:24 +0800 Subject: [PATCH 3/5] Added checks for Windows --- client/transport/stdio_test.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 445ba07ea..f40c4b571 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "sync" "testing" "time" @@ -31,6 +32,10 @@ func compileTestServer(outputPath string) error { func TestStdio(t *testing.T) { // Compile mock server mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + mockServerPath += ".exe" + } if err := compileTestServer(mockServerPath); err != nil { t.Fatalf("Failed to compile mock server: %v", err) } @@ -304,6 +309,10 @@ func TestStdioErrors(t *testing.T) { t.Run("RequestBeforeStart", func(t *testing.T) { // 创建一个新的 Stdio 实例但不调用 Start 方法 mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + mockServerPath += ".exe" + } if err := compileTestServer(mockServerPath); err != nil { t.Fatalf("Failed to compile mock server: %v", err) } @@ -331,6 +340,10 @@ func TestStdioErrors(t *testing.T) { t.Run("RequestAfterClose", func(t *testing.T) { // Compile mock server mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + mockServerPath += ".exe" + } if err := compileTestServer(mockServerPath); err != nil { t.Fatalf("Failed to compile mock server: %v", err) } From fceedb66c4b5209dcf949961a9e891988d18d7e9 Mon Sep 17 00:00:00 2001 From: WoodQ <2468439195@qq.com> Date: Fri, 18 Apr 2025 23:28:50 +0800 Subject: [PATCH 4/5] Update sse.go --- server/sse.go | 1 - 1 file changed, 1 deletion(-) diff --git a/server/sse.go b/server/sse.go index 5e9161d0f..5dc9adaf1 100644 --- a/server/sse.go +++ b/server/sse.go @@ -209,7 +209,6 @@ func (s *SSEServer) Shutdown(ctx context.Context) error { s.mu.RUnlock() if srv != nil { - // 关闭所有会话 s.sessions.Range(func(key, value interface{}) bool { if session, ok := value.(*sseSession); ok { close(session.done) From ea1e9301435b40654c8bfd809ce5a2f08810bb9d Mon Sep 17 00:00:00 2001 From: WoodQ <2468439195@qq.com> Date: Fri, 18 Apr 2025 23:33:02 +0800 Subject: [PATCH 5/5] Update stdio_test.go --- client/transport/stdio_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index f40c4b571..e6cdb30db 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -307,7 +307,6 @@ func TestStdioErrors(t *testing.T) { }) t.Run("RequestBeforeStart", func(t *testing.T) { - // 创建一个新的 Stdio 实例但不调用 Start 方法 mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") // Add .exe suffix on Windows if runtime.GOOS == "windows" { @@ -320,7 +319,7 @@ func TestStdioErrors(t *testing.T) { uninitiatedStdio := NewStdio(mockServerPath, nil) - // 准备一个请求 + // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", ID: 99,