Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion e
Submodule e updated from 181e48 to f0d46e
6 changes: 4 additions & 2 deletions lib/ai/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"testing"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tiktoken-go/tokenizer/codec"
)
Expand Down Expand Up @@ -107,11 +108,12 @@ func TestChat_Complete(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
// Use assert as require doesn't work when called from a goroutine
assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]

_, err := w.Write(dataBytes)
require.NoError(t, err, "Write error")
assert.NoError(t, err, "Write error")

responses = responses[1:]
}))
Expand Down
28 changes: 26 additions & 2 deletions lib/web/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ func checkAssistEnabled(a auth.ClientI, ctx context.Context) error {
// runAssistant upgrades the HTTP connection to a websocket and starts a chat loop.
func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
sctx *SessionContext, site reversetunnel.RemoteSite,
) error {
) (err error) {
q := r.URL.Query()
conversationID := q.Get("conversation_id")
if conversationID == "" {
Expand Down Expand Up @@ -371,7 +371,31 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
h.log.WithError(err).Error("Error setting websocket readline")
return nil
}
defer ws.Close()
defer func() {
closureReason := websocket.CloseNormalClosure
closureMsg := ""
if err != nil {
h.log.WithError(err).Error("Error in the Assistant loop")
_ = ws.WriteJSON(&assistantMessage{
Type: assist.MessageKindError,
Payload: "An error has occurred. Please try again later.",
CreatedTime: h.clock.Now().UTC().Format(time.RFC3339),
})
// Set server error code and message: https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1
closureReason = websocket.CloseInternalServerErr
closureMsg = err.Error()
}
// Send the close message to the client and close the connection
if err := ws.WriteControl(websocket.CloseMessage,
websocket.FormatCloseMessage(closureReason, closureMsg),
time.Now().Add(time.Second),
); err != nil {
h.log.Warnf("Failed to write close message: %v", err)
}
if err := ws.Close(); err != nil {
h.log.Warnf("Failed to close websocket: %v", err)
}
}()

// Update the read deadline upon receiving a pong message.
ws.SetPongHandler(func(_ string) error {
Expand Down
91 changes: 89 additions & 2 deletions lib/web/assistant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"

Expand Down Expand Up @@ -149,11 +150,12 @@ func Test_runAssistant(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
// Use assert as require doesn't work when called from a goroutine
assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]

_, err := w.Write(dataBytes)
require.NoError(t, err, "Write error")
assert.NoError(t, err, "Write error")

responses = responses[1:]
}))
Expand Down Expand Up @@ -194,6 +196,91 @@ func Test_runAssistant(t *testing.T) {
}
}

// Test_runAssistError tests that the assistant returns an error message
// when the OpenAI API returns an error.
func Test_runAssistError(t *testing.T) {
t.Parallel()

readHelloMsg := func(ws *websocket.Conn) {
_, payload, err := ws.ReadMessage()
require.NoError(t, err)

var msg assistantMessage
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)

// Expect "hello" message
require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
require.Contains(t, msg.Payload, "Hey, I'm Teleport")
}

readErrorMsg := func(ws *websocket.Conn) {
err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
require.NoError(t, err)

_, payload, err := ws.ReadMessage()
require.NoError(t, err)

var msg assistantMessage
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)

// Expect a generic error message
require.Equal(t, assist.MessageKindError, msg.Type)
require.Contains(t, msg.Payload, "An error has occurred.")
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// Simulate rate limit error
w.WriteHeader(429)

errMsg := openai.ErrorResponse{
Error: &openai.APIError{
Code: "rate_limit_reached",
Message: "You are sending requests too quickly.",
Param: nil,
Type: "rate_limit_reached",
HTTPStatusCode: 429,
},
}

dataBytes, err := json.Marshal(errMsg)
// Use assert as require doesn't work when called from a goroutine
assert.NoError(t, err, "Marshal error")

_, err = w.Write(dataBytes)
assert.NoError(t, err, "Write error")
}))
t.Cleanup(server.Close)

openaiCfg := openai.DefaultConfig("test-token")
openaiCfg.BaseURL = server.URL
s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg})

ctx := context.Background()
authPack := s.authPack(t, "foo")
// Create the conversation
conversationID := s.makeAssistConversation(t, ctx, authPack)

// Make WS client and start the conversation
ws, err := s.makeAssistant(t, authPack, conversationID)
require.NoError(t, err)
t.Cleanup(func() {
ws.Close()
})

// verify responses
readHelloMsg(ws)
readErrorMsg(ws)

