Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions client/transport/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ type StreamableHTTP struct {

// OAuth support
oauthHandler *OAuthHandler
wg sync.WaitGroup
}

// NewStreamableHTTP creates a new Streamable HTTP transport with the given server URL.
Expand Down Expand Up @@ -182,9 +183,10 @@ func (c *StreamableHTTP) Close() error {
sessionId := c.sessionID.Load().(string)
if sessionId != "" {
c.sessionID.Store("")

c.wg.Add(1)
// notify server session closed
go func() {
defer c.wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil)
Expand All @@ -201,7 +203,7 @@ func (c *StreamableHTTP) Close() error {
res.Body.Close()
}()
}

c.wg.Wait()
return nil
}

Expand Down Expand Up @@ -231,7 +233,6 @@ func (c *StreamableHTTP) SendRequest(
ctx context.Context,
request JSONRPCRequest,
) (*JSONRPCResponse, error) {

// Marshal request
requestBody, err := json.Marshal(request)
if err != nil {
Expand Down Expand Up @@ -316,7 +317,6 @@ func (c *StreamableHTTP) sendHTTP(
body io.Reader,
acceptType string,
) (resp *http.Response, err error) {

// Create HTTP request
req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
if err != nil {
Expand Down Expand Up @@ -374,7 +374,6 @@ func (c *StreamableHTTP) sendHTTP(
// It returns the final result for the request once received, or an error.
// If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening.
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) {

// Create a channel for this specific request
responseChan := make(chan *JSONRPCResponse, 1)

Expand All @@ -387,7 +386,6 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
defer close(responseChan)

c.readSSE(ctx, reader, func(event, data string) {

// (unsupported: batching)

var message JSONRPCResponse
Expand Down Expand Up @@ -490,7 +488,6 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
}

func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {

// Marshal request
requestBody, err := json.Marshal(notification)
if err != nil {
Expand Down Expand Up @@ -577,7 +574,6 @@ var (
)

func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error {

resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
Expand Down