Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 164 additions & 76 deletions client/transport/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
Expand All @@ -16,10 +17,24 @@ import (
"time"

"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/util"
)

type StreamableHTTPCOption func(*StreamableHTTP)

// WithContinuousListening enables receiving server-to-client notifications when no request is in flight.
// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification),
// you should enable this option.
//
// It will establish a standalone long-live GET HTTP connection to the server.
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
// NOTICE: Even enabled, the server may not support this feature.
func WithContinuousListening() StreamableHTTPCOption {
return func(sc *StreamableHTTP) {
sc.getListeningEnabled = true
}
}

func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption {
return func(sc *StreamableHTTP) {
sc.headers = headers
Expand All @@ -39,6 +54,12 @@ func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption {
}
}

func WithLogger(logger util.Logger) StreamableHTTPCOption {
return func(sc *StreamableHTTP) {
sc.logger = logger
}
}

// StreamableHTTP implements Streamable HTTP transport.
//
// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
Expand All @@ -49,18 +70,19 @@ func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption {
//
// The current implementation does not support the following features:
// - batching
// - continuously listening for server notifications when no request is in flight
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server)
// - resuming stream
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
// - server -> client request
type StreamableHTTP struct {
baseURL *url.URL
httpClient *http.Client
headers map[string]string
headerFunc HTTPHeaderFunc
baseURL *url.URL
httpClient *http.Client
headers map[string]string
headerFunc HTTPHeaderFunc
logger util.Logger
getListeningEnabled bool

sessionID atomic.Value // string
initialized chan struct{}
sessionID atomic.Value // string

notificationHandler func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
Expand All @@ -77,10 +99,12 @@ func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*Strea
}

smc := &StreamableHTTP{
baseURL: parsedURL,
httpClient: &http.Client{},
headers: make(map[string]string),
closed: make(chan struct{}),
baseURL: parsedURL,
httpClient: &http.Client{},
headers: make(map[string]string),
closed: make(chan struct{}),
logger: util.DefaultLogger(),
initialized: make(chan struct{}),
}
smc.sessionID.Store("") // set initial value to simplify later usage

Expand All @@ -93,7 +117,14 @@ func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*Strea

// Start initiates the HTTP connection to the server.
func (c *StreamableHTTP) Start(ctx context.Context) error {
// For Streamable HTTP, we don't need to establish a persistent connection
// For Streamable HTTP, we don't need to establish a persistent connection by default
if c.getListeningEnabled {
go func() {
<-c.initialized
c.listenForever()
}()
}

return nil
}

Expand Down Expand Up @@ -144,61 +175,20 @@ func (c *StreamableHTTP) SendRequest(
request JSONRPCRequest,
) (*JSONRPCResponse, error) {

// Create a combined context that could be canceled when the client is closed
newCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-c.closed:
cancel()
case <-newCtx.Done():
// The original context was canceled, no need to do anything
}
}()
ctx = newCtx

// Marshal request
requestBody, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}

// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
sessionID := c.sessionID.Load()
if sessionID != "" {
req.Header.Set(headerKeySessionID, sessionID.(string))
}
for k, v := range c.headers {
req.Header.Set(k, v)
}
if c.headerFunc != nil {
for k, v := range c.headerFunc(ctx) {
req.Header.Set(k, v)
}
}

// Send request
resp, err := c.httpClient.Do(req)
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()

// Check if we got an error response
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
// handle session closed
if resp.StatusCode == http.StatusNotFound {
c.sessionID.CompareAndSwap(sessionID, "")
return nil, fmt.Errorf("session terminated (404). need to re-initialize")
}

// handle error response
var errResponse JSONRPCResponse
Expand All @@ -215,6 +205,8 @@ func (c *StreamableHTTP) SendRequest(
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
c.sessionID.Store(sessionID)
}

close(c.initialized)
}

// Handle different response types
Expand Down Expand Up @@ -243,6 +235,62 @@ func (c *StreamableHTTP) SendRequest(
}
}