// Check for close message
_, _, err = ws.ReadMessage()
closeErr, ok := err.(*websocket.CloseError)
require.True(t, ok, "Expected close error")
require.Equal(t, websocket.CloseInternalServerErr, closeErr.Code, "Expected abnormal closure")
}

// makeAssistConversation creates a new assist conversation and returns its ID
func (s *WebSuite) makeAssistConversation(t *testing.T, ctx context.Context, authPack *authPack) string {
clt := authPack.clt
Expand Down
1 change: 1 addition & 0 deletions web/packages/teleport/src/Assist/Conversation/Message.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ function createComponentForEntry(
switch (entry.type) {
case ServerMessageType.Assist:
case ServerMessageType.User:
case ServerMessageType.Error:
return <MessageEntry content={entry.message} />;

case ServerMessageType.Command:
Expand Down
38 changes: 36 additions & 2 deletions web/packages/teleport/src/Assist/context/AssistContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
});
}

function setupWebSocket(conversationId: string) {
function setupWebSocket(conversationId: string, initialMessage?: string) {
activeWebSocket.current = new WebSocket(
cfg.getAssistConversationWebSocketUrl(
getHostName(),
Expand All @@ -123,6 +123,16 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
TEN_MINUTES * 0.8
);

activeWebSocket.current.onopen = () => {
if (initialMessage) {
activeWebSocket.current.send(initialMessage);
}
};

activeWebSocket.current.onclose = () => {
dispatch({ type: AssistStateActionType.SetStreaming, streaming: false });
};

activeWebSocket.current.onmessage = async event => {
const data = JSON.parse(event.data) as ServerMessage;

Expand Down Expand Up @@ -178,6 +188,21 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {

break;
}

case ServerMessageType.Error:
dispatch({
type: AssistStateActionType.AddMessage,
messageType: ServerMessageType.Error,
message: data.payload,
conversationId,
});

dispatch({
type: AssistStateActionType.SetStreaming,
streaming: false,
});

break;
}
};
}
Expand Down Expand Up @@ -273,7 +298,16 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {

dispatch({ type: AssistStateActionType.SetStreaming, streaming: true });

activeWebSocket.current.send(JSON.stringify({ payload: message }));
const data = JSON.stringify({ payload: message });

if (
!activeWebSocket.current ||
activeWebSocket.current.readyState === WebSocket.CLOSED
) {
setupWebSocket(state.conversations.selectedId, data);
} else {
activeWebSocket.current.send(data);
}

dispatch({
type: AssistStateActionType.AddMessage,
Expand Down
5 changes: 4 additions & 1 deletion web/packages/teleport/src/Assist/context/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ export interface SetConversationMessagesAction {

export interface AddMessageAction {
type: AssistStateActionType.AddMessage;
messageType: ServerMessageType.User | ServerMessageType.Assist;
messageType:
| ServerMessageType.User
| ServerMessageType.Assist
| ServerMessageType.Error;
message: string;
conversationId: string;
}
Expand Down
1 change: 1 addition & 0 deletions web/packages/teleport/src/Assist/context/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ function getMessageTypeAuthor(type: string) {
case ServerMessageType.Command:
case ServerMessageType.CommandResult:
case ServerMessageType.CommandResultStream:
case ServerMessageType.Error:
return Author.Teleport;
}
}
Expand Down
8 changes: 8 additions & 0 deletions web/packages/teleport/src/Assist/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import { EventType } from 'teleport/lib/term/enums';
export enum ServerMessageType {
Assist = 'CHAT_MESSAGE_ASSISTANT',
User = 'CHAT_MESSAGE_USER',
Error = 'CHAT_MESSAGE_ERROR',
Command = 'COMMAND',
CommandResult = 'COMMAND_RESULT',
CommandResultStream = 'COMMAND_RESULT_STREAM',
Expand Down Expand Up @@ -85,6 +86,12 @@ export interface ResolvedUserServerMessage {
created: Date;
}

export interface ResolvedErrorServerMessage {
type: ServerMessageType.Error;
message: string;
created: Date;
}

export interface ResolvedCommandResultStreamServerMessage {
type: ServerMessageType.CommandResultStream;
id: number;
Expand All @@ -99,6 +106,7 @@ export type ResolvedServerMessage =
| ResolvedCommandServerMessage
| ResolvedAssistServerMessage
| ResolvedUserServerMessage
| ResolvedErrorServerMessage
| ResolvedCommandResultServerMessage
| ResolvedAssistThoughtServerMessage
| ResolvedCommandResultStreamServerMessage;
Expand Down