Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions lib/web/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ type commandExecResult struct {
SessionID string `json:"session_id"`
}

// sessionEndEvent is an event that is sent when a session ends.
type sessionEndEvent struct {
// NodeID is the ID of the server where the session was created.
NodeID string `json:"node_id"`
}

// Check checks if the request is valid.
func (c *CommandRequest) Check() error {
if c.Command == "" {
Expand Down Expand Up @@ -342,7 +348,7 @@ func (h *Handler) computeAndSendSummary(
return trace.Wrap(err)
}

// Add the summary message to the backend so it is persisted on chat
// Add the summary message to the backend, so it is persisted on chat
// reload.
messagePayload, err := json.Marshal(&assistlib.CommandExecSummary{
ExecutionID: req.executionID,
Expand Down Expand Up @@ -674,7 +680,7 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl
return
}

if err := t.stream.SendCloseMessage(); err != nil {
if err := t.stream.SendCloseMessage(sessionEndEvent{NodeID: t.sessionData.ServerID}); err != nil {
t.log.WithError(err).Error("Unable to send close event to web client.")
return
}
Expand Down
40 changes: 17 additions & 23 deletions lib/web/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package web
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -180,12 +179,20 @@ func TestExecuteCommandSummary(t *testing.T) {
// Wait for command execution to complete
require.NoError(t, waitForCommandOutput(stream, "teleport"))

var env Envelope
dec := json.NewDecoder(stream)

// Consume the close message
var sessionMetadata sessionEndEvent
err = dec.Decode(&sessionMetadata)
require.NoError(t, err)
require.Equal(t, "node", sessionMetadata.NodeID)

// Consume the summary message
var env outEnvelope
err = dec.Decode(&env)
require.NoError(t, err)
require.Equal(t, envelopeTypeSummary, env.GetType())
require.NotEmpty(t, env.GetPayload())
require.Equal(t, envelopeTypeSummary, env.Type)
require.NotEmpty(t, env.Payload)

// Wait for the command execution history to be saved
var messages *assist.GetAssistantMessagesResponse
Expand Down Expand Up @@ -292,29 +299,16 @@ func waitForCommandOutput(stream io.Reader, substr string) error {
default:
}

out := make([]byte, 100)
n, err := stream.Read(out)
if err != nil {
return trace.Wrap(err)
var env outEnvelope
dec := json.NewDecoder(stream)
if err := dec.Decode(&env); err != nil {
return trace.Wrap(err, "decoding envelope JSON from stream")
}

var env Envelope
err = json.Unmarshal(out[:n], &env)
if err != nil {
return trace.Wrap(err)
}

d, err := base64.StdEncoding.DecodeString(env.Payload)
if err != nil {
return trace.Wrap(err)
}
data := removeSpace(string(d))
if n > 0 && strings.Contains(data, substr) {
data := removeSpace(string(env.Payload))
if strings.Contains(data, substr) {
return nil
}
if err != nil {
return trace.Wrap(err)
}
}
}

Expand Down
10 changes: 8 additions & 2 deletions lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor
}

// Send close envelope to web terminal upon exit without an error.
if err := t.stream.SendCloseMessage(); err != nil {
if err := t.stream.SendCloseMessage(sessionEndEvent{NodeID: t.sessionData.ServerID}); err != nil {
t.log.WithError(err).Error("Unable to send close event to web client.")
}

Expand Down Expand Up @@ -1297,10 +1297,16 @@ func (t *WSStream) Read(out []byte) (int, error) {
}

// SendCloseMessage sends a close message on the web socket.
func (t *WSStream) SendCloseMessage() error {
func (t *WSStream) SendCloseMessage(event sessionEndEvent) error {
sessionMetadataPayload, err := json.Marshal(&event)
if err != nil {
return trace.Wrap(err)
}

envelope := &Envelope{
Version: defaults.WebsocketVersion,
Type: defaults.WebsocketClose,
Payload: string(sessionMetadataPayload),
}
envelopeBytes, err := proto.Marshal(envelope)
if err != nil {
Expand Down
86 changes: 47 additions & 39 deletions web/packages/teleport/src/Assist/context/AssistContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ import { getAccessToken, getHostName } from 'teleport/services/api';

import {
ExecutionEnvelopeType,
ExecutionTeleportErrorType,
RawPayload,
ServerMessageType,
SessionEndData,
} from 'teleport/Assist/types';

import { MessageTypeEnum, Protobuf } from 'teleport/lib/term/protobuf';
Expand Down Expand Up @@ -78,6 +80,7 @@ const TEN_MINUTES = 10 * 60 * 1000;

export function AssistContextProvider(props: PropsWithChildren<unknown>) {
const activeWebSocket = useRef<WebSocket>(null);
// TODO(ryan): this should be removed once https://github.com/gravitational/teleport.e/pull/1609 is implemented
const executeCommandWebSocket = useRef<WebSocket>(null);
const refreshWebSocketTimeout = useRef<number | null>(null);

Expand Down Expand Up @@ -428,39 +431,51 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
);

const proto = new Protobuf();

executeCommandWebSocket.current = new WebSocket(url);
executeCommandWebSocket.current.binaryType = 'arraybuffer';

let sessionsEnded = 0;

executeCommandWebSocket.current.onmessage = event => {
const uintArray = new Uint8Array(event.data);

const msg = proto.decode(uintArray);

switch (msg.type) {
case MessageTypeEnum.RAW:
case MessageTypeEnum.RAW: {
const data = JSON.parse(msg.payload) as RawPayload;
const payload = atob(data.payload);

if (data.type === ExecutionEnvelopeType) {
dispatch({
type: AssistStateActionType.AddCommandResultSummary,
conversationId: state.conversations.selectedId,
summary: payload,
executionId: execParams.execution_id,
command: execParams.command,
});
} else {
dispatch({
type: AssistStateActionType.UpdateCommandResult,
conversationId: state.conversations.selectedId,
commandResultId: nodeIdToResultId.get(data.node_id),
output: payload,
});
switch (data.type) {
case ExecutionTeleportErrorType:
dispatch({
type: AssistStateActionType.FinishCommandResult,
conversationId: state.conversations.selectedId,
commandResultId: nodeIdToResultId.get(data.node_id),
});

nodeIdToResultId.delete(data.node_id);
break;

case ExecutionEnvelopeType:
dispatch({
type: AssistStateActionType.AddCommandResultSummary,
conversationId: state.conversations.selectedId,
summary: payload,
executionId: execParams.execution_id,
command: execParams.command,
});
break;

default:
dispatch({
type: AssistStateActionType.UpdateCommandResult,
conversationId: state.conversations.selectedId,
commandResultId: nodeIdToResultId.get(data.node_id),
output: payload,
});
}

break;
}

case MessageTypeEnum.WEBAUTHN_CHALLENGE:
const challenge = JSON.parse(msg.payload);
Expand All @@ -480,30 +495,19 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {

break;

case MessageTypeEnum.SESSION_END:
// we don't know the nodeId of the session that ended, so we have to
// count the finished sessions and then mark them all as done once
// they've all finished
sessionsEnded += 1;

if (sessionsEnded === nodeIdToResultId.size) {
const message = proto.encodeCloseMessage();
const bytearray = new Uint8Array(message);

for (const nodeId of nodeIdToResultId.keys()) {
dispatch({
type: AssistStateActionType.FinishCommandResult,
conversationId: state.conversations.selectedId,
commandResultId: nodeIdToResultId.get(nodeId),
});
case MessageTypeEnum.SESSION_END: {
const data = JSON.parse(msg.payload) as SessionEndData;

executeCommandWebSocket.current.send(bytearray.buffer);
}
dispatch({
type: AssistStateActionType.FinishCommandResult,
conversationId: state.conversations.selectedId,
commandResultId: nodeIdToResultId.get(data.node_id),
});

nodeIdToResultId.clear();
}
nodeIdToResultId.delete(data.node_id);

break;
}
}
};

Expand Down Expand Up @@ -549,6 +553,10 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {

useEffect(() => {
loadConversations();

return () => {
window.clearTimeout(refreshWebSocketTimeout.current);
};
}, []);

const selectedConversationMessages = useMemo(
Expand Down
10 changes: 10 additions & 0 deletions web/packages/teleport/src/Assist/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@ export enum ServerMessageType {
AssistThought = 'CHAT_MESSAGE_PROGRESS_UPDATE',
}

// ExecutionEnvelopeType is the type of message that is returned when
// the command summary is returned.
export const ExecutionEnvelopeType = 'summary';

// ExecutionTeleportErrorType is the type of error that is returned when
// Teleport returns an error (failed to execute command, failed to connect, etc.)
export const ExecutionTeleportErrorType = 'teleport-error';

export interface Conversation {
id: string;
title?: string;
Expand Down Expand Up @@ -192,6 +198,10 @@ export interface SessionData {
session: { server_id: string };
}

export interface SessionEndData {
node_id: string;
}

export interface ExecuteRemoteCommandPayload {
command: string;
login?: string;
Expand Down