Skip to content

Commit 35fc389

Browse files
committed
Format
1 parent 5298f64 commit 35fc389

File tree

9 files changed

+71
-71
lines changed

9 files changed

+71
-71
lines changed

client/transport/sse.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ type SSE struct {
3636
headerFunc HTTPHeaderFunc
3737
logger util.Logger
3838

39-
started atomic.Bool
40-
closed atomic.Bool
41-
cancelSSEStream context.CancelFunc
42-
protocolVersion atomic.Value // string
43-
onConnectionLost func(error)
44-
connectionLostMu sync.RWMutex
39+
started atomic.Bool
40+
closed atomic.Bool
41+
cancelSSEStream context.CancelFunc
42+
protocolVersion atomic.Value // string
43+
onConnectionLost func(error)
44+
connectionLostMu sync.RWMutex
4545

4646
// OAuth support
4747
oauthHandler *OAuthHandler
@@ -220,7 +220,7 @@ func (c *SSE) readSSE(reader io.ReadCloser) {
220220
c.connectionLostMu.RLock()
221221
handler := c.onConnectionLost
222222
c.connectionLostMu.RUnlock()
223-
223+
224224
if handler != nil {
225225
// This is not actually an error - HTTP2 idle timeout disconnection
226226
handler(err)

client/transport/streamable_http.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
605605
connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
606606
err := c.createGETConnectionToServer(connectCtx)
607607
cancel()
608-
608+
609609
if errors.Is(err, ErrGetMethodNotAllowed) {
610610
// server does not support listening
611611
c.logger.Errorf("server does not support listening")
@@ -621,7 +621,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
621621
if err != nil {
622622
c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
623623
}
624-
624+
625625
// Use context-aware sleep
626626
select {
627627
case <-time.After(retryInterval):
@@ -704,15 +704,15 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON
704704
// Create a new context with timeout for request handling, respecting parent context
705705
requestCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
706706
defer cancel()
707-
707+
708708
response, err := handler(requestCtx, request)
709709
if err != nil {
710710
c.logger.Errorf("error handling request %s: %v", request.Method, err)
711-
711+
712712
// Determine appropriate JSON-RPC error code based on error type
713713
var errorCode int
714714
var errorMessage string
715-
715+
716716
// Check for specific sampling-related errors
717717
if errors.Is(err, context.Canceled) {
718718
errorCode = -32800 // Request cancelled
@@ -731,7 +731,7 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON
731731
errorMessage = err.Error()
732732
}
733733
}
734-
734+
735735
// Send error response
736736
errorResponse := &JSONRPCResponse{
737737
JSONRPC: "2.0",

client/transport/streamable_http_sampling_test.go

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,27 @@ import (
1616

1717
// TestStreamableHTTP_SamplingFlow tests the complete sampling flow with HTTP transport
1818
func TestStreamableHTTP_SamplingFlow(t *testing.T) {
19-
// Create simple test server
19+
// Create simple test server
2020
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2121
// Just respond OK to any requests
2222
w.WriteHeader(http.StatusOK)
2323
}))
2424
defer server.Close()
25-
25+
2626
// Create HTTP client transport
2727
client, err := NewStreamableHTTP(server.URL)
2828
if err != nil {
2929
t.Fatalf("Failed to create client: %v", err)
3030
}
3131
defer client.Close()
32-
32+
3333
// Set up sampling request handler
3434
var handledRequest *JSONRPCRequest
3535
handlerCalled := make(chan struct{})
3636
client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) {
3737
handledRequest = &request
3838
close(handlerCalled)
39-
39+
4040
// Simulate sampling handler response
4141
result := map[string]any{
4242
"role": "assistant",
@@ -47,25 +47,25 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) {
4747
"model": "test-model",
4848
"stopReason": "stop_sequence",
4949
}
50-
50+
5151
resultBytes, _ := json.Marshal(result)
52-
52+
5353
return &JSONRPCResponse{
5454
JSONRPC: "2.0",
5555
ID: request.ID,
5656
Result: resultBytes,
5757
}, nil
5858
})
59-
59+
6060
// Start the client
6161
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
6262
defer cancel()
63-
63+
6464
err = client.Start(ctx)
6565
if err != nil {
6666
t.Fatalf("Failed to start client: %v", err)
6767
}
68-
68+
6969
// Test direct request handling (simulating a sampling request)
7070
samplingRequest := JSONRPCRequest{
7171
JSONRPC: "2.0",
@@ -83,23 +83,23 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) {
8383
},
8484
},
8585
}
86-
86+
8787
// Directly test request handling
8888
client.handleIncomingRequest(ctx, samplingRequest)
89-
89+
9090
// Wait for handler to be called
9191
select {
9292
case <-handlerCalled:
9393
// Handler was called
9494
case <-time.After(1 * time.Second):
9595
t.Fatal("Handler was not called within timeout")
9696
}
97-
97+
9898
// Verify the request was handled
9999
if handledRequest == nil {
100100
t.Fatal("Sampling request was not handled")
101101
}
102-
102+
103103
if handledRequest.Method != string(mcp.MethodSamplingCreateMessage) {
104104
t.Errorf("Expected method %s, got %s", mcp.MethodSamplingCreateMessage, handledRequest.Method)
105105
}
@@ -109,7 +109,7 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) {
109109
func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) {
110110
var errorHandled sync.WaitGroup
111111
errorHandled.Add(1)
112-
112+
113113
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
114114
if r.Method == http.MethodPost {
115115
var body map[string]any
@@ -118,7 +118,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) {
118118
w.WriteHeader(http.StatusOK)
119119
return
120120
}
121-
121+
122122
// Check if this is an error response
123123
if errorField, ok := body["error"]; ok {
124124
errorMap := errorField.(map[string]any)
@@ -132,36 +132,36 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) {
132132
w.WriteHeader(http.StatusOK)
133133
}))
134134
defer server.Close()
135-
135+
136136
client, err := NewStreamableHTTP(server.URL)
137137
if err != nil {
138138
t.Fatalf("Failed to create client: %v", err)
139139
}
140140
defer client.Close()
141-
141+
142142
// Set up request handler that returns an error
143143
client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) {
144144
return nil, fmt.Errorf("sampling failed")
145145
})
146-
146+
147147
// Start the client
148148
ctx := context.Background()
149149
err = client.Start(ctx)
150150
if err != nil {
151151
t.Fatalf("Failed to start client: %v", err)
152152
}
153-
153+
154154
// Simulate incoming sampling request
155155
samplingRequest := JSONRPCRequest{
156156
JSONRPC: "2.0",
157157
ID: mcp.NewRequestId(1),
158158
Method: string(mcp.MethodSamplingCreateMessage),
159159
Params: map[string]any{},
160160
}
161-
161+
162162
// This should trigger error handling
163163
client.handleIncomingRequest(ctx, samplingRequest)
164-
164+
165165
// Wait for error to be handled
166166
errorHandled.Wait()
167167
}
@@ -170,7 +170,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) {
170170
func TestStreamableHTTP_NoSamplingHandler(t *testing.T) {
171171
var errorReceived bool
172172
errorReceivedChan := make(chan struct{})
173-
173+
174174
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
175175
if r.Method == http.MethodPost {
176176
var body map[string]any
@@ -179,12 +179,12 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) {
179179
w.WriteHeader(http.StatusOK)
180180
return
181181
}
182-
182+
183183
// Check if this is an error response with method not found
184184
if errorField, ok := body["error"]; ok {
185185
errorMap := errorField.(map[string]any)
186186
if code, ok := errorMap["code"].(float64); ok && code == -32601 {
187-
if message, ok := errorMap["message"].(string); ok &&
187+
if message, ok := errorMap["message"].(string); ok &&
188188
strings.Contains(message, "no handler configured") {
189189
errorReceived = true
190190
close(errorReceivedChan)
@@ -195,40 +195,40 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) {
195195
w.WriteHeader(http.StatusOK)
196196
}))
197197
defer server.Close()
198-
198+
199199
client, err := NewStreamableHTTP(server.URL)
200200
if err != nil {
201201
t.Fatalf("Failed to create client: %v", err)
202202
}
203203
defer client.Close()
204-
204+
205205
// Don't set any request handler
206-
206+
207207
ctx := context.Background()
208208
err = client.Start(ctx)
209209
if err != nil {
210210
t.Fatalf("Failed to start client: %v", err)
211211
}
212-
212+
213213
// Simulate incoming sampling request
214214
samplingRequest := JSONRPCRequest{
215215
JSONRPC: "2.0",
216216
ID: mcp.NewRequestId(1),
217217
Method: string(mcp.MethodSamplingCreateMessage),
218218
Params: map[string]any{},
219219
}
220-
220+
221221
// This should trigger "method not found" error
222222
client.handleIncomingRequest(ctx, samplingRequest)
223-
223+
224224
// Wait for error to be received
225225
select {
226226
case <-errorReceivedChan:
227227
// Error was received
228228
case <-time.After(1 * time.Second):
229229
t.Fatal("Method not found error was not received within timeout")
230230
}
231-
231+
232232
if !errorReceived {
233233
t.Error("Expected method not found error, but didn't receive it")
234234
}
@@ -241,13 +241,13 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) {
241241
t.Fatalf("Failed to create client: %v", err)
242242
}
243243
defer client.Close()
244-
244+
245245
// Verify it implements BidirectionalInterface
246246
_, ok := any(client).(BidirectionalInterface)
247247
if !ok {
248248
t.Error("StreamableHTTP should implement BidirectionalInterface")
249249
}
250-
250+
251251
// Test SetRequestHandler
252252
handlerSet := false
253253
handlerSetChan := make(chan struct{})
@@ -256,23 +256,23 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) {
256256
close(handlerSetChan)
257257
return nil, nil
258258
})
259-
259+
260260
// Verify handler was set by triggering it
261261
ctx := context.Background()
262262
client.handleIncomingRequest(ctx, JSONRPCRequest{
263263
JSONRPC: "2.0",
264264
ID: mcp.NewRequestId(1),
265265
Method: "test",
266266
})
267-
267+
268268
// Wait for handler to be called
269269
select {
270270
case <-handlerSetChan:
271271
// Handler was called
272272
case <-time.After(1 * time.Second):
273273
t.Fatal("Handler was not called within timeout")
274274
}
275-
275+
276276
if !handlerSet {
277277
t.Error("Request handler was not properly set or called")
278278
}
@@ -315,16 +315,16 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) {
315315
// Track which requests have been received and their completion order
316316
var requestOrder []int
317317
var orderMutex sync.Mutex
318-
318+
319319
// Set up request handler that simulates different processing times
320320
client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) {
321321
// Extract request ID to determine processing time
322322
requestIDValue := request.ID.Value()
323-
323+
324324
var delay time.Duration
325325
var responseText string
326326
var requestNum int
327-
327+
328328
// First request (ID 1) takes longer, second request (ID 2) completes faster
329329
if requestIDValue == int64(1) {
330330
delay = 100 * time.Millisecond
@@ -341,7 +341,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) {
341341

342342
// Simulate processing time
343343
time.Sleep(delay)
344-
344+
345345
// Record completion order
346346
orderMutex.Lock()
347347
requestOrder = append(requestOrder, requestNum)
@@ -428,7 +428,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) {
428428
// Verify completion order: request 2 should complete first
429429
orderMutex.Lock()
430430
defer orderMutex.Unlock()
431-
431+
432432
if len(requestOrder) != 2 {
433433
t.Fatalf("Expected 2 completed requests, got %d", len(requestOrder))
434434
}
@@ -493,4 +493,4 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) {
493493
}
494494
}
495495
}
496-
}
496+
}

0 commit comments

Comments
 (0)