diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 32dc798e892ed..b5d66d2c432b6 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -375,9 +375,10 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, closureReason := websocket.CloseNormalClosure closureMsg := "" if err != nil { + h.log.WithError(err).Error("Error in the Assistant loop") _ = ws.WriteJSON(&assistantMessage{ Type: assist.MessageKindError, - Payload: err.Error(), + Payload: "An error has occurred. Please try again later.", CreatedTime: h.clock.Now().UTC().Format(time.RFC3339), }) // Set server error code and message: https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 diff --git a/web/packages/teleport/src/Assist/Conversation/Message.tsx b/web/packages/teleport/src/Assist/Conversation/Message.tsx index 31f60bef5714e..4cb5f2fad5e18 100644 --- a/web/packages/teleport/src/Assist/Conversation/Message.tsx +++ b/web/packages/teleport/src/Assist/Conversation/Message.tsx @@ -107,6 +107,7 @@ function createComponentForEntry( switch (entry.type) { case ServerMessageType.Assist: case ServerMessageType.User: + case ServerMessageType.Error: return ; case ServerMessageType.Command: diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx index 7ce6c0ed6136b..e76bd5f4d6a30 100644 --- a/web/packages/teleport/src/Assist/context/AssistContext.tsx +++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx @@ -105,7 +105,7 @@ export function AssistContextProvider(props: PropsWithChildren) { }); } - function setupWebSocket(conversationId: string) { + function setupWebSocket(conversationId: string, initialMessage?: string) { activeWebSocket.current = new WebSocket( cfg.getAssistConversationWebSocketUrl( getHostName(), @@ -123,6 +123,16 @@ export function AssistContextProvider(props: PropsWithChildren) { TEN_MINUTES * 0.8 ); + activeWebSocket.current.onopen = () => { + if (initialMessage) { + activeWebSocket.current.send(initialMessage); + } + }; + + activeWebSocket.current.onclose = () => { + dispatch({ type: AssistStateActionType.SetStreaming, streaming: false }); + }; + activeWebSocket.current.onmessage = async event => { const data = JSON.parse(event.data) as ServerMessage; @@ -178,6 +188,21 @@ export function AssistContextProvider(props: PropsWithChildren) { break; } + + case ServerMessageType.Error: + dispatch({ + type: AssistStateActionType.AddMessage, + messageType: ServerMessageType.Error, + message: data.payload, + conversationId, + }); + + dispatch({ + type: AssistStateActionType.SetStreaming, + streaming: false, + }); + + break; } }; } @@ -273,7 +298,16 @@ export function AssistContextProvider(props: PropsWithChildren) { dispatch({ type: AssistStateActionType.SetStreaming, streaming: true }); - activeWebSocket.current.send(JSON.stringify({ payload: message })); + const data = JSON.stringify({ payload: message }); + + if ( + !activeWebSocket.current || + activeWebSocket.current.readyState === WebSocket.CLOSED + ) { + setupWebSocket(state.conversations.selectedId, data); + } else { + activeWebSocket.current.send(data); + } dispatch({ type: AssistStateActionType.AddMessage, diff --git a/web/packages/teleport/src/Assist/context/state.ts b/web/packages/teleport/src/Assist/context/state.ts index c72222db77878..41dc5312ff367 100644 --- a/web/packages/teleport/src/Assist/context/state.ts +++ b/web/packages/teleport/src/Assist/context/state.ts @@ -90,7 +90,10 @@ export interface SetConversationMessagesAction { export interface AddMessageAction { type: AssistStateActionType.AddMessage; - messageType: ServerMessageType.User | ServerMessageType.Assist; + messageType: + | ServerMessageType.User + | ServerMessageType.Assist + | ServerMessageType.Error; message: string; conversationId: string; } diff --git a/web/packages/teleport/src/Assist/context/utils.ts b/web/packages/teleport/src/Assist/context/utils.ts index 08f41ebdd4fd4..e7473f7003ea7 100644 --- a/web/packages/teleport/src/Assist/context/utils.ts +++ b/web/packages/teleport/src/Assist/context/utils.ts @@ -32,6 +32,7 @@ function getMessageTypeAuthor(type: string) { case ServerMessageType.Command: case ServerMessageType.CommandResult: case ServerMessageType.CommandResultStream: + case ServerMessageType.Error: return Author.Teleport; } } diff --git a/web/packages/teleport/src/Assist/types.ts b/web/packages/teleport/src/Assist/types.ts index d9b2579c116b4..ba73068e0c80e 100644 --- a/web/packages/teleport/src/Assist/types.ts +++ b/web/packages/teleport/src/Assist/types.ts @@ -18,6 +18,7 @@ import { EventType } from 'teleport/lib/term/enums'; export enum ServerMessageType { Assist = 'CHAT_MESSAGE_ASSISTANT', User = 'CHAT_MESSAGE_USER', + Error = 'CHAT_MESSAGE_ERROR', Command = 'COMMAND', CommandResult = 'COMMAND_RESULT', CommandResultStream = 'COMMAND_RESULT_STREAM', @@ -85,6 +86,12 @@ export interface ResolvedUserServerMessage { created: Date; } +export interface ResolvedErrorServerMessage { + type: ServerMessageType.Error; + message: string; + created: Date; +} + export interface ResolvedCommandResultStreamServerMessage { type: ServerMessageType.CommandResultStream; id: number; @@ -99,6 +106,7 @@ export type ResolvedServerMessage = | ResolvedCommandServerMessage | ResolvedAssistServerMessage | ResolvedUserServerMessage + | ResolvedErrorServerMessage | ResolvedCommandResultServerMessage | ResolvedAssistThoughtServerMessage | ResolvedCommandResultStreamServerMessage;