diff --git a/api/utils/grpc/stream/stream.go b/api/utils/grpc/stream/stream.go index 7fe4694da7954..fecd9d2b7d1f6 100644 --- a/api/utils/grpc/stream/stream.go +++ b/api/utils/grpc/stream/stream.go @@ -47,17 +47,45 @@ type ReadWriter struct { wLock sync.Mutex rLock sync.Mutex rBytes []byte + + options *Options +} + +// Options is NewReadWriter config options. +type Options struct { + // DisableChunking disables automatic splitting of data messages + // that exceed MaxChunkSize during writes. + // This is useful when the receiver does not support chunked reads. + DisableChunking bool +} + +// Option allows setting options as functional arguments to NewReadWriter. +type Option func(s *Options) + +// WithDisabledChunking disables automatic splitting of data messages +// that exceed MaxChunkSize during writes. +// This is useful when the receiver does not support chunked reads. +func WithDisabledChunking() Option { + return func(s *Options) { + s.DisableChunking = true + } } // NewReadWriter creates a new ReadWriter that leverages the provided // source to retrieve data from and write data to. -func NewReadWriter(source Source) (*ReadWriter, error) { +func NewReadWriter(source Source, opts ...Option) (*ReadWriter, error) { if source == nil { return nil, trace.BadParameter("parameter source required") } + options := &Options{} + for _, opt := range opts { + opt(options) + } + return &ReadWriter{ - source: source, + source: source, + options: options, }, nil } @@ -102,7 +130,7 @@ func (c *ReadWriter) Read(b []byte) (n int, err error) { // the grpc stream. To prevent exhausting the stream all // sends on the stream are limited to be at most MaxChunkSize. // If the data exceeds the MaxChunkSize it will be sent in -// batches. +// batches. This behavior can be disabled by using WithDisabledChunking. func (c *ReadWriter) Write(b []byte) (int, error) { c.wLock.Lock() defer c.wLock.Unlock() @@ -110,7 +138,7 @@ func (c *ReadWriter) Write(b []byte) (int, error) { var sent int for len(b) > 0 { chunk := b - if len(chunk) > MaxChunkSize { + if !c.options.DisableChunking && len(chunk) > MaxChunkSize { chunk = chunk[:MaxChunkSize] } diff --git a/api/utils/grpc/stream/stream_test.go b/api/utils/grpc/stream/stream_test.go index 35dfb7070b53c..6658e4f97daf6 100644 --- a/api/utils/grpc/stream/stream_test.go +++ b/api/utils/grpc/stream/stream_test.go @@ -52,7 +52,7 @@ func (m *mockStream) Recv() ([]byte, error) { return b[:n], err } -func newStreamPipe(t *testing.T) (*ReadWriter, net.Conn) { +func newStreamPipe(t *testing.T, opts ...Option) (*ReadWriter, net.Conn) { local, remote := net.Pipe() stream := newMockStream(context.Background(), remote) @@ -63,7 +63,7 @@ func newStreamPipe(t *testing.T) (*ReadWriter, net.Conn) { require.NoError(t, remote.SetReadDeadline(timeout)) require.NoError(t, remote.SetWriteDeadline(timeout)) - streamConn, err := NewReadWriter(stream) + streamConn, err := NewReadWriter(stream, opts...) require.NoError(t, err) return streamConn, local @@ -122,6 +122,30 @@ func TestReadWriter_WriteChunk(t *testing.T) { wg.Wait() } +func TestReadWriter_WriteDisabledChunk(t *testing.T) { + streamConn, local := newStreamPipe(t, WithDisabledChunking()) + wg := &sync.WaitGroup{} + wg.Add(2) + + data := make([]byte, MaxChunkSize+1) + go func() { + defer wg.Done() + n, err := streamConn.Write(data) + assert.NoError(t, err) + assert.Len(t, data, n) + }() + go func() { + defer wg.Done() + b := make([]byte, 2*MaxChunkSize) + n, err := local.Read(b) + assert.NoError(t, err) + assert.Len(t, data, n) + assert.Equal(t, data[:n], b[:n]) + }() + + wg.Wait() +} + func TestReadWriter_Read(t *testing.T) { streamConn, local := newStreamPipe(t) wg := &sync.WaitGroup{} diff --git a/lib/teleterm/services/desktop/desktop.go b/lib/teleterm/services/desktop/desktop.go index 2ae2e1cc065f4..57a192eb76a90 100644 --- a/lib/teleterm/services/desktop/desktop.go +++ b/lib/teleterm/services/desktop/desktop.go @@ -130,9 +130,12 @@ func (s *Session) Start(ctx context.Context, stream grpc.BidiStreamingServer[api return trace.Wrap(err) } - downstreamRW, err := streamutils.NewReadWriter(&clientStream{ - stream: stream, - }) + downstreamRW, err := streamutils.NewReadWriter( + &clientStream{ + stream: stream, + }, + streamutils.WithDisabledChunking(), + ) if err != nil { return trace.Wrap(err) }