func (c *StreamableHTTP) sendHTTP(
ctx context.Context,
method string,
body io.Reader,
acceptType string,
) (resp *http.Response, err error) {
// Create a combined context that could be canceled when the client is closed
newCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-c.closed:
cancel()
case <-newCtx.Done():
// The original context was canceled, no need to do anything
}
}()
ctx = newCtx

// Create HTTP request
req, err := http.NewRequestWithContext(ctx, method, c.baseURL.String(), body)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", acceptType)
sessionID := c.sessionID.Load().(string)
if sessionID != "" {
req.Header.Set(headerKeySessionID, sessionID)
}
for k, v := range c.headers {
req.Header.Set(k, v)
}
if c.headerFunc != nil {
for k, v := range c.headerFunc(ctx) {
req.Header.Set(k, v)
}
}

// Send request
resp, err = c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}

// universal handling for session terminated
if resp.StatusCode == http.StatusNotFound {
c.sessionID.CompareAndSwap(sessionID, "")
return nil, fmt.Errorf("session terminated (404). need to re-initialize")
}

return resp, nil
}

// handleSSEResponse processes an SSE stream for a specific request.
// It returns the final result for the request once received, or an error.
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When readSSE#ReadString occur error, reader will be closed and won't receive notification anymore, but handleSSEResponse will wait ctx.Done(). Now, client will ignore all the notification, am I right?
image

image

Copy link
Contributor Author

@leavez leavez Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image when readSSE is done, the responseChan will be closed, and goes image

then the method returns

Expand Down Expand Up @@ -360,28 +408,7 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
}

// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}

// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
if sessionID := c.sessionID.Load(); sessionID != "" {
req.Header.Set(headerKeySessionID, sessionID.(string))
}
for k, v := range c.headers {
req.Header.Set(k, v)
}
if c.headerFunc != nil {
for k, v := range c.headerFunc(ctx) {
req.Header.Set(k, v)
}
}

// Send request
resp, err := c.httpClient.Do(req)
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
Expand All @@ -408,3 +435,64 @@ func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotifica
func (c *StreamableHTTP) GetSessionId() string {
return c.sessionID.Load().(string)
}

func (c *StreamableHTTP) listenForever() {
c.logger.Infof("listening to server forever")
for {
err := c.createGETConnectionToServer()
if errors.Is(err, errGetMethodNotAllowed) {
// server does not support listening
c.logger.Errorf("server does not support listening")
return
}

select {
case <-c.closed:
return
default:
}

if err != nil {
c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
}
time.Sleep(retryInterval)
}
}

var (
errGetMethodNotAllowed = fmt.Errorf("GET method not allowed")
retryInterval = 1 * time.Second // a variable is convenient for testing
)

func (c *StreamableHTTP) createGETConnectionToServer() error {

ctx := context.Background() // the sendHTTP will be automatically canceled when the client is closed
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()

// Check if we got an error response
if resp.StatusCode == http.StatusMethodNotAllowed {
return errGetMethodNotAllowed
}

if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
}

// handle SSE response
contentType := resp.Header.Get("Content-Type")
if contentType != "text/event-stream" {
return fmt.Errorf("unexpected content type: %s", contentType)
}

_, err = c.handleSSEResponse(ctx, resp.Body)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We currently ignore the response here.
However, if you refer to the Python SDK implementation, you'll see that server responses are actively written to a read_stream—a memory stream used for receiving messages.

This read_stream is initialized during client setup, and it's the same stream shared by both the GET and POST handlers.

That said, the intended behavior here is still somewhat unclear—the MCP spec doesn't explicitly define whether GET responses must be surfaced to the client, so it's possible the current handling is valid, but worth clarifying.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7dc57e6

The spec is unclear, and to be easier to use, we should be more compatible, however, currently the transport layer is message based, so is no easy way to handle the response messages.

if err != nil {
return fmt.Errorf("failed to handle SSE response: %w", err)
}

return nil
}
Loading
Loading