diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 39eb2cc7e0eea..fdfe3e6b890aa 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -2259,7 +2259,7 @@ func TestCloseConnectionsOnLogout(t *testing.T) { _, err = io.WriteString(stream, "expr 137 + 39\r\n") require.NoError(t, err) - // make sure server has replied + // make sure the server has replied out := make([]byte, 100) _, err = stream.Read(out) require.NoError(t, err) @@ -2267,7 +2267,7 @@ func TestCloseConnectionsOnLogout(t *testing.T) { _, err = pack.clt.Delete(s.ctx, pack.clt.Endpoint("webapi", "sessions", "web")) require.NoError(t, err) - // wait until we timeout or detect that connection has been closed + // wait until timeout or detect that the connection has been closed. after := time.After(5 * time.Second) errC := make(chan error) go func() { @@ -2275,6 +2275,7 @@ func TestCloseConnectionsOnLogout(t *testing.T) { _, err := stream.Read(out) if err != nil { errC <- err + return } } }() diff --git a/lib/web/command.go b/lib/web/command.go index 2bcdd223f8382..cd47e83538bc3 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -674,7 +674,7 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl return } - if err := t.stream.Close(); err != nil { + if err := t.stream.SendCloseMessage(); err != nil { t.log.WithError(err).Error("Unable to send close event to web client.") return } diff --git a/lib/web/command_test.go b/lib/web/command_test.go index ecff846c09c6e..5ae2d888d25c4 100644 --- a/lib/web/command_test.go +++ b/lib/web/command_test.go @@ -38,6 +38,7 @@ import ( "github.com/gravitational/trace" "github.com/sashabaranov/go-openai" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/timestamppb" @@ -173,43 +174,28 @@ func TestExecuteCommandSummary(t *testing.T) { ws, _, err := s.makeCommand(t, authPack, conversationID) require.NoError(t, err) - // The current Assist execution relies on a hack that multiplexes multiple - // streams over a single one. This causes issues when a stream close as the - // single stream receiver will consider it should close. We work around by - // using a non-closing websocket and initiating a proper stream close. - // Then we reopen a new stream on the same websocket and continue reading - // the summary. - nonClosableWebsocket := &noopCloserWS{ws} - stream := NewWStream(ctx, nonClosableWebsocket, utils.NewLoggerForTests(), nil) + // For simplicity, use simple WS to io.Reader adapter + stream := &wsReader{conn: ws} // Wait for command execution to complete require.NoError(t, waitForCommandOutput(stream, "teleport")) - // Stop the stream consumption - err = stream.Close() - require.NoError(t, err) - - // Start a new stream and process the summary event var env Envelope - stream = NewWStream(ctx, ws, utils.NewLoggerForTests(), nil) dec := json.NewDecoder(stream) err = dec.Decode(&env) require.NoError(t, err) require.Equal(t, envelopeTypeSummary, env.GetType()) require.NotEmpty(t, env.GetPayload()) - // Close the stream, and the underlying websocket - _ = stream.Close() - // Wait for the command execution history to be saved var messages *assist.GetAssistantMessagesResponse - // Command execution history is saved in asynchronusly, so we need to wait for it. + // Command execution history is saved in asynchronously, so we need to wait for it. require.Eventually(t, func() bool { messages, err = clt.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{ ConversationId: conversationID.String(), Username: testUser, }) - require.NoError(t, err) + assert.NoError(t, err) return len(messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResultSummary]) == 1 }, 5*time.Second, 100*time.Millisecond) @@ -364,3 +350,29 @@ func messagesByType(messages []*assist.AssistantMessage) map[assistlib.MessageTy } return byType } + +// wsReader implements io.Reader interface over websocket connection +type wsReader struct { + conn *websocket.Conn +} + +// Read reads data from websocket connection. +// The message should be in web.Envelope format and only the payload will be returned. +// It expects that the passed buffer is big enough to fit the whole message. +func (r *wsReader) Read(p []byte) (int, error) { + _, data, err := r.conn.ReadMessage() + if err != nil { + return 0, trace.Wrap(err) + } + + var envelope Envelope + if err := proto.Unmarshal(data, &envelope); err != nil { + return 0, trace.Errorf("Unable to parse message payload %v", err) + } + + if len(envelope.Payload) > len(p) { + return 0, trace.BadParameter("buffer too small") + } + + return copy(p, envelope.Payload), nil +} diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 09d1c3f96b9c1..b2235271b2ad0 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -744,6 +744,11 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor return } + // Send close envelope to web terminal upon exit without an error. + if err := t.stream.SendCloseMessage(); err != nil { + t.log.WithError(err).Error("Unable to send close event to web client.") + } + if err := t.stream.Close(); err != nil { t.log.WithError(err).Error("Unable to send close event to web client.") return @@ -910,7 +915,6 @@ func NewWStream(ctx context.Context, ws WSConn, log logrus.FieldLogger, handlers ws: ws, encoder: unicode.UTF8.NewEncoder(), decoder: unicode.UTF8.NewDecoder(), - completedC: make(chan struct{}), rawC: make(chan Envelope, 100), challengeC: make(chan Envelope, 1), handlers: handlers, @@ -956,7 +960,6 @@ type WSStream struct { once sync.Once challengeC chan Envelope rawC chan Envelope - completedC chan struct{} // buffer is a buffer used to store the remaining payload data if it did not // fit into the buffer provided by the callee to Read method @@ -993,7 +996,7 @@ func (t *WSStream) writeError(msg string) { func (t *WSStream) processMessages(ctx context.Context) { defer func() { - close(t.completedC) + t.close() }() t.ws.SetReadLimit(teleport.MaxHTTPRequestSize) @@ -1187,25 +1190,23 @@ func (t *WSStream) writeChallenge(challenge *client.MFAAuthenticateChallenge, co // readChallengeResponse reads and decodes the challenge response from the // websocket in the correct format. func (t *WSStream) readChallengeResponse(codec mfaCodec) (*authproto.MFAAuthenticateResponse, error) { - select { - case <-t.completedC: + envelope, ok := <-t.challengeC + if !ok { return nil, io.EOF - case envelope := <-t.challengeC: - resp, err := codec.decodeResponse([]byte(envelope.Payload), defaults.WebsocketWebauthnChallenge) - return resp, trace.Wrap(err) } + resp, err := codec.decodeResponse([]byte(envelope.Payload), defaults.WebsocketWebauthnChallenge) + return resp, trace.Wrap(err) } // readChallenge reads and decodes the challenge from the // websocket in the correct format. func (t *WSStream) readChallenge(codec mfaCodec) (*authproto.MFAAuthenticateChallenge, error) { - select { - case <-t.completedC: + envelope, ok := <-t.challengeC + if !ok { return nil, io.EOF - case envelope := <-t.challengeC: - challenge, err := codec.decodeChallenge([]byte(envelope.Payload), defaults.WebsocketWebauthnChallenge) - return challenge, trace.Wrap(err) } + challenge, err := codec.decodeChallenge([]byte(envelope.Payload), defaults.WebsocketWebauthnChallenge) + return challenge, trace.Wrap(err) } // writeAuditEvent encodes and writes the audit event to the @@ -1263,9 +1264,9 @@ func (t *WSStream) Write(data []byte) (n int, err error) { } // Read provides data received from [defaults.WebsocketRaw] envelopes. If -// the previous envelope was not consumed in the last read any remaining data +// the previous envelope was not consumed in the last read, any remaining data // is returned prior to processing the next envelope. -func (t *WSStream) Read(out []byte) (n int, err error) { +func (t *WSStream) Read(out []byte) (int, error) { if len(t.buffer) > 0 { n := copy(out, t.buffer) if n == len(t.buffer) { @@ -1276,53 +1277,53 @@ func (t *WSStream) Read(out []byte) (n int, err error) { return n, nil } - select { - case <-t.completedC: + envelope, ok := <-t.rawC + if !ok { return 0, io.EOF - case envelope := <-t.rawC: - data, err := t.decoder.Bytes([]byte(envelope.Payload)) - if err != nil { - return 0, trace.Wrap(err) - } + } - n := copy(out, data) - // if payload size is greater than [out], store the remaining - // part in the buffer to be processed on the next Read call - if len(data) > n { - t.buffer = data[n:] - } - return n, nil + data, err := t.decoder.Bytes([]byte(envelope.Payload)) + if err != nil { + return 0, trace.Wrap(err) } + + n := copy(out, data) + // if the payload size is greater than [out], store the remaining + // part in the buffer to be processed on the next Read call + if len(data) > n { + t.buffer = data[n:] + } + return n, nil } -// Close sends a close message on the web socket and closes the web socket. -func (t *WSStream) Close() error { - var closeErr error +// SendCloseMessage sends a close message on the web socket. +func (t *WSStream) SendCloseMessage() error { + envelope := &Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketClose, + } + envelopeBytes, err := proto.Marshal(envelope) + if err != nil { + return trace.Wrap(err) + } + + t.mu.Lock() + defer t.mu.Unlock() + return trace.Wrap(t.ws.WriteMessage(websocket.BinaryMessage, envelopeBytes)) +} + +func (t *WSStream) close() { t.once.Do(func() { defer func() { - <-t.completedC - close(t.rawC) close(t.challengeC) }() - - // Send close envelope to web terminal upon exit without an error. - envelope := &Envelope{ - Version: defaults.WebsocketVersion, - Type: defaults.WebsocketClose, - } - envelopeBytes, err := proto.Marshal(envelope) - if err != nil { - closeErr = trace.NewAggregate(err, t.ws.Close()) - return - } - - t.mu.Lock() - defer t.mu.Unlock() - closeErr = trace.NewAggregate(t.ws.WriteMessage(websocket.BinaryMessage, envelopeBytes), t.ws.Close()) }) +} - return trace.Wrap(closeErr) +// Close sends a close message on the web socket and closes the web socket. +func (t *WSStream) Close() error { + return trace.Wrap(t.ws.Close()) } // deadlineForInterval returns a suitable network read deadline for a given ping interval.