Skip to content

Commit 8d7440b

Browse files
committed
add streamable http transport
1 parent 37ac814 commit 8d7440b

File tree

2 files changed

+795
-0
lines changed

2 files changed

+795
-0
lines changed
Lines changed: 370 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,370 @@
1+
package transport
2+
3+
import (
4+
"bufio"
5+
"bytes"
6+
"context"
7+
"encoding/json"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
"net/url"
12+
"strings"
13+
"sync"
14+
"sync/atomic"
15+
"time"
16+
17+
"github.com/mark3labs/mcp-go/mcp"
18+
)
19+
20+
type StreamableHTTPCOption func(*StreamableHTTP)
21+
22+
func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption {
23+
return func(sc *StreamableHTTP) {
24+
sc.headers = headers
25+
}
26+
}
27+
28+
// WithHTTPTimeout sets the whole timeout for the HTTP request and stream.
29+
func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption {
30+
return func(sc *StreamableHTTP) {
31+
sc.httpClient.Timeout = timeout
32+
}
33+
}
34+
35+
// StreamableHTTP implements the transport.Interface using Streamable HTTP transport.
36+
// https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/transports/#streamable-http
37+
type StreamableHTTP struct {
38+
baseURL *url.URL
39+
httpClient *http.Client
40+
headers map[string]string
41+
42+
sessionID atomic.Value // string
43+
44+
notificationHandler func(mcp.JSONRPCNotification)
45+
notifyMu sync.RWMutex
46+
47+
closed chan struct{}
48+
}
49+
50+
// NewStreamableHTTP creates a new Streamable HTTP transport with the given base URL.
51+
// Returns an error if the URL is invalid.
52+
func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) {
53+
parsedURL, err := url.Parse(baseURL)
54+
if err != nil {
55+
return nil, fmt.Errorf("invalid URL: %w", err)
56+
}
57+
58+
smc := &StreamableHTTP{
59+
baseURL: parsedURL,
60+
httpClient: &http.Client{Timeout: 60 * time.Second},
61+
headers: make(map[string]string),
62+
closed: make(chan struct{}),
63+
}
64+
smc.sessionID.Store("") // set initial value to simplify later usage
65+
66+
for _, opt := range options {
67+
opt(smc)
68+
}
69+
70+
return smc, nil
71+
}
72+
73+
// Start initiates the HTTP connection to the server.
74+
func (c *StreamableHTTP) Start(ctx context.Context) error {
75+
// For Streamable HTTP, we don't need to establish a persistent connection
76+
return nil
77+
}
78+
79+
// Close closes the all the HTTP connections to the server.
80+
func (c *StreamableHTTP) Close() error {
81+
select {
82+
case <-c.closed:
83+
return nil
84+
default:
85+
}
86+
// Cancel all in-flight requests
87+
close(c.closed)
88+
89+
sessionId := c.sessionID.Load().(string)
90+
if sessionId != "" {
91+
c.sessionID.Store("")
92+
93+
// notify server session closed
94+
go func() {
95+
req, err := http.NewRequest(http.MethodDelete, c.baseURL.String(), nil)
96+
if err != nil {
97+
fmt.Printf("failed to create close request\n: %v", err)
98+
return
99+
}
100+
req.Header.Set(headerKeySessionID, sessionId)
101+
res, err := c.httpClient.Do(req)
102+
if err != nil {
103+
fmt.Printf("failed to send close request\n: %v", err)
104+
return
105+
}
106+
res.Body.Close()
107+
}()
108+
}
109+
110+
return nil
111+
}
112+
113+
const (
114+
initializeMethod = "initialize"
115+
headerKeySessionID = "Mcp-Session-Id"
116+
)
117+
118+
// sendRequest sends a JSON-RPC request to the server and waits for a response.
119+
// Returns the raw JSON response message or an error if the request fails.
120+
func (c *StreamableHTTP) SendRequest(
121+
ctx context.Context,
122+
request JSONRPCRequest,
123+
) (*JSONRPCResponse, error) {
124+
125+
// Create a combined context that could be canceled when the client is closed
126+
var cancelRequest context.CancelFunc
127+
ctx, cancelRequest = context.WithCancel(ctx)
128+
defer cancelRequest()
129+
go func() {
130+
select {
131+
case <-c.closed:
132+
cancelRequest()
133+
case <-ctx.Done():
134+
// The original context was canceled, no need to do anything
135+
}
136+
}()
137+
138+
// Marshal request
139+
requestBody, err := json.Marshal(request)
140+
if err != nil {
141+
return nil, fmt.Errorf("failed to marshal request: %w", err)
142+
}
143+
144+
// Create HTTP request
145+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody))
146+
if err != nil {
147+
return nil, fmt.Errorf("failed to create request: %w", err)
148+
}
149+
150+
// Set headers
151+
req.Header.Set("Content-Type", "application/json")
152+
req.Header.Set("Accept", "application/json, text/event-stream")
153+
sessionID := c.sessionID.Load()
154+
if sessionID != "" {
155+
req.Header.Set(headerKeySessionID, sessionID.(string))
156+
}
157+
for k, v := range c.headers {
158+
req.Header.Set(k, v)
159+
}
160+
161+
// Send request
162+
resp, err := c.httpClient.Do(req)
163+
if err != nil {
164+
return nil, fmt.Errorf("failed to send request: %w", err)
165+
}
166+
defer resp.Body.Close()
167+
168+
// Check if we got an error response
169+
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
170+
// handle session closed
171+
if resp.StatusCode == http.StatusNotFound {
172+
c.sessionID.CompareAndSwap(sessionID, "")
173+
return nil, fmt.Errorf("session terminated (404)")
174+
}
175+
176+
// handle error response
177+
var errResponse JSONRPCResponse
178+
body, _ := io.ReadAll(resp.Body)
179+
if err := json.Unmarshal(body, &errResponse); err == nil {
180+
return &errResponse, nil
181+
}
182+
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
183+
}
184+
185+
if request.Method == initializeMethod {
186+
// saved the received session ID in the response
187+
// empty session ID is allowed
188+
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
189+
c.sessionID.Store(sessionID)
190+
}
191+
}
192+
193+
// Handle different response types
194+
switch resp.Header.Get("Content-Type") {
195+
case "application/json":
196+
// Single response
197+
var response JSONRPCResponse
198+
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
199+
return nil, fmt.Errorf("failed to decode response: %w", err)
200+
}
201+
202+
// should not be a notification
203+
if response.ID == nil {
204+
return nil, fmt.Errorf("response should contain RPC id: %v", response)
205+
}
206+
207+
return &response, nil
208+
209+
case "text/event-stream":
210+
// Server is using SSE for streaming responses
211+
return c.handleSSEResponse(ctx, resp.Body)
212+
213+
default:
214+
return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
215+
}
216+
}
217+
218+
// handleSSEResponse processes an SSE stream for a specific request.
219+
// It returns the final result for the request once received, or an error.
220+
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
221+
222+
// Create a channel for this specific request
223+
responseChan := make(chan *JSONRPCResponse, 1)
224+
defer close(responseChan)
225+
226+
ctx, cancel := context.WithCancel(ctx)
227+
defer cancel()
228+
229+
// Start a goroutine to process the SSE stream
230+
go c.readSSE(ctx, reader, func(event, data string) {
231+
232+
// unsupported
233+
// - batching
234+
// - server -> client request
235+
236+
var message JSONRPCResponse
237+
if err := json.Unmarshal([]byte(data), &message); err != nil {
238+
fmt.Printf("failed to unmarshal message: %v", err)
239+
return
240+
}
241+
242+
// Handle notification
243+
if message.ID == nil {
244+
var notification mcp.JSONRPCNotification
245+
if err := json.Unmarshal([]byte(data), &notification); err != nil {
246+
fmt.Printf("failed to unmarshal notification: %v", err)
247+
return
248+
}
249+
c.notifyMu.RLock()
250+
if c.notificationHandler != nil {
251+
c.notificationHandler(notification)
252+
}
253+
c.notifyMu.RUnlock()
254+
return
255+
}
256+
257+
responseChan <- &message
258+
})
259+
260+
// Wait for the response or context cancellation
261+
select {
262+
case response := <-responseChan:
263+
if response == nil {
264+
return nil, fmt.Errorf("unexpected nil response")
265+
}
266+
return response, nil
267+
case <-ctx.Done():
268+
return nil, ctx.Err()
269+
}
270+
}
271+
272+
// readSSE reads the SSE stream(reader) and calls the handler for each event and data pair.
273+
// It will end when the reader is closed (or the context is done).
274+
func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) {
275+
defer reader.Close()
276+
277+
br := bufio.NewReader(reader)
278+
var event, data string
279+
280+
for {
281+
select {
282+
case <-ctx.Done():
283+
return
284+
default:
285+
line, err := br.ReadString('\n')
286+
if err != nil {
287+
if err == io.EOF {
288+
// Process any pending event before exit
289+
if event != "" && data != "" {
290+
handler(event, data)
291+
}
292+
return
293+
}
294+
select {
295+
case <-ctx.Done():
296+
return
297+
default:
298+
fmt.Printf("SSE stream error: %v\n", err)
299+
return
300+
}
301+
}
302+
303+
// Remove only newline markers
304+
line = strings.TrimRight(line, "\r\n")
305+
if line == "" {
306+
// Empty line means end of event
307+
if event != "" && data != "" {
308+
handler(event, data)
309+
event = ""
310+
data = ""
311+
}
312+
continue
313+
}
314+
315+
if strings.HasPrefix(line, "event:") {
316+
event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
317+
} else if strings.HasPrefix(line, "data:") {
318+
data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
319+
}
320+
}
321+
}
322+
}
323+
324+
func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
325+
326+
// Marshal request
327+
requestBody, err := json.Marshal(notification)
328+
if err != nil {
329+
return fmt.Errorf("failed to marshal notification: %w", err)
330+
}
331+
332+
// Create HTTP request
333+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody))
334+
if err != nil {
335+
return fmt.Errorf("failed to create request: %w", err)
336+
}
337+
338+
// Set headers
339+
req.Header.Set("Content-Type", "application/json")
340+
if sessionID := c.sessionID.Load(); sessionID != "" {
341+
req.Header.Set(headerKeySessionID, sessionID.(string))
342+
}
343+
for k, v := range c.headers {
344+
req.Header.Set(k, v)
345+
}
346+
347+
// Send request
348+
resp, err := c.httpClient.Do(req)
349+
if err != nil {
350+
return fmt.Errorf("failed to send request: %w", err)
351+
}
352+
defer resp.Body.Close()
353+
354+
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
355+
body, _ := io.ReadAll(resp.Body)
356+
return fmt.Errorf(
357+
"notification failed with status %d: %s",
358+
resp.StatusCode,
359+
body,
360+
)
361+
}
362+
363+
return nil
364+
}
365+
366+
func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) {
367+
c.notifyMu.Lock()
368+
defer c.notifyMu.Unlock()
369+
c.notificationHandler = handler
370+
}

0 commit comments

Comments
 (0)