diff --git a/client/transport/stdio.go b/client/transport/stdio.go index 85a300a15..a90605c04 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -33,6 +33,20 @@ type Stdio struct { notifyMu sync.RWMutex } +// NewIO returns a new stdio-based transport using existing input, output, and +// logging streams instead of spawning a subprocess. +// This is useful for testing and simulating client behavior. +func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio { + return &Stdio{ + stdin: output, + stdout: bufio.NewReader(input), + stderr: logging, + + responses: make(map[int64]chan *JSONRPCResponse), + done: make(chan struct{}), + } +} + // NewStdio creates a new stdio transport to communicate with a subprocess. // It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. // Returns an error if the subprocess cannot be started or the pipes cannot be created. @@ -55,6 +69,26 @@ func NewStdio( } func (c *Stdio) Start(ctx context.Context) error { + if err := c.spawnCommand(ctx); err != nil { + return err + } + + ready := make(chan struct{}) + go func() { + close(ready) + c.readResponses() + }() + <-ready + + return nil +} + +// spawnCommand spawns a new process running c.command. +func (c *Stdio) spawnCommand(ctx context.Context) error { + if c.command == "" { + return nil + } + cmd := exec.CommandContext(ctx, c.command, c.args...) mergedEnv := os.Environ() @@ -86,14 +120,6 @@ func (c *Stdio) Start(ctx context.Context) error { return fmt.Errorf("failed to start command: %w", err) } - // Start reading responses in a goroutine and wait for it to be ready - ready := make(chan struct{}) - go func() { - close(ready) - c.readResponses() - }() - <-ready - return nil } @@ -107,7 +133,12 @@ func (c *Stdio) Close() error { if err := c.stderr.Close(); err != nil { return fmt.Errorf("failed to close stderr: %w", err) } - return c.cmd.Wait() + + if c.cmd != nil { + return c.cmd.Wait() + } + + return nil } // OnNotification registers a handler function to be called when notifications are received. diff --git a/mcptest/mcptest.go b/mcptest/mcptest.go new file mode 100644 index 000000000..11dcad986 --- /dev/null +++ b/mcptest/mcptest.go @@ -0,0 +1,154 @@ +// Package mcptest implements helper functions for testing MCP servers. +package mcptest + +import ( + "bytes" + "context" + "fmt" + "io" + "log" + "sync" + "testing" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// Server encapsulates an MCP server and manages resources like pipes and context. +type Server struct { + name string + tools []server.ServerTool + + ctx context.Context + cancel func() + + serverReader *io.PipeReader + serverWriter *io.PipeWriter + clientReader *io.PipeReader + clientWriter *io.PipeWriter + + logBuffer bytes.Buffer + + transport transport.Interface + client *client.Client + + wg sync.WaitGroup +} + +// NewServer starts a new MCP server with the provided tools and returns the server instance. +func NewServer(t *testing.T, tools ...server.ServerTool) (*Server, error) { + server := NewUnstartedServer(t) + server.AddTools(tools...) + + if err := server.Start(); err != nil { + return nil, err + } + + return server, nil +} + +// NewUnstartedServer creates a new MCP server instance with the given name, but does not start the server. +// Useful for tests where you need to add tools before starting the server. +func NewUnstartedServer(t *testing.T) *Server { + server := &Server{ + name: t.Name(), + } + + // Use t.Context() once we switch to go >= 1.24 + ctx := context.TODO() + + // Set up context with cancellation, used to stop the server + server.ctx, server.cancel = context.WithCancel(ctx) + + // Set up pipes for client-server communication + server.serverReader, server.clientWriter = io.Pipe() + server.clientReader, server.serverWriter = io.Pipe() + + // Return the configured server + return server +} + +// AddTools adds multiple tools to an unstarted server. +func (s *Server) AddTools(tools ...server.ServerTool) { + s.tools = append(s.tools, tools...) +} + +// AddTool adds a tool to an unstarted server. +func (s *Server) AddTool(tool mcp.Tool, handler server.ToolHandlerFunc) { + s.tools = append(s.tools, server.ServerTool{ + Tool: tool, + Handler: handler, + }) +} + +// Start starts the server in a goroutine. Make sure to defer Close() after Start(). +// When using NewServer(), the returned server is already started. +func (s *Server) Start() error { + s.wg.Add(1) + + // Start the MCP server in a goroutine + go func() { + defer s.wg.Done() + + mcpServer := server.NewMCPServer(s.name, "1.0.0") + + mcpServer.AddTools(s.tools...) + + logger := log.New(&s.logBuffer, "", 0) + + stdioServer := server.NewStdioServer(mcpServer) + stdioServer.SetErrorLogger(logger) + + if err := stdioServer.Listen(s.ctx, s.serverReader, s.serverWriter); err != nil { + logger.Println("StdioServer.Listen failed:", err) + } + }() + + s.transport = transport.NewIO(s.clientReader, s.clientWriter, io.NopCloser(&s.logBuffer)) + if err := s.transport.Start(s.ctx); err != nil { + return fmt.Errorf("transport.Start(): %w", err) + } + + s.client = client.NewClient(s.transport) + + var initReq mcp.InitializeRequest + initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + if _, err := s.client.Initialize(s.ctx, initReq); err != nil { + return fmt.Errorf("client.Initialize(): %w", err) + } + + return nil +} + +// Close stops the server and cleans up resources like temporary directories. +func (s *Server) Close() { + if s.transport != nil { + s.transport.Close() + s.transport = nil + s.client = nil + } + + if s.cancel != nil { + s.cancel() + s.cancel = nil + } + + // Wait for server goroutine to finish + s.wg.Wait() + + s.serverWriter.Close() + s.serverReader.Close() + s.serverReader, s.serverWriter = nil, nil + + s.clientWriter.Close() + s.clientReader.Close() + s.clientReader, s.clientWriter = nil, nil +} + +// Client returns an MCP client connected to the server. +// The client is already initialized, i.e. you do _not_ need to call Client.Initialize(). +func (s *Server) Client() *client.Client { + return s.client +} diff --git a/mcptest/mcptest_test.go b/mcptest/mcptest_test.go new file mode 100644 index 000000000..3fa8af5ba --- /dev/null +++ b/mcptest/mcptest_test.go @@ -0,0 +1,79 @@ +package mcptest_test + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/mcptest" + "github.com/mark3labs/mcp-go/server" +) + +func TestServer(t *testing.T) { + ctx := context.Background() + + srv, err := mcptest.NewServer(t, server.ServerTool{ + Tool: mcp.NewTool("hello", + mcp.WithDescription("Says hello to the provided name, or world."), + mcp.WithString("name", mcp.Description("The name to say hello to.")), + ), + Handler: helloWorldHandler, + }) + if err != nil { + t.Fatal(err) + } + defer srv.Close() + + client := srv.Client() + + var req mcp.CallToolRequest + req.Params.Name = "hello" + req.Params.Arguments = map[string]any{ + "name": "Claude", + } + + result, err := client.CallTool(ctx, req) + if err != nil { + t.Fatal("CallTool:", err) + } + + got, err := resultToString(result) + if err != nil { + t.Fatal(err) + } + + want := "Hello, Claude!" + if got != want { + t.Errorf("Got %q, want %q", got, want) + } +} + +func helloWorldHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract name from request arguments + name, ok := request.Params.Arguments["name"].(string) + if !ok { + name = "World" + } + + return mcp.NewToolResultText(fmt.Sprintf("Hello, %s!", name)), nil +} + +func resultToString(result *mcp.CallToolResult) (string, error) { + var b strings.Builder + + for _, content := range result.Content { + text, ok := content.(mcp.TextContent) + if !ok { + return "", fmt.Errorf("unsupported content type: %T", content) + } + b.WriteString(text.Text) + } + + if result.IsError { + return "", fmt.Errorf("%s", b.String()) + } + + return b.String(), nil +}