Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions client/transport/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"

"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/util"
)

// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE).
Expand All @@ -33,6 +34,7 @@ type SSE struct {
endpointChan chan struct{}
headers map[string]string
headerFunc HTTPHeaderFunc
logger util.Logger

started atomic.Bool
closed atomic.Bool
Expand All @@ -45,6 +47,13 @@ type SSE struct {

type ClientOption func(*SSE)

// WithSSELogger sets a custom logger for the SSE client.
func WithSSELogger(logger util.Logger) ClientOption {
return func(sc *SSE) {
sc.logger = logger
}
}

func WithHeaders(headers map[string]string) ClientOption {
return func(sc *SSE) {
sc.headers = headers
Expand Down Expand Up @@ -83,6 +92,7 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
responses: make(map[string]chan *JSONRPCResponse),
endpointChan: make(chan struct{}),
headers: make(map[string]string),
logger: util.DefaultLogger(),
}

for _, opt := range options {
Expand All @@ -102,7 +112,6 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
// Start initiates the SSE connection to the server and waits for the endpoint information.
// Returns an error if the connection fails or times out waiting for the endpoint.
func (c *SSE) Start(ctx context.Context) error {

if c.started.Load() {
return fmt.Errorf("has already started")
}
Expand All @@ -111,7 +120,6 @@ func (c *SSE) Start(ctx context.Context) error {
c.cancelSSEStream = cancel

req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil)

if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
Expand Down Expand Up @@ -205,7 +213,7 @@ func (c *SSE) readSSE(reader io.ReadCloser) {
break
}
if !c.closed.Load() {
fmt.Printf("SSE stream error: %v\n", err)
c.logger.Errorf("SSE stream error: %v", err)
}
return
}
Expand Down Expand Up @@ -241,11 +249,11 @@ func (c *SSE) handleSSEEvent(event, data string) {
case "endpoint":
endpoint, err := c.baseURL.Parse(data)
if err != nil {
fmt.Printf("Error parsing endpoint URL: %v\n", err)
c.logger.Errorf("Error parsing endpoint URL: %v", err)
return
}
if endpoint.Host != c.baseURL.Host {
fmt.Printf("Endpoint origin does not match connection origin\n")
c.logger.Errorf("Endpoint origin does not match connection origin")
return
}
c.endpoint = endpoint
Expand All @@ -254,7 +262,7 @@ func (c *SSE) handleSSEEvent(event, data string) {
case "message":
var baseMessage JSONRPCResponse
if err := json.Unmarshal([]byte(data), &baseMessage); err != nil {
fmt.Printf("Error unmarshaling message: %v\n", err)
c.logger.Errorf("Error unmarshaling message: %v", err)
return
}

Expand Down Expand Up @@ -300,7 +308,6 @@ func (c *SSE) SendRequest(
ctx context.Context,
request JSONRPCRequest,
) (*JSONRPCResponse, error) {

if !c.started.Load() {
return nil, fmt.Errorf("transport not started yet")
}
Expand Down
8 changes: 7 additions & 1 deletion client/transport/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,18 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
}
}

func WithLogger(logger util.Logger) StreamableHTTPCOption {
// WithHTTPLogger sets a custom logger for the StreamableHTTP transport.
func WithHTTPLogger(logger util.Logger) StreamableHTTPCOption {
return func(sc *StreamableHTTP) {
sc.logger = logger
}
}

// Deprecated: Use [WithHTTPLogger] instead.
func WithLogger(logger util.Logger) StreamableHTTPCOption {
return WithHTTPLogger(logger)
}

// WithSession creates a client with a pre-configured session
func WithSession(sessionID string) StreamableHTTPCOption {
return func(sc *StreamableHTTP) {
Expand Down