From 385793f921c6b94887402ce70a7abcb0d15d34e0 Mon Sep 17 00:00:00 2001 From: Justinas Stankevicius Date: Wed, 10 May 2023 20:31:48 +0300 Subject: [PATCH 1/9] Add rate limiting to Assist --- lib/web/apiserver.go | 11 +++++++++++ lib/web/assistant.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index bbc83f4b2043c..784f863af756a 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -51,6 +51,7 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/exp/slices" "golang.org/x/mod/semver" + "golang.org/x/time/rate" "google.golang.org/protobuf/encoding/protojson" "github.com/gravitational/teleport" @@ -92,6 +93,14 @@ import ( const ( // SSOLoginFailureMessage is a generic error message to avoid disclosing sensitive SSO failure messages. SSOLoginFailureMessage = "Failed to login. Please check Teleport's log for more details." + + assistantTokensPerHour = 140 + // assistantLimiterRate is the rate (in tokens per second) + // at which tokens for the assistant rate limiter are replenished + assistantLimiterRate = rate.Limit(assistantTokensPerHour / float64(time.Hour/time.Second)) + // assistantLimiterCapacity is the total capacity of the token bucket for the assistant rate limiter. + // The bucket starts full, prefilled for a week. + assistantLimiterCapacity = assistantTokensPerHour * 24 * 7 ) // healthCheckAppServerFunc defines a function used to perform a health check @@ -111,6 +120,7 @@ type Handler struct { clock clockwork.Clock limiter *limiter.RateLimiter highLimiter *limiter.RateLimiter + assistantLimiter *rate.Limiter healthCheckAppServer healthCheckAppServerFunc // sshPort specifies the SSH proxy port extracted // from configuration @@ -299,6 +309,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { clock: clockwork.NewRealClock(), ClusterFeatures: cfg.ClusterFeatures, healthCheckAppServer: cfg.HealthCheckAppServer, + assistantLimiter: rate.NewLimiter(assistantLimiterRate, assistantLimiterCapacity), } // for properly handling url-encoded parameter values. diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 171efb6d98d3e..70622590e3e75 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -405,9 +405,23 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } + // We can not know how many tokens we will consume in advance. + // Try to consume a small amount of tokens first. + const lookaheadTokens = 100 + if !h.assistantLimiter.AllowN(time.Now(), lookaheadTokens) { + if err := sendRateLimitedMessage(h, conversationID, ws); err != nil { + return trace.Wrap(err) + } + continue + } + //TODO(jakule): Should we sanitize the payload? if err := chat.InsertAssistantMessage(ctx, assist.MessageKindUserMessage, wsIncoming.Payload); err != nil { return trace.Wrap(err) + + promptTokens, err := chat.PromptTokens() + if err != nil { + log.Warnf("Failed to calculate prompt tokens: %v", err) } usedTokens, err := chat.ProcessComplete(ctx, onMessageFn) @@ -415,6 +429,14 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } + // Once we know how many tokens were consumed for prompt+completion, + // consume the remaining tokens from the rate limiter bucket. + extraTokens := promptTokens + completionTokens - lookaheadTokens + if extraTokens < 0 { + extraTokens = 0 + } + h.assistantLimiter.ReserveN(time.Now(), extraTokens) + usageEventReq := &proto.SubmitUsageEventRequest{ Event: &usageeventsv1.UsageEventOneOf{ Event: &usageeventsv1.UsageEventOneOf_AssistCompletion{ @@ -436,3 +458,14 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return nil } + +// Sends a "rate-limited" message to the user without persisting it in the conversation. +func sendRateLimitedMessage(h *Handler, conversationID string, ws *websocket.Conn) error { + protoMsg := &assistantMessage{ + Type: messageKindAssistantMessage, + Payload: "You have reached the rate limit. Please try again later.", + CreatedTime: h.clock.Now().UTC().Format(time.RFC3339), + } + err := ws.WriteJSON(protoMsg) + return trace.Wrap(err) +} From bcd3b38af9ddbfff556f55842f4925b6430a5089 Mon Sep 17 00:00:00 2001 From: Justinas Stankevicius Date: Thu, 11 May 2023 15:54:53 +0300 Subject: [PATCH 2/9] Only rate limit Assist in Cloud --- lib/web/apiserver.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 784f863af756a..7ff2ed11cc61d 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -309,7 +309,15 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { clock: clockwork.NewRealClock(), ClusterFeatures: cfg.ClusterFeatures, healthCheckAppServer: cfg.HealthCheckAppServer, - assistantLimiter: rate.NewLimiter(assistantLimiterRate, assistantLimiterCapacity), + } + + // Check for self-hosted vs Cloud. + // TODO(justinas): this needs to be modified when we allow user-supplied API keys in Cloud + if modules.GetModules().Features().Cloud { + h.assistantLimiter = rate.NewLimiter(assistantLimiterRate, assistantLimiterCapacity) + } else { + // Set up a limiter with "infinite limit", the "burst" parameter is ignored + h.assistantLimiter = rate.NewLimiter(rate.Inf, 0) } // for properly handling url-encoded parameter values. From 41b1c4c6b088fbdc30174f12ee2a79fe494b1bd2 Mon Sep 17 00:00:00 2001 From: Justinas Stankevicius Date: Mon, 15 May 2023 14:17:38 +0300 Subject: [PATCH 3/9] Add a comment to assistantLimiter --- lib/web/apiserver.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 7ff2ed11cc61d..a89515e10b471 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -120,8 +120,13 @@ type Handler struct { clock clockwork.Clock limiter *limiter.RateLimiter highLimiter *limiter.RateLimiter - assistantLimiter *rate.Limiter - healthCheckAppServer healthCheckAppServerFunc + // assistantLimiter limits the amount of tokens that can be consumed + // by OpenAI API calls when using a shared key. + // golang.org/x/time/rate is used, as the oxy ratelimiter + // is quite tightly tied to individual http.Requests, + // and instead we want to consume arbitrary amounts of tokens. + assistantLimiter *rate.Limiter + healthCheckAppServer healthCheckAppServerFunc // sshPort specifies the SSH proxy port extracted // from configuration sshPort string From 54a4ebfe4bb9dc37fa5e12b56d16ba85ea93f77b Mon Sep 17 00:00:00 2001 From: Justinas Stankevicius Date: Mon, 15 May 2023 14:37:57 +0300 Subject: [PATCH 4/9] Fixes after rebase --- lib/assist/assist.go | 8 +++++--- lib/web/assistant.go | 24 +++++------------------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/lib/assist/assist.go b/lib/assist/assist.go index 1a003cbce7f53..fa19ce47ca367 100644 --- a/lib/assist/assist.go +++ b/lib/assist/assist.go @@ -58,6 +58,8 @@ const ( MessageKindAssistantPartialFinalize MessageType = "CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE" // MessageKindSystemMessage is the type of Assist message that contains the system message. MessageKindSystemMessage MessageType = "CHAT_MESSAGE_SYSTEM" + // MessageKindUIMessage is the type of Assist message that is presented to user as information, but not stored persistently in the conversation. This can include backend error messages and the like. + MessageKindUIMessage MessageType = "CHAT_MESSAGE_UI" ) // Assist is the Teleport Assist client. @@ -210,7 +212,7 @@ type TokensUsed struct { // Prompt is a number of tokens used in the prompt. Prompt int // Completion is a number of tokens used in the completion. - Competition int + Completion int } // ProcessComplete processes the completion request and returns the number of tokens used. @@ -386,8 +388,8 @@ func (c *Chat) ProcessComplete(ctx context.Context, } return &TokensUsed{ - Prompt: promptTokens, - Competition: numTokens, + Prompt: promptTokens, + Completion: numTokens, }, nil } diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 70622590e3e75..9dbcc60a6f2ca 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -409,7 +409,8 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, // Try to consume a small amount of tokens first. const lookaheadTokens = 100 if !h.assistantLimiter.AllowN(time.Now(), lookaheadTokens) { - if err := sendRateLimitedMessage(h, conversationID, ws); err != nil { + err := onMessageFn(assist.MessageKindUIMessage, []byte("You have reached the rate limit. Please try again later."), h.clock.Now().UTC()) + if err != nil { return trace.Wrap(err) } continue @@ -418,10 +419,6 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, //TODO(jakule): Should we sanitize the payload? if err := chat.InsertAssistantMessage(ctx, assist.MessageKindUserMessage, wsIncoming.Payload); err != nil { return trace.Wrap(err) - - promptTokens, err := chat.PromptTokens() - if err != nil { - log.Warnf("Failed to calculate prompt tokens: %v", err) } usedTokens, err := chat.ProcessComplete(ctx, onMessageFn) @@ -431,7 +428,7 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, // Once we know how many tokens were consumed for prompt+completion, // consume the remaining tokens from the rate limiter bucket. - extraTokens := promptTokens + completionTokens - lookaheadTokens + extraTokens := usedTokens.Prompt + usedTokens.Completion - lookaheadTokens if extraTokens < 0 { extraTokens = 0 } @@ -442,9 +439,9 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, Event: &usageeventsv1.UsageEventOneOf_AssistCompletion{ AssistCompletion: &usageeventsv1.AssistCompletionEvent{ ConversationId: conversationID, - TotalTokens: int64(usedTokens.Prompt + usedTokens.Competition), + TotalTokens: int64(usedTokens.Prompt + usedTokens.Completion), PromptTokens: int64(usedTokens.Prompt), - CompletionTokens: int64(usedTokens.Competition), + CompletionTokens: int64(usedTokens.Completion), }, }, }, @@ -458,14 +455,3 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return nil } - -// Sends a "rate-limited" message to the user without persisting it in the conversation. -func sendRateLimitedMessage(h *Handler, conversationID string, ws *websocket.Conn) error { - protoMsg := &assistantMessage{ - Type: messageKindAssistantMessage, - Payload: "You have reached the rate limit. Please try again later.", - CreatedTime: h.clock.Now().UTC().Format(time.RFC3339), - } - err := ws.WriteJSON(protoMsg) - return trace.Wrap(err) -} From b598ee05875d5dd3d11da0ad7ff42565b099151a Mon Sep 17 00:00:00 2001 From: Justinas Stankevicius Date: Mon, 15 May 2023 15:16:29 +0300 Subject: [PATCH 5/9] Add 'rate-limited' test case to assistant_test --- lib/web/assistant_test.go | 158 +++++++++++++++++++++++++++----------- 1 file changed, 113 insertions(+), 45 deletions(-) diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 551dc1d273c4d..7822f955720a3 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/trace" "github.com/sashabaranov/go-openai" "github.com/stretchr/testify/require" + "golang.org/x/time/rate" "github.com/gravitational/teleport/lib/assist" "github.com/gravitational/teleport/lib/client" @@ -40,71 +41,138 @@ import ( func Test_runAssistant(t *testing.T) { t.Parallel() - responses := [][]byte{ - generateTextResponse(), + readPartialMessage := func(t *testing.T, ws *websocket.Conn) string { + var msg assistantMessage + _, payload, err := ws.ReadMessage() + require.NoError(t, err) + + err = json.Unmarshal(payload, &msg) + require.NoError(t, err) + + require.Equal(t, assist.MessageKindAssistantPartialMessage, msg.Type) + return msg.Payload } - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") + readStreamEnd := func(t *testing.T, ws *websocket.Conn) { + var msg assistantMessage + _, payload, err := ws.ReadMessage() + require.NoError(t, err) - require.GreaterOrEqual(t, len(responses), 1, "Unexpected request") - dataBytes := responses[0] + err = json.Unmarshal(payload, &msg) + require.NoError(t, err) - _, err := w.Write(dataBytes) - require.NoError(t, err, "Write error") + require.Equal(t, assist.MessageKindAssistantPartialFinalize, msg.Type) + } + + readRateLimitedMessage := func(t *testing.T, ws *websocket.Conn) { + var msg assistantMessage + _, payload, err := ws.ReadMessage() + require.NoError(t, err) - responses = responses[1:] - })) - defer server.Close() + err = json.Unmarshal(payload, &msg) + require.NoError(t, err) - openaiCfg := openai.DefaultConfig("test-token") - openaiCfg.BaseURL = server.URL - s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg}) + require.Equal(t, assist.MessageKindUIMessage, msg.Type) + require.Equal(t, msg.Payload, "You have reached the rate limit. Please try again later.") + } - ws, err := s.makeAssistant(t, s.authPack(t, "foo")) - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, ws.Close()) }) + testCases := []struct { + name string + responses [][]byte + setup func(*WebSuite) + act func(*testing.T, *websocket.Conn) + }{ + { + name: "normal", + responses: [][]byte{ + generateTextResponse(), + }, + act: func(t *testing.T, ws *websocket.Conn) { + err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) + require.NoError(t, err) + + require.Contains(t, readPartialMessage(t, ws), "Which") + require.Contains(t, readPartialMessage(t, ws), "node do") + require.Contains(t, readPartialMessage(t, ws), "you want") + require.Contains(t, readPartialMessage(t, ws), "use?") + + readStreamEnd(t, ws) + }, + }, + { + name: "rate limited", + responses: [][]byte{ + generateTextResponse(), + generateTextResponse(), + }, + setup: func(s *WebSuite) { + // 101 token capacity (lookaheadTokens+1) and a slow replenish rate + // to let the first completion request succeed, but not the second one + s.webHandler.handler.assistantLimiter = rate.NewLimiter(rate.Limit(0.001), 101) + + }, + act: func(t *testing.T, ws *websocket.Conn) { + err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) + require.NoError(t, err) + + require.Contains(t, readPartialMessage(t, ws), "Which") + require.Contains(t, readPartialMessage(t, ws), "node do") + require.Contains(t, readPartialMessage(t, ws), "you want") + require.Contains(t, readPartialMessage(t, ws), "use?") + + readStreamEnd(t, ws) + + err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "all nodes, please"}`)) + require.NoError(t, err) + + readRateLimitedMessage(t, ws) + }, + }, + } - _, payload, err := ws.ReadMessage() - require.NoError(t, err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + responses := tc.responses + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") - var msg assistantMessage - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) + require.GreaterOrEqual(t, len(responses), 1, "Unexpected request") + dataBytes := responses[0] - require.Equal(t, assist.MessageKindAssistantMessage, msg.Type) - require.Contains(t, msg.Payload, "Hey, I'm Teleport") + _, err := w.Write(dataBytes) + require.NoError(t, err, "Write error") - err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) - require.NoError(t, err) + responses = responses[1:] + })) + t.Cleanup(server.Close) - readPartialMessage := func() string { - _, payload, err = ws.ReadMessage() - require.NoError(t, err) + openaiCfg := openai.DefaultConfig("test-token") + openaiCfg.BaseURL = server.URL + s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg}) - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) + if tc.setup != nil { + tc.setup(s) + } - require.Equal(t, assist.MessageKindAssistantPartialMessage, msg.Type) - return msg.Payload - } + ws, err := s.makeAssistant(t, s.authPack(t, "foo")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) - require.Contains(t, readPartialMessage(), "Which") - require.Contains(t, readPartialMessage(), "node do") - require.Contains(t, readPartialMessage(), "you want") - require.Contains(t, readPartialMessage(), "use?") + _, payload, err := ws.ReadMessage() + require.NoError(t, err) - readStraemEnd := func() { - _, payload, err = ws.ReadMessage() - require.NoError(t, err) + var msg assistantMessage + err = json.Unmarshal(payload, &msg) + require.NoError(t, err) - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) + // Expect "hello" message + require.Equal(t, assist.MessageKindAssistantMessage, msg.Type) + require.Contains(t, msg.Payload, "Hey, I'm Teleport") - require.Equal(t, assist.MessageKindAssistantPartialFinalize, msg.Type) + tc.act(t, ws) + }) } - readStraemEnd() } func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack) (*websocket.Conn, error) { From c52f3491984283328c2ab4654a3a0dd508842094 Mon Sep 17 00:00:00 2001 From: Justinas Stankevicius Date: Mon, 15 May 2023 15:30:18 +0300 Subject: [PATCH 6/9] Handle CHAT_MESSAGE_UI in Assist web UI --- .../teleport/src/Assist/contexts/messages.tsx | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/web/packages/teleport/src/Assist/contexts/messages.tsx b/web/packages/teleport/src/Assist/contexts/messages.tsx index c1912675adaa8..f6a3e39b7f126 100644 --- a/web/packages/teleport/src/Assist/contexts/messages.tsx +++ b/web/packages/teleport/src/Assist/contexts/messages.tsx @@ -152,7 +152,10 @@ async function convertServerMessage( message: ServerMessage, clusterId: string ): Promise { - if (message.type === 'CHAT_MESSAGE_ASSISTANT') { + if ( + message.type === 'CHAT_MESSAGE_ASSISTANT' || + message.type === 'CHAT_MESSAGE_UI' + ) { const newMessage: Message = { author: Author.Teleport, timestamp: message.created_time, @@ -263,6 +266,8 @@ async function convertServerMessage( return (messages: Message[]) => messages.push(newMessage); } + + throw new Error('unrecognized message type'); } function findIntersection(elems: T[][]): T[] { @@ -364,9 +369,12 @@ export function MessagesContextProvider( if (lastMessage !== null) { const value = JSON.parse(lastMessage.data) as ServerMessage; + // When a streaming message ends, or a non-streaming message arrives if ( value.type === 'CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE' || - value.type === 'COMMAND' + value.type === 'COMMAND' || + value.type === 'CHAT_MESSAGE_ASSISTANT' || + value.type === 'CHAT_MESSAGE_UI' ) { setResponding(false); } From 984c1b8314b17063485bcab7b117911bf629a1e3 Mon Sep 17 00:00:00 2001 From: Justinas Stankevicius Date: Mon, 15 May 2023 16:36:04 +0300 Subject: [PATCH 7/9] Add godoc --- lib/web/apiserver.go | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index a89515e10b471..2a4f907457be9 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -94,6 +94,7 @@ const ( // SSOLoginFailureMessage is a generic error message to avoid disclosing sensitive SSO failure messages. SSOLoginFailureMessage = "Failed to login. Please check Teleport's log for more details." + // assistantTokensPerHour defines how many assistant rate limiter tokens are replenished every hour. assistantTokensPerHour = 140 // assistantLimiterRate is the rate (in tokens per second) // at which tokens for the assistant rate limiter are replenished From 9b5f09ef2ff93122d16749cb76c466136601fb2b Mon Sep 17 00:00:00 2001 From: Justinas Stankevicius Date: Mon, 15 May 2023 17:43:31 +0300 Subject: [PATCH 8/9] CHAT_MESSAGE_UI -> CHAT_MESSAGE_ERROR --- lib/assist/assist.go | 4 ++-- lib/web/assistant.go | 2 +- lib/web/assistant_test.go | 2 +- web/packages/teleport/src/Assist/contexts/messages.tsx | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/assist/assist.go b/lib/assist/assist.go index fa19ce47ca367..0e24731d08d9e 100644 --- a/lib/assist/assist.go +++ b/lib/assist/assist.go @@ -58,8 +58,8 @@ const ( MessageKindAssistantPartialFinalize MessageType = "CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE" // MessageKindSystemMessage is the type of Assist message that contains the system message. MessageKindSystemMessage MessageType = "CHAT_MESSAGE_SYSTEM" - // MessageKindUIMessage is the type of Assist message that is presented to user as information, but not stored persistently in the conversation. This can include backend error messages and the like. - MessageKindUIMessage MessageType = "CHAT_MESSAGE_UI" + // MessageKindError is the type of Assist message that is presented to user as information, but not stored persistently in the conversation. This can include backend error messages and the like. + MessageKindError MessageType = "CHAT_MESSAGE_ERROR" ) // Assist is the Teleport Assist client. diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 9dbcc60a6f2ca..db7197bc4b2a2 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -409,7 +409,7 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, // Try to consume a small amount of tokens first. const lookaheadTokens = 100 if !h.assistantLimiter.AllowN(time.Now(), lookaheadTokens) { - err := onMessageFn(assist.MessageKindUIMessage, []byte("You have reached the rate limit. Please try again later."), h.clock.Now().UTC()) + err := onMessageFn(assist.MessageKindError, []byte("You have reached the rate limit. Please try again later."), h.clock.Now().UTC()) if err != nil { return trace.Wrap(err) } diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 7822f955720a3..0661ccfae01a5 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -72,7 +72,7 @@ func Test_runAssistant(t *testing.T) { err = json.Unmarshal(payload, &msg) require.NoError(t, err) - require.Equal(t, assist.MessageKindUIMessage, msg.Type) + require.Equal(t, assist.MessageKindError, msg.Type) require.Equal(t, msg.Payload, "You have reached the rate limit. Please try again later.") } diff --git a/web/packages/teleport/src/Assist/contexts/messages.tsx b/web/packages/teleport/src/Assist/contexts/messages.tsx index f6a3e39b7f126..e970a91fbe891 100644 --- a/web/packages/teleport/src/Assist/contexts/messages.tsx +++ b/web/packages/teleport/src/Assist/contexts/messages.tsx @@ -154,7 +154,7 @@ async function convertServerMessage( ): Promise { if ( message.type === 'CHAT_MESSAGE_ASSISTANT' || - message.type === 'CHAT_MESSAGE_UI' + message.type === 'CHAT_MESSAGE_ERROR' ) { const newMessage: Message = { author: Author.Teleport, @@ -374,7 +374,7 @@ export function MessagesContextProvider( value.type === 'CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE' || value.type === 'COMMAND' || value.type === 'CHAT_MESSAGE_ASSISTANT' || - value.type === 'CHAT_MESSAGE_UI' + value.type === 'CHAT_MESSAGE_ERROR' ) { setResponding(false); } From 5eb9fb59c9fd3326ab695df1c8a3f21b115a0031 Mon Sep 17 00:00:00 2001 From: Justinas Stankevicius Date: Mon, 15 May 2023 18:31:50 +0300 Subject: [PATCH 9/9] Run assistant test cases in parallel --- lib/web/assistant_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 0661ccfae01a5..25cff80c560cd 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -131,7 +131,9 @@ func Test_runAssistant(t *testing.T) { } for _, tc := range testCases { + tc := tc t.Run(tc.name, func(t *testing.T) { + t.Parallel() responses := tc.responses server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream")