Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions api/utils/grpc/stream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,45 @@ type ReadWriter struct {
wLock sync.Mutex
rLock sync.Mutex
rBytes []byte

options *Options
}

// Options holds the configuration applied by NewReadWriter.
type Options struct {
	// DisableChunking turns off the automatic splitting of outgoing
	// data messages that are larger than MaxChunkSize.
	// Set this when the peer cannot reassemble chunked reads.
	DisableChunking bool
}

// Option is a functional argument that mutates the Options
// passed to NewReadWriter.
type Option func(o *Options)

// WithDisabledChunking returns an Option that turns off the automatic
// splitting of outgoing data messages larger than MaxChunkSize.
// Set this when the peer cannot reassemble chunked reads.
func WithDisabledChunking() Option {
	return func(o *Options) {
		o.DisableChunking = true
	}
}

// NewReadWriter creates a new ReadWriter that leverages the provided
// source to retrieve data from and write data to.
func NewReadWriter(source Source) (*ReadWriter, error) {
func NewReadWriter(source Source, opts ...Option) (*ReadWriter, error) {
if source == nil {
return nil, trace.BadParameter("parameter source required")
}

options := &Options{}
for _, opt := range opts {
opt(options)
}

return &ReadWriter{
source: source,
source: source,
options: options,
}, nil
}

Expand Down Expand Up @@ -102,15 +130,15 @@ func (c *ReadWriter) Read(b []byte) (n int, err error) {
// the grpc stream. To prevent exhausting the stream all
// sends on the stream are limited to be at most MaxChunkSize.
// If the data exceeds the MaxChunkSize it will be sent in
// batches.
// batches. This behavior can be disabled by using WithDisabledChunking.
func (c *ReadWriter) Write(b []byte) (int, error) {
c.wLock.Lock()
defer c.wLock.Unlock()

var sent int
for len(b) > 0 {
chunk := b
if len(chunk) > MaxChunkSize {
if !c.options.DisableChunking && len(chunk) > MaxChunkSize {
Comment thread
gzdunek marked this conversation as resolved.
chunk = chunk[:MaxChunkSize]
}

Expand Down
28 changes: 26 additions & 2 deletions api/utils/grpc/stream/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (m *mockStream) Recv() ([]byte, error) {
return b[:n], err
}

func newStreamPipe(t *testing.T) (*ReadWriter, net.Conn) {
func newStreamPipe(t *testing.T, opts ...Option) (*ReadWriter, net.Conn) {
local, remote := net.Pipe()
stream := newMockStream(context.Background(), remote)

Expand All @@ -63,7 +63,7 @@ func newStreamPipe(t *testing.T) (*ReadWriter, net.Conn) {
require.NoError(t, remote.SetReadDeadline(timeout))
require.NoError(t, remote.SetWriteDeadline(timeout))

streamConn, err := NewReadWriter(stream)
streamConn, err := NewReadWriter(stream, opts...)
require.NoError(t, err)

return streamConn, local
Expand Down Expand Up @@ -122,6 +122,30 @@ func TestReadWriter_WriteChunk(t *testing.T) {
wg.Wait()
}

func TestReadWriter_WriteDisabledChunk(t *testing.T) {
streamConn, local := newStreamPipe(t, WithDisabledChunking())
wg := &sync.WaitGroup{}
wg.Add(2)

data := make([]byte, MaxChunkSize+1)
go func() {
defer wg.Done()
n, err := streamConn.Write(data)
assert.NoError(t, err)
assert.Len(t, data, n)
}()
go func() {
defer wg.Done()
b := make([]byte, 2*MaxChunkSize)
n, err := local.Read(b)
assert.NoError(t, err)
assert.Len(t, data, n)
assert.Equal(t, data[:n], b[:n])
Comment thread
zmb3 marked this conversation as resolved.
}()

wg.Wait()
}

func TestReadWriter_Read(t *testing.T) {
streamConn, local := newStreamPipe(t)
wg := &sync.WaitGroup{}
Expand Down
9 changes: 6 additions & 3 deletions lib/teleterm/services/desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,12 @@ func (s *Session) Start(ctx context.Context, stream grpc.BidiStreamingServer[api
return trace.Wrap(err)
}

downstreamRW, err := streamutils.NewReadWriter(&clientStream{
stream: stream,
})
downstreamRW, err := streamutils.NewReadWriter(
&clientStream{
stream: stream,
},
streamutils.WithDisabledChunking(),
)
if err != nil {
return trace.Wrap(err)
}
Expand Down
Loading