Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2259,22 +2259,23 @@ func TestCloseConnectionsOnLogout(t *testing.T) {
_, err = io.WriteString(stream, "expr 137 + 39\r\n")
require.NoError(t, err)

// make sure server has replied
// make sure the server has replied
out := make([]byte, 100)
_, err = stream.Read(out)
require.NoError(t, err)

_, err = pack.clt.Delete(s.ctx, pack.clt.Endpoint("webapi", "sessions", "web"))
require.NoError(t, err)

// wait until we timeout or detect that connection has been closed
// wait until timeout or detect that the connection has been closed.
after := time.After(5 * time.Second)
errC := make(chan error)
go func() {
for {
_, err := stream.Read(out)
if err != nil {
errC <- err
return
}
}
}()
Expand Down
2 changes: 1 addition & 1 deletion lib/web/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl
return
}

if err := t.stream.Close(); err != nil {
if err := t.stream.SendCloseMessage(); err != nil {
t.log.WithError(err).Error("Unable to send close event to web client.")
return
}
Expand Down
50 changes: 31 additions & 19 deletions lib/web/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"

Expand Down Expand Up @@ -173,43 +174,28 @@ func TestExecuteCommandSummary(t *testing.T) {
ws, _, err := s.makeCommand(t, authPack, conversationID)
require.NoError(t, err)

// The current Assist execution relies on a hack that multiplexes multiple
// streams over a single one. This causes issues when a stream close as the
// single stream receiver will consider it should close. We work around by
// using a non-closing websocket and initiating a proper stream close.
// Then we reopen a new stream on the same websocket and continue reading
// the summary.
nonClosableWebsocket := &noopCloserWS{ws}
stream := NewWStream(ctx, nonClosableWebsocket, utils.NewLoggerForTests(), nil)
// For simplicity, use simple WS to io.Reader adapter
stream := &wsReader{conn: ws}

// Wait for command execution to complete
require.NoError(t, waitForCommandOutput(stream, "teleport"))

// Stop the stream consumption
err = stream.Close()
require.NoError(t, err)

// Start a new stream and process the summary event
var env Envelope
stream = NewWStream(ctx, ws, utils.NewLoggerForTests(), nil)
dec := json.NewDecoder(stream)
err = dec.Decode(&env)
require.NoError(t, err)
require.Equal(t, envelopeTypeSummary, env.GetType())
require.NotEmpty(t, env.GetPayload())

// Close the stream, and the underlying websocket
_ = stream.Close()

// Wait for the command execution history to be saved
var messages *assist.GetAssistantMessagesResponse
// Command execution history is saved in asynchronusly, so we need to wait for it.
// Command execution history is saved in asynchronously, so we need to wait for it.
require.Eventually(t, func() bool {
messages, err = clt.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{
ConversationId: conversationID.String(),
Username: testUser,
})
require.NoError(t, err)
assert.NoError(t, err)

return len(messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResultSummary]) == 1
}, 5*time.Second, 100*time.Millisecond)
Expand Down Expand Up @@ -364,3 +350,29 @@ func messagesByType(messages []*assist.AssistantMessage) map[assistlib.MessageTy
}
return byType
}

// wsReader implements io.Reader interface over websocket connection
type wsReader struct {
conn *websocket.Conn
}

// Read reads data from websocket connection.
// The message should be in web.Envelope format and only the payload will be returned.
// It expects that the passed buffer is big enough to fit the whole message.
func (r *wsReader) Read(p []byte) (int, error) {
_, data, err := r.conn.ReadMessage()
if err != nil {
return 0, trace.Wrap(err)
}

var envelope Envelope
if err := proto.Unmarshal(data, &envelope); err != nil {
return 0, trace.Errorf("Unable to parse message payload %v", err)
}

if len(envelope.Payload) > len(p) {
return 0, trace.BadParameter("buffer too small")
}

return copy(p, envelope.Payload), nil
}
101 changes: 51 additions & 50 deletions lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,11 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor
return
}

// Send close envelope to web terminal upon exit without an error.
if err := t.stream.SendCloseMessage(); err != nil {
t.log.WithError(err).Error("Unable to send close event to web client.")
}

