Skip to content

Commit c04f42a

Browse files
committed
[chore][client] Add ability to override the http.Client
1 parent 37ac814 commit c04f42a

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

client/sse.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@ package client
22

33
import (
44
"fmt"
5-
"github.com/mark3labs/mcp-go/client/transport"
5+
"net/http"
66
"net/url"
7+
8+
"github.com/mark3labs/mcp-go/client/transport"
79
)
810

911
func WithHeaders(headers map[string]string) transport.ClientOption {
1012
return transport.WithHeaders(headers)
1113
}
1214

15+
func WithHTTPClient(httpClient *http.Client) transport.ClientOption {
16+
return transport.WithHTTPClient(httpClient)
17+
}
18+
1319
// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL.
1420
// Returns an error if the URL is invalid.
1521
func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) {

client/transport/sse.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ func WithHeaders(headers map[string]string) ClientOption {
4545
}
4646
}
4747

48+
func WithHTTPClient(httpClient *http.Client) ClientOption {
49+
return func(sc *SSE) {
50+
sc.httpClient = httpClient
51+
}
52+
}
53+
4854
// NewSSE creates a new SSE-based MCP client with the given base URL.
4955
// Returns an error if the URL is invalid.
5056
func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {

client/transport/sse_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,31 @@ func TestSSEErrors(t *testing.T) {
415415
}
416416
})
417417

418+
t.Run("WithHTTPClient", func(t *testing.T) {
419+
// Create a custom client with a very short timeout
420+
customClient := &http.Client{Timeout: 1 * time.Nanosecond}
421+
422+
url, closeF := startMockSSEEchoServer()
423+
defer closeF()
424+
// Initialize SSE transport with the custom HTTP client
425+
trans, err := NewSSE(url, WithHTTPClient(customClient))
426+
if err != nil {
427+
t.Fatalf("Failed to create SSE with custom client: %v", err)
428+
}
429+
430+
// Starting should immediately error due to timeout
431+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
432+
defer cancel()
433+
err = trans.Start(ctx)
434+
if err == nil {
435+
t.Error("Expected Start to fail with custom timeout, got nil")
436+
}
437+
if !errors.Is(err, context.DeadlineExceeded) {
438+
t.Errorf("Expected error 'context deadline exceeded', got '%s'", err.Error())
439+
}
440+
trans.Close()
441+
})
442+
418443
t.Run("RequestBeforeStart", func(t *testing.T) {
419444
url, closeF := startMockSSEEchoServer()
420445
defer closeF()

0 commit comments

Comments
 (0)