Skip to content

Commit f8b7dce

Browse files
committed
fix race
1 parent 32f36b9 commit f8b7dce

File tree

1 file changed

+47
-19
lines changed

1 file changed

+47
-19
lines changed

client/http_test.go

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import (
1212
"github.com/mark3labs/mcp-go/server"
1313
)
1414

15+
// SafeMap is a thread-safe map wrapper
16+
1517
func TestHTTPClient(t *testing.T) {
1618
hooks := &server.Hooks{}
1719
hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
@@ -87,12 +89,9 @@ func TestHTTPClient(t *testing.T) {
8789
return
8890
}
8991

90-
notificationNum := make(map[string]int)
91-
notificationNumMutex := sync.Mutex{}
92+
notificationNum := NewSafeMap()
9293
client.OnNotification(func(notification mcp.JSONRPCNotification) {
93-
notificationNumMutex.Lock()
94-
notificationNum[notification.Method] += 1
95-
notificationNumMutex.Unlock()
94+
notificationNum.Increment(notification.Method)
9695
})
9796

9897
ctx := context.Background()
@@ -120,19 +119,21 @@ func TestHTTPClient(t *testing.T) {
120119
t.Errorf("Expected 1 content item, got %d", len(result.Content))
121120
}
122121

123-
if n := notificationNum["notifications/progress"]; n != 1 {
122+
if n := notificationNum.Get("notifications/progress"); n != 1 {
124123
t.Errorf("Expected 1 progross notification item, got %d", n)
125124
}
126-
if n := len(notificationNum); n != 1 {
125+
if n := notificationNum.Len(); n != 1 {
127126
t.Errorf("Expected 1 type of notification, got %d", n)
128127
}
129128
})
130129

131-
t.Run("Cannot receive global notifications from server by default", func(t *testing.T) {
130+
t.Run("Can not receive global notifications from server by default", func(t *testing.T) {
132131
addServerToolfunc("hello1")
133132
time.Sleep(time.Millisecond * 50)
134-
if n := notificationNum["hello1"]; n != 0 {
135-
t.Errorf("Expected 0 notification item, got %d", n)
133+
134+
helloNotifications := notificationNum.Get("hello1")
135+
if helloNotifications != 0 {
136+
t.Errorf("Expected 0 notification item, got %d", helloNotifications)
136137
}
137138
})
138139

@@ -146,13 +147,9 @@ func TestHTTPClient(t *testing.T) {
146147
}
147148
defer client.Close()
148149

149-
notificationNum := make(map[string]int)
150-
notificationNumMutex := sync.Mutex{}
150+
notificationNum := NewSafeMap()
151151
client.OnNotification(func(notification mcp.JSONRPCNotification) {
152-
notificationNumMutex.Lock()
153-
println(notification.Method)
154-
notificationNum[notification.Method] += 1
155-
notificationNumMutex.Unlock()
152+
notificationNum.Increment(notification.Method)
156153
})
157154

158155
ctx := context.Background()
@@ -176,20 +173,51 @@ func TestHTTPClient(t *testing.T) {
176173
t.Fatalf("CallTool failed: %v", err)
177174
}
178175

179-
if n := notificationNum["notifications/progress"]; n != 1 {
176+
if n := notificationNum.Get("notifications/progress"); n != 1 {
180177
t.Errorf("Expected 1 progross notification item, got %d", n)
181178
}
182-
if n := len(notificationNum); n != 1 {
179+
if n := notificationNum.Len(); n != 1 {
183180
t.Errorf("Expected 1 type of notification, got %d", n)
184181
}
185182

186183
// can receive global notification
187184
addServerToolfunc("hello2")
188185
time.Sleep(time.Millisecond * 50) // wait for the notification to be sent as upper action is async
189-
if n := notificationNum["notifications/tools/list_changed"]; n != 1 {
186+
187+
n := notificationNum.Get("notifications/tools/list_changed")
188+
if n != 1 {
190189
t.Errorf("Expected 1 notification item, got %d, %v", n, notificationNum)
191190
}
192191
})
193192

194193
})
195194
}
195+
196+
type SafeMap struct {
197+
mu sync.RWMutex
198+
data map[string]int
199+
}
200+
201+
func NewSafeMap() *SafeMap {
202+
return &SafeMap{
203+
data: make(map[string]int),
204+
}
205+
}
206+
207+
func (sm *SafeMap) Increment(key string) {
208+
sm.mu.Lock()
209+
defer sm.mu.Unlock()
210+
sm.data[key]++
211+
}
212+
213+
func (sm *SafeMap) Get(key string) int {
214+
sm.mu.RLock()
215+
defer sm.mu.RUnlock()
216+
return sm.data[key]
217+
}
218+
219+
func (sm *SafeMap) Len() int {
220+
sm.mu.RLock()
221+
defer sm.mu.RUnlock()
222+
return len(sm.data)
223+
}

0 commit comments

Comments
 (0)