diff --git a/lib/web/command.go b/lib/web/command.go index fd38c44ea2930..2117bf51be768 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -67,6 +67,18 @@ type CommandRequest struct { ExecutionID string `json:"execution_id"` } +// commandExecResult is a result of a command execution. +type commandExecResult struct { + // NodeID is the ID of the node where the command was executed. + NodeID string `json:"node_id"` + // NodeName is the name of the node where the command was executed. + NodeName string `json:"node_name"` + // ExecutionID is a unique ID used to identify the command execution. + ExecutionID string `json:"execution_id"` + // SessionID is the ID of the session where the command was executed. + SessionID string `json:"session_id"` +} + // Check checks if the request is valid. func (c *CommandRequest) Check() error { if c.Command == "" { @@ -226,12 +238,9 @@ func (h *Handler) executeCommand( h.log.Infof("Executing command: %#v.", req) httplib.MakeTracingHandler(handler, teleport.ComponentProxy).ServeHTTP(w, r) - msgPayload, err := json.Marshal(struct { - NodeID string `json:"node_id"` - ExecutionID string `json:"execution_id"` - SessionID string `json:"session_id"` - }{ + msgPayload, err := json.Marshal(&commandExecResult{ NodeID: host.id, + NodeName: host.hostName, ExecutionID: req.ExecutionID, SessionID: string(sessionData.ID), }) diff --git a/lib/web/command_test.go b/lib/web/command_test.go index a21e41a82a3fb..2134720918d84 100644 --- a/lib/web/command_test.go +++ b/lib/web/command_test.go @@ -37,7 +37,11 @@ import ( "github.com/gravitational/trace" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" + assistlib "github.com/gravitational/teleport/lib/assist" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" @@ -47,7 +51,7 @@ func TestExecuteCommand(t *testing.T) { t.Parallel() s := newWebSuite(t) - ws, _, err := s.makeCommand(t, s.authPack(t, "foo")) + ws, _, err := s.makeCommand(t, s.authPack(t, "foo"), uuid.New()) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) @@ -56,11 +60,74 @@ func TestExecuteCommand(t *testing.T) { require.NoError(t, waitForCommandOutput(stream, "teleport")) } -func (s *WebSuite) makeCommand(t *testing.T, pack *authPack) (*websocket.Conn, *session.Session, error) { +func TestExecuteCommandHistory(t *testing.T) { + t.Parallel() + + // Given + s := newWebSuite(t) + authPack := s.authPack(t, "foo") + + ctx := context.Background() + clt, err := s.server.NewClient(auth.TestUser("foo")) + require.NoError(t, err) + + // Create conversation, otherwise the command execution will not be saved + conversation, err := clt.CreateAssistantConversation(context.Background(), &assist.CreateAssistantConversationRequest{ + Username: "foo", + CreatedTime: timestamppb.Now(), + }) + require.NoError(t, err) + + require.NotEmpty(t, conversation.GetId()) + + conversationID, err := uuid.Parse(conversation.GetId()) + require.NoError(t, err) + + ws, _, err := s.makeCommand(t, authPack, conversationID) + require.NoError(t, err) + + stream := NewWStream(ctx, ws, utils.NewLoggerForTests(), nil) + + // When command executes + require.NoError(t, waitForCommandOutput(stream, "teleport")) + + // Explecitly close the stream + err = stream.Close() + require.NoError(t, err) + + // Then command execution history is saved + var messages *assist.GetAssistantMessagesResponse + // Command execution history is saved in asynchronusly, so we need to wait for it. + require.Eventually(t, func() bool { + messages, err = clt.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{ + ConversationId: conversationID.String(), + Username: "foo", + }) + require.NoError(t, err) + + return len(messages.GetMessages()) == 1 + }, 5*time.Second, 100*time.Millisecond) + + // Assert the returned message + msg := messages.GetMessages()[0] + require.Equal(t, string(assistlib.MessageKindCommandResult), msg.Type) + require.NotZero(t, msg.CreatedTime) + + var result commandExecResult + err = json.Unmarshal([]byte(msg.GetPayload()), &result) + require.NoError(t, err) + + require.NotEmpty(t, result.ExecutionID) + require.NotEmpty(t, result.SessionID) + require.Equal(t, "node", result.NodeName) + require.Equal(t, "node", result.NodeID) +} + +func (s *WebSuite) makeCommand(t *testing.T, pack *authPack, conversationID uuid.UUID) (*websocket.Conn, *session.Session, error) { req := CommandRequest{ Query: fmt.Sprintf("name == \"%s\"", s.srvID), Login: pack.login, - ConversationID: uuid.New().String(), + ConversationID: conversationID.String(), ExecutionID: uuid.New().String(), Command: "echo txlxport | sed 's/x/e/g'", }