diff --git a/lib/ai/chat.go b/lib/ai/chat.go index 6d8424e49be2b..196b980bdf2f6 100644 --- a/lib/ai/chat.go +++ b/lib/ai/chat.go @@ -18,17 +18,13 @@ package ai import ( "context" - "encoding/json" - "errors" - "io" - "strings" "github.com/gravitational/trace" "github.com/sashabaranov/go-openai" "github.com/tiktoken-go/tokenizer" -) -const maxResponseTokens = 2000 + "github.com/gravitational/teleport/lib/ai/model" +) // Chat represents a conversation between a user and an assistant with context memory. type Chat struct { @@ -37,19 +33,14 @@ type Chat struct { tokenizer tokenizer.Codec } -// Insert inserts a message into the conversation. This is commonly in the -// form of a user's input but may also take the form of a system messages used for instructions. -func (chat *Chat) Insert(role string, content string) Message { +// Insert inserts a message into the conversation. Returns the index of the message. +func (chat *Chat) Insert(role string, content string) int { chat.messages = append(chat.messages, openai.ChatCompletionMessage{ Role: role, Content: content, }) - return Message{ - Role: role, - Content: content, - Idx: len(chat.messages) - 1, - } + return len(chat.messages) - 1 } // GetMessages returns the messages in the conversation. @@ -57,157 +48,31 @@ func (chat *Chat) GetMessages() []openai.ChatCompletionMessage { return chat.messages } -// PromptTokens uses the chat's tokenizer to calculate -// the total number of tokens in the prompt -// -// Ref: https://github.com/openai/openai-cookbook/blob/594fc6c952425810e9ea5bd1a275c8ca5f32e8f9/examples/How_to_count_tokens_with_tiktoken.ipynb -func (chat *Chat) PromptTokens() (int, error) { - // perRequest is the number of tokens used up for each completion request - const perRequest = 3 - // perRole is the number of tokens used to encode a message's role - const perRole = 1 - // perMessage is the token "overhead" for each message - const perMessage = 3 - - sum := perRequest - for _, m := range chat.messages { - tokens, _, err := chat.tokenizer.Encode(m.Content) - if err != nil { - return 0, trace.Wrap(err) - } - sum += len(tokens) - sum += perRole - sum += perMessage - } - - return sum, nil -} - -// Complete completes the conversation with a message from the assistant based on the current context. +// Complete completes the conversation with a message from the assistant based on the current context and user input. // On success, it returns the message. 
// Returned types: -// - Message: the message from the assistant +// - message: one of the message types below // - error: an error if one occurred // Message types: // - CompletionCommand: a command from the assistant -// - StreamingMessage: a message that is streamed from the assistant -func (chat *Chat) Complete(ctx context.Context) (any, error) { - var numTokens int - +// - Message: a text message from the assistant +func (chat *Chat) Complete(ctx context.Context, userInput string) (any, error) { // if the chat is empty, return the initial response we predefine instead of querying GPT-4 if len(chat.messages) == 1 { - return &Message{ - Role: openai.ChatMessageRoleAssistant, - Content: initialAIResponse, - Idx: len(chat.messages) - 1, + return &model.Message{ + Content: model.InitialAIResponse, }, nil } - // if not, copy the current chat log to a new slice and append the suffix instruction - messages := make([]openai.ChatCompletionMessage, len(chat.messages)+1) - copy(messages, chat.messages) - messages[len(messages)-1] = openai.ChatCompletionMessage{ + userMessage := openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleUser, - Content: promptExtractInstruction, + Content: userInput, } - // create a streaming completion request, we do this to optimistically stream the response when - // we don't believe it's a payload - stream, err := chat.client.svc.CreateChatCompletionStream( - ctx, - openai.ChatCompletionRequest{ - Model: openai.GPT4, - Messages: messages, - MaxTokens: maxResponseTokens, - Stream: true, - }, - ) + response, err := model.AssistAgent.PlanAndExecute(ctx, chat.client.svc, chat.messages, userMessage) if err != nil { return nil, trace.Wrap(err) } - var ( - response openai.ChatCompletionStreamResponse - trimmed string - ) - for trimmed == "" { - // fetch the first delta to check for a possible JSON payload - response, err = stream.Recv() - if err != nil { - return nil, trace.Wrap(err) - } - numTokens++ - - trimmed = strings.TrimSpace(response.Choices[0].Delta.Content) - } - - // if it looks like a JSON payload, let's wait for the entire response and try to parse it - if strings.HasPrefix(trimmed, "{") { - payload := strings.Builder{} - payload.WriteString(response.Choices[0].Delta.Content) - - for { - response, err := stream.Recv() - if errors.Is(err, io.EOF) { - break - } - if err != nil { - return nil, trace.Wrap(err) - } - numTokens++ - - payload.WriteString(response.Choices[0].Delta.Content) - } - - // if we can parse it, return the parsed payload, otherwise return a non-streaming message - var c CompletionCommand - err = json.Unmarshal([]byte(payload.String()), &c) - switch err { - case nil: - c.NumTokens = numTokens - return &c, nil - default: - return &Message{ - Role: openai.ChatMessageRoleAssistant, - Content: payload.String(), - Idx: len(chat.messages) - 1, - NumTokens: numTokens, - }, nil - } - } - - // if it doesn't look like a JSON payload, return a streaming message to the caller - chunks := make(chan string, 1) - errCh := make(chan error) - chunks <- response.Choices[0].Delta.Content - go func() { - defer close(chunks) - - for { - response, err := stream.Recv() - switch { - case errors.Is(err, io.EOF): - return - case err != nil: - select { - case <-ctx.Done(): - case errCh <- trace.Wrap(err): - } - return - } - - select { - case chunks <- response.Choices[0].Delta.Content: - case <-ctx.Done(): - return - } - } - }() - - return &StreamingMessage{ - Role: openai.ChatMessageRoleAssistant, - Idx: len(chat.messages) - 1, - Chunks: chunks, - Error: errCh, - 
}, nil + return response, nil } diff --git a/lib/ai/chat_test.go b/lib/ai/chat_test.go deleted file mode 100644 index bb9aaec20dd4f..0000000000000 --- a/lib/ai/chat_test.go +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Copyright 2023 Gravitational, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ai - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/sashabaranov/go-openai" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tiktoken-go/tokenizer/codec" -) - -func TestChat_PromptTokens(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - messages []openai.ChatCompletionMessage - want int - wantErr bool - }{ - { - name: "empty", - messages: []openai.ChatCompletionMessage{}, - want: 3, - }, - { - name: "only system message", - messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleSystem, - Content: "Hello", - }, - }, - want: 8, - }, - { - name: "system and user messages", - messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleSystem, - Content: "Hello", - }, - { - Role: openai.ChatMessageRoleUser, - Content: "Hi LLM.", - }, - }, - want: 16, - }, - { - name: "tokenize our prompt", - messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleSystem, - Content: promptCharacter("Bob"), - }, - { - Role: openai.ChatMessageRoleUser, - Content: "Show me free disk space on localhost node.", - }, - }, - want: 187, - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - chat := &Chat{ - messages: tt.messages, - tokenizer: codec.NewCl100kBase(), - } - usedTokens, err := chat.PromptTokens() - require.NoError(t, err) - require.Equal(t, tt.want, usedTokens) - }) - } -} - -func TestChat_Complete(t *testing.T) { - t.Parallel() - - responses := [][]byte{ - generateTextResponse(), - generateCommandResponse(), - } - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - - // Use assert as require doesn't work when called from a goroutine - assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") - dataBytes := responses[0] - - _, err := w.Write(dataBytes) - assert.NoError(t, err, "Write error") - - responses = responses[1:] - })) - defer server.Close() - - cfg := openai.DefaultConfig("secret-test-token") - cfg.BaseURL = server.URL + "/v1" - client := NewClientFromConfig(cfg) - - chat := client.NewChat("Bob") - - t.Run("initial message", func(t *testing.T) { - msg, err := chat.Complete(context.Background()) - require.NoError(t, err) - - expectedResp := &Message{Role: "assistant", - Content: "Hey, I'm Teleport - a powerful tool that can assist you in managing your Teleport cluster via OpenAI GPT-4.", - Idx: 0, - } - require.Equal(t, expectedResp, msg) - }) - - t.Run("text completion", func(t *testing.T) { - chat.Insert(openai.ChatMessageRoleUser, "Show me free disk space") - - msg, err 
:= chat.Complete(context.Background()) - require.NoError(t, err) - - require.IsType(t, &StreamingMessage{}, msg) - streamingMessage := msg.(*StreamingMessage) - require.Equal(t, openai.ChatMessageRoleAssistant, streamingMessage.Role) - - require.Equal(t, "Which ", <-streamingMessage.Chunks) - require.Equal(t, "node do ", <-streamingMessage.Chunks) - require.Equal(t, "you want ", <-streamingMessage.Chunks) - require.Equal(t, "use?", <-streamingMessage.Chunks) - }) - - t.Run("command completion", func(t *testing.T) { - chat.Insert(openai.ChatMessageRoleUser, "localhost") - - msg, err := chat.Complete(context.Background()) - require.NoError(t, err) - - require.IsType(t, &CompletionCommand{}, msg) - command := msg.(*CompletionCommand) - require.Equal(t, "df -h", command.Command) - require.Len(t, command.Nodes, 1) - require.Equal(t, "localhost", command.Nodes[0]) - }) -} - -// generateTextResponse generates a response for a text completion -func generateTextResponse() []byte { - dataBytes := []byte{} - dataBytes = append(dataBytes, []byte("event: message\n")...) - - data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "Which ", "role": "assistant"}}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - dataBytes = append(dataBytes, []byte("event: message\n")...) - - data = `{"id":"2","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "node do ", "role": "assistant"}}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - dataBytes = append(dataBytes, []byte("event: message\n")...) - - data = `{"id":"3","object":"completion","created":1598069255,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "you want ", "role": "assistant"}}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - dataBytes = append(dataBytes, []byte("event: message\n")...) - - data = `{"id":"4","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "use?", "role": "assistant"}}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - dataBytes = append(dataBytes, []byte("event: done\n")...) - - dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) - - return dataBytes -} - -// generateCommandResponse generates a response for the command "df -h" on the node "localhost" -func generateCommandResponse() []byte { - dataBytes := []byte{} - dataBytes = append(dataBytes, []byte("event: message\n")...) - - data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "{\"command\": \"df -h\",", "role": "assistant"}}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - - dataBytes = append(dataBytes, []byte("event: message\n")...) - - data = `{"id":"2","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "\"nodes\": [\"localhost\"]}", "role": "assistant"}}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - - dataBytes = append(dataBytes, []byte("event: done\n")...) - dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) 
- - return dataBytes -} diff --git a/lib/ai/client.go b/lib/ai/client.go index 39b2b0f050d00..b0f5c6ff768aa 100644 --- a/lib/ai/client.go +++ b/lib/ai/client.go @@ -22,6 +22,8 @@ import ( "github.com/gravitational/trace" "github.com/sashabaranov/go-openai" "github.com/tiktoken-go/tokenizer/codec" + + "github.com/gravitational/teleport/lib/ai/model" ) // Client is a client for OpenAI API. @@ -47,7 +49,7 @@ func (client *Client) NewChat(username string) *Chat { messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleSystem, - Content: promptCharacter(username), + Content: model.PromptCharacter(username), }, }, // Initialize a tokenizer for prompt token accounting. @@ -63,7 +65,7 @@ func (client *Client) Summary(ctx context.Context, message string) (string, erro openai.ChatCompletionRequest{ Model: openai.GPT4, Messages: []openai.ChatCompletionMessage{ - {Role: openai.ChatMessageRoleSystem, Content: promptSummarizeTitle}, + {Role: openai.ChatMessageRoleSystem, Content: model.PromptSummarizeTitle}, {Role: openai.ChatMessageRoleUser, Content: message}, }, }, diff --git a/lib/ai/messages.go b/lib/ai/messages.go deleted file mode 100644 index 8d5bd9c73df97..0000000000000 --- a/lib/ai/messages.go +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2023 Gravitational, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ai - -// Message represents a message within a live conversation. -// Indexed by ID for frontend ordering and future partial message streaming. -type Message struct { - Role string `json:"role"` - Content string `json:"content"` - Idx int `json:"idx"` - // NumTokens is the number of completion tokens for the (non-streaming) message - NumTokens int `json:"-"` -} - -// Label represents a label returned by OpenAI's completion API. -type Label struct { - Key string `json:"key"` - Value string `json:"value"` -} - -// CompletionCommand represents a command returned by OpenAI's completion API. -type CompletionCommand struct { - Command string `json:"command,omitempty"` - Nodes []string `json:"nodes,omitempty"` - Labels []Label `json:"labels,omitempty"` - // NumTokens is the number of completion tokens for the (non-streaming) message - NumTokens int `json:"-"` -} - -// StreamingMessage represents a message that is streamed from the assistant and will later be stored as a normal message in the conversation store. -type StreamingMessage struct { - // Role describes the OpenAI role of the message, i.e its sender. - Role string - - // Idx is a semi-unique ID assigned when loading a conversation so that the UI can group partial messages together. - Idx int - - // Chunks is a channel of message chunks that are streamed from the assistant. - Chunks <-chan string - - // Error is a channel which may receive one error if the assistant encounters an error while streaming. - // Consumers should stop reading from all channels if they receive an error and abort. 
-	Error <-chan error
-}
diff --git a/lib/ai/model/agent.go b/lib/ai/model/agent.go
new file mode 100644
index 0000000000000..1c5c2c4c24055
--- /dev/null
+++ b/lib/ai/model/agent.go
@@ -0,0 +1,340 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package model
+
+import (
+	"context"
+	"encoding/json"
+	"strings"
+	"time"
+
+	"github.com/gravitational/trace"
+	"github.com/sashabaranov/go-openai"
+	log "github.com/sirupsen/logrus"
+)
+
+const (
+	actionFinalAnswer = "Final Answer"
+	actionException   = "_Exception"
+	maxIterations     = 15
+	maxElapsedTime    = 5 * time.Minute
+)
+
+// AssistAgent is a global instance of the Assist agent which defines the model responsible for the Assist feature.
+var AssistAgent = &Agent{
+	tools: []Tool{
+		&commandExecutionTool{},
+	},
+}
+
+// Agent is a model storing static state which defines some properties of the chat model.
+type Agent struct {
+	tools []Tool
+}
+
+// agentAction is an event type representing the decision to take a single action, typically a tool invocation.
+type agentAction struct {
+	// The action to take, typically a tool name.
+	action string
+
+	// The input to the action, varies depending on the action.
+	input string
+
+	// The log is either a direct tool response or a thought prompt correlated to the input.
+	log string
+}
+
+// agentFinish is an event type representing the decision to finish a thought
+// loop and return a final text answer to the user.
+type agentFinish struct {
+	// output must be Message or CompletionCommand
+	output any
+}
+
+type executionState struct {
+	llm               *openai.Client
+	chatHistory       []openai.ChatCompletionMessage
+	humanMessage      openai.ChatCompletionMessage
+	intermediateSteps []agentAction
+	observations      []string
+	tokensUsed        *TokensUsed
+}
+
+// PlanAndExecute runs the agent with a given input until it arrives at a text answer it is satisfied
+// with or until it times out.
+func (a *Agent) PlanAndExecute(ctx context.Context, llm *openai.Client, chatHistory []openai.ChatCompletionMessage, humanMessage openai.ChatCompletionMessage) (any, error) {
+	log.Trace("entering agent think loop")
+	iterations := 0
+	start := time.Now()
+	tookTooLong := func() bool { return iterations > maxIterations || time.Since(start) > maxElapsedTime }
+	tokensUsed := newTokensUsed_Cl100kBase()
+	state := &executionState{
+		llm:               llm,
+		chatHistory:       chatHistory,
+		humanMessage:      humanMessage,
+		intermediateSteps: make([]agentAction, 0),
+		observations:      make([]string, 0),
+		tokensUsed:        tokensUsed,
+	}
+
+	for {
+		log.Tracef("performing iteration %v of loop, %v seconds elapsed", iterations, int(time.Since(start).Seconds()))
+
+		// This is intentionally not context-based, as we want to finish the current step before exiting
+		// and the concern is not that we're stuck but that we're taking too long over multiple iterations.
+		if tookTooLong() {
+			return nil, trace.Errorf("timeout: agent took too long to finish")
+		}
+
+		output, err := a.takeNextStep(ctx, state)
+		if err != nil {
+			return nil, trace.Wrap(err)
+		}
+
+		if output.finish != nil {
+			log.Tracef("agent finished with output: %v", output.finish.output)
+			switch v := output.finish.output.(type) {
+			case *Message:
+				v.TokensUsed = tokensUsed
+				return v, nil
+			case *CompletionCommand:
+				v.TokensUsed = tokensUsed
+				return v, nil
+			default:
+				return nil, trace.Errorf("invalid output type %T", v)
+			}
+		}
+
+		if output.action != nil {
+			state.intermediateSteps = append(state.intermediateSteps, *output.action)
+			state.observations = append(state.observations, output.observation)
+		}
+
+		iterations++
+	}
+}
+
+// stepOutput represents the inputs and outputs of a single thought step.
+type stepOutput struct {
+	// if the agent is done, finish is set.
+	finish *agentFinish
+
+	// if the agent is not done, action is set together with observation.
+	action      *agentAction
+	observation string
+}
+
+func (a *Agent) takeNextStep(ctx context.Context, state *executionState) (stepOutput, error) {
+	log.Trace("agent entering takeNextStep")
+	defer log.Trace("agent exiting takeNextStep")
+
+	action, finish, err := a.plan(ctx, state)
+	if err, ok := trace.Unwrap(err).(*invalidOutputError); ok {
+		log.Tracef("agent encountered an invalid output error: %v, attempting to recover", err)
+		action := &agentAction{
+			action: actionException,
+			input:  observationPrefix + "Invalid or incomplete response",
+			log:    thoughtPrefix + err.Error(),
+		}
+
+		// The exception tool is currently a bit special: the observation is always equal to the input.
+		// We can expand on this in the future to make it handle errors better.
+		log.Tracef("agent decided on action %v and received observation %v", action.action, action.input)
+		return stepOutput{action: action, observation: action.input}, nil
+	}
+	if err != nil {
+		log.Tracef("agent encountered an error: %v", err)
+		return stepOutput{}, trace.Wrap(err)
+	}
+
+	// If finish is set, the agent is done and did not call upon any tool.
+ if finish != nil { + log.Trace("agent picked finish, returning") + return stepOutput{finish: finish}, nil + } + + var tool Tool + for _, candidate := range a.tools { + if candidate.Name() == action.action { + tool = candidate + break + } + } + + if tool == nil { + log.Tracef("agent picked an unknown tool %v", action.action) + action := &agentAction{ + action: actionException, + input: observationPrefix + "Unknown tool", + log: thoughtPrefix + "No tool with name " + action.action + " exists.", + } + + return stepOutput{action: action, observation: action.input}, nil + } + + if tool, ok := tool.(*commandExecutionTool); ok { + input, err := tool.parseInput(action.input) + if err != nil { + action := &agentAction{ + action: actionException, + input: observationPrefix + "Invalid or incomplete response", + log: thoughtPrefix + err.Error(), + } + + return stepOutput{action: action, observation: action.input}, nil + } + + completion := &CompletionCommand{ + Command: input.Command, + Nodes: input.Nodes, + Labels: input.Labels, + } + + log.Tracef("agent decided on command execution, let's translate to an agentFinish") + return stepOutput{finish: &agentFinish{output: completion}}, nil + } + + return stepOutput{}, trace.NotImplemented("assist does not support non command execution tools yet") +} + +func (a *Agent) plan(ctx context.Context, state *executionState) (*agentAction, *agentFinish, error) { + scratchpad := a.constructScratchpad(state.intermediateSteps, state.observations) + prompt := a.createPrompt(state.chatHistory, scratchpad, state.humanMessage) + resp, err := state.llm.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4, + Messages: prompt, + }, + ) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + llmOut := resp.Choices[0].Message.Content + err = state.tokensUsed.AddTokens(prompt, llmOut) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + action, finish, err := parsePlanningOutput(llmOut) + return action, finish, trace.Wrap(err) +} + +func (a *Agent) createPrompt(chatHistory, agentScratchpad []openai.ChatCompletionMessage, humanMessage openai.ChatCompletionMessage) []openai.ChatCompletionMessage { + prompt := make([]openai.ChatCompletionMessage, 0) + prompt = append(prompt, chatHistory...) + toolList := strings.Builder{} + toolNames := make([]string, 0, len(a.tools)) + for _, tool := range a.tools { + toolNames = append(toolNames, tool.Name()) + toolList.WriteString("> ") + toolList.WriteString(tool.Name()) + toolList.WriteString(": ") + toolList.WriteString(tool.Description()) + toolList.WriteString("\n") + } + + if len(a.tools) == 0 { + toolList.WriteString("No tools available.") + } + + formatInstructions := conversationParserFormatInstructionsPrompt(toolNames) + newHumanMessage := conversationToolUsePrompt(toolList.String(), formatInstructions, humanMessage.Content) + prompt = append(prompt, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: newHumanMessage, + }) + + prompt = append(prompt, agentScratchpad...) 
+	return prompt
+}
+
+func (a *Agent) constructScratchpad(intermediateSteps []agentAction, observations []string) []openai.ChatCompletionMessage {
+	var thoughts []openai.ChatCompletionMessage
+	for i, action := range intermediateSteps {
+		thoughts = append(thoughts, openai.ChatCompletionMessage{
+			Role:    openai.ChatMessageRoleAssistant,
+			Content: action.log,
+		}, openai.ChatCompletionMessage{
+			Role:    openai.ChatMessageRoleUser,
+			Content: conversationToolResponse(observations[i]),
+		})
+	}
+
+	return thoughts
+}
+
+// parseJSONFromModel parses a JSON object from the model output and attempts to sanitize contaminant text
+// to avoid triggering self-correction due to some natural language being bundled with the JSON.
+// The output type is generic and thus the structure of the expected JSON varies depending on T.
+func parseJSONFromModel[T any](text string) (T, *invalidOutputError) {
+	cleaned := strings.TrimSpace(text)
+	if strings.Contains(cleaned, "```json") {
+		cleaned = strings.Split(cleaned, "```json")[1]
+	}
+	if strings.Contains(cleaned, "```") {
+		cleaned = strings.Split(cleaned, "```")[0]
+	}
+	cleaned = strings.TrimPrefix(cleaned, "```json")
+	cleaned = strings.TrimPrefix(cleaned, "```")
+	cleaned = strings.TrimSuffix(cleaned, "```")
+	cleaned = strings.TrimSpace(cleaned)
+	var output T
+	err := json.Unmarshal([]byte(cleaned), &output)
+	if err != nil {
+		return output, newInvalidOutputErrorWithParseError(err)
+	}
+
+	return output, nil
+}
+
+// planOutput describes the expected JSON output after asking the model to plan its next action.
+type planOutput struct {
+	Action      string `json:"action"`
+	ActionInput any    `json:"action_input"`
+}
+
+// parsePlanningOutput parses the output of the model after asking it to plan its next action
+// and returns the appropriate event type or an error.
+func parsePlanningOutput(text string) (*agentAction, *agentFinish, error) {
+	log.Tracef("received planning output: \"%v\"", text)
+	response, err := parseJSONFromModel[planOutput](text)
+	if err != nil {
+		return nil, nil, trace.Wrap(err)
+	}
+
+	if response.Action == actionFinalAnswer {
+		outputString, ok := response.ActionInput.(string)
+		if !ok {
+			return nil, nil, trace.Errorf("invalid final answer type %T", response.ActionInput)
+		}
+
+		return nil, &agentFinish{output: &Message{Content: outputString}}, nil
+	}
+
+	if v, ok := response.ActionInput.(string); ok {
+		return &agentAction{action: response.Action, input: v}, nil, nil
+	} else {
+		input, err := json.Marshal(response.ActionInput)
+		if err != nil {
+			return nil, nil, trace.Wrap(err)
+		}
+
+		return &agentAction{action: response.Action, input: string(input)}, nil, nil
+	}
+}
diff --git a/lib/ai/model/error.go b/lib/ai/model/error.go
new file mode 100644
index 0000000000000..c3a3c8fd8d697
--- /dev/null
+++ b/lib/ai/model/error.go
@@ -0,0 +1,41 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package model
+
+import (
+	"fmt"
+)
+
+// invalidOutputError represents an error caused by the output of an LLM.
+// These may be used automatically by the agent loop to attempt to correct an output until it is valid.
+type invalidOutputError struct {
+	coarse string
+	detail string
+}
+
+// newInvalidOutputErrorWithParseError creates a new invalidOutputError assuming a JSON parse error.
+func newInvalidOutputErrorWithParseError(err error) *invalidOutputError {
+	return &invalidOutputError{
+		coarse: "json parse error",
+		detail: err.Error(),
+	}
+}
+
+// Error returns a string representation of the error. This is used to satisfy the error interface.
+func (o *invalidOutputError) Error() string {
+	return fmt.Sprintf("%v: %v", o.coarse, o.detail)
+}
diff --git a/lib/ai/model/messages.go b/lib/ai/model/messages.go
new file mode 100644
index 0000000000000..58468972d0a1b
--- /dev/null
+++ b/lib/ai/model/messages.go
@@ -0,0 +1,97 @@
+/*
+ * Copyright 2023 Gravitational, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package model
+
+import (
+	"github.com/gravitational/trace"
+	"github.com/sashabaranov/go-openai"
+	"github.com/tiktoken-go/tokenizer"
+	"github.com/tiktoken-go/tokenizer/codec"
+)
+
+// Ref: https://github.com/openai/openai-cookbook/blob/594fc6c952425810e9ea5bd1a275c8ca5f32e8f9/examples/How_to_count_tokens_with_tiktoken.ipynb
+const (
+	// perMessage is the token "overhead" for each message
+	perMessage = 3
+
+	// perRequest is the number of tokens used up for each completion request
+	perRequest = 3
+
+	// perRole is the number of tokens used to encode a message's role
+	perRole = 1
+)
+
+// Message represents a new message within a live conversation.
+type Message struct {
+	*TokensUsed
+	Content string
+}
+
+// Label represents a label returned by OpenAI's completion API.
+type Label struct {
+	Key   string `json:"key"`
+	Value string `json:"value"`
+}
+
+// CompletionCommand represents a command returned by OpenAI's completion API.
+type CompletionCommand struct {
+	*TokensUsed
+	Command string   `json:"command,omitempty"`
+	Nodes   []string `json:"nodes,omitempty"`
+	Labels  []Label  `json:"labels,omitempty"`
+}
+
+// TokensUsed is used to track the number of tokens used during a single invocation of the agent.
+type TokensUsed struct {
+	tokenizer tokenizer.Codec
+
+	// Prompt is the number of prompt-class tokens used.
+	Prompt int
+
+	// Completion is the number of completion-class tokens used.
+	Completion int
+}
+
+// newTokensUsed_Cl100kBase creates a new TokensUsed instance with a Cl100kBase tokenizer.
+// This tokenizer is used by GPT-3.5 and GPT-4.
+func newTokensUsed_Cl100kBase() *TokensUsed {
+	return &TokensUsed{
+		tokenizer:  codec.NewCl100kBase(),
+		Prompt:     0,
+		Completion: 0,
+	}
+}
+
+// AddTokens updates TokensUsed with the tokens used for a single call to an LLM.
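+// For illustration, under this accounting a prompt of two messages with 10 and 12 content
+// tokens counts as (3+1+10) + (3+1+12) = 30 prompt tokens, and a 20-token completion counts
+// as 3+20 = 23 completion tokens. This approximates OpenAI's own accounting per the cookbook
+// reference above; the figures are estimates, not billed values.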
+func (t *TokensUsed) AddTokens(prompt []openai.ChatCompletionMessage, completion string) error {
+	for _, message := range prompt {
+		promptTokens, _, err := t.tokenizer.Encode(message.Content)
+		if err != nil {
+			return trace.Wrap(err)
+		}
+
+		t.Prompt = t.Prompt + perMessage + perRole + len(promptTokens)
+	}
+
+	completionTokens, _, err := t.tokenizer.Encode(completion)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	t.Completion = t.Completion + perRequest + len(completionTokens)
+	return nil
+}
diff --git a/lib/ai/model/prompt.go b/lib/ai/model/prompt.go
new file mode 100644
index 0000000000000..aaa0650dc4339
--- /dev/null
+++ b/lib/ai/model/prompt.go
@@ -0,0 +1,102 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package model
+
+import "fmt"
+
+var observationPrefix = "Observation: "
+var thoughtPrefix = "Thought: "
+
+const PromptSummarizeTitle = `You will be given a message. Create a short summary of that message.
+Respond only with the summary, nothing else.`
+
+const InitialAIResponse = `Hey, I'm Teleport - a powerful tool that can assist you in managing your Teleport cluster via OpenAI GPT-4.`
+
+func PromptCharacter(username string) string {
+	return fmt.Sprintf(`You are Teleport, a tool that users can use to connect to Linux servers and run relevant commands, as well as have a conversation.
+A Teleport cluster is a connectivity layer that allows access to a set of servers. Servers may also be referred to as nodes.
+Nodes sometimes have labels such as "production" and "staging" assigned to them. Labels are used to group nodes together.
+You will engage in professional conversation with the user and help accomplish tasks such as executing tasks
+within the cluster or answering relevant questions about Teleport, Linux or the cluster itself.
+
+You possess advanced capabilities to think and reason in multiple steps and use the available tools to accomplish the task at hand in a way a human would expect you to.
+
+You are not permitted to engage in conversation that is not related to Teleport, Linux or the cluster itself.
+If the user asks such an unrelated question, you must concisely respond that it is beyond your scope of knowledge.
+
+You are talking to %v.`, username)
+}
+
+func conversationParserFormatInstructionsPrompt(toolnames []string) string {
+	return fmt.Sprintf(`RESPONSE FORMAT INSTRUCTIONS
+----------------------------
+
+When responding to me, please output a response in one of two formats:
+
+**Option 1:**
+Use this if you want the human to use a tool.
+Markdown code snippet formatted in the following schema:
+
+%vjson
+{
+    "action": string \\ The action to take. Must be one of %v
+    "action_input": string \\ The input to the action
+}
+%v
+
+**Option 2:**
+Use this if you want to respond directly to the human or you want to ask the human a question to gather more information.
+You should avoid asking too many questions when you have other options available to you as it may be perceived as annoying.
+But asking is far better than guessing or making assumptions.
+Markdown code snippet formatted in the following schema:
+
+%vjson
+{
+    "action": "Final Answer",
+    "action_input": string \\ You should put what you want to return to the user here
+}
+%v`, "```", toolnames, "```", "```", "```",
+	)
+}
+
+func conversationToolUsePrompt(tools string, formatInstructions string, userInput string) string {
+	return fmt.Sprintf(`TOOLS
+------
+Assistant can ask the user to use tools to look up information that may be helpful in answering the user's original question. The tools the human can use are:
+
+%v
+
+%v
+
+USER'S INPUT
+--------------------
+Here is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
+
+%v`, tools, formatInstructions, userInput)
+}
+
+func conversationToolResponse(toolResponse string) string {
+	return fmt.Sprintf(`TOOL RESPONSE:
+---------------------
+
+%v
+
+USER'S INPUT
+--------------------
+
+Okay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! Remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else.`, toolResponse)
+}
diff --git a/lib/ai/model/tool.go b/lib/ai/model/tool.go
new file mode 100644
index 0000000000000..ab3c444cd2edb
--- /dev/null
+++ b/lib/ai/model/tool.go
@@ -0,0 +1,96 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package model
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/gravitational/trace"
+)
+
+// Tool is an interface that allows the agent to interact with the outside world.
+// It is used to implement things such as vector document retrieval and command execution.
+type Tool interface {
+	Name() string
+	Description() string
+	Run(ctx context.Context, input string) (string, error)
+}
+
+type commandExecutionTool struct{}
+
+type commandExecutionToolInput struct {
+	// Command is a unix command to execute.
+	Command string `json:"command"`
+
+	// Nodes is a list of hostnames to execute the command on.
+	Nodes []string `json:"nodes"`
+
+	// Labels is a list of labels specifying node groups to execute the command on.
+	Labels []Label `json:"labels"`
+}
+
+func (c *commandExecutionTool) Name() string {
+	return "Command Execution"
+}
+
+func (c *commandExecutionTool) Description() string {
+	return fmt.Sprintf(`Execute a command on a set of remote hosts based on a set of hostnames and/or a set of labels.
+The input must be a JSON object with the following schema:
+
+%vjson
+{
+    "command": string, \\ The command to execute
+    "nodes": []string, \\ Execute a command on all nodes that have the given hostnames
+    "labels": []{"key": string, "value": string} \\ Execute a command on all nodes that have at least one of the labels
+}
+%v
+`, "```", "```")
+}
+
+func (c *commandExecutionTool) Run(ctx context.Context, input string) (string, error) {
+	// This is stubbed because commandExecutionTool is handled specially.
+	// This is because execution of this tool breaks the loop and returns a command suggestion to the user.
+	// It is still handled as a tool because testing has shown that the LLM behaves better when it is treated as a tool.
+	//
+	// In addition, treating it as a Tool interface item simplifies the display and prompt assembly logic significantly.
+	return "", trace.NotImplemented("not implemented")
+}
+
+// parseInput is called in a special case if the planned tool is commandExecutionTool.
+// This is because commandExecutionTool is handled differently from most other tools and forcibly terminates the thought loop.
+func (*commandExecutionTool) parseInput(input string) (*commandExecutionToolInput, *invalidOutputError) {
+	output, err := parseJSONFromModel[commandExecutionToolInput](input)
+	if err != nil {
+		return nil, err
+	}
+
+	if output.Command == "" {
+		return nil, &invalidOutputError{
+			coarse: "command execution: missing command",
+			detail: "command must be non-empty",
+		}
+	}
+
+	if len(output.Nodes) == 0 && len(output.Labels) == 0 {
+		return nil, &invalidOutputError{
+			coarse: "command execution: missing nodes or labels",
+			detail: "at least one node or label must be specified",
+		}
+	}
+
+	return &output, nil
+}
diff --git a/lib/ai/prompt.go b/lib/ai/prompt.go
deleted file mode 100644
index df3b46dffd5d5..0000000000000
--- a/lib/ai/prompt.go
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Copyright 2023 Gravitational, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package ai
-
-import "fmt"
-
-const promptSummarizeTitle = `You will be given a message. Create a short summary of that message.
-Respond only with summary, nothing else.`
-
-const initialAIResponse = `Hey, I'm Teleport - a powerful tool that can assist you in managing your Teleport cluster via OpenAI GPT-4.`
-
-const promptExtractInstruction = `If the input is a request to complete a task on a server, try to extract the following information:
-- A Linux shell command
-- One or more target servers
-- One or more target labels
-
-If there is a lack of details, provide most logical solution.
-Ensure the output is a valid shell command.
-There must be at least one target server or label, otherwise we do not have enough information to complete the task.
-Provide the output in the following format with no other text: - -{ - "command": "", - "nodes": ["", ""], - "labels": [ - { - "key": "", - "value": "", - }, - { - "key": "", - "value": "", - } - ] -} - -If the user is not asking to complete a task on a server directly but is asking a question related to Teleport or Linux - disregard this entire message and help them with their Teleport or Linux related request.` - -// promptCharacter is a prompt that sets the context for the conversation. -// Username is the name of the user that the AI is talking to. -func promptCharacter(username string) string { - return fmt.Sprintf(`You are Teleport, a tool that users can use to connect to Linux servers and run relevant commands, as well as have a conversation. -A Teleport cluster is a connectivity layer that allows access to a set of servers. Servers may also be referred to as nodes. -Nodes sometimes have labels such as "production" and "staging" assigned to them. Labels are used to group nodes together. -You will engage in professional conversation with the user and help accomplish tasks such as executing tasks -within the cluster or answering relevant questions about Teleport, Linux or the cluster itself. - -You are not permitted to engage in conversation that is not related to Teleport, Linux or the cluster itself. -If this user asks such an unrelated question, you must concisely respond that it is beyond your scope of knowledge. - -You are talking to %v.`, username) -} diff --git a/lib/assist/assist.go b/lib/assist/assist.go index 0fcbe9f50685c..fa6ee949c1901 100644 --- a/lib/assist/assist.go +++ b/lib/assist/assist.go @@ -21,7 +21,6 @@ package assist import ( "context" "encoding/json" - "regexp" "time" "github.com/gravitational/trace" @@ -35,6 +34,7 @@ import ( assistpb "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" pluginsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1" "github.com/gravitational/teleport/lib/ai" + "github.com/gravitational/teleport/lib/ai/model" "github.com/gravitational/teleport/lib/auth" ) @@ -144,8 +144,7 @@ func (c *Chat) loadMessages(ctx context.Context) error { // IsNewConversation returns true if the conversation has no messages yet. func (c *Chat) IsNewConversation() bool { - // The first message is always the system message generated by us. - return len(c.chat.GetMessages()) <= 1 + return len(c.chat.GetMessages()) == 1 } // getAssistantClient returns the OpenAI client created base on Teleport Plugin information @@ -183,157 +182,38 @@ func getAssistantClient(ctx context.Context, proxyClient auth.ClientI, return ai.NewClient(apiKey), nil } -// InsertAssistantMessage inserts a message from the user into the conversation. -// Message is saved in the DB. -func (c *Chat) InsertAssistantMessage(ctx context.Context, msgType MessageType, msgPayload string, -) error { - // write a user message to both an in-memory chain and persistent storage - c.chat.Insert(kindToRole(msgType), msgPayload) +// ProcessComplete processes the completion request and returns the number of tokens used. 
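+// It queries the assistant for a reply to userInput, persists both the user message and the
+// assistant's answer, and emits each client-visible event through the onMessage callback.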
+func (c *Chat) ProcessComplete(ctx context.Context, + onMessage func(kind MessageType, payload []byte, createdTime time.Time) error, userInput string, +) (*model.TokensUsed, error) { + var tokensUsed *model.TokensUsed + + // query the assistant and fetch an answer + message, err := c.chat.Complete(ctx, userInput) + if err != nil { + return nil, trace.Wrap(err) + } + // write the user message to persistent storage and the chat structure + c.chat.Insert(openai.ChatMessageRoleUser, userInput) if err := c.authClient.CreateAssistantMessage(ctx, &assistpb.CreateAssistantMessageRequest{ Message: &assistpb.AssistantMessage{ - Type: string(msgType), //string(MessageKindUserMessage), - Payload: msgPayload, // TODO(jakule): Sanitize the payload + Type: string(MessageKindUserMessage), + Payload: userInput, // TODO(jakule): Sanitize the payload CreatedTime: timestamppb.New(c.assist.clock.Now().UTC()), }, ConversationId: c.ConversationID, Username: c.Username, }); err != nil { - return trace.Wrap(err) - } - - return nil -} - -// TokensUsed is a number of tokens used in the last completion. -type TokensUsed struct { - // Prompt is a number of tokens used in the prompt. - Prompt int - // Completion is a number of tokens used in the completion. - Completion int -} - -// ProcessComplete processes the completion request and returns the number of tokens used. -func (c *Chat) ProcessComplete(ctx context.Context, - onMessage func(kind MessageType, payload []byte, createdTime time.Time) error, -) (*TokensUsed, error) { - var numTokens int - - promptTokens, err := c.chat.PromptTokens() - if err != nil { - log.Warnf("Failed to calculate prompt tokens: %v", err) - } - - // query the assistant and fetch an answer - message, err := c.chat.Complete(ctx) - if err != nil { return nil, trace.Wrap(err) } switch message := message.(type) { - case *ai.StreamingMessage: - // collection of the entire message, used for writing to conversation log - content := "" - - // stream all chunks to the client - outer: - for { - select { - case chunk, ok := <-message.Chunks: - if !ok { - break outer - } - - if len(chunk) == 0 { - continue outer - } - - numTokens++ - content += chunk - payload := partialMessagePayload{ - Content: chunk, - Idx: message.Idx, - } - - payloadJSON, err := json.Marshal(payload) - if err != nil { - return nil, trace.Wrap(err) - } - - if err := onMessage(MessageKindAssistantPartialMessage, payloadJSON, c.assist.clock.Now().UTC()); err != nil { - return nil, trace.Wrap(err) - } - - case err = <-message.Error: - return nil, trace.Wrap(err) - } - } - - // tell the client that the message is complete - finalizePayload := partialFinalizePayload{ - Idx: message.Idx, - } - - finalizePayloadJSON, err := json.Marshal(finalizePayload) - if err != nil { - return nil, trace.Wrap(err) - } - - if err := onMessage(MessageKindAssistantPartialFinalize, finalizePayloadJSON, c.assist.clock.Now().UTC()); err != nil { - return nil, trace.Wrap(err) - } - - // write the entire message to both an in-memory chain and persistent storage - c.chat.Insert(message.Role, content) - protoMsg := &assist.CreateAssistantMessageRequest{ - ConversationId: c.ConversationID, - Username: c.Username, - Message: &assist.AssistantMessage{ - Type: string(MessageKindAssistantMessage), - Payload: content, - CreatedTime: timestamppb.New(c.assist.clock.Now().UTC()), - }, - } + case *model.Message: + tokensUsed = message.TokensUsed + c.chat.Insert(openai.ChatMessageRoleAssistant, message.Content) - if err := c.authClient.CreateAssistantMessage(ctx, protoMsg); 
err != nil { - return nil, trace.Wrap(err) - } - - // check if there's any embedded command in the response, if so, send a suggestion with it - if command := tryFindEmbeddedCommand(content); command != nil { - payload := commandPayload{ - Command: command.Command, - Nodes: command.Nodes, - Labels: command.Labels, - } - - payloadJson, err := json.Marshal(payload) - if err != nil { - return nil, trace.Wrap(err) - } - - msg := &assist.CreateAssistantMessageRequest{ - ConversationId: c.ConversationID, - Username: c.Username, - Message: &assist.AssistantMessage{ - Type: string(MessageKindCommand), - Payload: string(payloadJson), - CreatedTime: timestamppb.New(c.assist.clock.Now().UTC()), - }, - } - - if err := c.authClient.CreateAssistantMessage(ctx, msg); err != nil { - return nil, trace.Wrap(err) - } - - if err := onMessage(MessageKindCommand, payloadJson, c.assist.clock.Now().UTC()); err != nil { - return nil, trace.Wrap(err) - } - } - case *ai.Message: - numTokens = message.NumTokens - // write an assistant message to both an in-memory chain and persistent storage - c.chat.Insert(message.Role, message.Content) + // write an assistant message to persistent storage protoMsg := &assist.CreateAssistantMessageRequest{ ConversationId: c.ConversationID, Username: c.Username, @@ -351,8 +231,8 @@ func (c *Chat) ProcessComplete(ctx context.Context, if err := onMessage(MessageKindAssistantMessage, []byte(message.Content), c.assist.clock.Now().UTC()); err != nil { return nil, trace.Wrap(err) } - case *ai.CompletionCommand: - numTokens = message.NumTokens + case *model.CompletionCommand: + tokensUsed = message.TokensUsed payload := commandPayload{ Command: message.Command, Nodes: message.Nodes, @@ -385,26 +265,7 @@ func (c *Chat) ProcessComplete(ctx context.Context, return nil, trace.Errorf("unknown message type") } - return &TokensUsed{ - Prompt: promptTokens, - Completion: numTokens, - }, nil -} - -var jsonBlockPattern = regexp.MustCompile(`(?s){.+}`) - -// tryFindEmbeddedCommand tries to find an embedded command in the message. -func tryFindEmbeddedCommand(message string) *ai.CompletionCommand { - candidates := jsonBlockPattern.FindAllString(message, -1) - - for _, candidate := range candidates { - var c ai.CompletionCommand - if err := json.Unmarshal([]byte(candidate), &c); err == nil { - return &c - } - } - - return nil + return tokensUsed, nil } func getOpenAITokenFromDefaultPlugin(ctx context.Context, proxyClient auth.ClientI) (string, error) { diff --git a/lib/assist/messages.go b/lib/assist/messages.go index 2c12ebfc15e2d..f36274b8d89e2 100644 --- a/lib/assist/messages.go +++ b/lib/assist/messages.go @@ -19,22 +19,11 @@ package assist -import "github.com/gravitational/teleport/lib/ai" +import "github.com/gravitational/teleport/lib/ai/model" // commandPayload is a payload for a command message. type commandPayload struct { - Command string `json:"command,omitempty"` - Nodes []string `json:"nodes,omitempty"` - Labels []ai.Label `json:"labels,omitempty"` -} - -// partialMessagePayload is a payload for a partial message. -type partialMessagePayload struct { - Content string `json:"content,omitempty"` - Idx int `json:"idx,omitempty"` -} - -// partialFinalizePayload is a payload for a partial finalize message. 
-type partialFinalizePayload struct { - Idx int `json:"idx,omitempty"` + Command string `json:"command,omitempty"` + Nodes []string `json:"nodes,omitempty"` + Labels []model.Label `json:"labels,omitempty"` } diff --git a/lib/web/assistant.go b/lib/web/assistant.go index d7f41d9ca53b2..32dc798e892ed 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -433,7 +433,7 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, if chat.IsNewConversation() { // new conversation, generate a hello message - if _, err := chat.ProcessComplete(ctx, onMessageFn); err != nil { + if _, err := chat.ProcessComplete(ctx, onMessageFn, ""); err != nil { return trace.Wrap(err) } } @@ -465,11 +465,7 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, } //TODO(jakule): Should we sanitize the payload? - if err := chat.InsertAssistantMessage(ctx, assist.MessageKindUserMessage, wsIncoming.Payload); err != nil { - return trace.Wrap(err) - } - - usedTokens, err := chat.ProcessComplete(ctx, onMessageFn) + usedTokens, err := chat.ProcessComplete(ctx, onMessageFn, wsIncoming.Payload) if err != nil { return trace.Wrap(err) } diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go deleted file mode 100644 index 292517c7234e1..0000000000000 --- a/lib/web/assistant_test.go +++ /dev/null @@ -1,363 +0,0 @@ -/* - * Copyright 2023 Gravitational, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package web - -import ( - "context" - "crypto/tls" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/gorilla/websocket" - "github.com/gravitational/roundtrip" - "github.com/gravitational/trace" - "github.com/sashabaranov/go-openai" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/time/rate" - - authproto "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/lib/assist" - "github.com/gravitational/teleport/lib/client" -) - -func Test_runAssistant(t *testing.T) { - t.Parallel() - - readPartialMessage := func(t *testing.T, ws *websocket.Conn) string { - var msg assistantMessage - _, payload, err := ws.ReadMessage() - require.NoError(t, err) - - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) - - require.Equal(t, assist.MessageKindAssistantPartialMessage, msg.Type) - return msg.Payload - } - - readStreamEnd := func(t *testing.T, ws *websocket.Conn) { - var msg assistantMessage - _, payload, err := ws.ReadMessage() - require.NoError(t, err) - - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) - - require.Equal(t, assist.MessageKindAssistantPartialFinalize, msg.Type) - } - - readRateLimitedMessage := func(t *testing.T, ws *websocket.Conn) { - var msg assistantMessage - _, payload, err := ws.ReadMessage() - require.NoError(t, err) - - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) - - require.Equal(t, assist.MessageKindError, msg.Type) - require.Equal(t, msg.Payload, "You have reached the rate limit. Please try again later.") - } - - testCases := []struct { - name string - responses [][]byte - cfg webSuiteConfig - setup func(*testing.T, *WebSuite) - act func(*testing.T, *websocket.Conn) - }{ - { - name: "normal", - responses: [][]byte{ - generateTextResponse(), - }, - act: func(t *testing.T, ws *websocket.Conn) { - err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) - require.NoError(t, err) - - require.Contains(t, readPartialMessage(t, ws), "Which") - require.Contains(t, readPartialMessage(t, ws), "node do") - require.Contains(t, readPartialMessage(t, ws), "you want") - require.Contains(t, readPartialMessage(t, ws), "use?") - - readStreamEnd(t, ws) - }, - }, - { - name: "rate limited", - responses: [][]byte{ - generateTextResponse(), - generateTextResponse(), - }, - cfg: webSuiteConfig{ - ClusterFeatures: &authproto.Features{ - Cloud: true, - }, - }, - setup: func(t *testing.T, s *WebSuite) { - // Assert that rate limiter is set up when Cloud feature is active, - // before replacing with a lower capacity rate-limiter for test purposes - require.Equal(t, assistantLimiterRate, s.webHandler.handler.assistantLimiter.Limit()) - - // 101 token capacity (lookaheadTokens+1) and a slow replenish rate - // to let the first completion request succeed, but not the second one - s.webHandler.handler.assistantLimiter = rate.NewLimiter(rate.Limit(0.001), 101) - - }, - act: func(t *testing.T, ws *websocket.Conn) { - err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) - require.NoError(t, err) - - require.Contains(t, readPartialMessage(t, ws), "Which") - require.Contains(t, readPartialMessage(t, ws), "node do") - require.Contains(t, readPartialMessage(t, ws), "you want") - require.Contains(t, readPartialMessage(t, ws), "use?") - - readStreamEnd(t, ws) - - err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "all nodes, please"}`)) - 
require.NoError(t, err) - - readRateLimitedMessage(t, ws) - }, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - responses := tc.responses - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - - // Use assert as require doesn't work when called from a goroutine - assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") - dataBytes := responses[0] - - _, err := w.Write(dataBytes) - assert.NoError(t, err, "Write error") - - responses = responses[1:] - })) - t.Cleanup(server.Close) - - openaiCfg := openai.DefaultConfig("test-token") - openaiCfg.BaseURL = server.URL - tc.cfg.OpenAIConfig = &openaiCfg - s := newWebSuiteWithConfig(t, tc.cfg) - - if tc.setup != nil { - tc.setup(t, s) - } - - ctx := context.Background() - authPack := s.authPack(t, "foo") - // Create the conversation - conversationID := s.makeAssistConversation(t, ctx, authPack) - - // Make WS client and start the conversation - ws, err := s.makeAssistant(t, authPack, conversationID) - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, ws.Close()) }) - - _, payload, err := ws.ReadMessage() - require.NoError(t, err) - - var msg assistantMessage - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) - - // Expect "hello" message - require.Equal(t, assist.MessageKindAssistantMessage, msg.Type) - require.Contains(t, msg.Payload, "Hey, I'm Teleport") - - tc.act(t, ws) - }) - } -} - -// Test_runAssistError tests that the assistant returns an error message -// when the OpenAI API returns an error. -func Test_runAssistError(t *testing.T) { - t.Parallel() - - readHelloMsg := func(ws *websocket.Conn) { - _, payload, err := ws.ReadMessage() - require.NoError(t, err) - - var msg assistantMessage - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) - - // Expect "hello" message - require.Equal(t, assist.MessageKindAssistantMessage, msg.Type) - require.Contains(t, msg.Payload, "Hey, I'm Teleport") - } - - readErrorMsg := func(ws *websocket.Conn) { - err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) - require.NoError(t, err) - - _, payload, err := ws.ReadMessage() - require.NoError(t, err) - - var msg assistantMessage - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) - - // Expect OpenAI error message - require.Equal(t, assist.MessageKindError, msg.Type) - require.Contains(t, msg.Payload, "You are sending requests too quickly") - } - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - // Simulate rate limit error - w.WriteHeader(429) - - errMsg := openai.ErrorResponse{ - Error: &openai.APIError{ - Code: "rate_limit_reached", - Message: "You are sending requests too quickly.", - Param: nil, - Type: "rate_limit_reached", - HTTPStatusCode: 429, - }, - } - - dataBytes, err := json.Marshal(errMsg) - // Use assert as require doesn't work when called from a goroutine - assert.NoError(t, err, "Marshal error") - - _, err = w.Write(dataBytes) - assert.NoError(t, err, "Write error") - })) - t.Cleanup(server.Close) - - openaiCfg := openai.DefaultConfig("test-token") - openaiCfg.BaseURL = server.URL - s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg}) - - ctx := context.Background() - authPack := s.authPack(t, "foo") - // Create the conversation - conversationID := 
s.makeAssistConversation(t, ctx, authPack) - - // Make WS client and start the conversation - ws, err := s.makeAssistant(t, authPack, conversationID) - require.NoError(t, err) - t.Cleanup(func() { - // Close should yield an error as the server closes the connection - require.Error(t, ws.Close()) - }) - - // verify responses - readHelloMsg(ws) - readErrorMsg(ws) - - // Check for close message - _, _, err = ws.ReadMessage() - closeErr, ok := err.(*websocket.CloseError) - require.True(t, ok, "Expected close error") - require.Equal(t, websocket.CloseInternalServerErr, closeErr.Code, "Expected abnormal closure") -} - -// makeAssistConversation creates a new assist conversation and returns its ID -func (s *WebSuite) makeAssistConversation(t *testing.T, ctx context.Context, authPack *authPack) string { - clt := authPack.clt - - resp, err := clt.PostJSON(ctx, clt.Endpoint("webapi", "assistant", "conversations"), nil) - require.NoError(t, err) - - convResp := struct { - ConversationID string `json:"id"` - }{} - err = json.Unmarshal(resp.Bytes(), &convResp) - require.NoError(t, err) - - return convResp.ConversationID -} - -// makeAssistant creates a new assistant websocket connection. -func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack, conversationID string) (*websocket.Conn, error) { - u := url.URL{ - Host: s.url().Host, - Scheme: client.WSS, - Path: fmt.Sprintf("/v1/webapi/sites/%s/assistant", currentSiteShortcut), - } - - q := u.Query() - q.Set("conversation_id", conversationID) - q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) - u.RawQuery = q.Encode() - - dialer := websocket.Dialer{} - dialer.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, - } - - header := http.Header{} - header.Add("Origin", "http://localhost") - for _, cookie := range pack.cookies { - header.Add("Cookie", cookie.String()) - } - - ws, resp, err := dialer.Dial(u.String(), header) - if err != nil { - res, err2 := io.ReadAll(resp.Body) - t.Log("response body:", string(res), err2) - return nil, trace.Wrap(err) - } - - err = resp.Body.Close() - if err != nil { - return nil, trace.Wrap(err) - } - - return ws, nil -} - -func generateTextResponse() []byte { - dataBytes := []byte{} - dataBytes = append(dataBytes, []byte("event: message\n")...) - - data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "Which ", "role": "assistant"}}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - dataBytes = append(dataBytes, []byte("event: message\n")...) - - data = `{"id":"2","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "node do ", "role": "assistant"}}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - dataBytes = append(dataBytes, []byte("event: message\n")...) - - data = `{"id":"3","object":"completion","created":1598069255,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "you want ", "role": "assistant"}}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - dataBytes = append(dataBytes, []byte("event: message\n")...) - - data = `{"id":"4","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "use?", "role": "assistant"}}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - dataBytes = append(dataBytes, []byte("event: done\n")...) - - dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) - - return dataBytes -}
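Below is a minimal usage sketch of the API surface introduced by this change. It is illustrative only and not part of the diff; the API key, username, and input are placeholders. The type switch mirrors the one in ProcessComplete: Complete returns any, holding either a plain text answer or a structured command suggestion.

package main

import (
	"context"
	"fmt"

	"github.com/gravitational/teleport/lib/ai"
	"github.com/gravitational/teleport/lib/ai/model"
	"github.com/sashabaranov/go-openai"
)

func main() {
	ctx := context.Background()
	client := ai.NewClient("sk-placeholder") // in Teleport the key comes from the OpenAI default plugin
	chat := client.NewChat("alice")

	// A brand-new chat holds only the system prompt, so Complete short-circuits
	// and returns the canned greeting without calling the model.
	hello, err := chat.Complete(ctx, "")
	if err != nil {
		panic(err)
	}
	if m, ok := hello.(*model.Message); ok {
		chat.Insert(openai.ChatMessageRoleAssistant, m.Content)
	}

	// Subsequent calls run the agent loop. Persisting messages back into the
	// conversation store (as ProcessComplete does) is elided here.
	out, err := chat.Complete(ctx, "show free disk space on localhost")
	if err != nil {
		panic(err)
	}

	switch msg := out.(type) {
	case *model.Message:
		fmt.Println("assistant:", msg.Content)
	case *model.CompletionCommand:
		fmt.Printf("suggested %q on nodes %v (labels %v)\n", msg.Command, msg.Nodes, msg.Labels)
	}
}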