diff --git a/lib/ai/chat.go b/lib/ai/chat.go
index 196b980bdf2f6..617151d40bc75 100644
--- a/lib/ai/chat.go
+++ b/lib/ai/chat.go
@@ -60,7 +60,8 @@ func (chat *Chat) Complete(ctx context.Context, userInput string) (any, error) {
 	// if the chat is empty, return the initial response we predefine instead of querying GPT-4
 	if len(chat.messages) == 1 {
 		return &model.Message{
-			Content: model.InitialAIResponse,
+			Content:    model.InitialAIResponse,
+			TokensUsed: &model.TokensUsed{},
 		}, nil
 	}
 
diff --git a/lib/ai/chat_test.go b/lib/ai/chat_test.go
new file mode 100644
index 0000000000000..f71b20b76d5e6
--- /dev/null
+++ b/lib/ai/chat_test.go
@@ -0,0 +1,193 @@
+/*
+ * Copyright 2023 Gravitational, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ai
+
+import (
+	"context"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/sashabaranov/go-openai"
+	"github.com/stretchr/testify/require"
+
+	"github.com/gravitational/teleport/lib/ai/model"
+	aitest "github.com/gravitational/teleport/lib/ai/testutils"
+)
+
+func TestChat_PromptTokens(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name     string
+		messages []openai.ChatCompletionMessage
+		want     int
+	}{
+		{
+			name:     "empty",
+			messages: []openai.ChatCompletionMessage{},
+			want:     0,
+		},
+		{
+			name: "only system message",
+			messages: []openai.ChatCompletionMessage{
+				{
+					Role:    openai.ChatMessageRoleSystem,
+					Content: "Hello",
+				},
+			},
+			want: 632,
+		},
+		{
+			name: "system and user messages",
+			messages: []openai.ChatCompletionMessage{
+				{
+					Role:    openai.ChatMessageRoleSystem,
+					Content: "Hello",
+				},
+				{
+					Role:    openai.ChatMessageRoleUser,
+					Content: "Hi LLM.",
+				},
+			},
+			want: 640,
+		},
+		{
+			name: "tokenize our prompt",
+			messages: []openai.ChatCompletionMessage{
+				{
+					Role:    openai.ChatMessageRoleSystem,
+					Content: model.PromptCharacter("Bob"),
+				},
+				{
+					Role:    openai.ChatMessageRoleUser,
+					Content: "Show me free disk space on localhost node.",
+				},
+			},
+			want: 843,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			responses := []string{
+				generateCommandResponse(),
+			}
+			server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses))
+			t.Cleanup(server.Close)
+
+			cfg := openai.DefaultConfig("secret-test-token")
+			cfg.BaseURL = server.URL + "/v1"
+
+			client := NewClientFromConfig(cfg)
+			chat := client.NewChat("Bob")
+
+			for _, message := range tt.messages {
+				chat.Insert(message.Role, message.Content)
+			}
+
+			ctx := context.Background()
+			message, err := chat.Complete(ctx, "")
+			require.NoError(t, err)
+			msg, ok := message.(interface{ UsedTokens() *model.TokensUsed })
+			require.True(t, ok)
+
+			usedTokens := msg.UsedTokens().Completion + msg.UsedTokens().Prompt
+			require.Equal(t, tt.want, usedTokens)
+		})
+	}
+}
+
+func TestChat_Complete(t *testing.T) {
+	t.Parallel()
+
+	responses := []string{
+		generateTextResponse(),
+		generateCommandResponse(),
+	}
+	server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses))
+	t.Cleanup(server.Close)
+
openai.DefaultConfig("secret-test-token") + cfg.BaseURL = server.URL + "/v1" + client := NewClientFromConfig(cfg) + + chat := client.NewChat("Bob") + + t.Run("initial message", func(t *testing.T) { + msgAny, err := chat.Complete(context.Background(), "Hello") + require.NoError(t, err) + + msg, ok := msgAny.(*model.Message) + require.True(t, ok) + + expectedResp := &model.Message{ + Content: "Hey, I'm Teleport - a powerful tool that can assist you in managing your Teleport cluster via OpenAI GPT-4.", + } + require.Equal(t, expectedResp.Content, msg.Content) + require.NotNil(t, msg.TokensUsed) + }) + + t.Run("text completion", func(t *testing.T) { + chat.Insert(openai.ChatMessageRoleUser, "Show me free disk space") + + msg, err := chat.Complete(context.Background(), "") + require.NoError(t, err) + + require.IsType(t, &model.Message{}, msg) + streamingMessage := msg.(*model.Message) + + const expectedResponse = "Which node do you want use?" + + require.Equal(t, expectedResponse, streamingMessage.Content) + }) + + t.Run("command completion", func(t *testing.T) { + chat.Insert(openai.ChatMessageRoleUser, "localhost") + + msg, err := chat.Complete(context.Background(), "") + require.NoError(t, err) + + require.IsType(t, &model.CompletionCommand{}, msg) + command := msg.(*model.CompletionCommand) + require.Equal(t, "df -h", command.Command) + require.Len(t, command.Nodes, 1) + require.Equal(t, "localhost", command.Nodes[0]) + }) +} + +// generateTextResponse generates a response for a text completion +func generateTextResponse() string { + return "```" + `json + { + "action": "Final Answer", + "action_input": "Which node do you want use?" + } + ` + "```" +} + +// generateCommandResponse generates a response for the command "df -h" on the node "localhost" +func generateCommandResponse() string { + return "```" + `json + { + "action": "Command Execution", + "action_input": "{\"command\":\"df -h\",\"nodes\":[\"localhost\"],\"labels\":[]}" + } + ` + "```" +} diff --git a/lib/ai/model/messages.go b/lib/ai/model/messages.go index 58468972d0a1b..4a4a1487adf62 100644 --- a/lib/ai/model/messages.go +++ b/lib/ai/model/messages.go @@ -66,6 +66,12 @@ type TokensUsed struct { Completion int } +// UsedTokens returns the number of tokens used during a single invocation of the agent. +// This method creates a convinient way to get TokensUsed from embedded structs. +func (t *TokensUsed) UsedTokens() *TokensUsed { + return t +} + // newTokensUsed_Cl100kBase creates a new TokensUsed instance with a Cl100kBase tokenizer. // This tokenizer is used by GPT-3 and GPT-4. func newTokensUsed_Cl100kBase() *TokensUsed { diff --git a/lib/ai/testutils/http.go b/lib/ai/testutils/http.go new file mode 100644 index 0000000000000..c98de30435a2f --- /dev/null +++ b/lib/ai/testutils/http.go @@ -0,0 +1,73 @@ +/* + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package testutils
+
+import (
+	"encoding/json"
+	"net/http"
+	"strconv"
+	"testing"
+	"time"
+
+	"github.com/sashabaranov/go-openai"
+	"github.com/stretchr/testify/assert"
+)
+
+// GetTestHandlerFn returns a handler function that can be used to mock the OpenAI API
+// used by the chat API. It takes a list of responses that will be returned in order.
+func GetTestHandlerFn(t *testing.T, responses []string) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		req := &openai.ChatCompletionRequest{}
+		err := json.NewDecoder(r.Body).Decode(req)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+			return
+		}
+
+		// Use assert as require doesn't work when called from a goroutine
+		if !assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") {
+			http.Error(w, "Unexpected request", http.StatusBadRequest)
+			return
+		}
+
+		dataBytes := responses[0]
+
+		resp := openai.ChatCompletionResponse{
+			ID:      strconv.Itoa(int(time.Now().Unix())),
+			Object:  "test-object",
+			Created: time.Now().Unix(),
+			Model:   req.Model,
+			Choices: []openai.ChatCompletionChoice{
+				{
+					Message: openai.ChatCompletionMessage{
+						Role:    openai.ChatMessageRoleAssistant,
+						Content: dataBytes,
+						Name:    "",
+					},
+				},
+			},
+			Usage: openai.Usage{},
+		}
+
+		respBytes, err := json.Marshal(resp)
+		assert.NoError(t, err, "Marshal error")
+
+		_, err = w.Write(respBytes)
+		assert.NoError(t, err, "Write error")
+
+		responses = responses[1:]
+	}
+}
diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go
new file mode 100644
index 0000000000000..60b45d871336a
--- /dev/null
+++ b/lib/web/assistant_test.go
@@ -0,0 +1,319 @@
+/*
+ * Copyright 2023 Gravitational, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package web
+
+import (
+	"context"
+	"crypto/tls"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"testing"
+
+	"github.com/gorilla/websocket"
+	"github.com/gravitational/roundtrip"
+	"github.com/gravitational/trace"
+	"github.com/sashabaranov/go-openai"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/time/rate"
+
+	authproto "github.com/gravitational/teleport/api/client/proto"
+	aitest "github.com/gravitational/teleport/lib/ai/testutils"
+	"github.com/gravitational/teleport/lib/assist"
+	"github.com/gravitational/teleport/lib/client"
+)
+
+func Test_runAssistant(t *testing.T) {
+	t.Parallel()
+
+	readMessage := func(t *testing.T, ws *websocket.Conn) string {
+		var msg assistantMessage
+		_, payload, err := ws.ReadMessage()
+		require.NoError(t, err)
+
+		err = json.Unmarshal(payload, &msg)
+		require.NoError(t, err)
+
+		require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
+		return msg.Payload
+	}
+
+	readRateLimitedMessage := func(t *testing.T, ws *websocket.Conn) {
+		var msg assistantMessage
+		_, payload, err := ws.ReadMessage()
+		require.NoError(t, err)
+
+		err = json.Unmarshal(payload, &msg)
+		require.NoError(t, err)
+
+		require.Equal(t, assist.MessageKindError, msg.Type)
+		require.Equal(t, msg.Payload, "You have reached the rate limit. Please try again later.")
+	}
+
+	testCases := []struct {
+		name      string
+		responses []string
+		cfg       webSuiteConfig
+		setup     func(*testing.T, *WebSuite)
+		act       func(*testing.T, *websocket.Conn)
+	}{
+		{
+			name: "normal",
+			responses: []string{
+				generateTextResponse(),
+			},
+			act: func(t *testing.T, ws *websocket.Conn) {
+				err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
+				require.NoError(t, err)
+
+				const expectedMsg = "Which node do you want to use?"
+				require.Contains(t, expectedMsg, readMessage(t, ws))
+			},
+		},
+		{
+			name: "rate limited",
+			responses: []string{
+				generateTextResponse(),
+				generateTextResponse(),
+			},
+			cfg: webSuiteConfig{
+				ClusterFeatures: &authproto.Features{
+					Cloud: true,
+				},
+			},
+			setup: func(t *testing.T, s *WebSuite) {
+				// Assert that rate limiter is set up when Cloud feature is active,
+				// before replacing with a lower capacity rate-limiter for test purposes
+				require.Equal(t, assistantLimiterRate, s.webHandler.handler.assistantLimiter.Limit())
+
+				// 101 token capacity (lookaheadTokens+1) and a slow replenish rate
+				// to let the first completion request succeed, but not the second one
+				s.webHandler.handler.assistantLimiter = rate.NewLimiter(rate.Limit(0.001), 101)
+
+			},
+			act: func(t *testing.T, ws *websocket.Conn) {
+				err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
+				require.NoError(t, err)
+
+				const expectedMsg = "Which node do you want to use?"
+				require.Contains(t, expectedMsg, readMessage(t, ws))
+
+				err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "all nodes, please"}`))
+				require.NoError(t, err)
+
+				readRateLimitedMessage(t, ws)
+			},
+		},
+	}
+
+	for _, tc := range testCases {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+			responses := tc.responses
+			server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses))
+			t.Cleanup(server.Close)
+
+			openaiCfg := openai.DefaultConfig("test-token")
+			openaiCfg.BaseURL = server.URL
+			tc.cfg.OpenAIConfig = &openaiCfg
+			s := newWebSuiteWithConfig(t, tc.cfg)
+
+			if tc.setup != nil {
+				tc.setup(t, s)
+			}
+
+			ctx := context.Background()
+			authPack := s.authPack(t, "foo")
+			// Create the conversation
+			conversationID := s.makeAssistConversation(t, ctx, authPack)
+
+			// Make WS client and start the conversation
+			ws, err := s.makeAssistant(t, authPack, conversationID)
+			require.NoError(t, err)
+			t.Cleanup(func() { require.NoError(t, ws.Close()) })
+
+			_, payload, err := ws.ReadMessage()
+			require.NoError(t, err)
+
+			var msg assistantMessage
+			err = json.Unmarshal(payload, &msg)
+			require.NoError(t, err)
+
+			// Expect "hello" message
+			require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
+			require.Contains(t, msg.Payload, "Hey, I'm Teleport")
+
+			tc.act(t, ws)
+		})
+	}
+}
+
+// Test_runAssistError tests that the assistant returns an error message
+// when the OpenAI API returns an error.
+func Test_runAssistError(t *testing.T) {
+	t.Parallel()
+
+	readHelloMsg := func(ws *websocket.Conn) {
+		_, payload, err := ws.ReadMessage()
+		require.NoError(t, err)
+
+		var msg assistantMessage
+		err = json.Unmarshal(payload, &msg)
+		require.NoError(t, err)
+
+		// Expect "hello" message
+		require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
+		require.Contains(t, msg.Payload, "Hey, I'm Teleport")
+	}
+
+	readErrorMsg := func(ws *websocket.Conn) {
+		err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
+		require.NoError(t, err)
+
+		_, payload, err := ws.ReadMessage()
+		require.NoError(t, err)
+
+		var msg assistantMessage
+		err = json.Unmarshal(payload, &msg)
+		require.NoError(t, err)
+
+		// Expect OpenAI error message
+		require.Equal(t, assist.MessageKindError, msg.Type)
+		require.Contains(t, msg.Payload, "An error has occurred. Please try again later.")
+	}
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		// Simulate rate limit error
+		w.WriteHeader(429)
+
+		errMsg := openai.ErrorResponse{
+			Error: &openai.APIError{
+				Code:           "rate_limit_reached",
+				Message:        "You are sending requests too quickly.",
+				Param:          nil,
+				Type:           "rate_limit_reached",
+				HTTPStatusCode: 429,
+			},
+		}
+
+		dataBytes, err := json.Marshal(errMsg)
+		// Use assert as require doesn't work when called from a goroutine
+		assert.NoError(t, err, "Marshal error")
+
+		_, err = w.Write(dataBytes)
+		assert.NoError(t, err, "Write error")
+	}))
+	t.Cleanup(server.Close)
+
+	openaiCfg := openai.DefaultConfig("test-token")
+	openaiCfg.BaseURL = server.URL
+	s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg})
+
+	ctx := context.Background()
+	authPack := s.authPack(t, "foo")
+	// Create the conversation
+	conversationID := s.makeAssistConversation(t, ctx, authPack)
+
+	// Make WS client and start the conversation
+	ws, err := s.makeAssistant(t, authPack, conversationID)
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		// Close should yield an error as the server closes the connection
+		require.Error(t, ws.Close())
+	})
+
+	// verify responses
+	readHelloMsg(ws)
+	readErrorMsg(ws)
+
+	// Check for close message
+	_, _, err = ws.ReadMessage()
+	closeErr, ok := err.(*websocket.CloseError)
+	require.True(t, ok, "Expected close error")
+	require.Equal(t, websocket.CloseInternalServerErr, closeErr.Code, "Expected abnormal closure")
+}
+
+// makeAssistConversation creates a new assist conversation and returns its ID
+func (s *WebSuite) makeAssistConversation(t *testing.T, ctx context.Context, authPack *authPack) string {
+	clt := authPack.clt
+
+	resp, err := clt.PostJSON(ctx, clt.Endpoint("webapi", "assistant", "conversations"), nil)
+	require.NoError(t, err)
+
+	convResp := struct {
+		ConversationID string `json:"id"`
+	}{}
+	err = json.Unmarshal(resp.Bytes(), &convResp)
+	require.NoError(t, err)
+
+	return convResp.ConversationID
+}
+
+// makeAssistant creates a new assistant websocket connection.
+func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack, conversationID string) (*websocket.Conn, error) {
+	u := url.URL{
+		Host:   s.url().Host,
+		Scheme: client.WSS,
+		Path:   fmt.Sprintf("/v1/webapi/sites/%s/assistant", currentSiteShortcut),
+	}
+
+	q := u.Query()
+	q.Set("conversation_id", conversationID)
+	q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token)
+	u.RawQuery = q.Encode()
+
+	dialer := websocket.Dialer{}
+	dialer.TLSClientConfig = &tls.Config{
+		InsecureSkipVerify: true,
+	}
+
+	header := http.Header{}
+	header.Add("Origin", "http://localhost")
+	for _, cookie := range pack.cookies {
+		header.Add("Cookie", cookie.String())
+	}
+
+	ws, resp, err := dialer.Dial(u.String(), header)
+	if err != nil {
+		res, err2 := io.ReadAll(resp.Body)
+		t.Log("response body:", string(res), err2)
+		return nil, trace.Wrap(err)
+	}
+
+	err = resp.Body.Close()
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+
+	return ws, nil
+}
+
+// generateTextResponse generates a response for a text completion
+func generateTextResponse() string {
+	return "```" + `json
+	{
+	    "action": "Final Answer",
+	    "action_input": "Which node do you want to use?"
+	}
+	` + "```"
+}
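Not part of the diff: the tests above lean on `Chat.Complete` returning either a `*model.Message` or a `*model.CompletionCommand`, both of which expose the `UsedTokens()` accessor added in `lib/ai/model/messages.go`. Below is a minimal, hypothetical sketch of how a caller might consume that result; the helper name and the printing are illustrative assumptions, not code from this change.

```go
package example // illustrative only, not part of the PR

import (
	"context"
	"fmt"

	"github.com/gravitational/teleport/lib/ai"
	"github.com/gravitational/teleport/lib/ai/model"
)

// printCompletion is a hypothetical helper showing how the two result types
// returned by Chat.Complete can be handled, and how the token counters are
// read through the UsedTokens() accessor introduced in this change.
func printCompletion(ctx context.Context, chat *ai.Chat, userInput string) error {
	result, err := chat.Complete(ctx, userInput)
	if err != nil {
		return err
	}

	// Both result types satisfy this small interface, which is the same
	// assertion TestChat_PromptTokens relies on.
	if counted, ok := result.(interface{ UsedTokens() *model.TokensUsed }); ok {
		used := counted.UsedTokens()
		fmt.Printf("tokens: prompt=%d completion=%d\n", used.Prompt, used.Completion)
	}

	switch m := result.(type) {
	case *model.Message:
		fmt.Println(m.Content) // plain text reply for the user
	case *model.CompletionCommand:
		fmt.Printf("run %q on nodes %v\n", m.Command, m.Nodes) // suggested command
	}
	return nil
}
```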