diff --git a/lib/ai/client.go b/lib/ai/client.go
index b0f5c6ff768aa..6ec831f1b2cf4 100644
--- a/lib/ai/client.go
+++ b/lib/ai/client.go
@@ -77,3 +77,25 @@ func (client *Client) Summary(ctx context.Context, message string) (string, erro
return resp.Choices[0].Message.Content, nil
}
+
+// CommandSummary creates a command summary based on the command output.
+// The message history is also passed to the model in order to keep context
+// and extract relevant information from the output.
+func (client *Client) CommandSummary(ctx context.Context, messages []openai.ChatCompletionMessage, output map[string][]byte) (string, error) {
+ messages = append(messages, openai.ChatCompletionMessage{
+ Role: openai.ChatMessageRoleUser, Content: model.ConversationCommandResult(output)})
+
+ resp, err := client.svc.CreateChatCompletion(
+ ctx,
+ openai.ChatCompletionRequest{
+ Model: openai.GPT4,
+ Messages: messages,
+ },
+ )
+
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ return resp.Choices[0].Message.Content, nil
+}
diff --git a/lib/ai/model/agent.go b/lib/ai/model/agent.go
index 1c5c2c4c24055..f142f7f70c14e 100644
--- a/lib/ai/model/agent.go
+++ b/lib/ai/model/agent.go
@@ -315,6 +315,7 @@ func parsePlanningOutput(text string) (*agentAction, *agentFinish, error) {
log.Tracef("received planning output: \"%v\"", text)
response, err := parseJSONFromModel[planOutput](text)
if err != nil {
+ log.WithError(err).Trace("failed to parse planning output")
return nil, nil, trace.Wrap(err)
}
diff --git a/lib/ai/model/prompt.go b/lib/ai/model/prompt.go
index aaa0650dc4339..3b728d6bfb949 100644
--- a/lib/ai/model/prompt.go
+++ b/lib/ai/model/prompt.go
@@ -16,7 +16,10 @@ limitations under the License.
package model
-import "fmt"
+import (
+ "fmt"
+ "strings"
+)
var observationPrefix = "Observation: "
var thoughtPrefix = "Thought: "
@@ -24,6 +27,9 @@ var thoughtPrefix = "Thought: "
const PromptSummarizeTitle = `You will be given a message. Create a short summary of that message.
Respond only with summary, nothing else.`
+const PromptSummarizeCommand = `You will be given a chat history and a command output. Based on the history context, extract relevant information from the command output and write a short summary of the command output.
+Respond only with summary, nothing else.`
+
const InitialAIResponse = `Hey, I'm Teleport - a powerful tool that can assist you in managing your Teleport cluster via OpenAI GPT-4.`
func PromptCharacter(username string) string {
@@ -100,3 +106,14 @@ USER'S INPUT
Okay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! Remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else.`, toolResponse)
}
+
+func ConversationCommandResult(result map[string][]byte) string {
+ var message strings.Builder
+ for node, output := range result {
+ message.WriteString(fmt.Sprintf(`Command ran on node "%s" and produced the following output:\n`, node))
+ message.WriteString(string(output))
+ message.WriteString("\n")
+ }
+ message.WriteString("Based on the chat history, extract relevant information out of the command output and write a summary.")
+ return message.String()
+}
diff --git a/lib/assist/assist.go b/lib/assist/assist.go
index 61ba0e6feff76..5783cddbc2fe7 100644
--- a/lib/assist/assist.go
+++ b/lib/assist/assist.go
@@ -44,6 +44,11 @@ const (
MessageKindCommand MessageType = "COMMAND"
// MessageKindCommandResult is the type of Assist message that contains the command execution result.
MessageKindCommandResult MessageType = "COMMAND_RESULT"
+ // MessageKindCommandResultSummary is the type of message that is optionally
+ // emitted after a command and contains a summary of the command output.
+ // This message is both sent after the command execution to the web UI,
+ // and persisted in the conversation history.
+ MessageKindCommandResultSummary MessageType = "COMMAND_RESULT_SUMMARY"
// MessageKindUserMessage is the type of Assist message that contains the user message.
MessageKindUserMessage MessageType = "CHAT_MESSAGE_USER"
// MessageKindAssistantMessage is the type of Assist message that contains the assistant message.
@@ -104,6 +109,10 @@ type Chat struct {
ConversationID string
// Username is the username of the user who started the chat.
Username string
+ // potentiallyStaleHistory indicates messages might have been inserted into
+ // the chat history and the messages should be re-fetched before attempting
+ // the next completion.
+ potentiallyStaleHistory bool
}
// NewChat creates a new Assist chat.
@@ -113,11 +122,12 @@ func (a *Assist) NewChat(ctx context.Context, assistService MessageService,
aichat := a.client.NewChat(username)
chat := &Chat{
- assist: a,
- chat: aichat,
- assistService: assistService,
- ConversationID: conversationID,
- Username: username,
+ assist: a,
+ chat: aichat,
+ assistService: assistService,
+ ConversationID: conversationID,
+ Username: username,
+ potentiallyStaleHistory: false,
}
if err := chat.loadMessages(ctx); err != nil {
@@ -132,6 +142,29 @@ func (a *Assist) GenerateSummary(ctx context.Context, message string) (string, e
return a.client.Summary(ctx, message)
}
+// GenerateCommandSummary summarizes the output of a command executed on one or
+// many nodes. The conversation history is also sent into the prompt in order
+// to gather context and know what information is relevant in the command output.
+func (a *Assist) GenerateCommandSummary(ctx context.Context, messages []*assist.AssistantMessage, output map[string][]byte) (string, error) {
+ // Create system prompt
+ modelMessages := []openai.ChatCompletionMessage{
+ {Role: openai.ChatMessageRoleSystem, Content: model.PromptSummarizeCommand},
+ }
+
+ // Load context back into prompt
+ for _, message := range messages {
+ role := kindToRole(MessageType(message.Type))
+ if role != "" && role != openai.ChatMessageRoleSystem {
+ payload, err := formatMessagePayload(message)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ modelMessages = append(modelMessages, openai.ChatCompletionMessage{Role: role, Content: payload})
+ }
+ }
+ return a.client.CommandSummary(ctx, modelMessages, output)
+}
+
// loadMessages loads the messages from the database.
func (c *Chat) loadMessages(ctx context.Context) error {
// existing conversation, retrieve old messages
@@ -147,7 +180,11 @@ func (c *Chat) loadMessages(ctx context.Context) error {
for _, msg := range messages.GetMessages() {
role := kindToRole(MessageType(msg.Type))
if role != "" {
- c.chat.Insert(role, msg.Payload)
+ payload, err := formatMessagePayload(msg)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ c.chat.Insert(role, payload)
}
}
@@ -202,6 +239,16 @@ func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, use
) (*model.TokensUsed, error) {
var tokensUsed *model.TokensUsed
+ // If data might have been inserted into the chat history, we want to
+ // refresh and get the latest data before querying the model.
+ if c.potentiallyStaleHistory {
+ c.chat = c.assist.client.NewChat(c.Username)
+ err := c.loadMessages(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ }
+
// query the assistant and fetch an answer
message, err := c.chat.Complete(ctx, userInput)
if err != nil {
@@ -279,6 +326,11 @@ func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, use
if err := onMessage(MessageKindCommand, payloadJson, c.assist.clock.Now().UTC()); nil != err {
return nil, trace.Wrap(err)
}
+ // As we emitted a command suggestion, the user might have run it. If
+ // the command ran, a summary could have been inserted in the backend.
+ // To take this command summary into account we note the history might
+ // be stale.
+ c.potentiallyStaleHistory = true
default:
return nil, trace.Errorf("unknown message type")
}
@@ -313,7 +365,27 @@ func kindToRole(kind MessageType) string {
return openai.ChatMessageRoleAssistant
case MessageKindSystemMessage:
return openai.ChatMessageRoleSystem
+ case MessageKindCommandResultSummary:
+ return openai.ChatMessageRoleUser
default:
return ""
}
}
+
+// formatMessagePayload generates the OpemAI message payload corresponding to
+// an Assist message. Most Assist message payloads can be converted directly,
+// but some payloads are JSON-formatted and must be processed before being
+// passed to the model.
+func formatMessagePayload(message *assist.AssistantMessage) (string, error) {
+ switch MessageType(message.GetType()) {
+ case MessageKindCommandResultSummary:
+ var summary CommandExecSummary
+ err := json.Unmarshal([]byte(message.GetPayload()), &summary)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ return summary.String(), nil
+ default:
+ return message.GetPayload(), nil
+ }
+}
diff --git a/lib/assist/messages.go b/lib/assist/messages.go
index f36274b8d89e2..f5197348ca295 100644
--- a/lib/assist/messages.go
+++ b/lib/assist/messages.go
@@ -19,7 +19,11 @@
package assist
-import "github.com/gravitational/teleport/lib/ai/model"
+import (
+ "fmt"
+
+ "github.com/gravitational/teleport/lib/ai/model"
+)
// commandPayload is a payload for a command message.
type commandPayload struct {
@@ -27,3 +31,16 @@ type commandPayload struct {
Nodes []string `json:"nodes,omitempty"`
Labels []model.Label `json:"labels,omitempty"`
}
+
+// CommandExecSummary is a payload for the COMMAND_RESULT_SUMMARY message.
+type CommandExecSummary struct {
+ ExecutionID string `json:"execution_id"`
+ Summary string `json:"summary"`
+ Command string `json:"command"`
+}
+
+// String implements the Stringer interface and formats the message for AI
+// model consumption.
+func (s CommandExecSummary) String() string {
+ return fmt.Sprintf("Command: `%s` executed. The command output summary is: %s", s.Command, s.Summary)
+}
diff --git a/lib/web/command.go b/lib/web/command.go
index 7a92ca483513d..2bcdd223f8382 100644
--- a/lib/web/command.go
+++ b/lib/web/command.go
@@ -43,6 +43,7 @@ import (
"github.com/gravitational/teleport/api/observability/tracing"
"github.com/gravitational/teleport/lib/agentless"
assistlib "github.com/gravitational/teleport/lib/assist"
+ "github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/httplib"
@@ -50,9 +51,16 @@ import (
"github.com/gravitational/teleport/lib/reversetunnel"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/session"
+ "github.com/gravitational/teleport/lib/srv"
"github.com/gravitational/teleport/lib/teleagent"
)
+// summaryBufferCapacity is the summary buffer size in bytes. The summary buffer
+// is shared across all nodes the command is running on and stores the command
+// output. If the command output exceeds the buffer capacity, the summary won't
+// be computed.
+const summaryBufferCapacity = 2000
+
// CommandRequest is a request to execute a command on all nodes that match the query.
type CommandRequest struct {
// Command is the command to be executed on all nodes.
@@ -213,6 +221,8 @@ func (h *Handler) executeCommand(
mfaCacheFn := getMFACacheFn()
interactiveCommand := strings.Split(req.Command, " ")
+ buffer := newSummaryBuffer(summaryBufferCapacity)
+
runCmd := func(host *hostInfo) error {
sessionData, err := h.generateCommandSession(host, req.Login, clusterName, sessionCtx.cfg.User)
if err != nil {
@@ -234,6 +244,7 @@ func (h *Handler) executeCommand(
TracerProvider: h.cfg.TracerProvider,
LocalAuthProvider: h.auth.accessPoint,
mfaFuncCache: mfaCacheFn,
+ buffer: buffer,
}
handler, err := newCommandHandler(ctx, commandHandlerConfig)
@@ -275,9 +286,114 @@ func (h *Handler) executeCommand(
runCommands(hosts, runCmd, h.log)
+ // Optionally, try to compute the command summary.
+ if output, overflow := buffer.Export(); !overflow || len(output) != 0 {
+ summaryReq := summaryRequest{
+ hosts: hosts,
+ output: output,
+ authClient: clt,
+ identity: identity,
+ executionID: req.ExecutionID,
+ conversationID: req.ConversationID,
+ command: req.Command,
+ }
+ err := h.computeAndSendSummary(ctx, &summaryReq, ws)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ }
+
return nil, nil
}
+type summaryRequest struct {
+ hosts []hostInfo
+ output map[string][]byte
+ authClient auth.ClientI
+ identity srv.IdentityContext
+ executionID string
+ conversationID string
+ command string
+}
+
+func (h *Handler) computeAndSendSummary(
+ ctx context.Context,
+ req *summaryRequest,
+ ws WSConn,
+) error {
+ // Convert the map nodeId->output into a map nodeName->output
+ namedOutput := outputByName(req.hosts, req.output)
+
+ history, err := req.authClient.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{
+ ConversationId: req.conversationID,
+ Username: req.identity.TeleportUser,
+ })
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ assistClient, err := assistlib.NewClient(ctx, req.authClient, h.cfg.ProxySettings, h.cfg.OpenAIConfig)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ summary, err := assistClient.GenerateCommandSummary(ctx, history.GetMessages(), namedOutput)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Add the summary message to the backend so it is persisted on chat
+ // reload.
+ messagePayload, err := json.Marshal(&assistlib.CommandExecSummary{
+ ExecutionID: req.executionID,
+ Command: req.command,
+ Summary: summary,
+ })
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ summaryMessage := &assist.CreateAssistantMessageRequest{
+ ConversationId: req.conversationID,
+ Username: req.identity.TeleportUser,
+ Message: &assist.AssistantMessage{
+ Type: string(assistlib.MessageKindCommandResultSummary),
+ CreatedTime: timestamppb.New(time.Now().UTC()),
+ Payload: string(messagePayload),
+ },
+ }
+
+ err = req.authClient.CreateAssistantMessage(ctx, summaryMessage)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Send the summary over the execution websocket to provide instant
+ // feedback to the user.
+ out := &outEnvelope{
+ Type: envelopeTypeSummary,
+ Payload: []byte(summary),
+ }
+ data, err := json.Marshal(out)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ stream := NewWStream(ctx, ws, log, nil)
+ _, err = stream.Write(data)
+ return trace.Wrap(err)
+}
+
+func outputByName(hosts []hostInfo, output map[string][]byte) map[string][]byte {
+ hostIDToName := make(map[string]string, len(hosts))
+ for _, host := range hosts {
+ hostIDToName[host.id] = host.hostName
+ }
+ namedOutput := make(map[string][]byte, len(output))
+ for id, data := range output {
+ namedOutput[hostIDToName[id]] = data
+ }
+ return namedOutput
+}
+
// runCommands runs the given command on the given hosts.
func runCommands(hosts []hostInfo, runCmd func(host *hostInfo) error, log logrus.FieldLogger) {
// Create a synchronization channel to limit the number of concurrent commands.
@@ -356,6 +472,7 @@ func newCommandHandler(ctx context.Context, cfg CommandHandlerConfig) (*commandH
tracer: cfg.tracer,
},
mfaAuthCache: cfg.mfaFuncCache,
+ buffer: cfg.buffer,
}, nil
}
@@ -386,6 +503,9 @@ type CommandHandlerConfig struct {
tracer oteltrace.Tracer
// mfaFuncCache is used to cache the MFA auth method
mfaFuncCache mfaFuncCache
+ // buffer shared across multiple commandHandlers that saves the command
+ // output in order to generate a summary of the executed commands.
+ buffer *summaryBuffer
}
// CheckAndSetDefaults checks and sets default values.
@@ -451,6 +571,10 @@ type commandHandler struct {
// returns a list of ssh.AuthMethods. It is used to cache the result of
// the MFA challenge.
mfaAuthCache mfaFuncCache
+
+ // buffer shared across multiple commandHandlers that saves the command
+ // output in order to generate a summary of the executed commands.
+ buffer *summaryBuffer
}
// sendError sends an error message to the client using the provided websocket.
@@ -597,8 +721,8 @@ func (t *commandHandler) makeClient(ctx context.Context, ws WSConn) (*client.Tel
clientConfig.HostLogin = t.sessionData.Login
clientConfig.ForwardAgent = client.ForwardAgentLocal
clientConfig.Namespace = apidefaults.Namespace
- clientConfig.Stdout = newPayloadWriter(t.sessionData.ServerID, "stdout", t.stream)
- clientConfig.Stderr = newPayloadWriter(t.sessionData.ServerID, "stderr", t.stream)
+ clientConfig.Stdout = newBufferedPayloadWriter(newPayloadWriter(t.sessionData.ServerID, envelopeTypeStdout, t.stream), t.buffer)
+ clientConfig.Stderr = newBufferedPayloadWriter(newPayloadWriter(t.sessionData.ServerID, envelopeTypeStderr, t.stream), t.buffer)
clientConfig.Stdin = &bytes.Buffer{} // set stdin to a dummy buffer
clientConfig.SiteName = t.sessionData.ClusterName
if err := clientConfig.ParseProxyHost(t.proxyHostPort); err != nil {
@@ -622,7 +746,7 @@ func (t *commandHandler) makeClient(ctx context.Context, ws WSConn) (*client.Tel
func (t *commandHandler) writeError(err error) {
out := &outEnvelope{
NodeID: t.sessionData.ServerID,
- Type: "teleport-error",
+ Type: envelopeTypeError,
Payload: []byte(err.Error()),
}
data, err := json.Marshal(out)
diff --git a/lib/web/command_test.go b/lib/web/command_test.go
index c7794a0964aca..ecff846c09c6e 100644
--- a/lib/web/command_test.go
+++ b/lib/web/command_test.go
@@ -24,6 +24,7 @@ import (
"fmt"
"io"
"net/http"
+ "net/http/httptest"
"net/url"
"strings"
"sync/atomic"
@@ -35,11 +36,13 @@ import (
"github.com/gorilla/websocket"
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
+ "github.com/sashabaranov/go-openai"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/gravitational/teleport/api/gen/proto/go/assist/v1"
+ "github.com/gravitational/teleport/lib/ai/testutils"
assistlib "github.com/gravitational/teleport/lib/assist"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/client"
@@ -47,11 +50,22 @@ import (
"github.com/gravitational/teleport/lib/utils"
)
+const (
+ testCommand = "echo txlxport | sed 's/x/e/g'"
+ testUser = "foo"
+)
+
func TestExecuteCommand(t *testing.T) {
t.Parallel()
- s := newWebSuiteWithConfig(t, webSuiteConfig{disableDiskBasedRecording: true})
+ openAIMock := mockOpenAISummary(t)
+ openAIConfig := openai.DefaultConfig("test-token")
+ openAIConfig.BaseURL = openAIMock.URL + "/v1"
+ s := newWebSuiteWithConfig(t, webSuiteConfig{
+ disableDiskBasedRecording: true,
+ OpenAIConfig: &openAIConfig,
+ })
- ws, _, err := s.makeCommand(t, s.authPack(t, "foo"), uuid.New())
+ ws, _, err := s.makeCommand(t, s.authPack(t, testUser), uuid.New())
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, ws.Close()) })
@@ -63,17 +77,22 @@ func TestExecuteCommand(t *testing.T) {
func TestExecuteCommandHistory(t *testing.T) {
t.Parallel()
- // Given
- s := newWebSuiteWithConfig(t, webSuiteConfig{disableDiskBasedRecording: true})
- authPack := s.authPack(t, "foo")
+ openAIMock := mockOpenAISummary(t)
+ openAIConfig := openai.DefaultConfig("test-token")
+ openAIConfig.BaseURL = openAIMock.URL + "/v1"
+ s := newWebSuiteWithConfig(t, webSuiteConfig{
+ disableDiskBasedRecording: true,
+ OpenAIConfig: &openAIConfig,
+ })
+ authPack := s.authPack(t, testUser)
ctx := context.Background()
- clt, err := s.server.NewClient(auth.TestUser("foo"))
+ clt, err := s.server.NewClient(auth.TestUser(testUser))
require.NoError(t, err)
// Create conversation, otherwise the command execution will not be saved
conversation, err := clt.CreateAssistantConversation(context.Background(), &assist.CreateAssistantConversationRequest{
- Username: "foo",
+ Username: testUser,
CreatedTime: timestamppb.Now(),
})
require.NoError(t, err)
@@ -91,9 +110,8 @@ func TestExecuteCommandHistory(t *testing.T) {
// When command executes
require.NoError(t, waitForCommandOutput(stream, "teleport"))
- // Explecitly close the stream
- err = stream.Close()
- require.NoError(t, err)
+ // Close the stream if not already closed
+ _ = stream.Close()
// Then command execution history is saved
var messages *assist.GetAssistantMessagesResponse
@@ -101,16 +119,17 @@ func TestExecuteCommandHistory(t *testing.T) {
require.Eventually(t, func() bool {
messages, err = clt.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{
ConversationId: conversationID.String(),
- Username: "foo",
+ Username: testUser,
})
require.NoError(t, err)
- return len(messages.GetMessages()) == 1
+ return len(messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResult]) == 1
}, 5*time.Second, 100*time.Millisecond)
// Assert the returned message
- msg := messages.GetMessages()[0]
- require.Equal(t, string(assistlib.MessageKindCommandResult), msg.Type)
+ resultMessages, ok := messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResult]
+ require.True(t, ok, "Message must be of type COMMAND_RESULT")
+ msg := resultMessages[0]
require.NotZero(t, msg.CreatedTime)
var result commandExecResult
@@ -123,13 +142,100 @@ func TestExecuteCommandHistory(t *testing.T) {
require.Equal(t, "node", result.NodeID)
}
+func TestExecuteCommandSummary(t *testing.T) {
+ t.Parallel()
+
+ openAIMock := mockOpenAISummary(t)
+ openAIConfig := openai.DefaultConfig("test-token")
+ openAIConfig.BaseURL = openAIMock.URL + "/v1"
+ s := newWebSuiteWithConfig(t, webSuiteConfig{
+ disableDiskBasedRecording: true,
+ OpenAIConfig: &openAIConfig,
+ })
+ authPack := s.authPack(t, testUser)
+
+ ctx := context.Background()
+ clt, err := s.server.NewClient(auth.TestUser(testUser))
+ require.NoError(t, err)
+
+ // Create conversation, otherwise the command execution will not be saved
+ conversation, err := clt.CreateAssistantConversation(context.Background(), &assist.CreateAssistantConversationRequest{
+ Username: testUser,
+ CreatedTime: timestamppb.Now(),
+ })
+ require.NoError(t, err)
+
+ require.NotEmpty(t, conversation.GetId())
+
+ conversationID, err := uuid.Parse(conversation.GetId())
+ require.NoError(t, err)
+
+ ws, _, err := s.makeCommand(t, authPack, conversationID)
+ require.NoError(t, err)
+
+ // The current Assist execution relies on a hack that multiplexes multiple
+ // streams over a single one. This causes issues when a stream close as the
+ // single stream receiver will consider it should close. We work around by
+ // using a non-closing websocket and initiating a proper stream close.
+ // Then we reopen a new stream on the same websocket and continue reading
+ // the summary.
+ nonClosableWebsocket := &noopCloserWS{ws}
+ stream := NewWStream(ctx, nonClosableWebsocket, utils.NewLoggerForTests(), nil)
+
+ // Wait for command execution to complete
+ require.NoError(t, waitForCommandOutput(stream, "teleport"))
+
+ // Stop the stream consumption
+ err = stream.Close()
+ require.NoError(t, err)
+
+ // Start a new stream and process the summary event
+ var env Envelope
+ stream = NewWStream(ctx, ws, utils.NewLoggerForTests(), nil)
+ dec := json.NewDecoder(stream)
+ err = dec.Decode(&env)
+ require.NoError(t, err)
+ require.Equal(t, envelopeTypeSummary, env.GetType())
+ require.NotEmpty(t, env.GetPayload())
+
+ // Close the stream, and the underlying websocket
+ _ = stream.Close()
+
+ // Wait for the command execution history to be saved
+ var messages *assist.GetAssistantMessagesResponse
+ // Command execution history is saved in asynchronusly, so we need to wait for it.
+ require.Eventually(t, func() bool {
+ messages, err = clt.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{
+ ConversationId: conversationID.String(),
+ Username: testUser,
+ })
+ require.NoError(t, err)
+
+ return len(messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResultSummary]) == 1
+ }, 5*time.Second, 100*time.Millisecond)
+
+ // Check the returned summary message
+ summaryMessages, ok := messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResultSummary]
+ require.True(t, ok, "At least one summary message is expected")
+ msg := summaryMessages[0]
+ require.NotZero(t, msg.CreatedTime)
+
+ var result assistlib.CommandExecSummary
+ err = json.Unmarshal([]byte(msg.GetPayload()), &result)
+ require.NoError(t, err)
+
+ require.NotEmpty(t, result.ExecutionID)
+ require.Equal(t, testCommand, result.Command)
+ require.NotEmpty(t, result.Summary)
+}
+
func (s *WebSuite) makeCommand(t *testing.T, pack *authPack, conversationID uuid.UUID) (*websocket.Conn, *session.Session, error) {
req := CommandRequest{
Query: fmt.Sprintf("name == \"%s\"", s.srvID),
Login: pack.login,
ConversationID: conversationID.String(),
ExecutionID: uuid.New().String(),
- Command: "echo txlxport | sed 's/x/e/g'",
+ Command: testCommand,
}
u := url.URL{
@@ -243,3 +349,18 @@ func Test_runCommands(t *testing.T) {
require.Equal(t, int32(100), counter.Load())
}
+
+func mockOpenAISummary(t *testing.T) *httptest.Server {
+ responses := []string{"This is the summary of the command."}
+ server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses))
+ t.Cleanup(server.Close)
+ return server
+}
+
+func messagesByType(messages []*assist.AssistantMessage) map[assistlib.MessageType][]*assist.AssistantMessage {
+ byType := make(map[assistlib.MessageType][]*assist.AssistantMessage)
+ for _, message := range messages {
+ byType[assistlib.MessageType(message.GetType())] = append(byType[assistlib.MessageType(message.GetType())], message)
+ }
+ return byType
+}
diff --git a/lib/web/command_utils.go b/lib/web/command_utils.go
index db8ebdbc42a5a..48b4712fb74eb 100644
--- a/lib/web/command_utils.go
+++ b/lib/web/command_utils.go
@@ -45,6 +45,13 @@ type WSConn interface {
SetPongHandler(h func(appData string) error)
}
+const (
+ envelopeTypeStdout = "stdout"
+ envelopeTypeStderr = "stderr"
+ envelopeTypeError = "teleport-error"
+ envelopeTypeSummary = "summary"
+)
+
// outEnvelope is an envelope used to wrap messages send back to the client connected over WS.
type outEnvelope struct {
NodeID string `json:"node_id"`
@@ -56,7 +63,7 @@ type outEnvelope struct {
// outEnvelope and writes it to the underlying stream.
type payloadWriter struct {
nodeID string
- // output name, can be stdout, stderr or teleport-error.
+ // output name, can be stdout, stderr, teleport-error or summary.
outputName string
// stream is the underlying stream.
stream io.Writer
@@ -130,3 +137,65 @@ func (s *syncRWWSConn) ReadMessage() (messageType int, p []byte, err error) {
defer s.rmtx.Unlock()
return s.WSConn.ReadMessage()
}
+
+func newBufferedPayloadWriter(pw *payloadWriter, buffer *summaryBuffer) *bufferedPayloadWriter {
+ return &bufferedPayloadWriter{
+ payloadWriter: pw,
+ buffer: buffer,
+ }
+}
+
+type bufferedPayloadWriter struct {
+ *payloadWriter
+ buffer *summaryBuffer
+}
+
+func (bp *bufferedPayloadWriter) Write(data []byte) (int, error) {
+ bp.buffer.Write(bp.nodeID, data)
+ return bp.payloadWriter.Write(data)
+}
+
+func newSummaryBuffer(capacity int) *summaryBuffer {
+ return &summaryBuffer{
+ buffer: make(map[string][]byte),
+ remainingCapacity: capacity,
+ invalid: false,
+ mutex: sync.Mutex{},
+ }
+}
+
+type summaryBuffer struct {
+ buffer map[string][]byte
+ remainingCapacity int
+ invalid bool
+ // mutex protects all members of the struct and must be acquired before
+ // performing any read or write operation
+ mutex sync.Mutex
+}
+
+func (b *summaryBuffer) Write(node string, data []byte) {
+ b.mutex.Lock()
+ defer b.mutex.Unlock()
+ if b.invalid {
+ return
+ }
+ if len(data) > b.remainingCapacity {
+ // We're out of capacity, not all content will be written to the buffer
+ // it should not be used anymore
+ b.invalid = true
+ return
+ }
+ b.buffer[node] = append(b.buffer[node], data...)
+ b.remainingCapacity -= len(data)
+}
+
+// Export returns the buffer content and a whether the buffer overflowed.
+func (b *summaryBuffer) Export() (map[string][]byte, bool) {
+ b.mutex.Lock()
+ defer b.mutex.Unlock()
+ if b.invalid {
+ return nil, true
+ }
+ b.invalid = true
+ return b.buffer, false
+}
diff --git a/lib/web/command_utils_test.go b/lib/web/command_utils_test.go
new file mode 100644
index 0000000000000..01528bff29c1b
--- /dev/null
+++ b/lib/web/command_utils_test.go
@@ -0,0 +1,142 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+package web
+
+import (
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestSummaryBuffer(t *testing.T) {
+ tests := []struct {
+ name string
+ outputs map[string][][]byte
+ capacity int
+ expectedOutput map[string][]byte
+ expectedOverflow bool
+ }{
+ {
+ name: "Single node",
+ outputs: map[string][][]byte{
+ "node": {
+ []byte("foo"),
+ []byte("bar"),
+ []byte("baz"),
+ },
+ },
+ capacity: 9,
+ expectedOutput: map[string][]byte{
+ "node": []byte("foobarbaz"),
+ },
+ expectedOverflow: false,
+ },
+ {
+ name: "Single node overflow",
+ outputs: map[string][][]byte{
+ "node": {
+ []byte("foo"),
+ []byte("bar"),
+ []byte("baz"),
+ },
+ },
+ capacity: 8,
+ expectedOutput: nil,
+ expectedOverflow: true,
+ },
+ {
+ name: "Multiple nodes",
+ outputs: map[string][][]byte{
+ "node1": {
+ []byte("foo"),
+ []byte("bar"),
+ []byte("baz"),
+ },
+ "node2": {
+ []byte("baz"),
+ []byte("bar"),
+ []byte("foo"),
+ },
+ "node3": {
+ []byte("baz"),
+ []byte("baz"),
+ []byte("baz"),
+ },
+ },
+ capacity: 30,
+ expectedOutput: map[string][]byte{
+ "node1": []byte("foobarbaz"),
+ "node2": []byte("bazbarfoo"),
+ "node3": []byte("bazbazbaz"),
+ },
+ expectedOverflow: false,
+ },
+ {
+ name: "Multiple nodes overflow",
+ outputs: map[string][][]byte{
+ "node1": {
+ []byte("foo"),
+ []byte("bar"),
+ []byte("baz"),
+ },
+ "node2": {
+ []byte("baz"),
+ []byte("bar"),
+ []byte("foo"),
+ },
+ "node3": {
+ []byte("baz"),
+ []byte("baz"),
+ []byte("baz"),
+ },
+ },
+ capacity: 25,
+ expectedOutput: nil,
+ expectedOverflow: true,
+ },
+ {
+ name: "No output",
+ outputs: nil,
+ capacity: 10,
+ expectedOutput: map[string][]byte{},
+ expectedOverflow: false,
+ },
+ }
+ for _, tc := range tests {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ buffer := newSummaryBuffer(tc.capacity)
+ var wg sync.WaitGroup
+ for node, output := range tc.outputs {
+ node := node
+ output := output
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for _, chunk := range output {
+ buffer.Write(node, chunk)
+ }
+ }()
+ }
+ wg.Wait()
+ output, overflow := buffer.Export()
+ require.Equal(t, tc.expectedOutput, output)
+ require.Equal(t, tc.expectedOverflow, overflow)
+
+ })
+ }
+}
diff --git a/web/packages/teleport/src/Assist/Conversation/CommandResultSummaryEntry.tsx b/web/packages/teleport/src/Assist/Conversation/CommandResultSummaryEntry.tsx
new file mode 100644
index 0000000000000..5fb7626a8563b
--- /dev/null
+++ b/web/packages/teleport/src/Assist/Conversation/CommandResultSummaryEntry.tsx
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2023 Gravitational, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import React from 'react';
+import styled from 'styled-components';
+
+import ReactMarkdown from 'react-markdown';
+import remarkGfm from 'remark-gfm';
+
+import { markdownCSS } from 'teleport/Assist/markdown';
+
+interface CommandResultSummaryEntryProps {
+ command: string;
+ summary: string;
+}
+
+const Container = styled.div`
+ border-radius: 10px;
+ position: relative;
+`;
+
+const Title = styled.div`
+ font-size: 15px;
+ font-weight: 600;
+ padding: 10px 15px;
+`;
+
+const Summary = styled.div`
+ padding: 10px 15px 0 17px;
+
+ ${markdownCSS}
+`;
+
+const Header = styled.div`
+ display: flex;
+ justify-content: space-between;
+ padding-right: 20px;
+`;
+
+export function CommandResultSummaryEntry(
+ props: CommandResultSummaryEntryProps
+) {
+ return (
+
+
+
+ Summary of command execution {props.command}
+
+
+
+
+
+ {props.summary}
+
+
+
+ );
+}
diff --git a/web/packages/teleport/src/Assist/Conversation/Message.tsx b/web/packages/teleport/src/Assist/Conversation/Message.tsx
index 4cb5f2fad5e18..effcc450df6e7 100644
--- a/web/packages/teleport/src/Assist/Conversation/Message.tsx
+++ b/web/packages/teleport/src/Assist/Conversation/Message.tsx
@@ -21,6 +21,7 @@ import { CheckIcon } from 'design/SVGIcon';
import {
Author,
+ ConversationMessage,
ResolvedServerMessage,
ServerMessageType,
} from 'teleport/Assist/types';
@@ -39,8 +40,7 @@ import { MessageEntry } from 'teleport/Assist/Conversation/MessageEntry';
import { useAssist } from 'teleport/Assist/context/AssistContext';
import { ExecuteRemoteCommandEntry } from 'teleport/Assist/Conversation/ExecuteRemoteCommandEntry';
import { CommandResultEntry } from 'teleport/Assist/Conversation/CommandResultEntry';
-
-import type { ConversationMessage } from 'teleport/Assist/types';
+import { CommandResultSummaryEntry } from 'teleport/Assist/Conversation/CommandResultSummaryEntry';
interface MessageProps {
message: ConversationMessage;
@@ -139,6 +139,13 @@ function createComponentForEntry(
errorMessage={entry.errorMessage}
/>
);
+ case ServerMessageType.CommandResultSummary:
+ return (
+
+ );
}
}
diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx
index 942fe036c7c34..b61bb8de0535c 100644
--- a/web/packages/teleport/src/Assist/context/AssistContext.tsx
+++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx
@@ -31,7 +31,11 @@ import useStickyClusterId from 'teleport/useStickyClusterId';
import cfg from 'teleport/config';
import { getAccessToken, getHostName } from 'teleport/services/api';
-import { RawPayload, ServerMessageType } from 'teleport/Assist/types';
+import {
+ ExecutionEnvelopeType,
+ RawPayload,
+ ServerMessageType,
+} from 'teleport/Assist/types';
import { MessageTypeEnum, Protobuf } from 'teleport/lib/term/protobuf';
@@ -41,6 +45,7 @@ import {
} from 'teleport/services/auth';
import * as service from '../service';
+
import { resolveServerCommandMessage, resolveServerMessage } from '../service';
import type {
@@ -48,6 +53,7 @@ import type {
ResolvedServerMessage,
ServerMessage,
} from 'teleport/Assist/types';
+
import type { AssistState } from 'teleport/Assist/context/state';
interface AssistContextValue {
@@ -421,13 +427,22 @@ export function AssistContextProvider(props: PropsWithChildren) {
const data = JSON.parse(msg.payload) as RawPayload;
const payload = atob(data.payload);
- dispatch({
- type: AssistStateActionType.UpdateCommandResult,
- conversationId: state.conversations.selectedId,
- commandResultId: nodeIdToResultId.get(data.node_id),
- output: payload,
- });
-
+ if (data.type === ExecutionEnvelopeType) {
+ dispatch({
+ type: AssistStateActionType.AddCommandResultSummary,
+ conversationId: state.conversations.selectedId,
+ summary: payload,
+ executionId: execParams.execution_id,
+ command: execParams.command,
+ });
+ } else {
+ dispatch({
+ type: AssistStateActionType.UpdateCommandResult,
+ conversationId: state.conversations.selectedId,
+ commandResultId: nodeIdToResultId.get(data.node_id),
+ output: payload,
+ });
+ }
break;
case MessageTypeEnum.WEBAUTHN_CHALLENGE:
@@ -455,18 +470,20 @@ export function AssistContextProvider(props: PropsWithChildren) {
sessionsEnded += 1;
if (sessionsEnded === nodeIdToResultId.size) {
+ const message = proto.encodeCloseMessage();
+ const bytearray = new Uint8Array(message);
+
for (const nodeId of nodeIdToResultId.keys()) {
dispatch({
type: AssistStateActionType.FinishCommandResult,
conversationId: state.conversations.selectedId,
commandResultId: nodeIdToResultId.get(nodeId),
});
+
+ executeCommandWebSocket.current.send(bytearray.buffer);
}
nodeIdToResultId.clear();
-
- // TODO(ryan): move this to after the summary is sent once it's implemented
- executeCommandWebSocket.current.close();
}
break;
diff --git a/web/packages/teleport/src/Assist/context/state.ts b/web/packages/teleport/src/Assist/context/state.ts
index 41dc5312ff367..aa7ace7a5b6c0 100644
--- a/web/packages/teleport/src/Assist/context/state.ts
+++ b/web/packages/teleport/src/Assist/context/state.ts
@@ -60,6 +60,7 @@ export enum AssistStateActionType {
PromptMfa,
DeleteConversation,
UpdateConversationTitle,
+ AddCommandResultSummary,
}
export interface ReplaceConversationsAction {
@@ -164,6 +165,14 @@ export interface UpdateConversationTitleAction {
title: string;
}
+export interface AddCommandResultSummaryAction {
+ type: AssistStateActionType.AddCommandResultSummary;
+ summary: string;
+ conversationId: string;
+ command: string;
+ executionId: string;
+}
+
export type AssistContextAction =
| SetConversationsLoadingAction
| ReplaceConversationsAction
@@ -181,7 +190,8 @@ export type AssistContextAction =
| FinishCommandResultAction
| PromptMfaAction
| DeleteConversationAction
- | UpdateConversationTitleAction;
+ | UpdateConversationTitleAction
+ | AddCommandResultSummaryAction;
export function reducer(
state: AssistState,
@@ -239,6 +249,9 @@ export function reducer(
case AssistStateActionType.UpdateConversationTitle:
return updateConversationTitle(state, action);
+ case AssistStateActionType.AddCommandResultSummary:
+ return addCommandResultSummary(state, action);
+
default:
return state;
}
@@ -589,6 +602,35 @@ export function finishCommandResult(
};
}
+export function addCommandResultSummary(
+ state: AssistState,
+ action: AddCommandResultSummaryAction
+): AssistState {
+ const messages = new Map(state.messages.data);
+
+ let conversationMessages = messages.get(action.conversationId);
+
+ conversationMessages = [
+ ...conversationMessages,
+ {
+ type: ServerMessageType.CommandResultSummary,
+ created: new Date(),
+ executionId: action.executionId,
+ command: action.command,
+ summary: action.summary,
+ },
+ ];
+
+ messages.set(action.conversationId, conversationMessages);
+
+ return {
+ ...state,
+ messages: {
+ ...state.messages,
+ data: messages,
+ },
+ };
+}
export function promptMfa(
state: AssistState,
action: PromptMfaAction
diff --git a/web/packages/teleport/src/Assist/context/utils.ts b/web/packages/teleport/src/Assist/context/utils.ts
index e7473f7003ea7..795dcda0c7e23 100644
--- a/web/packages/teleport/src/Assist/context/utils.ts
+++ b/web/packages/teleport/src/Assist/context/utils.ts
@@ -32,6 +32,7 @@ function getMessageTypeAuthor(type: string) {
case ServerMessageType.Command:
case ServerMessageType.CommandResult:
case ServerMessageType.CommandResultStream:
+ case ServerMessageType.CommandResultSummary:
case ServerMessageType.Error:
return Author.Teleport;
}
diff --git a/web/packages/teleport/src/Assist/service.ts b/web/packages/teleport/src/Assist/service.ts
index 49d263a2da07e..b7294f85f1f28 100644
--- a/web/packages/teleport/src/Assist/service.ts
+++ b/web/packages/teleport/src/Assist/service.ts
@@ -31,6 +31,7 @@ import { ServerMessageType } from './types';
import type {
CommandResultPayload,
+ CommandResultSummaryPayload,
Conversation,
CreateConversationResponse,
ExecEvent,
@@ -39,6 +40,7 @@ import type {
GetConversationMessagesResponse,
GetConversationsResponse,
ResolvedCommandResultServerMessage,
+ ResolvedCommandResultSummaryServerMessage,
ResolvedCommandServerMessage,
ResolvedServerMessage,
ServerMessage,
@@ -68,6 +70,9 @@ export async function resolveServerMessage(
case ServerMessageType.CommandResult:
return resolveServerCommandResultMessage(message, clusterId);
+ case ServerMessageType.CommandResultSummary:
+ return resolveServerCommandResultSummaryMessage(message);
+
case ServerMessageType.Assist:
case ServerMessageType.User:
return {
@@ -173,6 +178,20 @@ export async function resolveServerCommandResultMessage(
}
}
+export function resolveServerCommandResultSummaryMessage(
+ message: ServerMessage
+): ResolvedCommandResultSummaryServerMessage {
+ const payload = JSON.parse(message.payload) as CommandResultSummaryPayload;
+
+ return {
+ type: ServerMessageType.CommandResultSummary,
+ executionId: payload.execution_id,
+ command: payload.command,
+ summary: payload.summary,
+ created: new Date(message.created_time),
+ };
+}
+
export function resolveServerCommandMessage(
message: ServerMessage
): ResolvedCommandServerMessage {
diff --git a/web/packages/teleport/src/Assist/types.ts b/web/packages/teleport/src/Assist/types.ts
index ba73068e0c80e..a46c9fb94b371 100644
--- a/web/packages/teleport/src/Assist/types.ts
+++ b/web/packages/teleport/src/Assist/types.ts
@@ -21,12 +21,15 @@ export enum ServerMessageType {
Error = 'CHAT_MESSAGE_ERROR',
Command = 'COMMAND',
CommandResult = 'COMMAND_RESULT',
+ CommandResultSummary = 'COMMAND_RESULT_SUMMARY',
CommandResultStream = 'COMMAND_RESULT_STREAM',
AssistPartialMessage = 'CHAT_PARTIAL_MESSAGE_ASSISTANT',
AssistPartialMessageEnd = 'CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE',
AssistThought = 'CHAT_THOUGHT_ASSISTANT',
}
+export const ExecutionEnvelopeType = 'summary';
+
export interface Conversation {
id: string;
title?: string;
@@ -68,6 +71,14 @@ export interface ResolvedCommandResultServerMessage {
created: Date;
}
+export interface ResolvedCommandResultSummaryServerMessage {
+ type: ServerMessageType.CommandResultSummary;
+ executionId: string;
+ summary: string;
+ command: string;
+ created: Date;
+}
+
export interface ResolvedAssistThoughtServerMessage {
type: ServerMessageType.AssistThought;
message: string;
@@ -108,6 +119,7 @@ export type ResolvedServerMessage =
| ResolvedUserServerMessage
| ResolvedErrorServerMessage
| ResolvedCommandResultServerMessage
+ | ResolvedCommandResultSummaryServerMessage
| ResolvedAssistThoughtServerMessage
| ResolvedCommandResultStreamServerMessage;
@@ -142,6 +154,12 @@ export interface CommandResultPayload {
execution_id: string;
}
+export interface CommandResultSummaryPayload {
+ execution_id: string;
+ command: string;
+ summary: string;
+}
+
export interface ExecEvent {
event: EventType.EXEC;
exitError?: string;
@@ -162,6 +180,7 @@ export interface NodeState {
export interface RawPayload {
node_id: string;
+ type: string;
payload: string;
}
diff --git a/web/packages/teleport/src/lib/term/protobuf.js b/web/packages/teleport/src/lib/term/protobuf.js
index a259070ae7804..c654dacbef29c 100644
--- a/web/packages/teleport/src/lib/term/protobuf.js
+++ b/web/packages/teleport/src/lib/term/protobuf.js
@@ -90,6 +90,11 @@ export class Protobuf {
return this.encode(messageFields.type.values.data, message);
}
+ encodeCloseMessage() {
+ // Close message has no payload
+ return this.encode(messageFields.type.values.close, '');
+ }
+
encodePayload(buffer, text) {
// set type
buffer.push(messageFields.payload.code);