@@ -12,6 +12,8 @@ import (
1212	"github.com/mark3labs/mcp-go/server" 
1313)
1414
15+ // SafeMap is a thread-safe map wrapper 
16+ 
1517func  TestHTTPClient (t  * testing.T ) {
1618	hooks  :=  & server.Hooks {}
1719	hooks .AddAfterCallTool (func (ctx  context.Context , id  any , message  * mcp.CallToolRequest , result  * mcp.CallToolResult ) {
@@ -87,12 +89,9 @@ func TestHTTPClient(t *testing.T) {
8789			return 
8890		}
8991
90- 		notificationNum  :=  make (map [string ]int )
91- 		notificationNumMutex  :=  sync.Mutex {}
92+ 		notificationNum  :=  NewSafeMap ()
9293		client .OnNotification (func (notification  mcp.JSONRPCNotification ) {
93- 			notificationNumMutex .Lock ()
94- 			notificationNum [notification .Method ] +=  1 
95- 			notificationNumMutex .Unlock ()
94+ 			notificationNum .Increment (notification .Method )
9695		})
9796
9897		ctx  :=  context .Background ()
@@ -120,19 +119,21 @@ func TestHTTPClient(t *testing.T) {
120119				t .Errorf ("Expected 1 content item, got %d" , len (result .Content ))
121120			}
122121
123- 			if  n  :=  notificationNum [ "notifications/progress" ] ; n  !=  1  {
122+ 			if  n  :=  notificationNum . Get ( "notifications/progress" ) ; n  !=  1  {
124123				t .Errorf ("Expected 1 progross notification item, got %d" , n )
125124			}
126- 			if  n  :=  len ( notificationNum ); n  !=  1  {
125+ 			if  n  :=  notificationNum . Len ( ); n  !=  1  {
127126				t .Errorf ("Expected 1 type of notification, got %d" , n )
128127			}
129128		})
130129
131- 		t .Run ("Cannot  receive global notifications from server by default" , func (t  * testing.T ) {
130+ 		t .Run ("Can not  receive global notifications from server by default" , func (t  * testing.T ) {
132131			addServerToolfunc ("hello1" )
133132			time .Sleep (time .Millisecond  *  50 )
134- 			if  n  :=  notificationNum ["hello1" ]; n  !=  0  {
135- 				t .Errorf ("Expected 0 notification item, got %d" , n )
133+ 
134+ 			helloNotifications  :=  notificationNum .Get ("hello1" )
135+ 			if  helloNotifications  !=  0  {
136+ 				t .Errorf ("Expected 0 notification item, got %d" , helloNotifications )
136137			}
137138		})
138139
@@ -146,13 +147,9 @@ func TestHTTPClient(t *testing.T) {
146147			}
147148			defer  client .Close ()
148149
149- 			notificationNum  :=  make (map [string ]int )
150- 			notificationNumMutex  :=  sync.Mutex {}
150+ 			notificationNum  :=  NewSafeMap ()
151151			client .OnNotification (func (notification  mcp.JSONRPCNotification ) {
152- 				notificationNumMutex .Lock ()
153- 				println (notification .Method )
154- 				notificationNum [notification .Method ] +=  1 
155- 				notificationNumMutex .Unlock ()
152+ 				notificationNum .Increment (notification .Method )
156153			})
157154
158155			ctx  :=  context .Background ()
@@ -176,20 +173,51 @@ func TestHTTPClient(t *testing.T) {
176173				t .Fatalf ("CallTool failed: %v" , err )
177174			}
178175
179- 			if  n  :=  notificationNum [ "notifications/progress" ] ; n  !=  1  {
176+ 			if  n  :=  notificationNum . Get ( "notifications/progress" ) ; n  !=  1  {
180177				t .Errorf ("Expected 1 progross notification item, got %d" , n )
181178			}
182- 			if  n  :=  len ( notificationNum ); n  !=  1  {
179+ 			if  n  :=  notificationNum . Len ( ); n  !=  1  {
183180				t .Errorf ("Expected 1 type of notification, got %d" , n )
184181			}
185182
186183			// can receive global notification 
187184			addServerToolfunc ("hello2" )
188185			time .Sleep (time .Millisecond  *  50 ) // wait for the notification to be sent as upper action is async 
189- 			if  n  :=  notificationNum ["notifications/tools/list_changed" ]; n  !=  1  {
186+ 
187+ 			n  :=  notificationNum .Get ("notifications/tools/list_changed" )
188+ 			if  n  !=  1  {
190189				t .Errorf ("Expected 1 notification item, got %d, %v" , n , notificationNum )
191190			}
192191		})
193192
194193	})
195194}
195+ 
196+ type  SafeMap  struct  {
197+ 	mu    sync.RWMutex 
198+ 	data  map [string ]int 
199+ }
200+ 
201+ func  NewSafeMap () * SafeMap  {
202+ 	return  & SafeMap {
203+ 		data : make (map [string ]int ),
204+ 	}
205+ }
206+ 
207+ func  (sm  * SafeMap ) Increment (key  string ) {
208+ 	sm .mu .Lock ()
209+ 	defer  sm .mu .Unlock ()
210+ 	sm .data [key ]++ 
211+ }
212+ 
213+ func  (sm  * SafeMap ) Get (key  string ) int  {
214+ 	sm .mu .RLock ()
215+ 	defer  sm .mu .RUnlock ()
216+ 	return  sm .data [key ]
217+ }
218+ 
219+ func  (sm  * SafeMap ) Len () int  {
220+ 	sm .mu .RLock ()
221+ 	defer  sm .mu .RUnlock ()
222+ 	return  len (sm .data )
223+ }
0 commit comments