Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions web/packages/teleport/src/Assist/context/AssistContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ import { AssistStateActionType, reducer } from 'teleport/Assist/context/state';
import { convertServerMessages } from 'teleport/Assist/context/utils';
import useStickyClusterId from 'teleport/useStickyClusterId';
import cfg from 'teleport/config';
import { getAccessToken, getHostName } from 'teleport/services/api';

import { WebsocketStatus } from 'teleport/types';
import { getHostName } from 'teleport/services/api';

import {
AccessRequestClientMessage,
Expand All @@ -50,6 +48,7 @@ import {
makeMfaAuthenticateChallenge,
WebauthnAssertionResponse,
} from 'teleport/services/auth';
import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket';

import * as service from '../service';
import {
Expand All @@ -65,7 +64,6 @@ import type {
ServerMessage,
} from 'teleport/Assist/types';
import type { AssistState } from 'teleport/Assist/context/state';
import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebsoscket';

interface AssistContextValue {
cancelMfaChallenge: () => void;
Expand Down Expand Up @@ -127,7 +125,13 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
}

function setupWebSocket(conversationId: string, initialMessage?: string) {

activeWebSocket.current = new AuthenticatedWebSocket(
cfg.getAssistConversationWebSocketUrl(
getHostName(),
clusterId,
conversationId
)
);

window.clearTimeout(refreshWebSocketTimeout.current);

Expand All @@ -137,21 +141,22 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
TEN_MINUTES * 0.8
);

const onopen = () => {
activeWebSocket.current.onopen = () => {
if (initialMessage) {
activeWebSocket.current.send(initialMessage);
}
}
};

const onclose = () => {
activeWebSocket.current.onclose = () => {
dispatch({
type: AssistStateActionType.SetStreaming,
streaming: false,
});
};

const onmessage = event => {
activeWebSocket.current.onmessage = async event => {
const data = JSON.parse(event.data) as ServerMessage;

switch (data.type) {
case ServerMessageType.Assist:
dispatch({
Expand Down Expand Up @@ -245,14 +250,6 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
break;
}
};

activeWebSocket.current = new AuthenticatedWebSocket(
cfg.getAssistConversationWebSocketUrl(
getHostName(),
clusterId,
conversationId
), onopen, onmessage, null, onclose
);
}

async function createConversation() {
Expand Down Expand Up @@ -353,7 +350,7 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {

if (
!activeWebSocket.current ||
activeWebSocket.current.readyState === WebSocket.CLOSED
activeWebSocket.current.readyState === AuthenticatedWebSocket.CLOSED
) {
setupWebSocket(state.conversations.selectedId, data);
} else {
Expand Down Expand Up @@ -383,7 +380,8 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
function sendMfaChallenge(data: WebauthnAssertionResponse) {
if (
!executeCommandWebSocket.current ||
executeCommandWebSocket.current.readyState !== WebSocket.OPEN ||
executeCommandWebSocket.current.readyState !==
AuthenticatedWebSocket.OPEN ||
!data
) {
console.warn(
Expand Down Expand Up @@ -455,8 +453,10 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
);

const proto = new Protobuf();
const onmessage = (event: MessageEvent) => {
executeCommandWebSocket.current.binaryType = 'arraybuffer';
executeCommandWebSocket.current = new AuthenticatedWebSocket(url);
executeCommandWebSocket.current.binaryType = 'arraybuffer';

executeCommandWebSocket.current.onmessage = event => {
const uintArray = new Uint8Array(event.data);

const msg = proto.decode(uintArray);
Expand Down Expand Up @@ -533,8 +533,9 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
}
};

const onclose = () => {
executeCommandWebSocket.current.onclose = () => {
executeCommandWebSocket.current = null;

// If the execution failed, we won't get a SESSION_END message, so we
// need to mark all the results as finished here.
for (const nodeId of nodeIdToResultId.keys()) {
Expand All @@ -546,8 +547,6 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
}
nodeIdToResultId.clear();
};

executeCommandWebSocket.current = new AuthenticatedWebSocket(url, null, onmessage, null, onclose);
}

async function deleteConversation(conversationId: string) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import React, {
} from 'react';

import { Author, ServerMessage } from 'teleport/Assist/types';
import { getAccessToken, getHostName } from 'teleport/services/api';
import { getHostName } from 'teleport/services/api';
import useStickyClusterId from 'teleport/useStickyClusterId';
import cfg from 'teleport/config';
import {
Expand All @@ -36,7 +36,7 @@ import {
SuggestedCommandMessage,
UserMessage,
} from 'teleport/Console/DocumentSsh/TerminalAssist/types';
import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebsoscket';
import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket';

interface TerminalAssistContextValue {
close: () => void;
Expand Down Expand Up @@ -72,7 +72,9 @@ export function TerminalAssistContextProvider(
const [messages, setMessages] = useState<Message[]>([]);

useEffect(() => {
let onmessage = (e: MessageEvent) => {
socketRef.current = new AuthenticatedWebSocket(socketUrl);

socketRef.current.onmessage = e => {
const data = JSON.parse(e.data) as ServerMessage;
const payload = JSON.parse(data.payload) as {
action: string;
Expand All @@ -93,8 +95,6 @@ export function TerminalAssistContextProvider(
setLoading(false);
setMessages(m => [message, ...m]);
};

socketRef.current = new AuthenticatedWebSocket(socketUrl, null, onmessage);
}, []);

function close() {
Expand All @@ -120,14 +120,15 @@ export function TerminalAssistContextProvider(
'ssh-explain'
);

const ws = new AuthenticatedWebSocket(socketUrl);


let onopen = () => {
ws.send(encodedOutput);
ws.onopen = () => {
ws.send(encodedOutput);
};

let onmessage = (event: MessageEvent) => {
const msg = JSON.parse(event.data) as ServerMessage;
ws.onmessage = event => {
const message = event.data;
const msg = JSON.parse(message) as ServerMessage;

const explanation: ExplanationMessage = {
author: Author.Teleport,
Expand All @@ -140,7 +141,6 @@ export function TerminalAssistContextProvider(

ws.close();
};
const ws = new AuthenticatedWebSocket(socketUrl, onopen, onmessage);
}

function send(message: string) {
Expand Down
Loading