diff --git a/lib/utils/mcputils/event.go b/lib/utils/mcputils/event.go new file mode 100644 index 0000000000000..9d9301c8367a5 --- /dev/null +++ b/lib/utils/mcputils/event.go @@ -0,0 +1,146 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// All content in this file is copied from the official SDK without +// modifications: +// https://github.com/modelcontextprotocol/go-sdk/blob/b4f957ff3c279051f9bcc88aa08e897add012a95/mcp/event.go + +package mcputils + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "iter" + "net/http" + "strings" +) + +// An Event is a server-sent event. +// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#fields. +type Event struct { + Name string // the "event" field + ID string // the "id" field + Data []byte // the "data" field + Retry string // the "retry" field +} + +// Empty reports whether the Event is empty. +func (e Event) Empty() bool { + return e.Name == "" && e.ID == "" && len(e.Data) == 0 && e.Retry == "" +} + +// writeEvent writes the event to w, and flushes. +func writeEvent(w io.Writer, evt Event) (int, error) { + var b bytes.Buffer + if evt.Name != "" { + fmt.Fprintf(&b, "event: %s\n", evt.Name) + } + if evt.ID != "" { + fmt.Fprintf(&b, "id: %s\n", evt.ID) + } + if evt.Retry != "" { + fmt.Fprintf(&b, "retry: %s\n", evt.Retry) + } + fmt.Fprintf(&b, "data: %s\n\n", string(evt.Data)) + n, err := w.Write(b.Bytes()) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return n, err +} + +// scanEvents iterates SSE events in the given scanner. The iterated error is +// terminal: if encountered, the stream is corrupt or broken and should no +// longer be used. +// +// TODO(rfindley): consider a different API here that makes failure modes more +// apparent. +func scanEvents(r io.Reader) iter.Seq2[Event, error] { + scanner := bufio.NewScanner(r) + const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size + scanner.Buffer(nil, maxTokenSize) + + // TODO: investigate proper behavior when events are out of order, or have + // non-standard names. + var ( + eventKey = []byte("event") + idKey = []byte("id") + dataKey = []byte("data") + retryKey = []byte("retry") + ) + + return func(yield func(Event, error) bool) { + // iterate event from the wire. + // https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#examples + // + // - `key: value` line records. + // - Consecutive `data: ...` fields are joined with newlines. + // - Unrecognized fields are ignored. Since we only care about 'event', 'id', and + // 'data', these are the only three we consider. + // - Lines starting with ":" are ignored. + // - Records are terminated with two consecutive newlines. + var ( + evt Event + dataBuf *bytes.Buffer // if non-nil, preceding field was also data + ) + flushData := func() { + if dataBuf != nil { + evt.Data = dataBuf.Bytes() + dataBuf = nil + } + } + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + flushData() + // \n\n is the record delimiter + if !evt.Empty() && !yield(evt, nil) { + return + } + evt = Event{} + continue + } + before, after, found := bytes.Cut(line, []byte{':'}) + if !found { + yield(Event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line))) + return + } + if !bytes.Equal(before, dataKey) { + flushData() + } + switch { + case bytes.Equal(before, eventKey): + evt.Name = strings.TrimSpace(string(after)) + case bytes.Equal(before, idKey): + evt.ID = strings.TrimSpace(string(after)) + case bytes.Equal(before, retryKey): + evt.Retry = strings.TrimSpace(string(after)) + case bytes.Equal(before, dataKey): + data := bytes.TrimSpace(after) + if dataBuf != nil { + dataBuf.WriteByte('\n') + dataBuf.Write(data) + } else { + dataBuf = new(bytes.Buffer) + dataBuf.Write(data) + } + } + } + if err := scanner.Err(); err != nil { + if errors.Is(err, bufio.ErrTooLong) { + err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize) + } + if !yield(Event{}, err) { + return + } + } + flushData() + if !evt.Empty() { + yield(evt, nil) + } + } +} diff --git a/lib/utils/mcputils/event_test.go b/lib/utils/mcputils/event_test.go new file mode 100644 index 0000000000000..2ecdc50781b9b --- /dev/null +++ b/lib/utils/mcputils/event_test.go @@ -0,0 +1,103 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// All content in this file is copied from the official SDK without +// modifications: +// https://github.com/modelcontextprotocol/go-sdk/blob/b4f957ff3c279051f9bcc88aa08e897add012a95/mcp/event_test.go + +package mcputils + +import ( + "strings" + "testing" +) + +func TestScanEvents(t *testing.T) { + tests := []struct { + name string + input string + want []Event + wantErr string + }{ + { + name: "simple event", + input: "event: message\nid: 1\ndata: hello\n\n", + want: []Event{ + {Name: "message", ID: "1", Data: []byte("hello")}, + }, + }, + { + name: "multiple data lines", + input: "data: line 1\ndata: line 2\n\n", + want: []Event{ + {Data: []byte("line 1\nline 2")}, + }, + }, + { + name: "multiple events", + input: "data: first\n\nevent: second\ndata: second\n\n", + want: []Event{ + {Data: []byte("first")}, + {Name: "second", Data: []byte("second")}, + }, + }, + { + name: "no trailing newline", + input: "data: hello", + want: []Event{ + {Data: []byte("hello")}, + }, + }, + { + name: "malformed line", + input: "invalid line\n\n", + wantErr: "malformed line", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := strings.NewReader(tt.input) + var got []Event + var err error + for e, err2 := range scanEvents(r) { + if err2 != nil { + err = err2 + break + } + got = append(got, e) + } + + if tt.wantErr != "" { + if err == nil { + t.Fatalf("scanEvents() got nil error, want error containing %q", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("scanEvents() error = %q, want containing %q", err, tt.wantErr) + } + return + } + + if err != nil { + t.Fatalf("scanEvents() returned unexpected error: %v", err) + } + + if len(got) != len(tt.want) { + t.Fatalf("scanEvents() got %d events, want %d", len(got), len(tt.want)) + } + + for i := range got { + if g, w := got[i].Name, tt.want[i].Name; g != w { + t.Errorf("event %d: name = %q, want %q", i, g, w) + } + if g, w := got[i].ID, tt.want[i].ID; g != w { + t.Errorf("event %d: id = %q, want %q", i, g, w) + } + if g, w := string(got[i].Data), string(tt.want[i].Data); g != w { + t.Errorf("event %d: data = %q, want %q", i, g, w) + } + } + }) + } +} diff --git a/lib/utils/mcputils/http.go b/lib/utils/mcputils/http.go index 638d3c0ca2d22..ec257c50ef67b 100644 --- a/lib/utils/mcputils/http.go +++ b/lib/utils/mcputils/http.go @@ -22,6 +22,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "log/slog" "mime" @@ -120,7 +121,9 @@ func (r *httpSSEResponseReplacer) Read(p []byte) (int, error) { msg, err := r.ReadMessage(r.ctx) if err != nil { - if utils.IsOKNetworkError(err) { + // Note that the underlying connection may be canceled by connection + // monitoring. + if utils.IsOKNetworkError(err) || errors.Is(err, context.Canceled) { return 0, io.EOF } return 0, trace.Wrap(err) @@ -147,11 +150,15 @@ func (r *httpSSEResponseReplacer) Read(p []byte) (int, error) { } // Convert to SSE. - e := event{ - name: sseEventMessage, - data: respToSendAsBody, + e := Event{ + Name: sseEventMessage, + Data: respToSendAsBody, } - r.buf = e.marshal() + var buf bytes.Buffer + if _, err := writeEvent(&buf, e); err != nil { + return 0, trace.Wrap(err) + } + r.buf = buf.Bytes() return r.Read(p) } diff --git a/lib/utils/mcputils/sse.go b/lib/utils/mcputils/sse.go index 1b9720aa8980b..7b1033da5d8a5 100644 --- a/lib/utils/mcputils/sse.go +++ b/lib/utils/mcputils/sse.go @@ -19,21 +19,20 @@ package mcputils import ( - "bufio" "bytes" "context" "encoding/json" - "fmt" "io" + "iter" "net/http" "net/url" - "slices" - "strings" + "sync" "github.com/gravitational/trace" "github.com/mark3labs/mcp-go/mcp" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/utils" ) // ConnectSSEServer establishes an SSE stream with the MCP server and finds the @@ -131,15 +130,26 @@ func (w *SSERequestWriter) WriteMessage(ctx context.Context, msg mcp.JSONRPCMess // SSE stream with the MCP server. type SSEResponseReader struct { io.Closer - scanner *bufio.Scanner + nextEvent func() (Event, error, bool) } // NewSSEResponseReader creates a new SSEResponseReader. Input reader is usually the // http body used for SSE stream. func NewSSEResponseReader(reader io.ReadCloser) *SSEResponseReader { + var mu sync.Mutex + nextEvent, stopFunc := iter.Pull2(scanEvents(reader)) return &SSEResponseReader{ - Closer: reader, - scanner: bufio.NewScanner(reader), + Closer: utils.CloseFunc(func() error { + mu.Lock() + stopFunc() + mu.Unlock() + return reader.Close() + }), + nextEvent: func() (Event, error, bool) { + mu.Lock() + defer mu.Unlock() + return nextEvent() + }, } } @@ -147,14 +157,16 @@ func NewSSEResponseReader(reader io.ReadCloser) *SSEResponseReader { // This should be the first event after connecting to SSE server, and any error // is critical. func (r *SSEResponseReader) ReadEndpoint(ctx context.Context, baseURL *url.URL) (*url.URL, error) { - evt, err := r.nextEvent() - if err != nil { + evt, err, ok := r.nextEvent() + if !ok { + return nil, trace.Wrap(io.EOF, "reading SSE server message") + } else if err != nil { return nil, trace.Wrap(err, "reading SSE server message") } - if evt.name != sseEventEndpoint { - return nil, trace.BadParameter("expecting endpoint event, got %s", evt.name) + if evt.Name != sseEventEndpoint { + return nil, trace.BadParameter("expecting endpoint event, got %s", evt.Name) } - endpointURI, err := baseURL.Parse(string(evt.data)) + endpointURI, err := baseURL.Parse(string(evt.Data)) if err != nil { return nil, trace.Wrap(err, "parsing endpoint data") } @@ -163,14 +175,16 @@ func (r *SSEResponseReader) ReadEndpoint(ctx context.Context, baseURL *url.URL) // ReadMessage reads the next SSE message event from SSE stream. func (r *SSEResponseReader) ReadMessage(ctx context.Context) (string, error) { - evt, err := r.nextEvent() - if err != nil { - return "", trace.Wrap(err) + evt, err, ok := r.nextEvent() + if !ok { + return "", trace.Wrap(io.EOF, "reading SSE server message") + } else if err != nil { + return "", trace.Wrap(err, "reading SSE server message") } - if evt.name != sseEventMessage { - return "", newReaderParseError(trace.BadParameter("unexpected event type %s", evt.name)) + if evt.Name != sseEventMessage { + return "", newReaderParseError(trace.BadParameter("unexpected event type %s", evt.Name)) } - return string(evt.data), nil + return string(evt.Data), nil } // Type returns "SSE". @@ -182,57 +196,3 @@ const ( sseEventEndpoint string = "endpoint" sseEventMessage string = "message" ) - -// event is an event is a server-sent event. -type event struct { - name string - data []byte -} - -func (e event) marshal() []byte { - return fmt.Appendf(nil, "event: %s\ndata: %s\n\n", e.name, e.data) -} - -// nextEvent reads one sse event from the wire. -// -// Logic is copied from golang internal mcp lib which might get released -// officially someday: -// https://cs.opensource.google/go/x/tools/+/refs/tags/v0.34.0:internal/mcp/sse.go. -// -// Original comment from above go source: -// https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#examples -// - `key: value` line records. -// - Consecutive `data: ...` fields are joined with newlines. -// - Unrecognized fields are ignored. Since we only care about 'event' and -// 'data', these are the only two we consider. -// - Lines starting with ":" are ignored. -// - Records are terminated with two consecutive newlines. -func (r *SSEResponseReader) nextEvent() (event, error) { - var ( - evt event - lastWasData bool // if set, preceding data field was also data - ) - for r.scanner.Scan() { - line := r.scanner.Bytes() - if len(line) == 0 && (evt.name != "" || len(evt.data) > 0) { - return evt, nil - } - before, after, found := bytes.Cut(line, []byte{':'}) - if !found { - return evt, fmt.Errorf("malformed line in SSE stream: %q", string(line)) - } - switch { - case bytes.Equal(before, []byte("event")): - evt.name = strings.TrimSpace(string(after)) - case bytes.Equal(before, []byte("data")): - data := bytes.TrimSpace(after) - if lastWasData { - evt.data = slices.Concat(evt.data, []byte{'\n'}, data) - } else { - evt.data = data - } - lastWasData = true - } - } - return evt, io.EOF -} diff --git a/lib/utils/mcputils/sse_test.go b/lib/utils/mcputils/sse_test.go index 26a846a65b618..ee07811694585 100644 --- a/lib/utils/mcputils/sse_test.go +++ b/lib/utils/mcputils/sse_test.go @@ -66,11 +66,3 @@ func TestConnectSSEServer(t *testing.T) { require.NoError(t, err) require.Equal(t, "test-server", initResult.ServerInfo.Name) } - -func TestEventMarshal(t *testing.T) { - e := event{ - name: sseEventMessage, - data: []byte("hello"), - } - require.Equal(t, "event: message\ndata: hello\n\n", string(e.marshal())) -}