diff --git a/lib/web/command.go b/lib/web/command.go index ad2814ccfc644..fc8a3fc7b9087 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -86,6 +86,12 @@ type commandExecResult struct { SessionID string `json:"session_id"` } +// sessionEndEvent is an event that is sent when a session ends. +type sessionEndEvent struct { + // NodeID is the ID of the server where the session was created. + NodeID string `json:"node_id"` +} + // Check checks if the request is valid. func (c *CommandRequest) Check() error { if c.Command == "" { @@ -336,7 +342,7 @@ func (h *Handler) computeAndSendSummary( return trace.Wrap(err) } - // Add the summary message to the backend so it is persisted on chat + // Add the summary message to the backend, so it is persisted on chat // reload. messagePayload, err := json.Marshal(&assistlib.CommandExecSummary{ ExecutionID: req.executionID, @@ -668,7 +674,7 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl return } - if err := t.stream.SendCloseMessage(); err != nil { + if err := t.stream.SendCloseMessage(sessionEndEvent{NodeID: t.sessionData.ServerID}); err != nil { t.log.WithError(err).Error("Unable to send close event to web client.") return } diff --git a/lib/web/command_test.go b/lib/web/command_test.go index 7a73d715099ba..37a242f550b4b 100644 --- a/lib/web/command_test.go +++ b/lib/web/command_test.go @@ -19,7 +19,6 @@ package web import ( "context" "crypto/tls" - "encoding/base64" "encoding/json" "fmt" "io" @@ -180,12 +179,20 @@ func TestExecuteCommandSummary(t *testing.T) { // Wait for command execution to complete require.NoError(t, waitForCommandOutput(stream, "teleport")) - var env Envelope dec := json.NewDecoder(stream) + + // Consume the close message + var sessionMetadata sessionEndEvent + err = dec.Decode(&sessionMetadata) + require.NoError(t, err) + require.Equal(t, "node", sessionMetadata.NodeID) + + // Consume the summary message + var env outEnvelope err = dec.Decode(&env) require.NoError(t, err) - require.Equal(t, envelopeTypeSummary, env.GetType()) - require.NotEmpty(t, env.GetPayload()) + require.Equal(t, envelopeTypeSummary, env.Type) + require.NotEmpty(t, env.Payload) // Wait for the command execution history to be saved var messages *assist.GetAssistantMessagesResponse @@ -292,18 +299,13 @@ func waitForCommandOutput(stream io.Reader, substr string) error { default: } - var env Envelope + var env outEnvelope dec := json.NewDecoder(stream) if err := dec.Decode(&env); err != nil { return trace.Wrap(err, "decoding envelope JSON from stream") } - d, err := base64.StdEncoding.DecodeString(env.Payload) - if err != nil { - return trace.Wrap(err, "decoding b64 payload") - } - - data := removeSpace(string(d)) + data := removeSpace(string(env.Payload)) if strings.Contains(data, substr) { return nil } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index b2235271b2ad0..befa5fa4b7fb9 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -745,7 +745,7 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor } // Send close envelope to web terminal upon exit without an error. - if err := t.stream.SendCloseMessage(); err != nil { + if err := t.stream.SendCloseMessage(sessionEndEvent{NodeID: t.sessionData.ServerID}); err != nil { t.log.WithError(err).Error("Unable to send close event to web client.") } @@ -1297,10 +1297,16 @@ func (t *WSStream) Read(out []byte) (int, error) { } // SendCloseMessage sends a close message on the web socket. -func (t *WSStream) SendCloseMessage() error { +func (t *WSStream) SendCloseMessage(event sessionEndEvent) error { + sessionMetadataPayload, err := json.Marshal(&event) + if err != nil { + return trace.Wrap(err) + } + envelope := &Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketClose, + Payload: string(sessionMetadataPayload), } envelopeBytes, err := proto.Marshal(envelope) if err != nil { diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx index 0e4f889641f71..784464cdf3549 100644 --- a/web/packages/teleport/src/Assist/context/AssistContext.tsx +++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx @@ -33,8 +33,10 @@ import { getAccessToken, getHostName } from 'teleport/services/api'; import { ExecutionEnvelopeType, + ExecutionTeleportErrorType, RawPayload, ServerMessageType, + SessionEndData, } from 'teleport/Assist/types'; import { MessageTypeEnum, Protobuf } from 'teleport/lib/term/protobuf'; @@ -429,39 +431,51 @@ export function AssistContextProvider(props: PropsWithChildren) { ); const proto = new Protobuf(); - executeCommandWebSocket.current = new WebSocket(url); executeCommandWebSocket.current.binaryType = 'arraybuffer'; - let sessionsEnded = 0; - executeCommandWebSocket.current.onmessage = event => { const uintArray = new Uint8Array(event.data); const msg = proto.decode(uintArray); switch (msg.type) { - case MessageTypeEnum.RAW: + case MessageTypeEnum.RAW: { const data = JSON.parse(msg.payload) as RawPayload; const payload = atob(data.payload); - if (data.type === ExecutionEnvelopeType) { - dispatch({ - type: AssistStateActionType.AddCommandResultSummary, - conversationId: state.conversations.selectedId, - summary: payload, - executionId: execParams.execution_id, - command: execParams.command, - }); - } else { - dispatch({ - type: AssistStateActionType.UpdateCommandResult, - conversationId: state.conversations.selectedId, - commandResultId: nodeIdToResultId.get(data.node_id), - output: payload, - }); + switch (data.type) { + case ExecutionTeleportErrorType: + dispatch({ + type: AssistStateActionType.FinishCommandResult, + conversationId: state.conversations.selectedId, + commandResultId: nodeIdToResultId.get(data.node_id), + }); + + nodeIdToResultId.delete(data.node_id); + break; + + case ExecutionEnvelopeType: + dispatch({ + type: AssistStateActionType.AddCommandResultSummary, + conversationId: state.conversations.selectedId, + summary: payload, + executionId: execParams.execution_id, + command: execParams.command, + }); + break; + + default: + dispatch({ + type: AssistStateActionType.UpdateCommandResult, + conversationId: state.conversations.selectedId, + commandResultId: nodeIdToResultId.get(data.node_id), + output: payload, + }); } + break; + } case MessageTypeEnum.WEBAUTHN_CHALLENGE: const challenge = JSON.parse(msg.payload); @@ -481,30 +495,19 @@ export function AssistContextProvider(props: PropsWithChildren) { break; - case MessageTypeEnum.SESSION_END: - // we don't know the nodeId of the session that ended, so we have to - // count the finished sessions and then mark them all as done once - // they've all finished - sessionsEnded += 1; - - if (sessionsEnded === nodeIdToResultId.size) { - const message = proto.encodeCloseMessage(); - const bytearray = new Uint8Array(message); + case MessageTypeEnum.SESSION_END: { + const data = JSON.parse(msg.payload) as SessionEndData; - for (const nodeId of nodeIdToResultId.keys()) { - dispatch({ - type: AssistStateActionType.FinishCommandResult, - conversationId: state.conversations.selectedId, - commandResultId: nodeIdToResultId.get(nodeId), - }); - - executeCommandWebSocket.current.send(bytearray.buffer); - } + dispatch({ + type: AssistStateActionType.FinishCommandResult, + conversationId: state.conversations.selectedId, + commandResultId: nodeIdToResultId.get(data.node_id), + }); - nodeIdToResultId.clear(); - } + nodeIdToResultId.delete(data.node_id); break; + } } }; diff --git a/web/packages/teleport/src/Assist/types.ts b/web/packages/teleport/src/Assist/types.ts index 5e205d3a0c5e7..ade6da6cd08b2 100644 --- a/web/packages/teleport/src/Assist/types.ts +++ b/web/packages/teleport/src/Assist/types.ts @@ -28,8 +28,14 @@ export enum ServerMessageType { AssistThought = 'CHAT_MESSAGE_PROGRESS_UPDATE', } +// ExecutionEnvelopeType is the type of message that is returned when +// the command summary is returned. export const ExecutionEnvelopeType = 'summary'; +// ExecutionTeleportErrorType is the type of error that is returned when +// Teleport returns an error (failed to execute command, failed to connect, etc.) +export const ExecutionTeleportErrorType = 'teleport-error'; + export interface Conversation { id: string; title?: string; @@ -192,6 +198,10 @@ export interface SessionData { session: { server_id: string }; } +export interface SessionEndData { + node_id: string; +} + export interface ExecuteRemoteCommandPayload { command: string; login?: string;