Skip to content
10 changes: 8 additions & 2 deletions lib/web/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ type commandExecResult struct {
SessionID string `json:"session_id"`
}

// sessionEndEvent is an event that is sent when a session ends.
type sessionEndEvent struct {
// NodeID is the ID of the server where the session was created.
NodeID string `json:"node_id"`
}

// Check checks if the request is valid.
func (c *CommandRequest) Check() error {
if c.Command == "" {
Expand Down Expand Up @@ -336,7 +342,7 @@ func (h *Handler) computeAndSendSummary(
return trace.Wrap(err)
}

// Add the summary message to the backend so it is persisted on chat
// Add the summary message to the backend, so it is persisted on chat
// reload.
messagePayload, err := json.Marshal(&assistlib.CommandExecSummary{
ExecutionID: req.executionID,
Expand Down Expand Up @@ -668,7 +674,7 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl
return
}

if err := t.stream.SendCloseMessage(); err != nil {
if err := t.stream.SendCloseMessage(sessionEndEvent{NodeID: t.sessionData.ServerID}); err != nil {
t.log.WithError(err).Error("Unable to send close event to web client.")
return
}
Expand Down
24 changes: 13 additions & 11 deletions lib/web/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package web
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -180,12 +179,20 @@ func TestExecuteCommandSummary(t *testing.T) {
// Wait for command execution to complete
require.NoError(t, waitForCommandOutput(stream, "teleport"))

var env Envelope
dec := json.NewDecoder(stream)

// Consume the close message
var sessionMetadata sessionEndEvent
err = dec.Decode(&sessionMetadata)
require.NoError(t, err)
require.Equal(t, "node", sessionMetadata.NodeID)

// Consume the summary message
var env outEnvelope
err = dec.Decode(&env)
require.NoError(t, err)
require.Equal(t, envelopeTypeSummary, env.GetType())
require.NotEmpty(t, env.GetPayload())
require.Equal(t, envelopeTypeSummary, env.Type)
require.NotEmpty(t, env.Payload)

// Wait for the command execution history to be saved
var messages *assist.GetAssistantMessagesResponse
Expand Down Expand Up @@ -292,18 +299,13 @@ func waitForCommandOutput(stream io.Reader, substr string) error {
default:
}

var env Envelope
var env outEnvelope
dec := json.NewDecoder(stream)
if err := dec.Decode(&env); err != nil {
return trace.Wrap(err, "decoding envelope JSON from stream")
}

d, err := base64.StdEncoding.DecodeString(env.Payload)
if err != nil {
return trace.Wrap(err, "decoding b64 payload")
}

data := removeSpace(string(d))
data := removeSpace(string(env.Payload))
if strings.Contains(data, substr) {
return nil
}
Expand Down
10 changes: 8 additions & 2 deletions lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor
}

// Send close envelope to web terminal upon exit without an error.
if err := t.stream.SendCloseMessage(); err != nil {
if err := t.stream.SendCloseMessage(sessionEndEvent{NodeID: t.sessionData.ServerID}); err != nil {
t.log.WithError(err).Error("Unable to send close event to web client.")
}

Expand Down Expand Up @@ -1297,10 +1297,16 @@ func (t *WSStream) Read(out []byte) (int, error) {
}

// SendCloseMessage sends a close message on the web socket.
func (t *WSStream) SendCloseMessage() error {
func (t *WSStream) SendCloseMessage(event sessionEndEvent) error {
sessionMetadataPayload, err := json.Marshal(&event)
if err != nil {
return trace.Wrap(err)
}

envelope := &Envelope{
Version: defaults.WebsocketVersion,
Type: defaults.WebsocketClose,
Payload: string(sessionMetadataPayload),
}
envelopeBytes, err := proto.Marshal(envelope)
if err != nil {
Expand Down
81 changes: 42 additions & 39 deletions web/packages/teleport/src/Assist/context/AssistContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ import { getAccessToken, getHostName } from 'teleport/services/api';

import {
ExecutionEnvelopeType,
ExecutionTeleportErrorType,
RawPayload,
ServerMessageType,
SessionEndData,
} from 'teleport/Assist/types';

import { MessageTypeEnum, Protobuf } from 'teleport/lib/term/protobuf';
Expand Down Expand Up @@ -429,39 +431,51 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
);

const proto = new Protobuf();

executeCommandWebSocket.current = new WebSocket(url);
executeCommandWebSocket.current.binaryType = 'arraybuffer';

let sessionsEnded = 0;

executeCommandWebSocket.current.onmessage = event => {
const uintArray = new Uint8Array(event.data);

const msg = proto.decode(uintArray);

switch (msg.type) {
case MessageTypeEnum.RAW:
case MessageTypeEnum.RAW: {
const data = JSON.parse(msg.payload) as RawPayload;
const payload = atob(data.payload);

if (data.type === ExecutionEnvelopeType) {
dispatch({
type: AssistStateActionType.AddCommandResultSummary,
conversationId: state.conversations.selectedId,
summary: payload,
executionId: execParams.execution_id,
command: execParams.command,
});
} else {
dispatch({
type: AssistStateActionType.UpdateCommandResult,
conversationId: state.conversations.selectedId,
commandResultId: nodeIdToResultId.get(data.node_id),
output: payload,
});
switch (data.type) {
case ExecutionTeleportErrorType:
dispatch({
type: AssistStateActionType.FinishCommandResult,
conversationId: state.conversations.selectedId,
commandResultId: nodeIdToResultId.get(data.node_id),
});

nodeIdToResultId.delete(data.node_id);
break;

case ExecutionEnvelopeType:
dispatch({
type: AssistStateActionType.AddCommandResultSummary,
conversationId: state.conversations.selectedId,
summary: payload,
executionId: execParams.execution_id,
command: execParams.command,
});
break;

default:
dispatch({
type: AssistStateActionType.UpdateCommandResult,
conversationId: state.conversations.selectedId,
commandResultId: nodeIdToResultId.get(data.node_id),
output: payload,
});
}

break;
}

case MessageTypeEnum.WEBAUTHN_CHALLENGE:
const challenge = JSON.parse(msg.payload);
Expand All @@ -481,30 +495,19 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {

break;

case MessageTypeEnum.SESSION_END:
// we don't know the nodeId of the session that ended, so we have to
// count the finished sessions and then mark them all as done once
// they've all finished
sessionsEnded += 1;

if (sessionsEnded === nodeIdToResultId.size) {
const message = proto.encodeCloseMessage();
const bytearray = new Uint8Array(message);
case MessageTypeEnum.SESSION_END: {
const data = JSON.parse(msg.payload) as SessionEndData;

for (const nodeId of nodeIdToResultId.keys()) {
dispatch({
type: AssistStateActionType.FinishCommandResult,
conversationId: state.conversations.selectedId,
commandResultId: nodeIdToResultId.get(nodeId),
});

executeCommandWebSocket.current.send(bytearray.buffer);
}
dispatch({
type: AssistStateActionType.FinishCommandResult,
conversationId: state.conversations.selectedId,
commandResultId: nodeIdToResultId.get(data.node_id),
});

nodeIdToResultId.clear();
}
nodeIdToResultId.delete(data.node_id);

break;
}
}
};

Expand Down
10 changes: 10 additions & 0 deletions web/packages/teleport/src/Assist/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@ export enum ServerMessageType {
AssistThought = 'CHAT_MESSAGE_PROGRESS_UPDATE',
}

// ExecutionEnvelopeType is the type of message that is returned when
// the command summary is returned.
export const ExecutionEnvelopeType = 'summary';

// ExecutionTeleportErrorType is the type of error that is returned when
// Teleport returns an error (failed to execute command, failed to connect, etc.)
export const ExecutionTeleportErrorType = 'teleport-error';

export interface Conversation {
id: string;
title?: string;
Expand Down Expand Up @@ -192,6 +198,10 @@ export interface SessionData {
session: { server_id: string };
}

export interface SessionEndData {
node_id: string;
}

export interface ExecuteRemoteCommandPayload {
command: string;
login?: string;
Expand Down