diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 445ba07ea..e6cdb30db 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "sync" "testing" "time" @@ -31,6 +32,10 @@ func compileTestServer(outputPath string) error { func TestStdio(t *testing.T) { // Compile mock server mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + mockServerPath += ".exe" + } if err := compileTestServer(mockServerPath); err != nil { t.Fatalf("Failed to compile mock server: %v", err) } @@ -302,8 +307,11 @@ func TestStdioErrors(t *testing.T) { }) t.Run("RequestBeforeStart", func(t *testing.T) { - // 创建一个新的 Stdio 实例但不调用 Start 方法 mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + mockServerPath += ".exe" + } if err := compileTestServer(mockServerPath); err != nil { t.Fatalf("Failed to compile mock server: %v", err) } @@ -311,7 +319,7 @@ func TestStdioErrors(t *testing.T) { uninitiatedStdio := NewStdio(mockServerPath, nil) - // 准备一个请求 + // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", ID: 99, @@ -331,6 +339,10 @@ func TestStdioErrors(t *testing.T) { t.Run("RequestAfterClose", func(t *testing.T) { // Compile mock server mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + mockServerPath += ".exe" + } if err := compileTestServer(mockServerPath); err != nil { t.Fatalf("Failed to compile mock server: %v", err) } diff --git a/server/sse.go b/server/sse.go index f69451c6d..5dc9adaf1 100644 --- a/server/sse.go +++ b/server/sse.go @@ -65,6 +65,8 @@ type SSEServer struct { keepAlive bool keepAliveInterval time.Duration + + mu sync.RWMutex } // SSEOption defines a function type for configuring SSEServer @@ -189,10 +191,12 @@ func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { // Start begins serving SSE connections on the specified address. // It sets up HTTP handlers for SSE and message endpoints. func (s *SSEServer) Start(addr string) error { + s.mu.Lock() s.srv = &http.Server{ Addr: addr, Handler: s, } + s.mu.Unlock() return s.srv.ListenAndServe() } @@ -200,7 +204,11 @@ func (s *SSEServer) Start(addr string) error { // Shutdown gracefully stops the SSE server, closing all active sessions // and shutting down the HTTP server. func (s *SSEServer) Shutdown(ctx context.Context) error { - if s.srv != nil { + s.mu.RLock() + srv := s.srv + s.mu.RUnlock() + + if srv != nil { s.sessions.Range(func(key, value interface{}) bool { if session, ok := value.(*sseSession); ok { close(session.done) @@ -209,7 +217,7 @@ func (s *SSEServer) Shutdown(ctx context.Context) error { return true }) - return s.srv.Shutdown(ctx) + return srv.Shutdown(ctx) } return nil } @@ -335,7 +343,6 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") return } - sessionI, ok := s.sessions.Load(sessionID) if !ok { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID")