-
Couldn't load subscription status.
- Fork 708
feat: client-side streamable-http transport supports continuously listening #317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
472f442
292bcea
b7cf1a3
6449b15
6a05fc6
0a1a9e9
6435882
8d3f236
c0f4403
cc540bb
42ba0ff
928d9ea
21316a0
b6ca548
4313aa1
1f5efb5
a6ad665
50f9c47
32f36b9
f8b7dce
c706c93
c222bab
7dc57e6
4e9411d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ import ( | |
| "bytes" | ||
| "context" | ||
| "encoding/json" | ||
| "errors" | ||
| "fmt" | ||
| "io" | ||
| "mime" | ||
|
|
@@ -16,10 +17,24 @@ import ( | |
| "time" | ||
|
|
||
| "github.com/mark3labs/mcp-go/mcp" | ||
| "github.com/mark3labs/mcp-go/util" | ||
| ) | ||
|
|
||
| type StreamableHTTPCOption func(*StreamableHTTP) | ||
|
|
||
| // WithContinuousListening enables receiving server-to-client notifications when no request is in flight. | ||
| // In particular, if you want to receive global notifications from the server (like ToolListChangedNotification), | ||
| // you should enable this option. | ||
| // | ||
| // It will establish a standalone long-live GET HTTP connection to the server. | ||
| // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server | ||
| // NOTICE: Even enabled, the server may not support this feature. | ||
| func WithContinuousListening() StreamableHTTPCOption { | ||
| return func(sc *StreamableHTTP) { | ||
| sc.getListeningEnabled = true | ||
| } | ||
| } | ||
|
|
||
| func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption { | ||
| return func(sc *StreamableHTTP) { | ||
| sc.headers = headers | ||
|
|
@@ -39,6 +54,12 @@ func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption { | |
| } | ||
| } | ||
|
|
||
| func WithLogger(logger util.Logger) StreamableHTTPCOption { | ||
| return func(sc *StreamableHTTP) { | ||
| sc.logger = logger | ||
| } | ||
| } | ||
|
|
||
| // StreamableHTTP implements Streamable HTTP transport. | ||
| // | ||
| // It transmits JSON-RPC messages over individual HTTP requests. One message per request. | ||
|
|
@@ -49,18 +70,19 @@ func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption { | |
| // | ||
| // The current implementation does not support the following features: | ||
| // - batching | ||
| // - continuously listening for server notifications when no request is in flight | ||
| // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server) | ||
| // - resuming stream | ||
| // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) | ||
| // - server -> client request | ||
| type StreamableHTTP struct { | ||
| baseURL *url.URL | ||
| httpClient *http.Client | ||
| headers map[string]string | ||
| headerFunc HTTPHeaderFunc | ||
| baseURL *url.URL | ||
| httpClient *http.Client | ||
| headers map[string]string | ||
| headerFunc HTTPHeaderFunc | ||
| logger util.Logger | ||
| getListeningEnabled bool | ||
|
|
||
| sessionID atomic.Value // string | ||
| initialized chan struct{} | ||
| sessionID atomic.Value // string | ||
|
|
||
| notificationHandler func(mcp.JSONRPCNotification) | ||
| notifyMu sync.RWMutex | ||
|
|
@@ -77,10 +99,12 @@ func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*Strea | |
| } | ||
|
|
||
| smc := &StreamableHTTP{ | ||
| baseURL: parsedURL, | ||
| httpClient: &http.Client{}, | ||
| headers: make(map[string]string), | ||
| closed: make(chan struct{}), | ||
| baseURL: parsedURL, | ||
| httpClient: &http.Client{}, | ||
| headers: make(map[string]string), | ||
| closed: make(chan struct{}), | ||
| logger: util.DefaultLogger(), | ||
| initialized: make(chan struct{}), | ||
| } | ||
| smc.sessionID.Store("") // set initial value to simplify later usage | ||
|
|
||
|
|
@@ -93,7 +117,14 @@ func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*Strea | |
|
|
||
| // Start initiates the HTTP connection to the server. | ||
| func (c *StreamableHTTP) Start(ctx context.Context) error { | ||
| // For Streamable HTTP, we don't need to establish a persistent connection | ||
| // For Streamable HTTP, we don't need to establish a persistent connection by default | ||
| if c.getListeningEnabled { | ||
| go func() { | ||
| <-c.initialized | ||
| c.listenForever() | ||
| }() | ||
| } | ||
|
|
||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return nil | ||
| } | ||
|
|
||
|
|
@@ -144,61 +175,20 @@ func (c *StreamableHTTP) SendRequest( | |
| request JSONRPCRequest, | ||
| ) (*JSONRPCResponse, error) { | ||
|
|
||
| // Create a combined context that could be canceled when the client is closed | ||
| newCtx, cancel := context.WithCancel(ctx) | ||
| defer cancel() | ||
| go func() { | ||
| select { | ||
| case <-c.closed: | ||
| cancel() | ||
| case <-newCtx.Done(): | ||
| // The original context was canceled, no need to do anything | ||
| } | ||
| }() | ||
| ctx = newCtx | ||
|
|
||
| // Marshal request | ||
| requestBody, err := json.Marshal(request) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to marshal request: %w", err) | ||
| } | ||
|
|
||
| // Create HTTP request | ||
| req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to create request: %w", err) | ||
| } | ||
|
|
||
| // Set headers | ||
| req.Header.Set("Content-Type", "application/json") | ||
| req.Header.Set("Accept", "application/json, text/event-stream") | ||
| sessionID := c.sessionID.Load() | ||
| if sessionID != "" { | ||
| req.Header.Set(headerKeySessionID, sessionID.(string)) | ||
| } | ||
| for k, v := range c.headers { | ||
| req.Header.Set(k, v) | ||
| } | ||
| if c.headerFunc != nil { | ||
| for k, v := range c.headerFunc(ctx) { | ||
| req.Header.Set(k, v) | ||
| } | ||
| } | ||
|
|
||
| // Send request | ||
| resp, err := c.httpClient.Do(req) | ||
| resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to send request: %w", err) | ||
| } | ||
| defer resp.Body.Close() | ||
|
|
||
| // Check if we got an error response | ||
| if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { | ||
| // handle session closed | ||
| if resp.StatusCode == http.StatusNotFound { | ||
| c.sessionID.CompareAndSwap(sessionID, "") | ||
| return nil, fmt.Errorf("session terminated (404). need to re-initialize") | ||
| } | ||
|
|
||
| // handle error response | ||
| var errResponse JSONRPCResponse | ||
|
|
@@ -215,6 +205,8 @@ func (c *StreamableHTTP) SendRequest( | |
| if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" { | ||
| c.sessionID.Store(sessionID) | ||
| } | ||
|
|
||
| close(c.initialized) | ||
leavez marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| // Handle different response types | ||
|
|
@@ -243,6 +235,62 @@ func (c *StreamableHTTP) SendRequest( | |
| } | ||
| } | ||
|
|
||
| func (c *StreamableHTTP) sendHTTP( | ||
| ctx context.Context, | ||
| method string, | ||
| body io.Reader, | ||
| acceptType string, | ||
| ) (resp *http.Response, err error) { | ||
| // Create a combined context that could be canceled when the client is closed | ||
| newCtx, cancel := context.WithCancel(ctx) | ||
| defer cancel() | ||
| go func() { | ||
| select { | ||
| case <-c.closed: | ||
| cancel() | ||
| case <-newCtx.Done(): | ||
| // The original context was canceled, no need to do anything | ||
| } | ||
| }() | ||
| ctx = newCtx | ||
|
|
||
| // Create HTTP request | ||
| req, err := http.NewRequestWithContext(ctx, method, c.baseURL.String(), body) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to create request: %w", err) | ||
| } | ||
|
|
||
| // Set headers | ||
| req.Header.Set("Content-Type", "application/json") | ||
| req.Header.Set("Accept", acceptType) | ||
| sessionID := c.sessionID.Load().(string) | ||
| if sessionID != "" { | ||
| req.Header.Set(headerKeySessionID, sessionID) | ||
| } | ||
| for k, v := range c.headers { | ||
| req.Header.Set(k, v) | ||
| } | ||
| if c.headerFunc != nil { | ||
| for k, v := range c.headerFunc(ctx) { | ||
| req.Header.Set(k, v) | ||
| } | ||
| } | ||
|
|
||
| // Send request | ||
| resp, err = c.httpClient.Do(req) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to send request: %w", err) | ||
| } | ||
|
|
||
| // universal handling for session terminated | ||
| if resp.StatusCode == http.StatusNotFound { | ||
| c.sessionID.CompareAndSwap(sessionID, "") | ||
| return nil, fmt.Errorf("session terminated (404). need to re-initialize") | ||
| } | ||
leavez marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| return resp, nil | ||
| } | ||
|
|
||
| // handleSSEResponse processes an SSE stream for a specific request. | ||
| // It returns the final result for the request once received, or an error. | ||
| func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) { | ||
|
||
|
|
@@ -360,28 +408,7 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. | |
| } | ||
|
|
||
| // Create HTTP request | ||
| req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) | ||
| if err != nil { | ||
| return fmt.Errorf("failed to create request: %w", err) | ||
| } | ||
|
|
||
| // Set headers | ||
| req.Header.Set("Content-Type", "application/json") | ||
| req.Header.Set("Accept", "application/json, text/event-stream") | ||
| if sessionID := c.sessionID.Load(); sessionID != "" { | ||
| req.Header.Set(headerKeySessionID, sessionID.(string)) | ||
| } | ||
| for k, v := range c.headers { | ||
| req.Header.Set(k, v) | ||
| } | ||
| if c.headerFunc != nil { | ||
| for k, v := range c.headerFunc(ctx) { | ||
| req.Header.Set(k, v) | ||
| } | ||
| } | ||
|
|
||
| // Send request | ||
| resp, err := c.httpClient.Do(req) | ||
| resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") | ||
| if err != nil { | ||
| return fmt.Errorf("failed to send request: %w", err) | ||
| } | ||
|
|
@@ -408,3 +435,64 @@ func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotifica | |
| func (c *StreamableHTTP) GetSessionId() string { | ||
| return c.sessionID.Load().(string) | ||
| } | ||
|
|
||
| func (c *StreamableHTTP) listenForever() { | ||
| c.logger.Infof("listening to server forever") | ||
| for { | ||
| err := c.createGETConnectionToServer() | ||
| if errors.Is(err, errGetMethodNotAllowed) { | ||
| // server does not support listening | ||
| c.logger.Errorf("server does not support listening") | ||
| return | ||
| } | ||
|
|
||
| select { | ||
| case <-c.closed: | ||
| return | ||
| default: | ||
| } | ||
|
|
||
| if err != nil { | ||
| c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err) | ||
| } | ||
| time.Sleep(retryInterval) | ||
| } | ||
| } | ||
|
|
||
| var ( | ||
| errGetMethodNotAllowed = fmt.Errorf("GET method not allowed") | ||
| retryInterval = 1 * time.Second // a variable is convenient for testing | ||
| ) | ||
|
|
||
| func (c *StreamableHTTP) createGETConnectionToServer() error { | ||
|
|
||
| ctx := context.Background() // the sendHTTP will be automatically canceled when the client is closed | ||
| resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream") | ||
| if err != nil { | ||
| return fmt.Errorf("failed to send request: %w", err) | ||
| } | ||
| defer resp.Body.Close() | ||
|
|
||
| // Check if we got an error response | ||
| if resp.StatusCode == http.StatusMethodNotAllowed { | ||
| return errGetMethodNotAllowed | ||
| } | ||
|
|
||
| if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { | ||
| body, _ := io.ReadAll(resp.Body) | ||
| return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body) | ||
| } | ||
|
|
||
| // handle SSE response | ||
| contentType := resp.Header.Get("Content-Type") | ||
| if contentType != "text/event-stream" { | ||
| return fmt.Errorf("unexpected content type: %s", contentType) | ||
| } | ||
|
|
||
| _, err = c.handleSSEResponse(ctx, resp.Body) | ||
|
||
| if err != nil { | ||
| return fmt.Errorf("failed to handle SSE response: %w", err) | ||
| } | ||
|
|
||
| return nil | ||
| } | ||




Uh oh!
There was an error while loading. Please reload this page.