if err := t.stream.Close(); err != nil {
t.log.WithError(err).Error("Unable to send close event to web client.")
return
Expand Down Expand Up @@ -910,7 +915,6 @@ func NewWStream(ctx context.Context, ws WSConn, log logrus.FieldLogger, handlers
ws: ws,
encoder: unicode.UTF8.NewEncoder(),
decoder: unicode.UTF8.NewDecoder(),
completedC: make(chan struct{}),
rawC: make(chan Envelope, 100),
challengeC: make(chan Envelope, 1),
handlers: handlers,
Expand Down Expand Up @@ -956,7 +960,6 @@ type WSStream struct {
once sync.Once
challengeC chan Envelope
rawC chan Envelope
completedC chan struct{}

// buffer is a buffer used to store the remaining payload data if it did not
// fit into the buffer provided by the callee to Read method
Expand Down Expand Up @@ -993,7 +996,7 @@ func (t *WSStream) writeError(msg string) {

func (t *WSStream) processMessages(ctx context.Context) {
defer func() {
close(t.completedC)
t.close()
}()
t.ws.SetReadLimit(teleport.MaxHTTPRequestSize)

Expand Down Expand Up @@ -1187,25 +1190,23 @@ func (t *WSStream) writeChallenge(challenge *client.MFAAuthenticateChallenge, co
// readChallengeResponse reads and decodes the challenge response from the
// websocket in the correct format.
func (t *WSStream) readChallengeResponse(codec mfaCodec) (*authproto.MFAAuthenticateResponse, error) {
select {
case <-t.completedC:
envelope, ok := <-t.challengeC
if !ok {
return nil, io.EOF
case envelope := <-t.challengeC:
resp, err := codec.decodeResponse([]byte(envelope.Payload), defaults.WebsocketWebauthnChallenge)
return resp, trace.Wrap(err)
}
resp, err := codec.decodeResponse([]byte(envelope.Payload), defaults.WebsocketWebauthnChallenge)
return resp, trace.Wrap(err)
}

// readChallenge reads and decodes the challenge from the
// websocket in the correct format.
func (t *WSStream) readChallenge(codec mfaCodec) (*authproto.MFAAuthenticateChallenge, error) {
select {
case <-t.completedC:
envelope, ok := <-t.challengeC
if !ok {
return nil, io.EOF
case envelope := <-t.challengeC:
challenge, err := codec.decodeChallenge([]byte(envelope.Payload), defaults.WebsocketWebauthnChallenge)
return challenge, trace.Wrap(err)
}
challenge, err := codec.decodeChallenge([]byte(envelope.Payload), defaults.WebsocketWebauthnChallenge)
return challenge, trace.Wrap(err)
}

// writeAuditEvent encodes and writes the audit event to the
Expand Down Expand Up @@ -1263,9 +1264,9 @@ func (t *WSStream) Write(data []byte) (n int, err error) {
}

// Read provides data received from [defaults.WebsocketRaw] envelopes. If
// the previous envelope was not consumed in the last read any remaining data
// the previous envelope was not consumed in the last read, any remaining data
// is returned prior to processing the next envelope.
func (t *WSStream) Read(out []byte) (n int, err error) {
func (t *WSStream) Read(out []byte) (int, error) {
if len(t.buffer) > 0 {
n := copy(out, t.buffer)
if n == len(t.buffer) {
Expand All @@ -1276,53 +1277,53 @@ func (t *WSStream) Read(out []byte) (n int, err error) {
return n, nil
}

select {
case <-t.completedC:
envelope, ok := <-t.rawC
if !ok {
return 0, io.EOF
case envelope := <-t.rawC:
data, err := t.decoder.Bytes([]byte(envelope.Payload))
if err != nil {
return 0, trace.Wrap(err)
}
}

n := copy(out, data)
// if payload size is greater than [out], store the remaining
// part in the buffer to be processed on the next Read call
if len(data) > n {
t.buffer = data[n:]
}
return n, nil
data, err := t.decoder.Bytes([]byte(envelope.Payload))
if err != nil {
return 0, trace.Wrap(err)
}

n := copy(out, data)
// if the payload size is greater than [out], store the remaining
// part in the buffer to be processed on the next Read call
if len(data) > n {
t.buffer = data[n:]
}
return n, nil
}

// Close sends a close message on the web socket and closes the web socket.
func (t *WSStream) Close() error {
var closeErr error
// SendCloseMessage sends a close message on the web socket.
func (t *WSStream) SendCloseMessage() error {
envelope := &Envelope{
Version: defaults.WebsocketVersion,
Type: defaults.WebsocketClose,
}
envelopeBytes, err := proto.Marshal(envelope)
if err != nil {
return trace.Wrap(err)
}

t.mu.Lock()
defer t.mu.Unlock()
return trace.Wrap(t.ws.WriteMessage(websocket.BinaryMessage, envelopeBytes))
}

func (t *WSStream) close() {
t.once.Do(func() {
defer func() {
<-t.completedC

close(t.rawC)
close(t.challengeC)
}()

// Send close envelope to web terminal upon exit without an error.
envelope := &Envelope{
Version: defaults.WebsocketVersion,
Type: defaults.WebsocketClose,
}
envelopeBytes, err := proto.Marshal(envelope)
if err != nil {
closeErr = trace.NewAggregate(err, t.ws.Close())
return
}

t.mu.Lock()
defer t.mu.Unlock()
closeErr = trace.NewAggregate(t.ws.WriteMessage(websocket.BinaryMessage, envelopeBytes), t.ws.Close())
})
}

return trace.Wrap(closeErr)
// Close sends a close message on the web socket and closes the web socket.
func (t *WSStream) Close() error {
return trace.Wrap(t.ws.Close())
}

// deadlineForInterval returns a suitable network read deadline for a given ping interval.
Expand Down