diff --git a/lib/ai/client.go b/lib/ai/client.go index b0f5c6ff768aa..6ec831f1b2cf4 100644 --- a/lib/ai/client.go +++ b/lib/ai/client.go @@ -77,3 +77,25 @@ func (client *Client) Summary(ctx context.Context, message string) (string, erro return resp.Choices[0].Message.Content, nil } + +// CommandSummary creates a command summary based on the command output. +// The message history is also passed to the model in order to keep context +// and extract relevant information from the output. +func (client *Client) CommandSummary(ctx context.Context, messages []openai.ChatCompletionMessage, output map[string][]byte) (string, error) { + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, Content: model.ConversationCommandResult(output)}) + + resp, err := client.svc.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4, + Messages: messages, + }, + ) + + if err != nil { + return "", trace.Wrap(err) + } + + return resp.Choices[0].Message.Content, nil +} diff --git a/lib/ai/model/agent.go b/lib/ai/model/agent.go index 1c5c2c4c24055..f142f7f70c14e 100644 --- a/lib/ai/model/agent.go +++ b/lib/ai/model/agent.go @@ -315,6 +315,7 @@ func parsePlanningOutput(text string) (*agentAction, *agentFinish, error) { log.Tracef("received planning output: \"%v\"", text) response, err := parseJSONFromModel[planOutput](text) if err != nil { + log.WithError(err).Trace("failed to parse planning output") return nil, nil, trace.Wrap(err) } diff --git a/lib/ai/model/prompt.go b/lib/ai/model/prompt.go index aaa0650dc4339..3b728d6bfb949 100644 --- a/lib/ai/model/prompt.go +++ b/lib/ai/model/prompt.go @@ -16,7 +16,10 @@ limitations under the License. package model -import "fmt" +import ( + "fmt" + "strings" +) var observationPrefix = "Observation: " var thoughtPrefix = "Thought: " @@ -24,6 +27,9 @@ var thoughtPrefix = "Thought: " const PromptSummarizeTitle = `You will be given a message. Create a short summary of that message. Respond only with summary, nothing else.` +const PromptSummarizeCommand = `You will be given a chat history and a command output. Based on the history context, extract relevant information from the command output and write a short summary of the command output. +Respond only with summary, nothing else.` + const InitialAIResponse = `Hey, I'm Teleport - a powerful tool that can assist you in managing your Teleport cluster via OpenAI GPT-4.` func PromptCharacter(username string) string { @@ -100,3 +106,14 @@ USER'S INPUT Okay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! Remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else.`, toolResponse) } + +func ConversationCommandResult(result map[string][]byte) string { + var message strings.Builder + for node, output := range result { + message.WriteString(fmt.Sprintf(`Command ran on node "%s" and produced the following output:\n`, node)) + message.WriteString(string(output)) + message.WriteString("\n") + } + message.WriteString("Based on the chat history, extract relevant information out of the command output and write a summary.") + return message.String() +} diff --git a/lib/assist/assist.go b/lib/assist/assist.go index 61ba0e6feff76..5783cddbc2fe7 100644 --- a/lib/assist/assist.go +++ b/lib/assist/assist.go @@ -44,6 +44,11 @@ const ( MessageKindCommand MessageType = "COMMAND" // MessageKindCommandResult is the type of Assist message that contains the command execution result. MessageKindCommandResult MessageType = "COMMAND_RESULT" + // MessageKindCommandResultSummary is the type of message that is optionally + // emitted after a command and contains a summary of the command output. + // This message is both sent after the command execution to the web UI, + // and persisted in the conversation history. + MessageKindCommandResultSummary MessageType = "COMMAND_RESULT_SUMMARY" // MessageKindUserMessage is the type of Assist message that contains the user message. MessageKindUserMessage MessageType = "CHAT_MESSAGE_USER" // MessageKindAssistantMessage is the type of Assist message that contains the assistant message. @@ -104,6 +109,10 @@ type Chat struct { ConversationID string // Username is the username of the user who started the chat. Username string + // potentiallyStaleHistory indicates messages might have been inserted into + // the chat history and the messages should be re-fetched before attempting + // the next completion. + potentiallyStaleHistory bool } // NewChat creates a new Assist chat. @@ -113,11 +122,12 @@ func (a *Assist) NewChat(ctx context.Context, assistService MessageService, aichat := a.client.NewChat(username) chat := &Chat{ - assist: a, - chat: aichat, - assistService: assistService, - ConversationID: conversationID, - Username: username, + assist: a, + chat: aichat, + assistService: assistService, + ConversationID: conversationID, + Username: username, + potentiallyStaleHistory: false, } if err := chat.loadMessages(ctx); err != nil { @@ -132,6 +142,29 @@ func (a *Assist) GenerateSummary(ctx context.Context, message string) (string, e return a.client.Summary(ctx, message) } +// GenerateCommandSummary summarizes the output of a command executed on one or +// many nodes. The conversation history is also sent into the prompt in order +// to gather context and know what information is relevant in the command output. +func (a *Assist) GenerateCommandSummary(ctx context.Context, messages []*assist.AssistantMessage, output map[string][]byte) (string, error) { + // Create system prompt + modelMessages := []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleSystem, Content: model.PromptSummarizeCommand}, + } + + // Load context back into prompt + for _, message := range messages { + role := kindToRole(MessageType(message.Type)) + if role != "" && role != openai.ChatMessageRoleSystem { + payload, err := formatMessagePayload(message) + if err != nil { + return "", trace.Wrap(err) + } + modelMessages = append(modelMessages, openai.ChatCompletionMessage{Role: role, Content: payload}) + } + } + return a.client.CommandSummary(ctx, modelMessages, output) +} + // loadMessages loads the messages from the database. func (c *Chat) loadMessages(ctx context.Context) error { // existing conversation, retrieve old messages @@ -147,7 +180,11 @@ func (c *Chat) loadMessages(ctx context.Context) error { for _, msg := range messages.GetMessages() { role := kindToRole(MessageType(msg.Type)) if role != "" { - c.chat.Insert(role, msg.Payload) + payload, err := formatMessagePayload(msg) + if err != nil { + return trace.Wrap(err) + } + c.chat.Insert(role, payload) } } @@ -202,6 +239,16 @@ func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, use ) (*model.TokensUsed, error) { var tokensUsed *model.TokensUsed + // If data might have been inserted into the chat history, we want to + // refresh and get the latest data before querying the model. + if c.potentiallyStaleHistory { + c.chat = c.assist.client.NewChat(c.Username) + err := c.loadMessages(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + } + // query the assistant and fetch an answer message, err := c.chat.Complete(ctx, userInput) if err != nil { @@ -279,6 +326,11 @@ func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, use if err := onMessage(MessageKindCommand, payloadJson, c.assist.clock.Now().UTC()); nil != err { return nil, trace.Wrap(err) } + // As we emitted a command suggestion, the user might have run it. If + // the command ran, a summary could have been inserted in the backend. + // To take this command summary into account we note the history might + // be stale. + c.potentiallyStaleHistory = true default: return nil, trace.Errorf("unknown message type") } @@ -313,7 +365,27 @@ func kindToRole(kind MessageType) string { return openai.ChatMessageRoleAssistant case MessageKindSystemMessage: return openai.ChatMessageRoleSystem + case MessageKindCommandResultSummary: + return openai.ChatMessageRoleUser default: return "" } } + +// formatMessagePayload generates the OpemAI message payload corresponding to +// an Assist message. Most Assist message payloads can be converted directly, +// but some payloads are JSON-formatted and must be processed before being +// passed to the model. +func formatMessagePayload(message *assist.AssistantMessage) (string, error) { + switch MessageType(message.GetType()) { + case MessageKindCommandResultSummary: + var summary CommandExecSummary + err := json.Unmarshal([]byte(message.GetPayload()), &summary) + if err != nil { + return "", trace.Wrap(err) + } + return summary.String(), nil + default: + return message.GetPayload(), nil + } +} diff --git a/lib/assist/messages.go b/lib/assist/messages.go index f36274b8d89e2..f5197348ca295 100644 --- a/lib/assist/messages.go +++ b/lib/assist/messages.go @@ -19,7 +19,11 @@ package assist -import "github.com/gravitational/teleport/lib/ai/model" +import ( + "fmt" + + "github.com/gravitational/teleport/lib/ai/model" +) // commandPayload is a payload for a command message. type commandPayload struct { @@ -27,3 +31,16 @@ type commandPayload struct { Nodes []string `json:"nodes,omitempty"` Labels []model.Label `json:"labels,omitempty"` } + +// CommandExecSummary is a payload for the COMMAND_RESULT_SUMMARY message. +type CommandExecSummary struct { + ExecutionID string `json:"execution_id"` + Summary string `json:"summary"` + Command string `json:"command"` +} + +// String implements the Stringer interface and formats the message for AI +// model consumption. +func (s CommandExecSummary) String() string { + return fmt.Sprintf("Command: `%s` executed. The command output summary is: %s", s.Command, s.Summary) +} diff --git a/lib/web/command.go b/lib/web/command.go index 7a92ca483513d..2bcdd223f8382 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -43,6 +43,7 @@ import ( "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/lib/agentless" assistlib "github.com/gravitational/teleport/lib/assist" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" @@ -50,9 +51,16 @@ import ( "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/teleagent" ) +// summaryBufferCapacity is the summary buffer size in bytes. The summary buffer +// is shared across all nodes the command is running on and stores the command +// output. If the command output exceeds the buffer capacity, the summary won't +// be computed. +const summaryBufferCapacity = 2000 + // CommandRequest is a request to execute a command on all nodes that match the query. type CommandRequest struct { // Command is the command to be executed on all nodes. @@ -213,6 +221,8 @@ func (h *Handler) executeCommand( mfaCacheFn := getMFACacheFn() interactiveCommand := strings.Split(req.Command, " ") + buffer := newSummaryBuffer(summaryBufferCapacity) + runCmd := func(host *hostInfo) error { sessionData, err := h.generateCommandSession(host, req.Login, clusterName, sessionCtx.cfg.User) if err != nil { @@ -234,6 +244,7 @@ func (h *Handler) executeCommand( TracerProvider: h.cfg.TracerProvider, LocalAuthProvider: h.auth.accessPoint, mfaFuncCache: mfaCacheFn, + buffer: buffer, } handler, err := newCommandHandler(ctx, commandHandlerConfig) @@ -275,9 +286,114 @@ func (h *Handler) executeCommand( runCommands(hosts, runCmd, h.log) + // Optionally, try to compute the command summary. + if output, overflow := buffer.Export(); !overflow || len(output) != 0 { + summaryReq := summaryRequest{ + hosts: hosts, + output: output, + authClient: clt, + identity: identity, + executionID: req.ExecutionID, + conversationID: req.ConversationID, + command: req.Command, + } + err := h.computeAndSendSummary(ctx, &summaryReq, ws) + if err != nil { + return nil, trace.Wrap(err) + } + } + return nil, nil } +type summaryRequest struct { + hosts []hostInfo + output map[string][]byte + authClient auth.ClientI + identity srv.IdentityContext + executionID string + conversationID string + command string +} + +func (h *Handler) computeAndSendSummary( + ctx context.Context, + req *summaryRequest, + ws WSConn, +) error { + // Convert the map nodeId->output into a map nodeName->output + namedOutput := outputByName(req.hosts, req.output) + + history, err := req.authClient.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{ + ConversationId: req.conversationID, + Username: req.identity.TeleportUser, + }) + if err != nil { + return trace.Wrap(err) + } + + assistClient, err := assistlib.NewClient(ctx, req.authClient, h.cfg.ProxySettings, h.cfg.OpenAIConfig) + if err != nil { + return trace.Wrap(err) + } + + summary, err := assistClient.GenerateCommandSummary(ctx, history.GetMessages(), namedOutput) + if err != nil { + return trace.Wrap(err) + } + + // Add the summary message to the backend so it is persisted on chat + // reload. + messagePayload, err := json.Marshal(&assistlib.CommandExecSummary{ + ExecutionID: req.executionID, + Command: req.command, + Summary: summary, + }) + if err != nil { + return trace.Wrap(err) + } + summaryMessage := &assist.CreateAssistantMessageRequest{ + ConversationId: req.conversationID, + Username: req.identity.TeleportUser, + Message: &assist.AssistantMessage{ + Type: string(assistlib.MessageKindCommandResultSummary), + CreatedTime: timestamppb.New(time.Now().UTC()), + Payload: string(messagePayload), + }, + } + + err = req.authClient.CreateAssistantMessage(ctx, summaryMessage) + if err != nil { + return trace.Wrap(err) + } + + // Send the summary over the execution websocket to provide instant + // feedback to the user. + out := &outEnvelope{ + Type: envelopeTypeSummary, + Payload: []byte(summary), + } + data, err := json.Marshal(out) + if err != nil { + return trace.Wrap(err) + } + stream := NewWStream(ctx, ws, log, nil) + _, err = stream.Write(data) + return trace.Wrap(err) +} + +func outputByName(hosts []hostInfo, output map[string][]byte) map[string][]byte { + hostIDToName := make(map[string]string, len(hosts)) + for _, host := range hosts { + hostIDToName[host.id] = host.hostName + } + namedOutput := make(map[string][]byte, len(output)) + for id, data := range output { + namedOutput[hostIDToName[id]] = data + } + return namedOutput +} + // runCommands runs the given command on the given hosts. func runCommands(hosts []hostInfo, runCmd func(host *hostInfo) error, log logrus.FieldLogger) { // Create a synchronization channel to limit the number of concurrent commands. @@ -356,6 +472,7 @@ func newCommandHandler(ctx context.Context, cfg CommandHandlerConfig) (*commandH tracer: cfg.tracer, }, mfaAuthCache: cfg.mfaFuncCache, + buffer: cfg.buffer, }, nil } @@ -386,6 +503,9 @@ type CommandHandlerConfig struct { tracer oteltrace.Tracer // mfaFuncCache is used to cache the MFA auth method mfaFuncCache mfaFuncCache + // buffer shared across multiple commandHandlers that saves the command + // output in order to generate a summary of the executed commands. + buffer *summaryBuffer } // CheckAndSetDefaults checks and sets default values. @@ -451,6 +571,10 @@ type commandHandler struct { // returns a list of ssh.AuthMethods. It is used to cache the result of // the MFA challenge. mfaAuthCache mfaFuncCache + + // buffer shared across multiple commandHandlers that saves the command + // output in order to generate a summary of the executed commands. + buffer *summaryBuffer } // sendError sends an error message to the client using the provided websocket. @@ -597,8 +721,8 @@ func (t *commandHandler) makeClient(ctx context.Context, ws WSConn) (*client.Tel clientConfig.HostLogin = t.sessionData.Login clientConfig.ForwardAgent = client.ForwardAgentLocal clientConfig.Namespace = apidefaults.Namespace - clientConfig.Stdout = newPayloadWriter(t.sessionData.ServerID, "stdout", t.stream) - clientConfig.Stderr = newPayloadWriter(t.sessionData.ServerID, "stderr", t.stream) + clientConfig.Stdout = newBufferedPayloadWriter(newPayloadWriter(t.sessionData.ServerID, envelopeTypeStdout, t.stream), t.buffer) + clientConfig.Stderr = newBufferedPayloadWriter(newPayloadWriter(t.sessionData.ServerID, envelopeTypeStderr, t.stream), t.buffer) clientConfig.Stdin = &bytes.Buffer{} // set stdin to a dummy buffer clientConfig.SiteName = t.sessionData.ClusterName if err := clientConfig.ParseProxyHost(t.proxyHostPort); err != nil { @@ -622,7 +746,7 @@ func (t *commandHandler) makeClient(ctx context.Context, ws WSConn) (*client.Tel func (t *commandHandler) writeError(err error) { out := &outEnvelope{ NodeID: t.sessionData.ServerID, - Type: "teleport-error", + Type: envelopeTypeError, Payload: []byte(err.Error()), } data, err := json.Marshal(out) diff --git a/lib/web/command_test.go b/lib/web/command_test.go index c7794a0964aca..ecff846c09c6e 100644 --- a/lib/web/command_test.go +++ b/lib/web/command_test.go @@ -24,6 +24,7 @@ import ( "fmt" "io" "net/http" + "net/http/httptest" "net/url" "strings" "sync/atomic" @@ -35,11 +36,13 @@ import ( "github.com/gorilla/websocket" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" + "github.com/sashabaranov/go-openai" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/timestamppb" "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" + "github.com/gravitational/teleport/lib/ai/testutils" assistlib "github.com/gravitational/teleport/lib/assist" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" @@ -47,11 +50,22 @@ import ( "github.com/gravitational/teleport/lib/utils" ) +const ( + testCommand = "echo txlxport | sed 's/x/e/g'" + testUser = "foo" +) + func TestExecuteCommand(t *testing.T) { t.Parallel() - s := newWebSuiteWithConfig(t, webSuiteConfig{disableDiskBasedRecording: true}) + openAIMock := mockOpenAISummary(t) + openAIConfig := openai.DefaultConfig("test-token") + openAIConfig.BaseURL = openAIMock.URL + "/v1" + s := newWebSuiteWithConfig(t, webSuiteConfig{ + disableDiskBasedRecording: true, + OpenAIConfig: &openAIConfig, + }) - ws, _, err := s.makeCommand(t, s.authPack(t, "foo"), uuid.New()) + ws, _, err := s.makeCommand(t, s.authPack(t, testUser), uuid.New()) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) @@ -63,17 +77,22 @@ func TestExecuteCommand(t *testing.T) { func TestExecuteCommandHistory(t *testing.T) { t.Parallel() - // Given - s := newWebSuiteWithConfig(t, webSuiteConfig{disableDiskBasedRecording: true}) - authPack := s.authPack(t, "foo") + openAIMock := mockOpenAISummary(t) + openAIConfig := openai.DefaultConfig("test-token") + openAIConfig.BaseURL = openAIMock.URL + "/v1" + s := newWebSuiteWithConfig(t, webSuiteConfig{ + disableDiskBasedRecording: true, + OpenAIConfig: &openAIConfig, + }) + authPack := s.authPack(t, testUser) ctx := context.Background() - clt, err := s.server.NewClient(auth.TestUser("foo")) + clt, err := s.server.NewClient(auth.TestUser(testUser)) require.NoError(t, err) // Create conversation, otherwise the command execution will not be saved conversation, err := clt.CreateAssistantConversation(context.Background(), &assist.CreateAssistantConversationRequest{ - Username: "foo", + Username: testUser, CreatedTime: timestamppb.Now(), }) require.NoError(t, err) @@ -91,9 +110,8 @@ func TestExecuteCommandHistory(t *testing.T) { // When command executes require.NoError(t, waitForCommandOutput(stream, "teleport")) - // Explecitly close the stream - err = stream.Close() - require.NoError(t, err) + // Close the stream if not already closed + _ = stream.Close() // Then command execution history is saved var messages *assist.GetAssistantMessagesResponse @@ -101,16 +119,17 @@ func TestExecuteCommandHistory(t *testing.T) { require.Eventually(t, func() bool { messages, err = clt.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{ ConversationId: conversationID.String(), - Username: "foo", + Username: testUser, }) require.NoError(t, err) - return len(messages.GetMessages()) == 1 + return len(messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResult]) == 1 }, 5*time.Second, 100*time.Millisecond) // Assert the returned message - msg := messages.GetMessages()[0] - require.Equal(t, string(assistlib.MessageKindCommandResult), msg.Type) + resultMessages, ok := messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResult] + require.True(t, ok, "Message must be of type COMMAND_RESULT") + msg := resultMessages[0] require.NotZero(t, msg.CreatedTime) var result commandExecResult @@ -123,13 +142,100 @@ func TestExecuteCommandHistory(t *testing.T) { require.Equal(t, "node", result.NodeID) } +func TestExecuteCommandSummary(t *testing.T) { + t.Parallel() + + openAIMock := mockOpenAISummary(t) + openAIConfig := openai.DefaultConfig("test-token") + openAIConfig.BaseURL = openAIMock.URL + "/v1" + s := newWebSuiteWithConfig(t, webSuiteConfig{ + disableDiskBasedRecording: true, + OpenAIConfig: &openAIConfig, + }) + authPack := s.authPack(t, testUser) + + ctx := context.Background() + clt, err := s.server.NewClient(auth.TestUser(testUser)) + require.NoError(t, err) + + // Create conversation, otherwise the command execution will not be saved + conversation, err := clt.CreateAssistantConversation(context.Background(), &assist.CreateAssistantConversationRequest{ + Username: testUser, + CreatedTime: timestamppb.Now(), + }) + require.NoError(t, err) + + require.NotEmpty(t, conversation.GetId()) + + conversationID, err := uuid.Parse(conversation.GetId()) + require.NoError(t, err) + + ws, _, err := s.makeCommand(t, authPack, conversationID) + require.NoError(t, err) + + // The current Assist execution relies on a hack that multiplexes multiple + // streams over a single one. This causes issues when a stream close as the + // single stream receiver will consider it should close. We work around by + // using a non-closing websocket and initiating a proper stream close. + // Then we reopen a new stream on the same websocket and continue reading + // the summary. + nonClosableWebsocket := &noopCloserWS{ws} + stream := NewWStream(ctx, nonClosableWebsocket, utils.NewLoggerForTests(), nil) + + // Wait for command execution to complete + require.NoError(t, waitForCommandOutput(stream, "teleport")) + + // Stop the stream consumption + err = stream.Close() + require.NoError(t, err) + + // Start a new stream and process the summary event + var env Envelope + stream = NewWStream(ctx, ws, utils.NewLoggerForTests(), nil) + dec := json.NewDecoder(stream) + err = dec.Decode(&env) + require.NoError(t, err) + require.Equal(t, envelopeTypeSummary, env.GetType()) + require.NotEmpty(t, env.GetPayload()) + + // Close the stream, and the underlying websocket + _ = stream.Close() + + // Wait for the command execution history to be saved + var messages *assist.GetAssistantMessagesResponse + // Command execution history is saved in asynchronusly, so we need to wait for it. + require.Eventually(t, func() bool { + messages, err = clt.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{ + ConversationId: conversationID.String(), + Username: testUser, + }) + require.NoError(t, err) + + return len(messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResultSummary]) == 1 + }, 5*time.Second, 100*time.Millisecond) + + // Check the returned summary message + summaryMessages, ok := messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResultSummary] + require.True(t, ok, "At least one summary message is expected") + msg := summaryMessages[0] + require.NotZero(t, msg.CreatedTime) + + var result assistlib.CommandExecSummary + err = json.Unmarshal([]byte(msg.GetPayload()), &result) + require.NoError(t, err) + + require.NotEmpty(t, result.ExecutionID) + require.Equal(t, testCommand, result.Command) + require.NotEmpty(t, result.Summary) +} + func (s *WebSuite) makeCommand(t *testing.T, pack *authPack, conversationID uuid.UUID) (*websocket.Conn, *session.Session, error) { req := CommandRequest{ Query: fmt.Sprintf("name == \"%s\"", s.srvID), Login: pack.login, ConversationID: conversationID.String(), ExecutionID: uuid.New().String(), - Command: "echo txlxport | sed 's/x/e/g'", + Command: testCommand, } u := url.URL{ @@ -243,3 +349,18 @@ func Test_runCommands(t *testing.T) { require.Equal(t, int32(100), counter.Load()) } + +func mockOpenAISummary(t *testing.T) *httptest.Server { + responses := []string{"This is the summary of the command."} + server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses)) + t.Cleanup(server.Close) + return server +} + +func messagesByType(messages []*assist.AssistantMessage) map[assistlib.MessageType][]*assist.AssistantMessage { + byType := make(map[assistlib.MessageType][]*assist.AssistantMessage) + for _, message := range messages { + byType[assistlib.MessageType(message.GetType())] = append(byType[assistlib.MessageType(message.GetType())], message) + } + return byType +} diff --git a/lib/web/command_utils.go b/lib/web/command_utils.go index db8ebdbc42a5a..48b4712fb74eb 100644 --- a/lib/web/command_utils.go +++ b/lib/web/command_utils.go @@ -45,6 +45,13 @@ type WSConn interface { SetPongHandler(h func(appData string) error) } +const ( + envelopeTypeStdout = "stdout" + envelopeTypeStderr = "stderr" + envelopeTypeError = "teleport-error" + envelopeTypeSummary = "summary" +) + // outEnvelope is an envelope used to wrap messages send back to the client connected over WS. type outEnvelope struct { NodeID string `json:"node_id"` @@ -56,7 +63,7 @@ type outEnvelope struct { // outEnvelope and writes it to the underlying stream. type payloadWriter struct { nodeID string - // output name, can be stdout, stderr or teleport-error. + // output name, can be stdout, stderr, teleport-error or summary. outputName string // stream is the underlying stream. stream io.Writer @@ -130,3 +137,65 @@ func (s *syncRWWSConn) ReadMessage() (messageType int, p []byte, err error) { defer s.rmtx.Unlock() return s.WSConn.ReadMessage() } + +func newBufferedPayloadWriter(pw *payloadWriter, buffer *summaryBuffer) *bufferedPayloadWriter { + return &bufferedPayloadWriter{ + payloadWriter: pw, + buffer: buffer, + } +} + +type bufferedPayloadWriter struct { + *payloadWriter + buffer *summaryBuffer +} + +func (bp *bufferedPayloadWriter) Write(data []byte) (int, error) { + bp.buffer.Write(bp.nodeID, data) + return bp.payloadWriter.Write(data) +} + +func newSummaryBuffer(capacity int) *summaryBuffer { + return &summaryBuffer{ + buffer: make(map[string][]byte), + remainingCapacity: capacity, + invalid: false, + mutex: sync.Mutex{}, + } +} + +type summaryBuffer struct { + buffer map[string][]byte + remainingCapacity int + invalid bool + // mutex protects all members of the struct and must be acquired before + // performing any read or write operation + mutex sync.Mutex +} + +func (b *summaryBuffer) Write(node string, data []byte) { + b.mutex.Lock() + defer b.mutex.Unlock() + if b.invalid { + return + } + if len(data) > b.remainingCapacity { + // We're out of capacity, not all content will be written to the buffer + // it should not be used anymore + b.invalid = true + return + } + b.buffer[node] = append(b.buffer[node], data...) + b.remainingCapacity -= len(data) +} + +// Export returns the buffer content and a whether the buffer overflowed. +func (b *summaryBuffer) Export() (map[string][]byte, bool) { + b.mutex.Lock() + defer b.mutex.Unlock() + if b.invalid { + return nil, true + } + b.invalid = true + return b.buffer, false +} diff --git a/lib/web/command_utils_test.go b/lib/web/command_utils_test.go new file mode 100644 index 0000000000000..01528bff29c1b --- /dev/null +++ b/lib/web/command_utils_test.go @@ -0,0 +1,142 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package web + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSummaryBuffer(t *testing.T) { + tests := []struct { + name string + outputs map[string][][]byte + capacity int + expectedOutput map[string][]byte + expectedOverflow bool + }{ + { + name: "Single node", + outputs: map[string][][]byte{ + "node": { + []byte("foo"), + []byte("bar"), + []byte("baz"), + }, + }, + capacity: 9, + expectedOutput: map[string][]byte{ + "node": []byte("foobarbaz"), + }, + expectedOverflow: false, + }, + { + name: "Single node overflow", + outputs: map[string][][]byte{ + "node": { + []byte("foo"), + []byte("bar"), + []byte("baz"), + }, + }, + capacity: 8, + expectedOutput: nil, + expectedOverflow: true, + }, + { + name: "Multiple nodes", + outputs: map[string][][]byte{ + "node1": { + []byte("foo"), + []byte("bar"), + []byte("baz"), + }, + "node2": { + []byte("baz"), + []byte("bar"), + []byte("foo"), + }, + "node3": { + []byte("baz"), + []byte("baz"), + []byte("baz"), + }, + }, + capacity: 30, + expectedOutput: map[string][]byte{ + "node1": []byte("foobarbaz"), + "node2": []byte("bazbarfoo"), + "node3": []byte("bazbazbaz"), + }, + expectedOverflow: false, + }, + { + name: "Multiple nodes overflow", + outputs: map[string][][]byte{ + "node1": { + []byte("foo"), + []byte("bar"), + []byte("baz"), + }, + "node2": { + []byte("baz"), + []byte("bar"), + []byte("foo"), + }, + "node3": { + []byte("baz"), + []byte("baz"), + []byte("baz"), + }, + }, + capacity: 25, + expectedOutput: nil, + expectedOverflow: true, + }, + { + name: "No output", + outputs: nil, + capacity: 10, + expectedOutput: map[string][]byte{}, + expectedOverflow: false, + }, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + buffer := newSummaryBuffer(tc.capacity) + var wg sync.WaitGroup + for node, output := range tc.outputs { + node := node + output := output + wg.Add(1) + go func() { + defer wg.Done() + for _, chunk := range output { + buffer.Write(node, chunk) + } + }() + } + wg.Wait() + output, overflow := buffer.Export() + require.Equal(t, tc.expectedOutput, output) + require.Equal(t, tc.expectedOverflow, overflow) + + }) + } +} diff --git a/web/packages/teleport/src/Assist/Conversation/CommandResultSummaryEntry.tsx b/web/packages/teleport/src/Assist/Conversation/CommandResultSummaryEntry.tsx new file mode 100644 index 0000000000000..5fb7626a8563b --- /dev/null +++ b/web/packages/teleport/src/Assist/Conversation/CommandResultSummaryEntry.tsx @@ -0,0 +1,71 @@ +/** + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import React from 'react'; +import styled from 'styled-components'; + +import ReactMarkdown from 'react-markdown'; +import remarkGfm from 'remark-gfm'; + +import { markdownCSS } from 'teleport/Assist/markdown'; + +interface CommandResultSummaryEntryProps { + command: string; + summary: string; +} + +const Container = styled.div` + border-radius: 10px; + position: relative; +`; + +const Title = styled.div` + font-size: 15px; + font-weight: 600; + padding: 10px 15px; +`; + +const Summary = styled.div` + padding: 10px 15px 0 17px; + + ${markdownCSS} +`; + +const Header = styled.div` + display: flex; + justify-content: space-between; + padding-right: 20px; +`; + +export function CommandResultSummaryEntry( + props: CommandResultSummaryEntryProps +) { + return ( + +
+ + Summary of command execution <pre>{props.command}</pre> + +
+ + + + {props.summary} + + +
+ ); +} diff --git a/web/packages/teleport/src/Assist/Conversation/Message.tsx b/web/packages/teleport/src/Assist/Conversation/Message.tsx index 4cb5f2fad5e18..effcc450df6e7 100644 --- a/web/packages/teleport/src/Assist/Conversation/Message.tsx +++ b/web/packages/teleport/src/Assist/Conversation/Message.tsx @@ -21,6 +21,7 @@ import { CheckIcon } from 'design/SVGIcon'; import { Author, + ConversationMessage, ResolvedServerMessage, ServerMessageType, } from 'teleport/Assist/types'; @@ -39,8 +40,7 @@ import { MessageEntry } from 'teleport/Assist/Conversation/MessageEntry'; import { useAssist } from 'teleport/Assist/context/AssistContext'; import { ExecuteRemoteCommandEntry } from 'teleport/Assist/Conversation/ExecuteRemoteCommandEntry'; import { CommandResultEntry } from 'teleport/Assist/Conversation/CommandResultEntry'; - -import type { ConversationMessage } from 'teleport/Assist/types'; +import { CommandResultSummaryEntry } from 'teleport/Assist/Conversation/CommandResultSummaryEntry'; interface MessageProps { message: ConversationMessage; @@ -139,6 +139,13 @@ function createComponentForEntry( errorMessage={entry.errorMessage} /> ); + case ServerMessageType.CommandResultSummary: + return ( + + ); } } diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx index 942fe036c7c34..b61bb8de0535c 100644 --- a/web/packages/teleport/src/Assist/context/AssistContext.tsx +++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx @@ -31,7 +31,11 @@ import useStickyClusterId from 'teleport/useStickyClusterId'; import cfg from 'teleport/config'; import { getAccessToken, getHostName } from 'teleport/services/api'; -import { RawPayload, ServerMessageType } from 'teleport/Assist/types'; +import { + ExecutionEnvelopeType, + RawPayload, + ServerMessageType, +} from 'teleport/Assist/types'; import { MessageTypeEnum, Protobuf } from 'teleport/lib/term/protobuf'; @@ -41,6 +45,7 @@ import { } from 'teleport/services/auth'; import * as service from '../service'; + import { resolveServerCommandMessage, resolveServerMessage } from '../service'; import type { @@ -48,6 +53,7 @@ import type { ResolvedServerMessage, ServerMessage, } from 'teleport/Assist/types'; + import type { AssistState } from 'teleport/Assist/context/state'; interface AssistContextValue { @@ -421,13 +427,22 @@ export function AssistContextProvider(props: PropsWithChildren) { const data = JSON.parse(msg.payload) as RawPayload; const payload = atob(data.payload); - dispatch({ - type: AssistStateActionType.UpdateCommandResult, - conversationId: state.conversations.selectedId, - commandResultId: nodeIdToResultId.get(data.node_id), - output: payload, - }); - + if (data.type === ExecutionEnvelopeType) { + dispatch({ + type: AssistStateActionType.AddCommandResultSummary, + conversationId: state.conversations.selectedId, + summary: payload, + executionId: execParams.execution_id, + command: execParams.command, + }); + } else { + dispatch({ + type: AssistStateActionType.UpdateCommandResult, + conversationId: state.conversations.selectedId, + commandResultId: nodeIdToResultId.get(data.node_id), + output: payload, + }); + } break; case MessageTypeEnum.WEBAUTHN_CHALLENGE: @@ -455,18 +470,20 @@ export function AssistContextProvider(props: PropsWithChildren) { sessionsEnded += 1; if (sessionsEnded === nodeIdToResultId.size) { + const message = proto.encodeCloseMessage(); + const bytearray = new Uint8Array(message); + for (const nodeId of nodeIdToResultId.keys()) { dispatch({ type: AssistStateActionType.FinishCommandResult, conversationId: state.conversations.selectedId, commandResultId: nodeIdToResultId.get(nodeId), }); + + executeCommandWebSocket.current.send(bytearray.buffer); } nodeIdToResultId.clear(); - - // TODO(ryan): move this to after the summary is sent once it's implemented - executeCommandWebSocket.current.close(); } break; diff --git a/web/packages/teleport/src/Assist/context/state.ts b/web/packages/teleport/src/Assist/context/state.ts index 41dc5312ff367..aa7ace7a5b6c0 100644 --- a/web/packages/teleport/src/Assist/context/state.ts +++ b/web/packages/teleport/src/Assist/context/state.ts @@ -60,6 +60,7 @@ export enum AssistStateActionType { PromptMfa, DeleteConversation, UpdateConversationTitle, + AddCommandResultSummary, } export interface ReplaceConversationsAction { @@ -164,6 +165,14 @@ export interface UpdateConversationTitleAction { title: string; } +export interface AddCommandResultSummaryAction { + type: AssistStateActionType.AddCommandResultSummary; + summary: string; + conversationId: string; + command: string; + executionId: string; +} + export type AssistContextAction = | SetConversationsLoadingAction | ReplaceConversationsAction @@ -181,7 +190,8 @@ export type AssistContextAction = | FinishCommandResultAction | PromptMfaAction | DeleteConversationAction - | UpdateConversationTitleAction; + | UpdateConversationTitleAction + | AddCommandResultSummaryAction; export function reducer( state: AssistState, @@ -239,6 +249,9 @@ export function reducer( case AssistStateActionType.UpdateConversationTitle: return updateConversationTitle(state, action); + case AssistStateActionType.AddCommandResultSummary: + return addCommandResultSummary(state, action); + default: return state; } @@ -589,6 +602,35 @@ export function finishCommandResult( }; } +export function addCommandResultSummary( + state: AssistState, + action: AddCommandResultSummaryAction +): AssistState { + const messages = new Map(state.messages.data); + + let conversationMessages = messages.get(action.conversationId); + + conversationMessages = [ + ...conversationMessages, + { + type: ServerMessageType.CommandResultSummary, + created: new Date(), + executionId: action.executionId, + command: action.command, + summary: action.summary, + }, + ]; + + messages.set(action.conversationId, conversationMessages); + + return { + ...state, + messages: { + ...state.messages, + data: messages, + }, + }; +} export function promptMfa( state: AssistState, action: PromptMfaAction diff --git a/web/packages/teleport/src/Assist/context/utils.ts b/web/packages/teleport/src/Assist/context/utils.ts index e7473f7003ea7..795dcda0c7e23 100644 --- a/web/packages/teleport/src/Assist/context/utils.ts +++ b/web/packages/teleport/src/Assist/context/utils.ts @@ -32,6 +32,7 @@ function getMessageTypeAuthor(type: string) { case ServerMessageType.Command: case ServerMessageType.CommandResult: case ServerMessageType.CommandResultStream: + case ServerMessageType.CommandResultSummary: case ServerMessageType.Error: return Author.Teleport; } diff --git a/web/packages/teleport/src/Assist/service.ts b/web/packages/teleport/src/Assist/service.ts index 49d263a2da07e..b7294f85f1f28 100644 --- a/web/packages/teleport/src/Assist/service.ts +++ b/web/packages/teleport/src/Assist/service.ts @@ -31,6 +31,7 @@ import { ServerMessageType } from './types'; import type { CommandResultPayload, + CommandResultSummaryPayload, Conversation, CreateConversationResponse, ExecEvent, @@ -39,6 +40,7 @@ import type { GetConversationMessagesResponse, GetConversationsResponse, ResolvedCommandResultServerMessage, + ResolvedCommandResultSummaryServerMessage, ResolvedCommandServerMessage, ResolvedServerMessage, ServerMessage, @@ -68,6 +70,9 @@ export async function resolveServerMessage( case ServerMessageType.CommandResult: return resolveServerCommandResultMessage(message, clusterId); + case ServerMessageType.CommandResultSummary: + return resolveServerCommandResultSummaryMessage(message); + case ServerMessageType.Assist: case ServerMessageType.User: return { @@ -173,6 +178,20 @@ export async function resolveServerCommandResultMessage( } } +export function resolveServerCommandResultSummaryMessage( + message: ServerMessage +): ResolvedCommandResultSummaryServerMessage { + const payload = JSON.parse(message.payload) as CommandResultSummaryPayload; + + return { + type: ServerMessageType.CommandResultSummary, + executionId: payload.execution_id, + command: payload.command, + summary: payload.summary, + created: new Date(message.created_time), + }; +} + export function resolveServerCommandMessage( message: ServerMessage ): ResolvedCommandServerMessage { diff --git a/web/packages/teleport/src/Assist/types.ts b/web/packages/teleport/src/Assist/types.ts index ba73068e0c80e..a46c9fb94b371 100644 --- a/web/packages/teleport/src/Assist/types.ts +++ b/web/packages/teleport/src/Assist/types.ts @@ -21,12 +21,15 @@ export enum ServerMessageType { Error = 'CHAT_MESSAGE_ERROR', Command = 'COMMAND', CommandResult = 'COMMAND_RESULT', + CommandResultSummary = 'COMMAND_RESULT_SUMMARY', CommandResultStream = 'COMMAND_RESULT_STREAM', AssistPartialMessage = 'CHAT_PARTIAL_MESSAGE_ASSISTANT', AssistPartialMessageEnd = 'CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE', AssistThought = 'CHAT_THOUGHT_ASSISTANT', } +export const ExecutionEnvelopeType = 'summary'; + export interface Conversation { id: string; title?: string; @@ -68,6 +71,14 @@ export interface ResolvedCommandResultServerMessage { created: Date; } +export interface ResolvedCommandResultSummaryServerMessage { + type: ServerMessageType.CommandResultSummary; + executionId: string; + summary: string; + command: string; + created: Date; +} + export interface ResolvedAssistThoughtServerMessage { type: ServerMessageType.AssistThought; message: string; @@ -108,6 +119,7 @@ export type ResolvedServerMessage = | ResolvedUserServerMessage | ResolvedErrorServerMessage | ResolvedCommandResultServerMessage + | ResolvedCommandResultSummaryServerMessage | ResolvedAssistThoughtServerMessage | ResolvedCommandResultStreamServerMessage; @@ -142,6 +154,12 @@ export interface CommandResultPayload { execution_id: string; } +export interface CommandResultSummaryPayload { + execution_id: string; + command: string; + summary: string; +} + export interface ExecEvent { event: EventType.EXEC; exitError?: string; @@ -162,6 +180,7 @@ export interface NodeState { export interface RawPayload { node_id: string; + type: string; payload: string; } diff --git a/web/packages/teleport/src/lib/term/protobuf.js b/web/packages/teleport/src/lib/term/protobuf.js index a259070ae7804..c654dacbef29c 100644 --- a/web/packages/teleport/src/lib/term/protobuf.js +++ b/web/packages/teleport/src/lib/term/protobuf.js @@ -90,6 +90,11 @@ export class Protobuf { return this.encode(messageFields.type.values.data, message); } + encodeCloseMessage() { + // Close message has no payload + return this.encode(messageFields.type.values.close, ''); + } + encodePayload(buffer, text) { // set type buffer.push(messageFields.payload.code);