Skip to content

Commit fb3c8e4

Browse files
committed
update implmentaion
1 parent 4419e2a commit fb3c8e4

File tree

1 file changed

+39
-23
lines changed

1 file changed

+39
-23
lines changed

client/transport/streamable_http.go

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,28 @@ func (c *StreamableHTTP) Close() error {
8686
}
8787
// Cancel all in-flight requests
8888
close(c.closed)
89-
c.sessionID.Store("")
89+
90+
sessionId := c.sessionID.Load().(string)
91+
if sessionId != "" {
92+
c.sessionID.Store("")
93+
94+
// notify server session closed
95+
go func() {
96+
req, err := http.NewRequest(http.MethodDelete, c.baseURL.String(), nil)
97+
if err != nil {
98+
fmt.Printf("failed to create close request\n: %v", err)
99+
return
100+
}
101+
req.Header.Set(headerKeySessionID, sessionId)
102+
res, err := c.httpClient.Do(req)
103+
if err != nil {
104+
fmt.Printf("failed to send close request\n: %v", err)
105+
return
106+
}
107+
res.Body.Close()
108+
}()
109+
}
110+
90111
return nil
91112
}
92113

@@ -102,10 +123,6 @@ func (c *StreamableHTTP) SendRequest(
102123
request JSONRPCRequest,
103124
) (*JSONRPCResponse, error) {
104125

105-
if request.Method != initializeMethod && c.sessionID.Load() == nil {
106-
return nil, fmt.Errorf("no session ID. please call initialize first")
107-
}
108-
109126
// Create a combined context that could be canceled when the client is closed
110127
var cancelRequest context.CancelFunc
111128
ctx, cancelRequest = context.WithCancel(ctx)
@@ -120,6 +137,7 @@ func (c *StreamableHTTP) SendRequest(
120137
}()
121138

122139
id := c.requestID.Add(1)
140+
request.ID = id
123141

124142
// Marshal request
125143
requestBody, err := json.Marshal(request)
@@ -136,8 +154,9 @@ func (c *StreamableHTTP) SendRequest(
136154
// Set headers
137155
req.Header.Set("Content-Type", "application/json")
138156
req.Header.Set("Accept", "application/json, text/event-stream")
139-
if v := c.sessionID.Load(); v != "" {
140-
req.Header.Set(headerKeySessionID, v.(string))
157+
sessionID := c.sessionID.Load()
158+
if sessionID != "" {
159+
req.Header.Set(headerKeySessionID, sessionID.(string))
141160
}
142161
for k, v := range c.headers {
143162
req.Header.Set(k, v)
@@ -154,10 +173,8 @@ func (c *StreamableHTTP) SendRequest(
154173
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
155174
// handle session closed
156175
if resp.StatusCode == http.StatusNotFound {
157-
sessionID := req.Header.Get(headerKeySessionID)
158-
if sessionID != "" && c.sessionID.CompareAndSwap(sessionID, "") {
159-
return nil, fmt.Errorf("session ID not found (Session may be closed). initialize needs to be called again")
160-
}
176+
c.sessionID.CompareAndSwap(sessionID, "")
177+
return nil, fmt.Errorf("session terminated (404)")
161178
}
162179

163180
// handle error response
@@ -170,11 +187,10 @@ func (c *StreamableHTTP) SendRequest(
170187
}
171188

172189
if request.Method == initializeMethod {
173-
// Check if we got a session ID in the response
190+
// saved the received session ID in the response
191+
// empty session ID is allowed
174192
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
175193
c.sessionID.Store(sessionID)
176-
} else {
177-
return nil, fmt.Errorf("invalid response: initialize request should return a session ID")
178194
}
179195
}
180196

@@ -196,7 +212,7 @@ func (c *StreamableHTTP) SendRequest(
196212

197213
case "text/event-stream":
198214
// Server is using SSE for streaming responses
199-
return c.handleSSEResponse(ctx, resp.Body, id)
215+
return c.handleSSEResponse(ctx, resp.Body)
200216

201217
default:
202218
return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
@@ -205,7 +221,7 @@ func (c *StreamableHTTP) SendRequest(
205221

206222
// handleSSEResponse processes an SSE stream for a specific request.
207223
// It returns the final result for the request once received, or an error.
208-
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, requestID int64) (*JSONRPCResponse, error) {
224+
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
209225

210226
// Create a channel for this specific request
211227
responseChan := make(chan *JSONRPCResponse, 1)
@@ -217,7 +233,10 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
217233
// Start a goroutine to process the SSE stream
218234
go c.readSSE(ctx, reader, func(event, data string) {
219235

220-
// (batch not supported yet)
236+
// unsupported
237+
// - batching
238+
// - server -> client request
239+
221240
var message JSONRPCResponse
222241
if err := json.Unmarshal([]byte(data), &message); err != nil {
223242
fmt.Printf("failed to unmarshal message: %v", err)
@@ -308,11 +327,6 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
308327

309328
func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
310329

311-
sessionID := c.sessionID.Load()
312-
if sessionID == nil {
313-
return fmt.Errorf("no session ID")
314-
}
315-
316330
// Marshal request
317331
requestBody, err := json.Marshal(notification)
318332
if err != nil {
@@ -327,7 +341,9 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
327341

328342
// Set headers
329343
req.Header.Set("Content-Type", "application/json")
330-
req.Header.Set(headerKeySessionID, sessionID.(string))
344+
if sessionID := c.sessionID.Load(); sessionID != "" {
345+
req.Header.Set(headerKeySessionID, sessionID.(string))
346+
}
331347
for k, v := range c.headers {
332348
req.Header.Set(k, v)
333349
}

0 commit comments

Comments
 (0)