diff --git a/client/client.go b/client/client.go index ba03b5828..60fe0cbfe 100644 --- a/client/client.go +++ b/client/client.go @@ -1,124 +1,390 @@ -// Package client provides MCP (Model Control Protocol) client implementations. package client import ( "context" "encoding/json" + "errors" "fmt" + "sync" + "sync/atomic" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" ) -// MCPClient represents an MCP client interface -type MCPClient interface { - // Initialize sends the initial connection request to the server - Initialize( - ctx context.Context, - request mcp.InitializeRequest, - ) (*mcp.InitializeResult, error) - - // Ping checks if the server is alive - Ping(ctx context.Context) error - - // ListResourcesByPage manually list resources by page. - ListResourcesByPage( - ctx context.Context, - request mcp.ListResourcesRequest, - ) (*mcp.ListResourcesResult, error) - - // ListResources requests a list of available resources from the server - ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, - ) (*mcp.ListResourcesResult, error) - - // ListResourceTemplatesByPage manually list resource templates by page. - ListResourceTemplatesByPage( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, - ) (*mcp.ListResourceTemplatesResult, - error) - - // ListResourceTemplates requests a list of available resource templates from the server - ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, - ) (*mcp.ListResourceTemplatesResult, - error) - - // ReadResource reads a specific resource from the server - ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, - ) (*mcp.ReadResourceResult, error) - - // Subscribe requests notifications for changes to a specific resource - Subscribe(ctx context.Context, request mcp.SubscribeRequest) error - - // Unsubscribe cancels notifications for a specific resource - Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error - - // ListPromptsByPage manually list prompts by page. - ListPromptsByPage( - ctx context.Context, - request mcp.ListPromptsRequest, - ) (*mcp.ListPromptsResult, error) - - // ListPrompts requests a list of available prompts from the server - ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, - ) (*mcp.ListPromptsResult, error) - - // GetPrompt retrieves a specific prompt from the server - GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, - ) (*mcp.GetPromptResult, error) - - // ListToolsByPage manually list tools by page. - ListToolsByPage( - ctx context.Context, - request mcp.ListToolsRequest, - ) (*mcp.ListToolsResult, error) - - // ListTools requests a list of available tools from the server - ListTools( - ctx context.Context, - request mcp.ListToolsRequest, - ) (*mcp.ListToolsResult, error) - - // CallTool invokes a specific tool on the server - CallTool( - ctx context.Context, - request mcp.CallToolRequest, - ) (*mcp.CallToolResult, error) - - // SetLevel sets the logging level for the server - SetLevel(ctx context.Context, request mcp.SetLevelRequest) error - - // Complete requests completion options for a given argument - Complete( - ctx context.Context, - request mcp.CompleteRequest, - ) (*mcp.CompleteResult, error) - - // Close client connection and cleanup resources - Close() error - - // OnNotification registers a handler for notifications - OnNotification(handler func(notification mcp.JSONRPCNotification)) -} - -type mcpClient interface { - MCPClient - - sendRequest(ctx context.Context, method string, params interface{}) (*json.RawMessage, error) +// Client implements the MCP client. +type Client struct { + transport transport.Interface + + initialized bool + notifications []func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + requestID atomic.Int64 + capabilities mcp.ServerCapabilities +} + +// NewClient creates a new MCP client with the given transport. +// Usage: +// +// stdio := transport.NewStdio("./mcp_server", nil, "xxx") +// client, err := NewClient(stdio) +// if err != nil { +// log.Fatalf("Failed to create client: %v", err) +// } +func NewClient(transport transport.Interface) *Client { + return &Client{ + transport: transport, + } +} + +// Start initiates the connection to the server. +// Must be called before using the client. +func (c *Client) Start(ctx context.Context) error { + if c.transport == nil { + return fmt.Errorf("transport is nil") + } + err := c.transport.Start(ctx) + if err != nil { + return err + } + + c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + c.notifyMu.RLock() + defer c.notifyMu.RUnlock() + for _, handler := range c.notifications { + handler(notification) + } + }) + return nil +} + +// Close shuts down the client and closes the transport. +func (c *Client) Close() error { + return c.transport.Close() +} + +// OnNotification registers a handler function to be called when notifications are received. +// Multiple handlers can be registered and will be called in the order they were added. +func (c *Client) OnNotification( + handler func(notification mcp.JSONRPCNotification), +) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.notifications = append(c.notifications, handler) +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// Returns the raw JSON response message or an error if the request fails. +func (c *Client) sendRequest( + ctx context.Context, + method string, + params interface{}, +) (*json.RawMessage, error) { + if !c.initialized && method != "initialize" { + return nil, fmt.Errorf("client not initialized") + } + + id := c.requestID.Add(1) + + request := transport.JSONRPCRequest{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: method, + Params: params, + } + + response, err := c.transport.SendRequest(ctx, request) + if err != nil { + return nil, fmt.Errorf("transport error: %w", err) + } + + if response.Error != nil { + return nil, errors.New(response.Error.Message) + } + + return &response.Result, nil +} + +// Initialize negotiates with the server. +// Must be called after Start, and before any request methods. +func (c *Client) Initialize( + ctx context.Context, + request mcp.InitializeRequest, +) (*mcp.InitializeResult, error) { + // Ensure we send a params object with all required fields + params := struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo mcp.Implementation `json:"clientInfo"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + }{ + ProtocolVersion: request.Params.ProtocolVersion, + ClientInfo: request.Params.ClientInfo, + Capabilities: request.Params.Capabilities, // Will be empty struct if not set + } + + response, err := c.sendRequest(ctx, "initialize", params) + if err != nil { + return nil, err + } + + var result mcp.InitializeResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Store capabilities + c.capabilities = result.Capabilities + + // Send initialized notification + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: "notifications/initialized", + }, + } + + err = c.transport.SendNotification(ctx, notification) + if err != nil { + return nil, fmt.Errorf( + "failed to send initialized notification: %w", + err, + ) + } + + c.initialized = true + return &result, nil +} + +func (c *Client) Ping(ctx context.Context) error { + _, err := c.sendRequest(ctx, "ping", nil) + return err +} + +// ListResourcesByPage manually list resources by page. +func (c *Client) ListResourcesByPage( + ctx context.Context, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, error) { + result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list") + if err != nil { + return nil, err + } + return result, nil +} + +func (c *Client) ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, error) { + result, err := c.ListResourcesByPage(ctx, request) + if err != nil { + return nil, err + } + for result.NextCursor != "" { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + request.Params.Cursor = result.NextCursor + newPageRes, err := c.ListResourcesByPage(ctx, request) + if err != nil { + return nil, err + } + result.Resources = append(result.Resources, newPageRes.Resources...) + result.NextCursor = newPageRes.NextCursor + } + } + return result, nil +} + +func (c *Client) ListResourceTemplatesByPage( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, error) { + result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list") + if err != nil { + return nil, err + } + return result, nil +} + +func (c *Client) ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, error) { + result, err := c.ListResourceTemplatesByPage(ctx, request) + if err != nil { + return nil, err + } + for result.NextCursor != "" { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + request.Params.Cursor = result.NextCursor + newPageRes, err := c.ListResourceTemplatesByPage(ctx, request) + if err != nil { + return nil, err + } + result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...) + result.NextCursor = newPageRes.NextCursor + } + } + return result, nil +} + +func (c *Client) ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, +) (*mcp.ReadResourceResult, error) { + response, err := c.sendRequest(ctx, "resources/read", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseReadResourceResult(response) +} + +func (c *Client) Subscribe( + ctx context.Context, + request mcp.SubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) + return err +} + +func (c *Client) Unsubscribe( + ctx context.Context, + request mcp.UnsubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) + return err +} + +func (c *Client) ListPromptsByPage( + ctx context.Context, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, error) { + result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list") + if err != nil { + return nil, err + } + return result, nil +} + +func (c *Client) ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, error) { + result, err := c.ListPromptsByPage(ctx, request) + if err != nil { + return nil, err + } + for result.NextCursor != "" { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + request.Params.Cursor = result.NextCursor + newPageRes, err := c.ListPromptsByPage(ctx, request) + if err != nil { + return nil, err + } + result.Prompts = append(result.Prompts, newPageRes.Prompts...) + result.NextCursor = newPageRes.NextCursor + } + } + return result, nil +} + +func (c *Client) GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, +) (*mcp.GetPromptResult, error) { + response, err := c.sendRequest(ctx, "prompts/get", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseGetPromptResult(response) +} + +func (c *Client) ListToolsByPage( + ctx context.Context, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, error) { + result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list") + if err != nil { + return nil, err + } + return result, nil +} + +func (c *Client) ListTools( + ctx context.Context, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, error) { + result, err := c.ListToolsByPage(ctx, request) + if err != nil { + return nil, err + } + for result.NextCursor != "" { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + request.Params.Cursor = result.NextCursor + newPageRes, err := c.ListToolsByPage(ctx, request) + if err != nil { + return nil, err + } + result.Tools = append(result.Tools, newPageRes.Tools...) + result.NextCursor = newPageRes.NextCursor + } + } + return result, nil +} + +func (c *Client) CallTool( + ctx context.Context, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, error) { + response, err := c.sendRequest(ctx, "tools/call", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseCallToolResult(response) +} + +func (c *Client) SetLevel( + ctx context.Context, + request mcp.SetLevelRequest, +) error { + _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) + return err +} + +func (c *Client) Complete( + ctx context.Context, + request mcp.CompleteRequest, +) (*mcp.CompleteResult, error) { + response, err := c.sendRequest(ctx, "completion/complete", request.Params) + if err != nil { + return nil, err + } + + var result mcp.CompleteResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil } func listByPage[T any]( ctx context.Context, - client mcpClient, + client *Client, request mcp.PaginatedRequest, method string, ) (*T, error) { @@ -132,3 +398,11 @@ func listByPage[T any]( } return &result, nil } + +// Helper methods + +// GetTransport gives access to the underlying transport layer. +// Cast it to the specific transport type and obtain the other helper methods. +func (c *Client) GetTransport() transport.Interface { + return c.transport +} diff --git a/client/interface.go b/client/interface.go new file mode 100644 index 000000000..ea7f4d1fb --- /dev/null +++ b/client/interface.go @@ -0,0 +1,109 @@ +// Package client provides MCP (Model Control Protocol) client implementations. +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// MCPClient represents an MCP client interface +type MCPClient interface { + // Initialize sends the initial connection request to the server + Initialize( + ctx context.Context, + request mcp.InitializeRequest, + ) (*mcp.InitializeResult, error) + + // Ping checks if the server is alive + Ping(ctx context.Context) error + + // ListResourcesByPage manually list resources by page. + ListResourcesByPage( + ctx context.Context, + request mcp.ListResourcesRequest, + ) (*mcp.ListResourcesResult, error) + + // ListResources requests a list of available resources from the server + ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, + ) (*mcp.ListResourcesResult, error) + + // ListResourceTemplatesByPage manually list resource templates by page. + ListResourceTemplatesByPage( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, + ) (*mcp.ListResourceTemplatesResult, + error) + + // ListResourceTemplates requests a list of available resource templates from the server + ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, + ) (*mcp.ListResourceTemplatesResult, + error) + + // ReadResource reads a specific resource from the server + ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, + ) (*mcp.ReadResourceResult, error) + + // Subscribe requests notifications for changes to a specific resource + Subscribe(ctx context.Context, request mcp.SubscribeRequest) error + + // Unsubscribe cancels notifications for a specific resource + Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error + + // ListPromptsByPage manually list prompts by page. + ListPromptsByPage( + ctx context.Context, + request mcp.ListPromptsRequest, + ) (*mcp.ListPromptsResult, error) + + // ListPrompts requests a list of available prompts from the server + ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, + ) (*mcp.ListPromptsResult, error) + + // GetPrompt retrieves a specific prompt from the server + GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, + ) (*mcp.GetPromptResult, error) + + // ListToolsByPage manually list tools by page. + ListToolsByPage( + ctx context.Context, + request mcp.ListToolsRequest, + ) (*mcp.ListToolsResult, error) + + // ListTools requests a list of available tools from the server + ListTools( + ctx context.Context, + request mcp.ListToolsRequest, + ) (*mcp.ListToolsResult, error) + + // CallTool invokes a specific tool on the server + CallTool( + ctx context.Context, + request mcp.CallToolRequest, + ) (*mcp.CallToolResult, error) + + // SetLevel sets the logging level for the server + SetLevel(ctx context.Context, request mcp.SetLevelRequest) error + + // Complete requests completion options for a given argument + Complete( + ctx context.Context, + request mcp.CompleteRequest, + ) (*mcp.CompleteResult, error) + + // Close client connection and cleanup resources + Close() error + + // OnNotification registers a handler for notifications + OnNotification(handler func(notification mcp.JSONRPCNotification)) +} diff --git a/client/sse.go b/client/sse.go index e7aaaa494..c26744a39 100644 --- a/client/sse.go +++ b/client/sse.go @@ -1,653 +1,32 @@ package client import ( - "bufio" - "bytes" - "context" - "encoding/json" - "errors" "fmt" - "io" - "net/http" + "github.com/mark3labs/mcp-go/client/transport" "net/url" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/mark3labs/mcp-go/mcp" ) -// SSEMCPClient implements the MCPClient interface using Server-Sent Events (SSE). -// It maintains a persistent HTTP connection to receive server-pushed events -// while sending requests over regular HTTP POST calls. The client handles -// automatic reconnection and message routing between requests and responses. -type SSEMCPClient struct { - baseURL *url.URL - endpoint *url.URL - httpClient *http.Client - requestID atomic.Int64 - responses map[int64]chan RPCResponse - mu sync.RWMutex - done chan struct{} - initialized bool - notifications []func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - endpointChan chan struct{} - capabilities mcp.ServerCapabilities - headers map[string]string -} - -type ClientOption func(*SSEMCPClient) - -func WithHeaders(headers map[string]string) ClientOption { - return func(sc *SSEMCPClient) { - sc.headers = headers - } +func WithHeaders(headers map[string]string) transport.ClientOption { + return transport.WithHeaders(headers) } // NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. // Returns an error if the URL is invalid. -func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) { - parsedURL, err := url.Parse(baseURL) - if err != nil { - return nil, fmt.Errorf("invalid URL: %w", err) - } - - smc := &SSEMCPClient{ - baseURL: parsedURL, - httpClient: &http.Client{}, - responses: make(map[int64]chan RPCResponse), - done: make(chan struct{}), - endpointChan: make(chan struct{}), - headers: make(map[string]string), - } - - for _, opt := range options { - opt(smc) - } - - return smc, nil -} - -// Start initiates the SSE connection to the server and waits for the endpoint information. -// Returns an error if the connection fails or times out waiting for the endpoint. -func (c *SSEMCPClient) Start(ctx context.Context) error { - - req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) - - if err != nil { - - return fmt.Errorf("failed to create request: %w", err) - - } - - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - for k, v := range c.headers { - req.Header.Set(k, v) - } - - resp, err := c.httpClient.Do(req) - if err != nil { - return fmt.Errorf("failed to connect to SSE stream: %w", err) - } - - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - go c.readSSE(resp.Body) - - // Wait for the endpoint to be received - - select { - case <-c.endpointChan: - // Endpoint received, proceed - case <-ctx.Done(): - return fmt.Errorf("context cancelled while waiting for endpoint") - case <-time.After(30 * time.Second): // Add a timeout - return fmt.Errorf("timeout waiting for endpoint") - } - - return nil -} - -// readSSE continuously reads the SSE stream and processes events. -// It runs until the connection is closed or an error occurs. -func (c *SSEMCPClient) readSSE(reader io.ReadCloser) { - defer reader.Close() - - br := bufio.NewReader(reader) - var event, data string - - for { - select { - case <-c.done: - return - default: - line, err := br.ReadString('\n') - if err != nil { - if err == io.EOF { - // Process any pending event before exit - if event != "" && data != "" { - c.handleSSEEvent(event, data) - } - break - } - select { - case <-c.done: - return - default: - fmt.Printf("SSE stream error: %v\n", err) - return - } - } - - // Remove only newline markers - line = strings.TrimRight(line, "\r\n") - if line == "" { - // Empty line means end of event - if event != "" && data != "" { - c.handleSSEEvent(event, data) - event = "" - data = "" - } - continue - } - - if strings.HasPrefix(line, "event:") { - event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) - } else if strings.HasPrefix(line, "data:") { - data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) - } - } - } -} - -// handleSSEEvent processes SSE events based on their type. -// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication. -func (c *SSEMCPClient) handleSSEEvent(event, data string) { - switch event { - case "endpoint": - endpoint, err := c.baseURL.Parse(data) - if err != nil { - fmt.Printf("Error parsing endpoint URL: %v\n", err) - return - } - if endpoint.Host != c.baseURL.Host { - fmt.Printf("Endpoint origin does not match connection origin\n") - return - } - c.endpoint = endpoint - close(c.endpointChan) - - case "message": - var baseMessage struct { - JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` - Method string `json:"method,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` - } - - if err := json.Unmarshal([]byte(data), &baseMessage); err != nil { - fmt.Printf("Error unmarshaling message: %v\n", err) - return - } - - // Handle notification - if baseMessage.ID == nil { - var notification mcp.JSONRPCNotification - if err := json.Unmarshal([]byte(data), ¬ification); err != nil { - return - } - c.notifyMu.RLock() - for _, handler := range c.notifications { - handler(notification) - } - c.notifyMu.RUnlock() - return - } - - c.mu.RLock() - ch, ok := c.responses[*baseMessage.ID] - c.mu.RUnlock() - - if ok { - if baseMessage.Error != nil { - ch <- RPCResponse{ - Error: &baseMessage.Error.Message, - } - } else { - ch <- RPCResponse{ - Response: &baseMessage.Result, - } - } - c.mu.Lock() - delete(c.responses, *baseMessage.ID) - c.mu.Unlock() - } - } -} - -// OnNotification registers a handler function to be called when notifications are received. -// Multiple handlers can be registered and will be called in the order they were added. -func (c *SSEMCPClient) OnNotification( - handler func(notification mcp.JSONRPCNotification), -) { - c.notifyMu.Lock() - defer c.notifyMu.Unlock() - c.notifications = append(c.notifications, handler) -} - -// sendRequest sends a JSON-RPC request to the server and waits for a response. -// Returns the raw JSON response message or an error if the request fails. -func (c *SSEMCPClient) sendRequest( - ctx context.Context, - method string, - params interface{}, -) (*json.RawMessage, error) { - if !c.initialized && method != "initialize" { - return nil, fmt.Errorf("client not initialized") - } - - if c.endpoint == nil { - return nil, fmt.Errorf("endpoint not received") - } - - id := c.requestID.Add(1) - - request := mcp.JSONRPCRequest{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: id, - Request: mcp.Request{ - Method: method, - }, - Params: params, - } - - requestBytes, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - responseChan := make(chan RPCResponse, 1) - c.mu.Lock() - c.responses[id] = responseChan - c.mu.Unlock() - - req, err := http.NewRequestWithContext( - ctx, - "POST", - c.endpoint.String(), - bytes.NewReader(requestBytes), - ) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - // set custom http headers - for k, v := range c.headers { - req.Header.Set(k, v) - } - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK && - resp.StatusCode != http.StatusAccepted { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf( - "request failed with status %d: %s", - resp.StatusCode, - body, - ) - } - - select { - case <-ctx.Done(): - c.mu.Lock() - delete(c.responses, id) - c.mu.Unlock() - return nil, ctx.Err() - case response := <-responseChan: - if response.Error != nil { - return nil, errors.New(*response.Error) - } - return response.Response, nil - } -} - -func (c *SSEMCPClient) Initialize( - ctx context.Context, - request mcp.InitializeRequest, -) (*mcp.InitializeResult, error) { - // Ensure we send a params object with all required fields - params := struct { - ProtocolVersion string `json:"protocolVersion"` - ClientInfo mcp.Implementation `json:"clientInfo"` - Capabilities mcp.ClientCapabilities `json:"capabilities"` - }{ - ProtocolVersion: request.Params.ProtocolVersion, - ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set - } - - response, err := c.sendRequest(ctx, "initialize", params) - if err != nil { - return nil, err - } - - var result mcp.InitializeResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - // Store capabilities - c.capabilities = result.Capabilities +func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) { - // Send initialized notification - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: "notifications/initialized", - }, - } - - notificationBytes, err := json.Marshal(notification) - if err != nil { - return nil, fmt.Errorf( - "failed to marshal initialized notification: %w", - err, - ) - } - - req, err := http.NewRequestWithContext( - ctx, - "POST", - c.endpoint.String(), - bytes.NewReader(notificationBytes), - ) - if err != nil { - return nil, fmt.Errorf("failed to create notification request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf( - "failed to send initialized notification: %w", - err, - ) - } - defer resp.Body.Close() - - c.initialized = true - return &result, nil -} - -func (c *SSEMCPClient) Ping(ctx context.Context) error { - _, err := c.sendRequest(ctx, "ping", nil) - return err -} - -// ListResourcesByPage manually list resources by page. -func (c *SSEMCPClient) ListResourcesByPage( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp.ListResourcesResult, error) { - result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *SSEMCPClient) ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp.ListResourcesResult, error) { - result, err := c.ListResourcesByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListResourcesByPage(ctx, request) - if err != nil { - return nil, err - } - result.Resources = append(result.Resources, newPageRes.Resources...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *SSEMCPClient) ListResourceTemplatesByPage( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp.ListResourceTemplatesResult, error) { - result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *SSEMCPClient) ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp.ListResourceTemplatesResult, error) { - result, err := c.ListResourceTemplatesByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListResourceTemplatesByPage(ctx, request) - if err != nil { - return nil, err - } - result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *SSEMCPClient) ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, -) (*mcp.ReadResourceResult, error) { - response, err := c.sendRequest(ctx, "resources/read", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseReadResourceResult(response) -} - -func (c *SSEMCPClient) Subscribe( - ctx context.Context, - request mcp.SubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) - return err -} - -func (c *SSEMCPClient) Unsubscribe( - ctx context.Context, - request mcp.UnsubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) - return err -} - -func (c *SSEMCPClient) ListPromptsByPage( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *SSEMCPClient) ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - result, err := c.ListPromptsByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListPromptsByPage(ctx, request) - if err != nil { - return nil, err - } - result.Prompts = append(result.Prompts, newPageRes.Prompts...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *SSEMCPClient) GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, -) (*mcp.GetPromptResult, error) { - response, err := c.sendRequest(ctx, "prompts/get", request.Params) + sseTransport, err := transport.NewSSE(baseURL, options...) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create SSE transport: %w", err) } - return mcp.ParseGetPromptResult(response) + return NewClient(sseTransport), nil } -func (c *SSEMCPClient) ListToolsByPage( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *SSEMCPClient) ListTools( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - result, err := c.ListToolsByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListToolsByPage(ctx, request) - if err != nil { - return nil, err - } - result.Tools = append(result.Tools, newPageRes.Tools...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *SSEMCPClient) CallTool( - ctx context.Context, - request mcp.CallToolRequest, -) (*mcp.CallToolResult, error) { - response, err := c.sendRequest(ctx, "tools/call", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseCallToolResult(response) -} - -func (c *SSEMCPClient) SetLevel( - ctx context.Context, - request mcp.SetLevelRequest, -) error { - _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) - return err -} - -func (c *SSEMCPClient) Complete( - ctx context.Context, - request mcp.CompleteRequest, -) (*mcp.CompleteResult, error) { - response, err := c.sendRequest(ctx, "completion/complete", request.Params) - if err != nil { - return nil, err - } - - var result mcp.CompleteResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil -} - -// Helper methods - // GetEndpoint returns the current endpoint URL for the SSE connection. -func (c *SSEMCPClient) GetEndpoint() *url.URL { - return c.endpoint -} - -// Close shuts down the SSE client connection and cleans up any pending responses. -// Returns an error if the shutdown process fails. -func (c *SSEMCPClient) Close() error { - select { - case <-c.done: - return nil // Already closed - default: - close(c.done) - } - - // Clean up any pending responses - c.mu.Lock() - for _, ch := range c.responses { - close(ch) - } - c.responses = make(map[int64]chan RPCResponse) - c.mu.Unlock() - - return nil +// +// Note: This method only works with SSE transport, or it will panic. +func GetEndpoint(c *Client) *url.URL { + t := c.GetTransport() + sse := t.(*transport.SSE) + return sse.GetEndpoint() } diff --git a/client/sse_test.go b/client/sse_test.go index 366fbc517..7308d043f 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "github.com/mark3labs/mcp-go/client/transport" "testing" "time" @@ -46,7 +47,8 @@ func TestSSEMCPClient(t *testing.T) { } defer client.Close() - if client.baseURL == nil { + sseTransport := client.GetTransport().(*transport.SSE) + if sseTransport.GetBaseURL() == nil { t.Error("Base URL should not be nil") } }) diff --git a/client/stdio.go b/client/stdio.go index c92334928..a25f6d19d 100644 --- a/client/stdio.go +++ b/client/stdio.go @@ -1,524 +1,40 @@ package client import ( - "bufio" "context" - "encoding/json" - "errors" "fmt" "io" - "os" - "os/exec" - "sync" - "sync/atomic" - "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/client/transport" ) -// StdioMCPClient implements the MCPClient interface using stdio communication. -// It launches a subprocess and communicates with it via standard input/output streams -// using JSON-RPC messages. The client handles message routing between requests and -// responses, and supports asynchronous notifications. -type StdioMCPClient struct { - cmd *exec.Cmd - stdin io.WriteCloser - stdout *bufio.Reader - stderr io.ReadCloser - requestID atomic.Int64 - responses map[int64]chan RPCResponse - mu sync.RWMutex - done chan struct{} - initialized bool - notifications []func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - capabilities mcp.ServerCapabilities -} - // NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess. // It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. // Returns an error if the subprocess cannot be started or the pipes cannot be created. +// +// NOTICE: NewStdioMCPClient will start the connection automatically. Don't call the Start method manually. +// This is for backward compatibility. func NewStdioMCPClient( command string, env []string, args ...string, -) (*StdioMCPClient, error) { - cmd := exec.Command(command, args...) - - mergedEnv := os.Environ() - mergedEnv = append(mergedEnv, env...) - - cmd.Env = mergedEnv - - stdin, err := cmd.StdinPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stdin pipe: %w", err) - } - - stdout, err := cmd.StdoutPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stdout pipe: %w", err) - } +) (*Client, error) { - stderr, err := cmd.StderrPipe() + stdioTransport := transport.NewStdio(command, env, args...) + err := stdioTransport.Start(context.Background()) if err != nil { - return nil, fmt.Errorf("failed to create stderr pipe: %w", err) - } - - client := &StdioMCPClient{ - cmd: cmd, - stdin: stdin, - stderr: stderr, - stdout: bufio.NewReader(stdout), - responses: make(map[int64]chan RPCResponse), - done: make(chan struct{}), + return nil, fmt.Errorf("failed to start stdio transport: %w", err) } - if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("failed to start command: %w", err) - } - - // Start reading responses in a goroutine and wait for it to be ready - ready := make(chan struct{}) - go func() { - close(ready) - client.readResponses() - }() - <-ready - - return client, nil -} - -// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. -// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. -func (c *StdioMCPClient) Close() error { - close(c.done) - if err := c.stdin.Close(); err != nil { - return fmt.Errorf("failed to close stdin: %w", err) - } - if err := c.stderr.Close(); err != nil { - return fmt.Errorf("failed to close stderr: %w", err) - } - return c.cmd.Wait() + return NewClient(stdioTransport), nil } -// Stderr returns a reader for the stderr output of the subprocess. +// GetStderr returns a reader for the stderr output of the subprocess. // This can be used to capture error messages or logs from the subprocess. -func (c *StdioMCPClient) Stderr() io.Reader { - return c.stderr -} - -// OnNotification registers a handler function to be called when notifications are received. -// Multiple handlers can be registered and will be called in the order they were added. -func (c *StdioMCPClient) OnNotification( - handler func(notification mcp.JSONRPCNotification), -) { - c.notifyMu.Lock() - defer c.notifyMu.Unlock() - c.notifications = append(c.notifications, handler) -} - -// readResponses continuously reads and processes responses from the server's stdout. -// It handles both responses to requests and notifications, routing them appropriately. -// Runs until the done channel is closed or an error occurs reading from stdout. -func (c *StdioMCPClient) readResponses() { - for { - select { - case <-c.done: - return - default: - line, err := c.stdout.ReadString('\n') - if err != nil { - if err != io.EOF { - fmt.Printf("Error reading response: %v\n", err) - } - return - } - - var baseMessage struct { - JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` - Method string `json:"method,omitempty"` - Result json.RawMessage `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` - } - - if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { - continue - } - - // Handle notification - if baseMessage.ID == nil { - var notification mcp.JSONRPCNotification - if err := json.Unmarshal([]byte(line), ¬ification); err != nil { - continue - } - c.notifyMu.RLock() - for _, handler := range c.notifications { - handler(notification) - } - c.notifyMu.RUnlock() - continue - } - - c.mu.RLock() - ch, ok := c.responses[*baseMessage.ID] - c.mu.RUnlock() - - if ok { - if baseMessage.Error != nil { - ch <- RPCResponse{ - Error: &baseMessage.Error.Message, - } - } else { - ch <- RPCResponse{ - Response: &baseMessage.Result, - } - } - c.mu.Lock() - delete(c.responses, *baseMessage.ID) - c.mu.Unlock() - } - } - } -} - -// sendRequest sends a JSON-RPC request to the server and waits for a response. -// It creates a unique request ID, sends the request over stdin, and waits for -// the corresponding response or context cancellation. -// Returns the raw JSON response message or an error if the request fails. -func (c *StdioMCPClient) sendRequest( - ctx context.Context, - method string, - params interface{}, -) (*json.RawMessage, error) { - if !c.initialized && method != "initialize" { - return nil, fmt.Errorf("client not initialized") - } - - id := c.requestID.Add(1) - - // Create the complete request structure - request := mcp.JSONRPCRequest{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: id, - Request: mcp.Request{ - Method: method, - }, - Params: params, - } - - responseChan := make(chan RPCResponse, 1) - c.mu.Lock() - c.responses[id] = responseChan - c.mu.Unlock() - - requestBytes, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - requestBytes = append(requestBytes, '\n') - - if _, err := c.stdin.Write(requestBytes); err != nil { - return nil, fmt.Errorf("failed to write request: %w", err) - } - - select { - case <-ctx.Done(): - c.mu.Lock() - delete(c.responses, id) - c.mu.Unlock() - return nil, ctx.Err() - case response := <-responseChan: - if response.Error != nil { - return nil, errors.New(*response.Error) - } - return response.Response, nil - } -} - -func (c *StdioMCPClient) Ping(ctx context.Context) error { - _, err := c.sendRequest(ctx, "ping", nil) - return err -} - -func (c *StdioMCPClient) Initialize( - ctx context.Context, - request mcp.InitializeRequest, -) (*mcp.InitializeResult, error) { - // This structure ensures Capabilities is always included in JSON - params := struct { - ProtocolVersion string `json:"protocolVersion"` - ClientInfo mcp.Implementation `json:"clientInfo"` - Capabilities mcp.ClientCapabilities `json:"capabilities"` - }{ - ProtocolVersion: request.Params.ProtocolVersion, - ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set - } - - response, err := c.sendRequest(ctx, "initialize", params) - if err != nil { - return nil, err - } - - var result mcp.InitializeResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - // Store capabilities - c.capabilities = result.Capabilities - - // Send initialized notification - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: "notifications/initialized", - }, - } - - notificationBytes, err := json.Marshal(notification) - if err != nil { - return nil, fmt.Errorf( - "failed to marshal initialized notification: %w", - err, - ) - } - notificationBytes = append(notificationBytes, '\n') - - if _, err := c.stdin.Write(notificationBytes); err != nil { - return nil, fmt.Errorf( - "failed to send initialized notification: %w", - err, - ) - } - - c.initialized = true - return &result, nil -} - -// ListResourcesByPage manually list resources by page. -func (c *StdioMCPClient) ListResourcesByPage( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp.ListResourcesResult, error) { - result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *StdioMCPClient) ListResources( - ctx context.Context, - request mcp.ListResourcesRequest, -) (*mcp.ListResourcesResult, error) { - result, err := c.ListResourcesByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListResourcesByPage(ctx, request) - if err != nil { - return nil, err - } - result.Resources = append(result.Resources, newPageRes.Resources...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *StdioMCPClient) ListResourceTemplatesByPage( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp.ListResourceTemplatesResult, error) { - result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *StdioMCPClient) ListResourceTemplates( - ctx context.Context, - request mcp.ListResourceTemplatesRequest, -) (*mcp.ListResourceTemplatesResult, error) { - result, err := c.ListResourceTemplatesByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListResourceTemplatesByPage(ctx, request) - if err != nil { - return nil, err - } - result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *StdioMCPClient) ReadResource( - ctx context.Context, - request mcp.ReadResourceRequest, -) (*mcp.ReadResourceResult, - error) { - response, err := c.sendRequest(ctx, "resources/read", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseReadResourceResult(response) -} - -func (c *StdioMCPClient) Subscribe( - ctx context.Context, - request mcp.SubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) - return err -} - -func (c *StdioMCPClient) Unsubscribe( - ctx context.Context, - request mcp.UnsubscribeRequest, -) error { - _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) - return err -} - -func (c *StdioMCPClient) ListPromptsByPage( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *StdioMCPClient) ListPrompts( - ctx context.Context, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, error) { - result, err := c.ListPromptsByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListPromptsByPage(ctx, request) - if err != nil { - return nil, err - } - result.Prompts = append(result.Prompts, newPageRes.Prompts...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *StdioMCPClient) GetPrompt( - ctx context.Context, - request mcp.GetPromptRequest, -) (*mcp.GetPromptResult, error) { - response, err := c.sendRequest(ctx, "prompts/get", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseGetPromptResult(response) -} - -func (c *StdioMCPClient) ListToolsByPage( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list") - if err != nil { - return nil, err - } - return result, nil -} - -func (c *StdioMCPClient) ListTools( - ctx context.Context, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, error) { - result, err := c.ListToolsByPage(ctx, request) - if err != nil { - return nil, err - } - for result.NextCursor != "" { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - request.Params.Cursor = result.NextCursor - newPageRes, err := c.ListToolsByPage(ctx, request) - if err != nil { - return nil, err - } - result.Tools = append(result.Tools, newPageRes.Tools...) - result.NextCursor = newPageRes.NextCursor - } - } - return result, nil -} - -func (c *StdioMCPClient) CallTool( - ctx context.Context, - request mcp.CallToolRequest, -) (*mcp.CallToolResult, error) { - response, err := c.sendRequest(ctx, "tools/call", request.Params) - if err != nil { - return nil, err - } - - return mcp.ParseCallToolResult(response) -} - -func (c *StdioMCPClient) SetLevel( - ctx context.Context, - request mcp.SetLevelRequest, -) error { - _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) - return err -} - -func (c *StdioMCPClient) Complete( - ctx context.Context, - request mcp.CompleteRequest, -) (*mcp.CompleteResult, error) { - response, err := c.sendRequest(ctx, "completion/complete", request.Params) - if err != nil { - return nil, err - } - - var result mcp.CompleteResult - if err := json.Unmarshal(*response, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return &result, nil +// +// Note: This method only works with stdio transport, or it will panic. +func GetStderr(c *Client) io.Reader { + t := c.GetTransport() + stdio := t.(*transport.Stdio) + return stdio.Stderr() } diff --git a/client/stdio_test.go b/client/stdio_test.go index df69b46a3..94da0b541 100644 --- a/client/stdio_test.go +++ b/client/stdio_test.go @@ -47,7 +47,7 @@ func TestStdioMCPClient(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - dec := json.NewDecoder(client.Stderr()) + dec := json.NewDecoder(GetStderr(client)) for { var record map[string]any if err := dec.Decode(&record); err != nil { diff --git a/client/transport/interface.go b/client/transport/interface.go new file mode 100644 index 000000000..8ac75d746 --- /dev/null +++ b/client/transport/interface.go @@ -0,0 +1,45 @@ +package transport + +import ( + "context" + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Interface for the transport layer. +type Interface interface { + // Start the connection. Start should only be called once. + Start(ctx context.Context) error + + // SendRequest sends a json RPC request and returns the response synchronously. + SendRequest(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) + + // SendNotification sends a json RPC Notification to the server. + SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error + + // SetNotificationHandler sets the handler for notifications. + // Any notification before the handler is set will be discarded. + SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) + + // Close the connection. + Close() error +} + +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID *int64 `json:"id"` + Result json.RawMessage `json:"result"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + } `json:"error"` +} diff --git a/client/transport/sse.go b/client/transport/sse.go new file mode 100644 index 000000000..a515ae760 --- /dev/null +++ b/client/transport/sse.go @@ -0,0 +1,376 @@ +package transport + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE). +// It maintains a persistent HTTP connection to receive server-pushed events +// while sending requests over regular HTTP POST calls. The client handles +// automatic reconnection and message routing between requests and responses. +type SSE struct { + baseURL *url.URL + endpoint *url.URL + httpClient *http.Client + responses map[int64]chan *JSONRPCResponse + mu sync.RWMutex + onNotification func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + endpointChan chan struct{} + headers map[string]string + + started atomic.Bool + closed atomic.Bool + cancelSSEStream context.CancelFunc +} + +type ClientOption func(*SSE) + +func WithHeaders(headers map[string]string) ClientOption { + return func(sc *SSE) { + sc.headers = headers + } +} + +// NewSSE creates a new SSE-based MCP client with the given base URL. +// Returns an error if the URL is invalid. +func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { + parsedURL, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + + smc := &SSE{ + baseURL: parsedURL, + httpClient: &http.Client{}, + responses: make(map[int64]chan *JSONRPCResponse), + endpointChan: make(chan struct{}), + headers: make(map[string]string), + } + + for _, opt := range options { + opt(smc) + } + + return smc, nil +} + +// Start initiates the SSE connection to the server and waits for the endpoint information. +// Returns an error if the connection fails or times out waiting for the endpoint. +func (c *SSE) Start(ctx context.Context) error { + + if c.started.Load() { + return fmt.Errorf("has already started") + } + + ctx, cancel := context.WithCancel(ctx) + c.cancelSSEStream = cancel + + req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) + + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + // set custom http headers + for k, v := range c.headers { + req.Header.Set(k, v) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to connect to SSE stream: %w", err) + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + go c.readSSE(resp.Body) + + // Wait for the endpoint to be received + timeout := time.NewTimer(30 * time.Second) + defer timeout.Stop() + select { + case <-c.endpointChan: + // Endpoint received, proceed + case <-ctx.Done(): + return fmt.Errorf("context cancelled while waiting for endpoint") + case <-timeout.C: // Add a timeout + cancel() + return fmt.Errorf("timeout waiting for endpoint") + } + + c.started.Store(true) + return nil +} + +// readSSE continuously reads the SSE stream and processes events. +// It runs until the connection is closed or an error occurs. +func (c *SSE) readSSE(reader io.ReadCloser) { + defer reader.Close() + + br := bufio.NewReader(reader) + var event, data string + + for { + // when close or start's ctx cancel, the reader will be closed + // and the for loop will break. + line, err := br.ReadString('\n') + if err != nil { + if err == io.EOF { + // Process any pending event before exit + if event != "" && data != "" { + c.handleSSEEvent(event, data) + } + break + } + if !c.closed.Load() { + fmt.Printf("SSE stream error: %v\n", err) + } + return + } + + // Remove only newline markers + line = strings.TrimRight(line, "\r\n") + if line == "" { + // Empty line means end of event + if event != "" && data != "" { + c.handleSSEEvent(event, data) + event = "" + data = "" + } + continue + } + + if strings.HasPrefix(line, "event:") { + event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } +} + +// handleSSEEvent processes SSE events based on their type. +// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication. +func (c *SSE) handleSSEEvent(event, data string) { + switch event { + case "endpoint": + endpoint, err := c.baseURL.Parse(data) + if err != nil { + fmt.Printf("Error parsing endpoint URL: %v\n", err) + return + } + if endpoint.Host != c.baseURL.Host { + fmt.Printf("Endpoint origin does not match connection origin\n") + return + } + c.endpoint = endpoint + close(c.endpointChan) + + case "message": + var baseMessage JSONRPCResponse + if err := json.Unmarshal([]byte(data), &baseMessage); err != nil { + fmt.Printf("Error unmarshaling message: %v\n", err) + return + } + + // Handle notification + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(data), ¬ification); err != nil { + return + } + c.notifyMu.RLock() + if c.onNotification != nil { + c.onNotification(notification) + } + c.notifyMu.RUnlock() + return + } + + c.mu.RLock() + ch, ok := c.responses[*baseMessage.ID] + c.mu.RUnlock() + + if ok { + ch <- &baseMessage + c.mu.Lock() + delete(c.responses, *baseMessage.ID) + c.mu.Unlock() + } + } +} + +func (c *SSE) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.onNotification = handler +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// Returns the raw JSON response message or an error if the request fails. +func (c *SSE) SendRequest( + ctx context.Context, + request JSONRPCRequest, +) (*JSONRPCResponse, error) { + + if !c.started.Load() { + return nil, fmt.Errorf("transport not started yet") + } + if c.closed.Load() { + return nil, fmt.Errorf("transport has been closed") + } + if c.endpoint == nil { + return nil, fmt.Errorf("endpoint not received") + } + + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + responseChan := make(chan *JSONRPCResponse, 1) + c.mu.Lock() + c.responses[request.ID] = responseChan + c.mu.Unlock() + + req, err := http.NewRequestWithContext( + ctx, + "POST", + c.endpoint.String(), + bytes.NewReader(requestBytes), + ) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + // set custom http headers + for k, v := range c.headers { + req.Header.Set(k, v) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf( + "request failed with status %d: %s", + resp.StatusCode, + body, + ) + } + + select { + case <-ctx.Done(): + c.mu.Lock() + delete(c.responses, request.ID) + c.mu.Unlock() + return nil, ctx.Err() + case response := <-responseChan: + return response, nil + } +} + +// Close shuts down the SSE client connection and cleans up any pending responses. +// Returns an error if the shutdown process fails. +func (c *SSE) Close() error { + if !c.closed.CompareAndSwap(false, true) { + return nil // Already closed + } + + if c.cancelSSEStream != nil { + // It could stop the sse stream body, to quit the readSSE loop immediately + // Also, it could quit start() immediately if not receiving the endpoint + c.cancelSSEStream() + } + + // Clean up any pending responses + c.mu.Lock() + for _, ch := range c.responses { + close(ch) + } + c.responses = make(map[int64]chan *JSONRPCResponse) + c.mu.Unlock() + + return nil +} + +// SendNotification sends a JSON-RPC notification to the server without expecting a response. +func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + if c.endpoint == nil { + return fmt.Errorf("endpoint not received") + } + + notificationBytes, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + + req, err := http.NewRequestWithContext( + ctx, + "POST", + c.endpoint.String(), + bytes.NewReader(notificationBytes), + ) + if err != nil { + return fmt.Errorf("failed to create notification request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + // Set custom HTTP headers + for k, v := range c.headers { + req.Header.Set(k, v) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send notification: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf( + "notification failed with status %d: %s", + resp.StatusCode, + body, + ) + } + + return nil +} + +// GetEndpoint returns the current endpoint URL for the SSE connection. +func (c *SSE) GetEndpoint() *url.URL { + return c.endpoint +} + +// GetBaseURL returns the base URL set in the SSE constructor. +func (c *SSE) GetBaseURL() *url.URL { + return c.baseURL +} diff --git a/client/transport/sse_test.go b/client/transport/sse_test.go new file mode 100644 index 000000000..0c4dff6a2 --- /dev/null +++ b/client/transport/sse_test.go @@ -0,0 +1,480 @@ +package transport + +import ( + "context" + "encoding/json" + "errors" + "sync" + "testing" + "time" + + "fmt" + "net/http" + "net/http/httptest" + + "github.com/mark3labs/mcp-go/mcp" +) + +// startMockSSEEchoServer starts a test HTTP server that implements +// a minimal SSE-based echo server for testing purposes. +// It returns the server URL and a function to close the server. +func startMockSSEEchoServer() (string, func()) { + // Create handler for SSE endpoint + var sseWriter http.ResponseWriter + var flush func() + var mu sync.Mutex + sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Setup SSE headers + defer func() { + mu.Lock() // for passing race test + sseWriter = nil + flush = nil + mu.Unlock() + fmt.Printf("SSEHandler ends: %v\n", r.Context().Err()) + }() + + w.Header().Set("Content-Type", "text/event-stream") + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + + mu.Lock() + sseWriter = w + flush = flusher.Flush + mu.Unlock() + + // Send initial endpoint event with message endpoint URL + mu.Lock() + fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", "/message") + flusher.Flush() + mu.Unlock() + + // Keep connection open + <-r.Context().Done() + }) + + // Create handler for message endpoint + messageHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle only POST requests + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse incoming JSON-RPC request + var request map[string]interface{} + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&request); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + // Echo back the request as the response result + response := map[string]interface{}{ + "jsonrpc": "2.0", + "id": request["id"], + "result": request, + } + + method := request["method"] + switch method { + case "debug/echo": + response["result"] = request + case "debug/echo_notification": + response["result"] = request + // send notification to client + responseBytes, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "method": "debug/test", + "params": request, + }) + mu.Lock() + fmt.Fprintf(sseWriter, "event: message\ndata: %s\n\n", responseBytes) + flush() + mu.Unlock() + case "debug/echo_error_string": + data, _ := json.Marshal(request) + response["error"] = map[string]interface{}{ + "code": -1, + "message": string(data), + } + } + + // Set response headers + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + + go func() { + data, _ := json.Marshal(response) + mu.Lock() + defer mu.Unlock() + if sseWriter != nil && flush != nil { + fmt.Fprintf(sseWriter, "event: message\ndata: %s\n\n", data) + flush() + } + }() + + }) + + // Create a router to handle different endpoints + mux := http.NewServeMux() + mux.Handle("/", sseHandler) + mux.Handle("/message", messageHandler) + + // Start test server + testServer := httptest.NewServer(mux) + + return testServer.URL, testServer.Close +} + +func TestSSE(t *testing.T) { + // Compile mock server + url, closeF := startMockSSEEchoServer() + defer closeF() + + trans, err := NewSSE(url) + if err != nil { + t.Fatal(err) + } + + // Start the transport + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + err = trans.Start(ctx) + if err != nil { + t.Fatalf("Failed to start transport: %v", err) + } + defer trans.Close() + + t.Run("SendRequest", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + params := map[string]interface{}{ + "string": "hello world", + "array": []interface{}{1, 2, 3}, + } + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "debug/echo", + Params: params, + } + + // Send the request + response, err := trans.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + // Verify response data matches what was sent + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + if result.ID != 1 { + t.Errorf("Expected ID 1, got %d", result.ID) + } + if result.Method != "debug/echo" { + t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) + } + + if str, ok := result.Params["string"].(string); !ok || str != "hello world" { + t.Errorf("Expected string 'hello world', got %v", result.Params["string"]) + } + + if arr, ok := result.Params["array"].([]interface{}); !ok || len(arr) != 3 { + t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) + } + }) + + t.Run("SendRequestWithTimeout", func(t *testing.T) { + // Create a context that's already canceled + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel the context immediately + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 3, + Method: "debug/echo", + } + + // The request should fail because the context is canceled + _, err := trans.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected context canceled error, got nil") + } else if !errors.Is(err, context.Canceled) { + t.Errorf("Expected context.Canceled error, got: %v", err) + } + }) + + t.Run("SendNotification & NotificationHandler", func(t *testing.T) { + + var wg sync.WaitGroup + notificationChan := make(chan mcp.JSONRPCNotification, 1) + + // Set notification handler + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + notificationChan <- notification + }) + + // Send a notification + // This would trigger a notification from the server + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + notification := mcp.JSONRPCNotification{ + JSONRPC: "2.0", + Notification: mcp.Notification{ + Method: "debug/echo_notification", + Params: mcp.NotificationParams{ + AdditionalFields: map[string]interface{}{"test": "value"}, + }, + }, + } + err := trans.SendNotification(ctx, notification) + if err != nil { + t.Fatalf("SendNotification failed: %v", err) + } + + wg.Add(1) + go func() { + defer wg.Done() + select { + case nt := <-notificationChan: + // We received a notification + responseJson, _ := json.Marshal(nt.Params.AdditionalFields) + requestJson, _ := json.Marshal(notification) + if string(responseJson) != string(requestJson) { + t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) + } + + case <-time.After(1 * time.Second): + t.Errorf("Expected notification, got none") + } + }() + + wg.Wait() + }) + + t.Run("MultipleRequests", func(t *testing.T) { + var wg sync.WaitGroup + const numRequests = 5 + + // Send multiple requests concurrently + mu := sync.Mutex{} + responses := make([]*JSONRPCResponse, numRequests) + errors := make([]error, numRequests) + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Each request has a unique ID and payload + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: int64(100 + idx), + Method: "debug/echo", + Params: map[string]interface{}{ + "requestIndex": idx, + "timestamp": time.Now().UnixNano(), + }, + } + + resp, err := trans.SendRequest(ctx, request) + mu.Lock() + responses[idx] = resp + errors[idx] = err + mu.Unlock() + }(i) + } + + wg.Wait() + + // Check results + for i := 0; i < numRequests; i++ { + if errors[i] != nil { + t.Errorf("Request %d failed: %v", i, errors[i]) + continue + } + + if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { + t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + continue + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(responses[i].Result, &result); err != nil { + t.Errorf("Request %d: Failed to unmarshal result: %v", i, err) + continue + } + + // Verify data matches what was sent + if result.ID != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + } + + if result.Method != "debug/echo" { + t.Errorf("Request %d: Expected method 'debug/echo', got '%s'", i, result.Method) + } + + // Verify the requestIndex parameter + if idx, ok := result.Params["requestIndex"].(float64); !ok || int(idx) != i { + t.Errorf("Request %d: Expected requestIndex %d, got %v", i, i, result.Params["requestIndex"]) + } + } + }) + + t.Run("ResponseError", func(t *testing.T) { + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 100, + Method: "debug/echo_error_string", + } + + // The request should fail because the context is canceled + reps, err := trans.SendRequest(ctx, request) + if err != nil { + t.Errorf("SendRequest failed: %v", err) + } + + if reps.Error == nil { + t.Errorf("Expected error, got nil") + } + + var responseError JSONRPCRequest + if err := json.Unmarshal([]byte(reps.Error.Message), &responseError); err != nil { + t.Errorf("Failed to unmarshal result: %v", err) + } + + if responseError.Method != "debug/echo_error_string" { + t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) + } + if responseError.ID != 100 { + t.Errorf("Expected ID 100, got %d", responseError.ID) + } + if responseError.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) + } + }) + +} + +func TestSSEErrors(t *testing.T) { + t.Run("InvalidURL", func(t *testing.T) { + // Create a new SSE transport with an invalid URL + _, err := NewSSE("://invalid-url") + if err == nil { + t.Errorf("Expected error when creating with invalid URL, got nil") + } + }) + + t.Run("NonExistentURL", func(t *testing.T) { + // Create a new SSE transport with a non-existent URL + sse, err := NewSSE("http://localhost:1") + if err != nil { + t.Fatalf("Failed to create SSE transport: %v", err) + } + + // Start should fail + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = sse.Start(ctx) + if err == nil { + t.Errorf("Expected error when starting with non-existent URL, got nil") + sse.Close() + } + }) + + t.Run("RequestBeforeStart", func(t *testing.T) { + url, closeF := startMockSSEEchoServer() + defer closeF() + + // Create a new SSE instance without calling Start method + sse, err := NewSSE(url) + if err != nil { + t.Fatalf("Failed to create SSE transport: %v", err) + } + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 99, + Method: "ping", + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + _, err = sse.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected SendRequest to fail before Start(), but it didn't") + } + }) + + t.Run("RequestAfterClose", func(t *testing.T) { + // Start a mock server + url, closeF := startMockSSEEchoServer() + defer closeF() + + // Create a new SSE transport + sse, err := NewSSE(url) + if err != nil { + t.Fatalf("Failed to create SSE transport: %v", err) + } + + // Start the transport + ctx := context.Background() + if err := sse.Start(ctx); err != nil { + t.Fatalf("Failed to start SSE transport: %v", err) + } + + // Close the transport + sse.Close() + + // Wait a bit to ensure connection has closed + time.Sleep(100 * time.Millisecond) + + // Try to send a request after close + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "ping", + } + + _, err = sse.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected error when sending request after close, got nil") + } + }) + +} diff --git a/client/transport/stdio.go b/client/transport/stdio.go new file mode 100644 index 000000000..85a300a15 --- /dev/null +++ b/client/transport/stdio.go @@ -0,0 +1,234 @@ +package transport + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "sync" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Stdio implements the transport layer of the MCP protocol using stdio communication. +// It launches a subprocess and communicates with it via standard input/output streams +// using JSON-RPC messages. The client handles message routing between requests and +// responses, and supports asynchronous notifications. +type Stdio struct { + command string + args []string + env []string + + cmd *exec.Cmd + stdin io.WriteCloser + stdout *bufio.Reader + stderr io.ReadCloser + responses map[int64]chan *JSONRPCResponse + mu sync.RWMutex + done chan struct{} + onNotification func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex +} + +// NewStdio creates a new stdio transport to communicate with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Returns an error if the subprocess cannot be started or the pipes cannot be created. +func NewStdio( + command string, + env []string, + args ...string, +) *Stdio { + + client := &Stdio{ + command: command, + args: args, + env: env, + + responses: make(map[int64]chan *JSONRPCResponse), + done: make(chan struct{}), + } + + return client +} + +func (c *Stdio) Start(ctx context.Context) error { + cmd := exec.CommandContext(ctx, c.command, c.args...) + + mergedEnv := os.Environ() + mergedEnv = append(mergedEnv, c.env...) + + cmd.Env = mergedEnv + + stdin, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("failed to create stdin pipe: %w", err) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("failed to create stdout pipe: %w", err) + } + + stderr, err := cmd.StderrPipe() + if err != nil { + return fmt.Errorf("failed to create stderr pipe: %w", err) + } + + c.cmd = cmd + c.stdin = stdin + c.stderr = stderr + c.stdout = bufio.NewReader(stdout) + + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start command: %w", err) + } + + // Start reading responses in a goroutine and wait for it to be ready + ready := make(chan struct{}) + go func() { + close(ready) + c.readResponses() + }() + <-ready + + return nil +} + +// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. +// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. +func (c *Stdio) Close() error { + close(c.done) + if err := c.stdin.Close(); err != nil { + return fmt.Errorf("failed to close stdin: %w", err) + } + if err := c.stderr.Close(); err != nil { + return fmt.Errorf("failed to close stderr: %w", err) + } + return c.cmd.Wait() +} + +// OnNotification registers a handler function to be called when notifications are received. +// Multiple handlers can be registered and will be called in the order they were added. +func (c *Stdio) SetNotificationHandler( + handler func(notification mcp.JSONRPCNotification), +) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.onNotification = handler +} + +// readResponses continuously reads and processes responses from the server's stdout. +// It handles both responses to requests and notifications, routing them appropriately. +// Runs until the done channel is closed or an error occurs reading from stdout. +func (c *Stdio) readResponses() { + for { + select { + case <-c.done: + return + default: + line, err := c.stdout.ReadString('\n') + if err != nil { + if err != io.EOF { + fmt.Printf("Error reading response: %v\n", err) + } + return + } + + var baseMessage JSONRPCResponse + if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { + continue + } + + // Handle notification + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(line), ¬ification); err != nil { + continue + } + c.notifyMu.RLock() + if c.onNotification != nil { + c.onNotification(notification) + } + c.notifyMu.RUnlock() + continue + } + + c.mu.RLock() + ch, ok := c.responses[*baseMessage.ID] + c.mu.RUnlock() + + if ok { + ch <- &baseMessage + c.mu.Lock() + delete(c.responses, *baseMessage.ID) + c.mu.Unlock() + } + } + } +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// It creates a unique request ID, sends the request over stdin, and waits for +// the corresponding response or context cancellation. +// Returns the raw JSON response message or an error if the request fails. +func (c *Stdio) SendRequest( + ctx context.Context, + request JSONRPCRequest, +) (*JSONRPCResponse, error) { + if c.stdin == nil { + return nil, fmt.Errorf("stdio client not started") + } + + // Create the complete request structure + responseChan := make(chan *JSONRPCResponse, 1) + c.mu.Lock() + c.responses[request.ID] = responseChan + c.mu.Unlock() + + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := c.stdin.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write request: %w", err) + } + + select { + case <-ctx.Done(): + c.mu.Lock() + delete(c.responses, request.ID) + c.mu.Unlock() + return nil, ctx.Err() + case response := <-responseChan: + return response, nil + } +} + +// SendNotification sends a json RPC Notification to the server. +func (c *Stdio) SendNotification( + ctx context.Context, + notification mcp.JSONRPCNotification, +) error { + notificationBytes, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + notificationBytes = append(notificationBytes, '\n') + + if _, err := c.stdin.Write(notificationBytes); err != nil { + return fmt.Errorf("failed to write notification: %w", err) + } + + return nil +} + +// Stderr returns a reader for the stderr output of the subprocess. +// This can be used to capture error messages or logs from the subprocess. +func (c *Stdio) Stderr() io.Reader { + return c.stderr +} diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go new file mode 100644 index 000000000..445ba07ea --- /dev/null +++ b/client/transport/stdio_test.go @@ -0,0 +1,367 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +func compileTestServer(outputPath string) error { + cmd := exec.Command( + "go", + "build", + "-o", + outputPath, + "../../testdata/mockstdio_server.go", + ) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) + } + return nil +} + +func TestStdio(t *testing.T) { + // Compile mock server + mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + if err := compileTestServer(mockServerPath); err != nil { + t.Fatalf("Failed to compile mock server: %v", err) + } + defer os.Remove(mockServerPath) + + // Create a new Stdio transport + stdio := NewStdio(mockServerPath, nil) + + // Start the transport + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := stdio.Start(ctx) + if err != nil { + t.Fatalf("Failed to start Stdio transport: %v", err) + } + defer stdio.Close() + + t.Run("SendRequest", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5000000000*time.Second) + defer cancel() + + params := map[string]interface{}{ + "string": "hello world", + "array": []interface{}{1, 2, 3}, + } + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "debug/echo", + Params: params, + } + + // Send the request + response, err := stdio.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + // Verify response data matches what was sent + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + if result.ID != 1 { + t.Errorf("Expected ID 1, got %d", result.ID) + } + if result.Method != "debug/echo" { + t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) + } + + if str, ok := result.Params["string"].(string); !ok || str != "hello world" { + t.Errorf("Expected string 'hello world', got %v", result.Params["string"]) + } + + if arr, ok := result.Params["array"].([]interface{}); !ok || len(arr) != 3 { + t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) + } + }) + + t.Run("SendRequestWithTimeout", func(t *testing.T) { + // Create a context that's already canceled + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel the context immediately + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 3, + Method: "debug/echo", + } + + // The request should fail because the context is canceled + _, err := stdio.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected context canceled error, got nil") + } else if err != context.Canceled { + t.Errorf("Expected context.Canceled error, got: %v", err) + } + }) + + t.Run("SendNotification & NotificationHandler", func(t *testing.T) { + + var wg sync.WaitGroup + notificationChan := make(chan mcp.JSONRPCNotification, 1) + + // Set notification handler + stdio.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + notificationChan <- notification + }) + + // Send a notification + // This would trigger a notification from the server + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + notification := mcp.JSONRPCNotification{ + JSONRPC: "2.0", + Notification: mcp.Notification{ + Method: "debug/echo_notification", + Params: mcp.NotificationParams{ + AdditionalFields: map[string]interface{}{"test": "value"}, + }, + }, + } + err := stdio.SendNotification(ctx, notification) + if err != nil { + t.Fatalf("SendNotification failed: %v", err) + } + + wg.Add(1) + go func() { + defer wg.Done() + select { + case nt := <-notificationChan: + // We received a notification + responseJson, _ := json.Marshal(nt.Params.AdditionalFields) + requestJson, _ := json.Marshal(notification) + if string(responseJson) != string(requestJson) { + t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) + } + + case <-time.After(1 * time.Second): + t.Errorf("Expected notification, got none") + } + }() + + wg.Wait() + }) + + t.Run("MultipleRequests", func(t *testing.T) { + var wg sync.WaitGroup + const numRequests = 5 + + // Send multiple requests concurrently + responses := make([]*JSONRPCResponse, numRequests) + errors := make([]error, numRequests) + mu := sync.Mutex{} + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Each request has a unique ID and payload + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: int64(100 + idx), + Method: "debug/echo", + Params: map[string]interface{}{ + "requestIndex": idx, + "timestamp": time.Now().UnixNano(), + }, + } + + resp, err := stdio.SendRequest(ctx, request) + mu.Lock() + responses[idx] = resp + errors[idx] = err + mu.Unlock() + }(i) + } + + wg.Wait() + + // Check results + for i := 0; i < numRequests; i++ { + if errors[i] != nil { + t.Errorf("Request %d failed: %v", i, errors[i]) + continue + } + + if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { + t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + continue + } + + // Parse the result to verify echo + var result struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + } + + if err := json.Unmarshal(responses[i].Result, &result); err != nil { + t.Errorf("Request %d: Failed to unmarshal result: %v", i, err) + continue + } + + // Verify data matches what was sent + if result.ID != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + } + + if result.Method != "debug/echo" { + t.Errorf("Request %d: Expected method 'debug/echo', got '%s'", i, result.Method) + } + + // Verify the requestIndex parameter + if idx, ok := result.Params["requestIndex"].(float64); !ok || int(idx) != i { + t.Errorf("Request %d: Expected requestIndex %d, got %v", i, i, result.Params["requestIndex"]) + } + } + }) + + t.Run("ResponseError", func(t *testing.T) { + + // Prepare a request + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 100, + Method: "debug/echo_error_string", + } + + // The request should fail because the context is canceled + reps, err := stdio.SendRequest(ctx, request) + if err != nil { + t.Errorf("SendRequest failed: %v", err) + } + + if reps.Error == nil { + t.Errorf("Expected error, got nil") + } + + var responseError JSONRPCRequest + if err := json.Unmarshal([]byte(reps.Error.Message), &responseError); err != nil { + t.Errorf("Failed to unmarshal result: %v", err) + } + + if responseError.Method != "debug/echo_error_string" { + t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) + } + if responseError.ID != 100 { + t.Errorf("Expected ID 100, got %d", responseError.ID) + } + if responseError.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) + } + }) + +} + +func TestStdioErrors(t *testing.T) { + t.Run("InvalidCommand", func(t *testing.T) { + // Create a new Stdio transport with a non-existent command + stdio := NewStdio("non_existent_command", nil) + + // Start should fail + ctx := context.Background() + err := stdio.Start(ctx) + if err == nil { + t.Errorf("Expected error when starting with invalid command, got nil") + stdio.Close() + } + }) + + t.Run("RequestBeforeStart", func(t *testing.T) { + // 创建一个新的 Stdio 实例但不调用 Start 方法 + mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + if err := compileTestServer(mockServerPath); err != nil { + t.Fatalf("Failed to compile mock server: %v", err) + } + defer os.Remove(mockServerPath) + + uninitiatedStdio := NewStdio(mockServerPath, nil) + + // 准备一个请求 + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 99, + Method: "ping", + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + _, err := uninitiatedStdio.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected SendRequest to panic before Start(), but it didn't") + } else if err.Error() != "stdio client not started" { + t.Errorf("Expected error 'stdio client not started', got: %v", err) + } + }) + + t.Run("RequestAfterClose", func(t *testing.T) { + // Compile mock server + mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + if err := compileTestServer(mockServerPath); err != nil { + t.Fatalf("Failed to compile mock server: %v", err) + } + defer os.Remove(mockServerPath) + + // Create a new Stdio transport + stdio := NewStdio(mockServerPath, nil) + + // Start the transport + ctx := context.Background() + if err := stdio.Start(ctx); err != nil { + t.Fatalf("Failed to start Stdio transport: %v", err) + } + + // Close the transport - ignore errors like "broken pipe" since the process might exit already + stdio.Close() + + // Wait a bit to ensure process has exited + time.Sleep(100 * time.Millisecond) + + // Try to send a request after close + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "ping", + } + + _, err := stdio.SendRequest(ctx, request) + if err == nil { + t.Errorf("Expected error when sending request after close, got nil") + } + }) + +} diff --git a/client/types.go b/client/types.go deleted file mode 100644 index 4402bd024..000000000 --- a/client/types.go +++ /dev/null @@ -1,8 +0,0 @@ -package client - -import "encoding/json" - -type RPCResponse struct { - Error *string - Response *json.RawMessage -} diff --git a/testdata/mockstdio_server.go b/testdata/mockstdio_server.go index 3100c5a2a..9f13d5547 100644 --- a/testdata/mockstdio_server.go +++ b/testdata/mockstdio_server.go @@ -10,14 +10,14 @@ import ( type JSONRPCRequest struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID *int64 `json:"id,omitempty"` Method string `json:"method"` Params json.RawMessage `json:"params"` } type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID *int64 `json:"id,omitempty"` Result interface{} `json:"result,omitempty"` Error *struct { Code int `json:"code"` @@ -138,6 +138,30 @@ func handleRequest(request JSONRPCRequest) JSONRPCResponse { "values": []string{"test completion"}, }, } + + // Debug methods for testing transport. + case "debug/echo": + response.Result = request + case "debug/echo_notification": + response.Result = request + + // send notification to client + responseBytes, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "method": "debug/test", + "params": request, + }) + fmt.Fprintf(os.Stdout, "%s\n", responseBytes) + + case "debug/echo_error_string": + all, _ := json.Marshal(request) + response.Error = &struct { + Code int `json:"code"` + Message string `json:"message"` + }{ + Code: -32601, + Message: string(all), + } default: response.Error = &struct { Code int `json:"code"`