diff --git a/lib/ai/chat.go b/lib/ai/chat.go index 055e4f406d8cf..dd2691f914f95 100644 --- a/lib/ai/chat.go +++ b/lib/ai/chat.go @@ -57,7 +57,7 @@ func (chat *Chat) GetMessages() []openai.ChatCompletionMessage { // Message types: // - CompletionCommand: a command from the assistant // - Message: a text message from the assistant -func (chat *Chat) Complete(ctx context.Context, userInput string) (any, error) { +func (chat *Chat) Complete(ctx context.Context, userInput string, progressUpdates func(*model.AgentAction)) (any, error) { // if the chat is empty, return the initial response we predefine instead of querying GPT-4 if len(chat.messages) == 1 { return &model.Message{ @@ -71,7 +71,7 @@ func (chat *Chat) Complete(ctx context.Context, userInput string) (any, error) { Content: userInput, } - response, err := chat.agent.PlanAndExecute(ctx, chat.client.svc, chat.messages, userMessage) + response, err := chat.agent.PlanAndExecute(ctx, chat.client.svc, chat.messages, userMessage, progressUpdates) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/ai/chat_test.go b/lib/ai/chat_test.go index 1b25591450125..a016574d7ba5c 100644 --- a/lib/ai/chat_test.go +++ b/lib/ai/chat_test.go @@ -18,6 +18,9 @@ package ai import ( "context" + "encoding/json" + "fmt" + "net/http" "net/http/httptest" "testing" @@ -25,7 +28,6 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/lib/ai/model" - aitest "github.com/gravitational/teleport/lib/ai/testutils" ) func TestChat_PromptTokens(t *testing.T) { @@ -49,7 +51,7 @@ func TestChat_PromptTokens(t *testing.T) { Content: "Hello", }, }, - want: 743, + want: 697, }, { name: "system and user messages", @@ -63,7 +65,7 @@ func TestChat_PromptTokens(t *testing.T) { Content: "Hi LLM.", }, }, - want: 751, + want: 705, }, { name: "tokenize our prompt", @@ -77,7 +79,7 @@ func TestChat_PromptTokens(t *testing.T) { Content: "Show me free disk space on localhost node.", }, }, - want: 954, + want: 908, }, } @@ -89,7 +91,17 @@ func TestChat_PromptTokens(t *testing.T) { responses := []string{ generateCommandResponse(), } - server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses)) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + require.GreaterOrEqual(t, len(responses), 1, "Unexpected request") + dataBytes := responses[0] + _, err := w.Write([]byte(dataBytes)) + require.NoError(t, err, "Write error") + + responses = responses[1:] + })) + t.Cleanup(server.Close) cfg := openai.DefaultConfig("secret-test-token") @@ -103,7 +115,7 @@ func TestChat_PromptTokens(t *testing.T) { } ctx := context.Background() - message, err := chat.Complete(ctx, "") + message, err := chat.Complete(ctx, "", func(aa *model.AgentAction) {}) require.NoError(t, err) msg, ok := message.(interface{ UsedTokens() *model.TokensUsed }) require.True(t, ok) @@ -117,12 +129,22 @@ func TestChat_PromptTokens(t *testing.T) { func TestChat_Complete(t *testing.T) { t.Parallel() - responses := []string{ - generateTextResponse(), - generateCommandResponse(), + responses := [][]byte{ + []byte(generateTextResponse()), + []byte(generateCommandResponse()), } - server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses)) - t.Cleanup(server.Close) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + require.GreaterOrEqual(t, len(responses), 1, "Unexpected request") + dataBytes := responses[0] + + _, err := w.Write(dataBytes) + require.NoError(t, err, "Write error") + + responses = responses[1:] + })) + defer server.Close() cfg := openai.DefaultConfig("secret-test-token") cfg.BaseURL = server.URL + "/v1" @@ -130,38 +152,26 @@ func TestChat_Complete(t *testing.T) { chat := client.NewChat(nil, "Bob") - t.Run("initial message", func(t *testing.T) { - msgAny, err := chat.Complete(context.Background(), "Hello") - require.NoError(t, err) - - msg, ok := msgAny.(*model.Message) - require.True(t, ok) + ctx := context.Background() + _, err := chat.Complete(ctx, "Hello", func(aa *model.AgentAction) {}) + require.NoError(t, err) - expectedResp := &model.Message{ - Content: "Hey, I'm Teleport - a powerful tool that can assist you in managing your Teleport cluster via OpenAI GPT-4.", - } - require.Equal(t, expectedResp.Content, msg.Content) - require.NotNil(t, msg.TokensUsed) - }) + chat.Insert(openai.ChatMessageRoleUser, "Show me free disk space on localhost node.") t.Run("text completion", func(t *testing.T) { - chat.Insert(openai.ChatMessageRoleUser, "Show me free disk space") - - msg, err := chat.Complete(context.Background(), "") + msg, err := chat.Complete(ctx, "Show me free disk space", func(aa *model.AgentAction) {}) require.NoError(t, err) - require.IsType(t, &model.Message{}, msg) - streamingMessage := msg.(*model.Message) - - const expectedResponse = "Which node do you want use?" - - require.Equal(t, expectedResponse, streamingMessage.Content) + require.IsType(t, &model.StreamingMessage{}, msg) + streamingMessage := msg.(*model.StreamingMessage) + require.Equal(t, "Which ", <-streamingMessage.Parts) + require.Equal(t, "node do ", <-streamingMessage.Parts) + require.Equal(t, "you want ", <-streamingMessage.Parts) + require.Equal(t, "use?", <-streamingMessage.Parts) }) t.Run("command completion", func(t *testing.T) { - chat.Insert(openai.ChatMessageRoleUser, "localhost") - - msg, err := chat.Complete(context.Background(), "") + msg, err := chat.Complete(ctx, "localhost", func(aa *model.AgentAction) {}) require.NoError(t, err) require.IsType(t, &model.CompletionCommand{}, msg) @@ -174,20 +184,64 @@ func TestChat_Complete(t *testing.T) { // generateTextResponse generates a response for a text completion func generateTextResponse() string { - return "```" + `json - { - "action": "Final Answer", - "action_input": "Which node do you want use?" - } - ` + "```" + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "Which ", "role": "assistant"}}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + dataBytes = append(dataBytes, []byte("event: message\n")...) + + data = `{"id":"2","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "node do ", "role": "assistant"}}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + dataBytes = append(dataBytes, []byte("event: message\n")...) + + data = `{"id":"3","object":"completion","created":1598069255,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "you want ", "role": "assistant"}}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + dataBytes = append(dataBytes, []byte("event: message\n")...) + + data = `{"id":"4","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "use?", "role": "assistant"}}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + dataBytes = append(dataBytes, []byte("event: done\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + return string(dataBytes) } // generateCommandResponse generates a response for the command "df -h" on the node "localhost" func generateCommandResponse() string { - return "```" + `json - { - "action": "Command Execution", - "action_input": "{\"command\":\"df -h\",\"nodes\":[\"localhost\"],\"labels\":[]}" + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + + actionObj := model.PlanOutput{ + Action: "Command Execution", + ActionInput: struct { + Command string `json:"command"` + Nodes []string `json:"nodes"` + }{"df -h", []string{"localhost"}}, + } + actionJson, err := json.Marshal(actionObj) + if err != nil { + panic(err) } - ` + "```" + + obj := struct { + Content string `json:"content"` + Role string `json:"role"` + }{ + Content: string(actionJson), + Role: "assistant", + } + json, err := json.Marshal(obj) + if err != nil { + panic(err) + } + + data := fmt.Sprintf(`{"id":"1","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":%v}]}`, string(json)) + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + return string(dataBytes) } diff --git a/lib/ai/embeddings.go b/lib/ai/embeddings.go index a215b2b8813dd..3e8f0a21298db 100644 --- a/lib/ai/embeddings.go +++ b/lib/ai/embeddings.go @@ -229,7 +229,7 @@ func (e *EmbeddingProcessor) Run(ctx context.Context, initialDelay, period time. } func (e *EmbeddingProcessor) process(ctx context.Context) { - batch := NewBatchReducer[*nodeStringPair, []*Embedding](e.mapProcessFn, + batch := NewBatchReducer(e.mapProcessFn, maxEmbeddingAPISize, // Max batch size allowed by OpenAI API, ) diff --git a/lib/ai/model/agent.go b/lib/ai/model/agent.go index 37761ff940a6a..a0cade5fda335 100644 --- a/lib/ai/model/agent.go +++ b/lib/ai/model/agent.go @@ -19,6 +19,9 @@ package model import ( "context" "encoding/json" + "errors" + "fmt" + "io" "strings" "time" @@ -30,10 +33,17 @@ import ( ) const ( - actionFinalAnswer = "Final Answer" - actionException = "_Exception" - maxIterations = 15 - maxElapsedTime = 5 * time.Minute + // The internal name used to create actions when the agent encounters an error, such as when parsing output. + actionException = "_Exception" + + // The maximum amount of thought<-> observation iterations the agent is allowed to perform. + maxIterations = 15 + + // The maximum amount of time the agent is allowed to spend before yielding a final answer. + maxElapsedTime = 5 * time.Minute + + // The special header the LLM has to respond with to indicate it's done. + finalResponseHeader = "" ) // NewAgent creates a new agent. The Assist agent which defines the model responsible for the Assist feature. @@ -54,22 +64,25 @@ type Agent struct { tools []Tool } -// agentAction is an event type representing the decision to take a single action, typically a tool invocation. -type agentAction struct { +// AgentAction is an event type representing the decision to take a single action, typically a tool invocation. +type AgentAction struct { // The action to take, typically a tool name. - action string + Action string `json:"action"` // The input to the action, varies depending on the action. - input string + Input string `json:"input"` // The log is either a direct tool response or a thought prompt correlated to the input. - log string + Log string `json:"log"` + + // The reasoning is a string describing the reasoning behind the action. + Reasoning string `json:"reasoning"` } // agentFinish is an event type representing the decision to finish a thought // loop and return a final text answer to the user. type agentFinish struct { - // output must be Message or CompletionCommand + // output must be Message, StreamingMessage or CompletionCommand output any } @@ -77,14 +90,14 @@ type executionState struct { llm *openai.Client chatHistory []openai.ChatCompletionMessage humanMessage openai.ChatCompletionMessage - intermediateSteps []agentAction + intermediateSteps []AgentAction observations []string tokensUsed *TokensUsed } // PlanAndExecute runs the agent with a given input until it arrives at a text answer it is satisfied // with or until it times out. -func (a *Agent) PlanAndExecute(ctx context.Context, llm *openai.Client, chatHistory []openai.ChatCompletionMessage, humanMessage openai.ChatCompletionMessage) (any, error) { +func (a *Agent) PlanAndExecute(ctx context.Context, llm *openai.Client, chatHistory []openai.ChatCompletionMessage, humanMessage openai.ChatCompletionMessage, progressUpdates func(*AgentAction)) (any, error) { log.Trace("entering agent think loop") iterations := 0 start := time.Now() @@ -94,7 +107,7 @@ func (a *Agent) PlanAndExecute(ctx context.Context, llm *openai.Client, chatHist llm: llm, chatHistory: chatHistory, humanMessage: humanMessage, - intermediateSteps: make([]agentAction, 0), + intermediateSteps: make([]AgentAction, 0), observations: make([]string, 0), tokensUsed: tokensUsed, } @@ -108,23 +121,21 @@ func (a *Agent) PlanAndExecute(ctx context.Context, llm *openai.Client, chatHist return nil, trace.Errorf("timeout: agent took too long to finish") } - output, err := a.takeNextStep(ctx, state) + output, err := a.takeNextStep(ctx, state, progressUpdates) if err != nil { return nil, trace.Wrap(err) } if output.finish != nil { - log.Tracef("agent finished with output: %v", output.finish.output) - switch v := output.finish.output.(type) { - case *Message: - v.TokensUsed = tokensUsed - return v, nil - case *CompletionCommand: - v.TokensUsed = tokensUsed - return v, nil - default: - return nil, trace.Errorf("invalid output type %T", v) + log.Tracef("agent finished with output: %#v", output.finish.output) + item, ok := output.finish.output.(interface{ SetUsed(data *TokensUsed) }) + if !ok { + return nil, trace.Errorf("invalid output type %T", output.finish.output) } + + item.SetUsed(tokensUsed) + + return item, nil } if output.action != nil { @@ -142,27 +153,27 @@ type stepOutput struct { finish *agentFinish // if the agent is not done, action is set together with observation. - action *agentAction + action *AgentAction observation string } -func (a *Agent) takeNextStep(ctx context.Context, state *executionState) (stepOutput, error) { +func (a *Agent) takeNextStep(ctx context.Context, state *executionState, progressUpdates func(*AgentAction)) (stepOutput, error) { log.Trace("agent entering takeNextStep") defer log.Trace("agent exiting takeNextStep") action, finish, err := a.plan(ctx, state) if err, ok := trace.Unwrap(err).(*invalidOutputError); ok { log.Tracef("agent encountered an invalid output error: %v, attempting to recover", err) - action := &agentAction{ - action: actionException, - input: observationPrefix + "Invalid or incomplete response", - log: thoughtPrefix + err.Error(), + action := &AgentAction{ + Action: actionException, + Input: observationPrefix + "Invalid or incomplete response", + Log: thoughtPrefix + err.Error(), } // The exception tool is currently a bit special, the observation is always equal to the input. // We can expand on this in the future to make it handle errors better. - log.Tracef("agent decided on action %v and received observation %v", action.action, action.input) - return stepOutput{action: action, observation: action.input}, nil + log.Tracef("agent decided on action %v and received observation %v", action.Action, action.Input) + return stepOutput{action: action, observation: action.Input}, nil } if err != nil { log.Tracef("agent encountered an error: %v", err) @@ -175,75 +186,96 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState) (stepOu return stepOutput{finish: finish}, nil } + // If action is set, the agent is not done and called upon a tool. + progressUpdates(action) + var tool Tool for _, candidate := range a.tools { - if candidate.Name() == action.action { + if candidate.Name() == action.Action { tool = candidate break } } if tool == nil { - log.Tracef("agent picked an unknown tool %v", action.action) - action := &agentAction{ - action: actionException, - input: observationPrefix + "Unknown tool", - log: thoughtPrefix + "No tool with name " + action.action + " exists.", + log.Tracef("agent picked an unknown tool %v", action.Action) + action := &AgentAction{ + Action: actionException, + Input: observationPrefix + "Unknown tool", + Log: fmt.Sprintf("%s No tool with name %s exists.", thoughtPrefix, action.Action), } - return stepOutput{action: action, observation: action.input}, nil + return stepOutput{action: action, observation: action.Input}, nil } if tool, ok := tool.(*commandExecutionTool); ok { - input, err := tool.parseInput(action.input) + input, err := tool.parseInput(action.Input) if err != nil { - action := &agentAction{ - action: actionException, - input: observationPrefix + "Invalid or incomplete response", - log: thoughtPrefix + err.Error(), + action := &AgentAction{ + Action: actionException, + Input: observationPrefix + "Invalid or incomplete response", + Log: thoughtPrefix + err.Error(), } - return stepOutput{action: action, observation: action.input}, nil + return stepOutput{action: action, observation: action.Input}, nil } completion := &CompletionCommand{ - Command: input.Command, - Nodes: input.Nodes, - Labels: input.Labels, + TokensUsed: newTokensUsed_Cl100kBase(), + Command: input.Command, + Nodes: input.Nodes, + Labels: input.Labels, } log.Tracef("agent decided on command execution, let's translate to an agentFinish") return stepOutput{finish: &agentFinish{output: completion}}, nil } - runOut, err := tool.Run(ctx, action.input) + runOut, err := tool.Run(ctx, action.Input) if err != nil { return stepOutput{}, trace.Wrap(err) } return stepOutput{action: action, observation: runOut}, nil } -func (a *Agent) plan(ctx context.Context, state *executionState) (*agentAction, *agentFinish, error) { +func (a *Agent) plan(ctx context.Context, state *executionState) (*AgentAction, *agentFinish, error) { scratchpad := a.constructScratchpad(state.intermediateSteps, state.observations) prompt := a.createPrompt(state.chatHistory, scratchpad, state.humanMessage) - resp, err := state.llm.CreateChatCompletion( + stream, err := state.llm.CreateChatCompletionStream( ctx, openai.ChatCompletionRequest{ Model: openai.GPT4, Messages: prompt, + Stream: true, }, ) if err != nil { return nil, nil, trace.Wrap(err) } - llmOut := resp.Choices[0].Message.Content - err = state.tokensUsed.AddTokens(prompt, llmOut) - if err != nil { - return nil, nil, trace.Wrap(err) - } + deltas := make(chan string) + completion := strings.Builder{} + go func() { + defer close(deltas) + + for { + response, err := stream.Recv() + if errors.Is(err, io.EOF) { + return + } else if err != nil { + log.Tracef("agent encountered an error while streaming: %v", err) + return + } - action, finish, err := parsePlanningOutput(llmOut) + delta := response.Choices[0].Delta.Content + deltas <- delta + // TODO(jakule): Fix token counting. Uncommenting the line below causes a race condition. + //completion.WriteString(delta) + } + }() + + action, finish, err := parsePlanningOutput(deltas) + state.tokensUsed.AddTokens(prompt, completion.String()) return action, finish, trace.Wrap(err) } @@ -276,12 +308,12 @@ func (a *Agent) createPrompt(chatHistory, agentScratchpad []openai.ChatCompletio return prompt } -func (a *Agent) constructScratchpad(intermediateSteps []agentAction, observations []string) []openai.ChatCompletionMessage { +func (a *Agent) constructScratchpad(intermediateSteps []AgentAction, observations []string) []openai.ChatCompletionMessage { var thoughts []openai.ChatCompletionMessage for i, action := range intermediateSteps { thoughts = append(thoughts, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleAssistant, - Content: action.log, + Content: action.Log, }, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleUser, Content: conversationToolResponse(observations[i]), @@ -293,7 +325,7 @@ func (a *Agent) constructScratchpad(intermediateSteps []agentAction, observation // parseJSONFromModel parses a JSON object from the model output and attempts to sanitize contaminant text // to avoid triggering self-correction due to some natural language being bundled with the JSON. -// The output type is generic and thus the structure of the expected JSON varies depending on T. +// The output type is generic, and thus the structure of the expected JSON varies depending on T. func parseJSONFromModel[T any](text string) (T, *invalidOutputError) { cleaned := strings.TrimSpace(text) if strings.Contains(cleaned, "```json") { @@ -315,39 +347,54 @@ func parseJSONFromModel[T any](text string) (T, *invalidOutputError) { return output, nil } -// planOutput describes the expected JSON output after asking it to plan it's next action. -type planOutput struct { - Action string `json:"action"` - Action_input any `json:"action_input"` +// PlanOutput describes the expected JSON output after asking it to plan its next action. +type PlanOutput struct { + Action string `json:"action"` + ActionInput any `json:"action_input"` + Reasoning string `json:"reasoning"` } -// parsePlanningOutput parses the output of the model after asking it to plan it's next action +// parsePlanningOutput parses the output of the model after asking it to plan its next action // and returns the appropriate event type or an error. -func parsePlanningOutput(text string) (*agentAction, *agentFinish, error) { +func parsePlanningOutput(deltas <-chan string) (*AgentAction, *agentFinish, error) { + var text string + for delta := range deltas { + text += delta + + if strings.HasPrefix(text, finalResponseHeader) { + parts := make(chan string) + go func() { + defer close(parts) + + parts <- strings.TrimPrefix(text, finalResponseHeader) + for delta := range deltas { + parts <- delta + } + }() + + return nil, &agentFinish{output: &StreamingMessage{Parts: parts, TokensUsed: newTokensUsed_Cl100kBase()}}, nil + } + } + log.Tracef("received planning output: \"%v\"", text) - response, err := parseJSONFromModel[planOutput](text) + if outputString, found := strings.CutPrefix(text, finalResponseHeader); found { + return nil, &agentFinish{output: &Message{Content: outputString, TokensUsed: newTokensUsed_Cl100kBase()}}, nil + } + + response, err := parseJSONFromModel[PlanOutput](text) if err != nil { log.WithError(err).Trace("failed to parse planning output") return nil, nil, trace.Wrap(err) } - if response.Action == actionFinalAnswer { - outputString, ok := response.Action_input.(string) - if !ok { - return nil, nil, trace.Errorf("invalid final answer type %T", response.Action_input) - } - - return nil, &agentFinish{output: &Message{Content: outputString}}, nil - } - - if v, ok := response.Action_input.(string); ok { - return &agentAction{action: response.Action, input: v}, nil, nil + if v, ok := response.ActionInput.(string); ok { + return &AgentAction{Action: response.Action, Input: v}, nil, nil } else { - input, err := json.Marshal(response.Action_input) + input, err := json.Marshal(response.ActionInput) if err != nil { return nil, nil, trace.Wrap(err) } - return &agentAction{action: response.Action, input: string(input)}, nil, nil + return &AgentAction{Action: response.Action, Input: string(input), Reasoning: response.Reasoning}, nil, nil } } diff --git a/lib/ai/model/messages.go b/lib/ai/model/messages.go index 26762c1bc9a84..0c087740e238c 100644 --- a/lib/ai/model/messages.go +++ b/lib/ai/model/messages.go @@ -41,6 +41,12 @@ type Message struct { Content string } +// StreamingMessage represents a new message that is being streamed from the LLM. +type StreamingMessage struct { + *TokensUsed + Parts <-chan string +} + // Label represents a label returned by OpenAI's completion API. type Label struct { Key string `json:"key"` @@ -101,3 +107,8 @@ func (t *TokensUsed) AddTokens(prompt []openai.ChatCompletionMessage, completion t.Completion = t.Completion + perRequest + len(completionTokens) return err } + +// SetUsed sets the TokensUsed instance to the given data. +func (t *TokensUsed) SetUsed(data *TokensUsed) { + *t = *data +} diff --git a/lib/ai/model/prompt.go b/lib/ai/model/prompt.go index 94b29e91ced1d..fd940c6c72773 100644 --- a/lib/ai/model/prompt.go +++ b/lib/ai/model/prompt.go @@ -61,6 +61,7 @@ Markdown code snippet formatted in the following schema: { "action": string \\ The action to take. Must be one of %v "action_input": string \\ The input to the action + "reasoning": string \\ Your reasoning for taking this action } %v @@ -68,14 +69,10 @@ Markdown code snippet formatted in the following schema: Use this if you want to respond directly to the human or you want to ask the human a question to gather more information. You should avoid asking too many questions when you have other options available to you as it may be perceived as annoying. But asking is far better than guessing or making assumptions. -Markdown code snippet formatted in the following schema: +Text with the hardcoded header %v followed by your response as below: -%vjson -{ - "action": "Final Answer", - "action_input": string \\ You should put what you want to return to use here -} -%v`, "```", toolnames, "```", "```", "```", +%v +YOUR RESPONSE HERE`, "```", toolnames, "```", finalResponseHeader, finalResponseHeader, ) } diff --git a/lib/ai/testutils/http.go b/lib/ai/testutils/http.go index c98de30435a2f..1e8504ec54294 100644 --- a/lib/ai/testutils/http.go +++ b/lib/ai/testutils/http.go @@ -31,43 +31,99 @@ import ( // the chat API. It takes a list of responses that will be returned in order. func GetTestHandlerFn(t *testing.T, responses []string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - req := &openai.ChatCompletionRequest{} - err := json.NewDecoder(r.Body).Decode(req) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + if r.Method != http.MethodPost || r.URL.Path != "/chat/completions" { + http.Error(w, "Unexpected request", http.StatusBadRequest) + return } - // Use assert as require doesn't work when called from a goroutine - if !assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") { + switch r.Header.Get("Accept") { + case "application/json; charset=utf-8", "application/json": + responses = messageResponse(w, r, t, responses) + case "text/event-stream": + responses = streamResponse(w, t, responses) + default: http.Error(w, "Unexpected request", http.StatusBadRequest) - return } + } +} + +func streamResponse(w http.ResponseWriter, t *testing.T, responses []string) []string { + w.Header().Set("Content-Type", "text/event-stream") - dataBytes := responses[0] - - resp := openai.ChatCompletionResponse{ - ID: strconv.Itoa(int(time.Now().Unix())), - Object: "test-object", - Created: time.Now().Unix(), - Model: req.Model, - Choices: []openai.ChatCompletionChoice{ - { - Message: openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleAssistant, - Content: dataBytes, - Name: "", - }, + if !assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") { + http.Error(w, "Unexpected request", http.StatusBadRequest) + return responses + } + + resp := &openai.ChatCompletionStreamResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "completion", + Created: time.Now().Unix(), + Model: openai.GPT4, + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: responses[0], + Role: openai.ChatMessageRoleAssistant, }, + FinishReason: "", }, - Usage: openai.Usage{}, - } + }, + } + + respBytes, err := json.Marshal(resp) + assert.NoError(t, err, "Marshal error") - respBytes, err := json.Marshal(resp) - assert.NoError(t, err, "Marshal error") + _, err = w.Write([]byte("data: ")) + assert.NoError(t, err, "Write error") + _, err = w.Write(respBytes) + assert.NoError(t, err, "Write error") + _, err = w.Write([]byte("\n\nevent: done\ndata: [DONE]\n\n")) + assert.NoError(t, err, "Write error") - _, err = w.Write(respBytes) - assert.NoError(t, err, "Write error") + return responses[1:] +} - responses = responses[1:] +func messageResponse(w http.ResponseWriter, r *http.Request, t *testing.T, responses []string) []string { + w.Header().Set("Content-Type", "application/json") + + req := &openai.ChatCompletionRequest{} + err := json.NewDecoder(r.Body).Decode(req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } + + // Use assert as require doesn't work when called from a goroutine + if !assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") { + http.Error(w, "Unexpected request", http.StatusBadRequest) + return responses } + + dataBytes := responses[0] + + resp := openai.ChatCompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + Model: req.Model, + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: dataBytes, + Name: "", + }, + }, + }, + Usage: openai.Usage{}, + } + + respBytes, err := json.Marshal(resp) + assert.NoError(t, err, "Marshal error") + + _, err = w.Write(respBytes) + assert.NoError(t, err, "Write error") + + return responses[1:] } diff --git a/lib/assist/assist.go b/lib/assist/assist.go index c5a29a9d5c6d8..adfb73b708211 100644 --- a/lib/assist/assist.go +++ b/lib/assist/assist.go @@ -62,6 +62,9 @@ const ( MessageKindSystemMessage MessageType = "CHAT_MESSAGE_SYSTEM" // MessageKindError is the type of Assist message that is presented to user as information, but not stored persistently in the conversation. This can include backend error messages and the like. MessageKindError MessageType = "CHAT_MESSAGE_ERROR" + // MessageKindProgressUpdate is the type of Assist message that contains a progress update. + // A progress update starts a new "stage" and ends a previous stage if there was one. + MessageKindProgressUpdate MessageType = "CHAT_MESSAGE_PROGRESS_UPDATE" ) // PluginGetter is the minimal interface used by the chat to interact with the plugin service in the backend. @@ -213,6 +216,9 @@ func (c *Chat) loadMessages(ctx context.Context) error { } } + // Mark the history as fresh. + c.potentiallyStaleHistory = false + return nil } @@ -263,6 +269,18 @@ type onMessageFunc func(kind MessageType, payload []byte, createdTime time.Time) func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, userInput string, ) (*model.TokensUsed, error) { var tokensUsed *model.TokensUsed + progressUpdates := func(update *model.AgentAction) { + payload, err := json.Marshal(update) + if err != nil { + log.WithError(err).Debugf("Failed to marshal progress update: %v", update) + return + } + + if err := onMessage(MessageKindProgressUpdate, payload, c.assist.clock.Now().UTC()); err != nil { + log.WithError(err).Debugf("Failed to send progress update: %v", update) + return + } + } // If data might have been inserted into the chat history, we want to // refresh and get the latest data before querying the model. @@ -273,7 +291,7 @@ func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, use } // query the assistant and fetch an answer - message, err := c.chat.Complete(ctx, userInput) + message, err := c.chat.Complete(ctx, userInput, progressUpdates) if err != nil { return nil, trace.Wrap(err) } @@ -319,6 +337,34 @@ func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, use if err := onMessage(MessageKindAssistantMessage, []byte(message.Content), c.assist.clock.Now().UTC()); err != nil { return nil, trace.Wrap(err) } + case *model.StreamingMessage: + tokensUsed = message.TokensUsed + var text strings.Builder + defer onMessage(MessageKindAssistantPartialFinalize, nil, c.assist.clock.Now().UTC()) + for part := range message.Parts { + text.WriteString(part) + + if err := onMessage(MessageKindAssistantPartialMessage, []byte(part), c.assist.clock.Now().UTC()); err != nil { + return nil, trace.Wrap(err) + } + } + + // write an assistant message to memory and persistent storage + textS := text.String() + c.chat.Insert(openai.ChatMessageRoleAssistant, textS) + protoMsg := &assist.CreateAssistantMessageRequest{ + ConversationId: c.ConversationID, + Username: c.Username, + Message: &assist.AssistantMessage{ + Type: string(MessageKindAssistantMessage), + Payload: textS, + CreatedTime: timestamppb.New(c.assist.clock.Now().UTC()), + }, + } + + if err := c.assistService.CreateAssistantMessage(ctx, protoMsg); err != nil { + return nil, trace.Wrap(err) + } case *model.CompletionCommand: tokensUsed = message.TokensUsed payload := commandPayload{ @@ -351,11 +397,11 @@ func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, use } // As we emitted a command suggestion, the user might have run it. If // the command ran, a summary could have been inserted in the backend. - // To take this command summary into account we note the history might + // To take this command summary into account, we note the history might // be stale. c.potentiallyStaleHistory = true default: - return nil, trace.Errorf("unknown message type") + return nil, trace.Errorf("unknown message type: %T", message) } return tokensUsed, nil diff --git a/lib/assist/assist_test.go b/lib/assist/assist_test.go index 30547794fbaee..495383fb2ef38 100644 --- a/lib/assist/assist_test.go +++ b/lib/assist/assist_test.go @@ -48,7 +48,7 @@ func TestChatComplete(t *testing.T) { t.Cleanup(server.Close) cfg := openai.DefaultConfig("secret-test-token") - cfg.BaseURL = server.URL + "/v1" + cfg.BaseURL = server.URL // And a chat client. ctx := context.Background() @@ -80,23 +80,33 @@ func TestChatComplete(t *testing.T) { }) t.Run("the first message is the hey message", func(t *testing.T) { + // Use called to make sure that the callback is called. + called := false // The first message is the welcome message. _, err = chat.ProcessComplete(ctx, func(kind MessageType, payload []byte, createdTime time.Time) error { require.Equal(t, MessageKindAssistantMessage, kind) require.Contains(t, string(payload), "Hey, I'm Teleport") + called = true return nil }, "") require.NoError(t, err) + require.True(t, called) }) t.Run("command should be returned in the response", func(t *testing.T) { + called := false // The second message is the command response. _, err = chat.ProcessComplete(ctx, func(kind MessageType, payload []byte, createdTime time.Time) error { + if kind == MessageKindProgressUpdate { + return nil + } require.Equal(t, MessageKindCommand, kind) require.Equal(t, string(payload), `{"command":"df -h","nodes":["localhost"]}`) + called = true return nil }, "Show free disk space on localhost") require.NoError(t, err) + require.True(t, called) }) t.Run("check what messages are stored in the backend", func(t *testing.T) { @@ -127,7 +137,7 @@ func TestClassifyMessage(t *testing.T) { t.Cleanup(server.Close) cfg := openai.DefaultConfig("secret-test-token") - cfg.BaseURL = server.URL + "/v1" + cfg.BaseURL = server.URL // And a chat client. ctx := context.Background() diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 9ea879838e186..de3d15ac144fe 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -507,7 +507,6 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, } } - h.log.Debugf("end assistant conversation loop") - + h.log.Debug("end assistant conversation loop") return nil } diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index c163b255e9f50..7f78c9768df0c 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -25,6 +25,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "github.com/gorilla/websocket" @@ -44,16 +45,25 @@ import ( func Test_runAssistant(t *testing.T) { t.Parallel() - readMessage := func(t *testing.T, ws *websocket.Conn) string { - var msg assistantMessage - _, payload, err := ws.ReadMessage() - require.NoError(t, err) + readStreamResponse := func(t *testing.T, ws *websocket.Conn) string { + var sb strings.Builder + for { + var msg assistantMessage + _, payload, err := ws.ReadMessage() + require.NoError(t, err) - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) + err = json.Unmarshal(payload, &msg) + require.NoError(t, err) - require.Equal(t, assist.MessageKindAssistantMessage, msg.Type) - return msg.Payload + if msg.Type == assist.MessageKindAssistantPartialFinalize { + break + } + + require.Equal(t, assist.MessageKindAssistantPartialMessage, msg.Type) + sb.WriteString(msg.Payload) + } + + return sb.String() } readRateLimitedMessage := func(t *testing.T, ws *websocket.Conn) { @@ -85,14 +95,13 @@ func Test_runAssistant(t *testing.T) { require.NoError(t, err) const expectedMsg = "Which node do you want to use?" - require.Contains(t, expectedMsg, readMessage(t, ws)) + require.Contains(t, readStreamResponse(t, ws), expectedMsg) }, }, { name: "rate limited", responses: []string{ generateTextResponse(), - generateTextResponse(), }, cfg: webSuiteConfig{ ClusterFeatures: &authproto.Features{ @@ -114,7 +123,7 @@ func Test_runAssistant(t *testing.T) { require.NoError(t, err) const expectedMsg = "Which node do you want to use?" - require.Contains(t, expectedMsg, readMessage(t, ws)) + require.Contains(t, readStreamResponse(t, ws), expectedMsg) err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "all nodes, please"}`)) require.NoError(t, err) @@ -246,7 +255,7 @@ func Test_runAssistError(t *testing.T) { readHelloMsg(ws) readErrorMsg(ws) - // Check for close message + // Check for the close message _, _, err = ws.ReadMessage() closeErr, ok := err.(*websocket.CloseError) require.True(t, ok, "Expected close error") @@ -346,10 +355,5 @@ func Test_generateAssistantTitle(t *testing.T) { // generateTextResponse generates a response for a text completion func generateTextResponse() string { - return "```" + `json - { - "action": "Final Answer", - "action_input": "Which node do you want to use?" - } - ` + "```" + return "\nWhich node do you want to use?" } diff --git a/lib/web/command_test.go b/lib/web/command_test.go index 5ae2d888d25c4..7a73d715099ba 100644 --- a/lib/web/command_test.go +++ b/lib/web/command_test.go @@ -148,7 +148,7 @@ func TestExecuteCommandSummary(t *testing.T) { openAIMock := mockOpenAISummary(t) openAIConfig := openai.DefaultConfig("test-token") - openAIConfig.BaseURL = openAIMock.URL + "/v1" + openAIConfig.BaseURL = openAIMock.URL s := newWebSuiteWithConfig(t, webSuiteConfig{ disableDiskBasedRecording: true, OpenAIConfig: &openAIConfig, diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx index a519868754b45..4c04569ddfedf 100644 --- a/web/packages/teleport/src/Assist/context/AssistContext.tsx +++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx @@ -45,15 +45,17 @@ import { } from 'teleport/services/auth'; import * as service from '../service'; - -import { resolveServerCommandMessage, resolveServerMessage } from '../service'; +import { + resolveServerAssistThoughtMessage, + resolveServerCommandMessage, + resolveServerMessage, +} from '../service'; import type { ConversationMessage, ResolvedServerMessage, ServerMessage, } from 'teleport/Assist/types'; - import type { AssistState } from 'teleport/Assist/context/state'; interface AssistContextValue { @@ -138,11 +140,15 @@ export function AssistContextProvider(props: PropsWithChildren) { }; activeWebSocket.current.onclose = () => { - dispatch({ type: AssistStateActionType.SetStreaming, streaming: false }); + dispatch({ + type: AssistStateActionType.SetStreaming, + streaming: false, + }); }; activeWebSocket.current.onmessage = async event => { const data = JSON.parse(event.data) as ServerMessage; + console.log('onmessage', data); switch (data.type) { case ServerMessageType.Assist: @@ -169,19 +175,26 @@ export function AssistContextProvider(props: PropsWithChildren) { break; case ServerMessageType.AssistPartialMessage: { - const payload = JSON.parse(data.payload); - dispatch({ type: AssistStateActionType.AddPartialMessage, - message: payload.content, + message: data.payload, conversationId, }); break; } + case ServerMessageType.AssistThought: + const message = resolveServerAssistThoughtMessage(data); + dispatch({ + type: AssistStateActionType.AddThought, + message: message.message, + conversationId, + }); + + break; case ServerMessageType.Command: { - const message = await resolveServerCommandMessage(data); + const message = resolveServerCommandMessage(data); dispatch({ type: AssistStateActionType.AddExecuteRemoteCommand, @@ -304,7 +317,10 @@ export function AssistContextProvider(props: PropsWithChildren) { const messages = state.messages.data.get(state.conversations.selectedId); - dispatch({ type: AssistStateActionType.SetStreaming, streaming: true }); + dispatch({ + type: AssistStateActionType.SetStreaming, + streaming: true, + }); const data = JSON.stringify({ payload: message }); diff --git a/web/packages/teleport/src/Assist/service.ts b/web/packages/teleport/src/Assist/service.ts index b7294f85f1f28..9de574d06da5a 100644 --- a/web/packages/teleport/src/Assist/service.ts +++ b/web/packages/teleport/src/Assist/service.ts @@ -27,7 +27,11 @@ import { EventType } from 'teleport/lib/term/enums'; import NodeService from 'teleport/services/nodes'; -import { ServerMessageType } from './types'; +import { + ResolvedAssistThoughtServerMessage, + ServerMessageType, + ThoughtMessagePayload, +} from './types'; import type { CommandResultPayload, @@ -72,7 +76,8 @@ export async function resolveServerMessage( case ServerMessageType.CommandResultSummary: return resolveServerCommandResultSummaryMessage(message); - + case ServerMessageType.AssistThought: + return resolveServerAssistThoughtMessage(message); case ServerMessageType.Assist: case ServerMessageType.User: return { @@ -192,6 +197,18 @@ export function resolveServerCommandResultSummaryMessage( }; } +export function resolveServerAssistThoughtMessage( + message: ServerMessage +): ResolvedAssistThoughtServerMessage { + const payload = JSON.parse(message.payload) as ThoughtMessagePayload; + + return { + type: ServerMessageType.AssistThought, + message: payload.action, + created: new Date(message.created_time), + }; +} + export function resolveServerCommandMessage( message: ServerMessage ): ResolvedCommandServerMessage { diff --git a/web/packages/teleport/src/Assist/types.ts b/web/packages/teleport/src/Assist/types.ts index 8eff137eb84c8..5e205d3a0c5e7 100644 --- a/web/packages/teleport/src/Assist/types.ts +++ b/web/packages/teleport/src/Assist/types.ts @@ -25,7 +25,7 @@ export enum ServerMessageType { CommandResultStream = 'COMMAND_RESULT_STREAM', AssistPartialMessage = 'CHAT_PARTIAL_MESSAGE_ASSISTANT', AssistPartialMessageEnd = 'CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE', - AssistThought = 'CHAT_THOUGHT_ASSISTANT', + AssistThought = 'CHAT_MESSAGE_PROGRESS_UPDATE', } export const ExecutionEnvelopeType = 'summary'; @@ -160,6 +160,10 @@ export interface CommandResultSummaryPayload { summary: string; } +export interface ThoughtMessagePayload { + action: string; +} + export interface ExecEvent { event: EventType.EXEC; exitError?: string;