@@ -666,7 +666,8 @@ func TestSSEServer(t *testing.T) {
666666 t .Fatalf ("Failed to marshal tool request: %v" , err )
667667 }
668668
669- req , err := http .NewRequest (http .MethodPost , messageURL , bytes .NewBuffer (requestBody ))
669+ var req * http.Request
670+ req , err = http .NewRequest (http .MethodPost , messageURL , bytes .NewBuffer (requestBody ))
670671 if err != nil {
671672 t .Fatalf ("Failed to create tool request: %v" , err )
672673 }
@@ -1129,6 +1130,116 @@ func TestSSEServer(t *testing.T) {
11291130 })
11301131 }
11311132 })
1133+
1134+ t .Run ("SessionWithTools implementation" , func (t * testing.T ) {
1135+ // Create hooks to track sessions
1136+ hooks := & Hooks {}
1137+ var registeredSession * sseSession
1138+ hooks .AddOnRegisterSession (func (ctx context.Context , session ClientSession ) {
1139+ if s , ok := session .(* sseSession ); ok {
1140+ registeredSession = s
1141+ }
1142+ })
1143+
1144+ mcpServer := NewMCPServer ("test" , "1.0.0" , WithHooks (hooks ))
1145+ testServer := NewTestServer (mcpServer )
1146+ defer testServer .Close ()
1147+
1148+ // Connect to SSE endpoint
1149+ sseResp , err := http .Get (fmt .Sprintf ("%s/sse" , testServer .URL ))
1150+ if err != nil {
1151+ t .Fatalf ("Failed to connect to SSE endpoint: %v" , err )
1152+ }
1153+ defer sseResp .Body .Close ()
1154+
1155+ // Read the endpoint event to ensure session is established
1156+ _ , err = readSeeEvent (sseResp )
1157+ if err != nil {
1158+ t .Fatalf ("Failed to read SSE response: %v" , err )
1159+ }
1160+
1161+ // Verify we got a session
1162+ if registeredSession == nil {
1163+ t .Fatal ("Session was not registered via hook" )
1164+ }
1165+
1166+ // Test setting and getting tools
1167+ tools := map [string ]ServerTool {
1168+ "test_tool" : {
1169+ Tool : mcp.Tool {
1170+ Name : "test_tool" ,
1171+ Description : "A test tool" ,
1172+ Annotations : mcp.ToolAnnotation {
1173+ Title : "Test Tool" ,
1174+ },
1175+ },
1176+ Handler : func (ctx context.Context , request mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
1177+ return mcp .NewToolResultText ("test" ), nil
1178+ },
1179+ },
1180+ }
1181+
1182+ // Test SetSessionTools
1183+ registeredSession .SetSessionTools (tools )
1184+
1185+ // Test GetSessionTools
1186+ retrievedTools := registeredSession .GetSessionTools ()
1187+ if len (retrievedTools ) != 1 {
1188+ t .Errorf ("Expected 1 tool, got %d" , len (retrievedTools ))
1189+ }
1190+ if tool , exists := retrievedTools ["test_tool" ]; ! exists {
1191+ t .Error ("Expected test_tool to exist" )
1192+ } else if tool .Tool .Name != "test_tool" {
1193+ t .Errorf ("Expected tool name test_tool, got %s" , tool .Tool .Name )
1194+ }
1195+
1196+ // Test concurrent access
1197+ var wg sync.WaitGroup
1198+ for i := 0 ; i < 10 ; i ++ {
1199+ wg .Add (2 )
1200+ go func (i int ) {
1201+ defer wg .Done ()
1202+ tools := map [string ]ServerTool {
1203+ fmt .Sprintf ("tool_%d" , i ): {
1204+ Tool : mcp.Tool {
1205+ Name : fmt .Sprintf ("tool_%d" , i ),
1206+ Description : fmt .Sprintf ("Tool %d" , i ),
1207+ Annotations : mcp.ToolAnnotation {
1208+ Title : fmt .Sprintf ("Tool %d" , i ),
1209+ },
1210+ },
1211+ },
1212+ }
1213+ registeredSession .SetSessionTools (tools )
1214+ }(i )
1215+ go func () {
1216+ defer wg .Done ()
1217+ _ = registeredSession .GetSessionTools ()
1218+ }()
1219+ }
1220+ wg .Wait ()
1221+
1222+ // Verify we can still get and set tools after concurrent access
1223+ finalTools := map [string ]ServerTool {
1224+ "final_tool" : {
1225+ Tool : mcp.Tool {
1226+ Name : "final_tool" ,
1227+ Description : "Final Tool" ,
1228+ Annotations : mcp.ToolAnnotation {
1229+ Title : "Final Tool" ,
1230+ },
1231+ },
1232+ },
1233+ }
1234+ registeredSession .SetSessionTools (finalTools )
1235+ retrievedTools = registeredSession .GetSessionTools ()
1236+ if len (retrievedTools ) != 1 {
1237+ t .Errorf ("Expected 1 tool, got %d" , len (retrievedTools ))
1238+ }
1239+ if _ , exists := retrievedTools ["final_tool" ]; ! exists {
1240+ t .Error ("Expected final_tool to exist" )
1241+ }
1242+ })
11321243}
11331244
11341245func readSeeEvent (sseResp * http.Response ) (string , error ) {
0 commit comments