diff --git a/lib/web/command.go b/lib/web/command.go index ba47b4dac654f..5e758f4f97f0e 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -188,80 +188,108 @@ func (h *Handler) executeCommand( h.log.Debugf("Found %d hosts to run Assist command %q on.", len(hosts), req.Command) mfaCacheFn := getMFACacheFn() + interactiveCommand := strings.Split(req.Command, " ") - for _, host := range hosts { - err := func() error { - sessionData, err := h.generateCommandSession(&host, req.Login, clusterName, sessionCtx.cfg.User) - if err != nil { - h.log.WithError(err).Debug("Unable to generate new ssh session.") - return trace.Wrap(err) - } + runCmd := func(host *hostInfo) error { + sessionData, err := h.generateCommandSession(host, req.Login, clusterName, sessionCtx.cfg.User) + if err != nil { + h.log.WithError(err).Debug("Unable to generate new ssh session.") + return trace.Wrap(err) + } - h.log.Debugf("New command request for server=%s, id=%v, login=%s, sid=%s, websid=%s.", - host.hostName, host.id, req.Login, sessionData.ID, sessionCtx.GetSessionID()) - - commandHandlerConfig := CommandHandlerConfig{ - SessionCtx: sessionCtx, - AuthProvider: clt, - SessionData: sessionData, - KeepAliveInterval: netConfig.GetKeepAliveInterval(), - ProxyHostPort: h.ProxyHostPort(), - InteractiveCommand: strings.Split(req.Command, " "), - Router: h.cfg.Router, - TracerProvider: h.cfg.TracerProvider, - LocalAuthProvider: h.auth.accessPoint, - mfaFuncCache: mfaCacheFn, - } + h.log.Debugf("New command request for server=%s, id=%v, login=%s, sid=%s, websid=%s.", + host.hostName, host.id, req.Login, sessionData.ID, sessionCtx.GetSessionID()) + + commandHandlerConfig := CommandHandlerConfig{ + SessionCtx: sessionCtx, + AuthProvider: clt, + SessionData: sessionData, + KeepAliveInterval: keepAliveInterval, + ProxyHostPort: h.ProxyHostPort(), + InteractiveCommand: interactiveCommand, + Router: h.cfg.Router, + TracerProvider: h.cfg.TracerProvider, + LocalAuthProvider: h.auth.accessPoint, + mfaFuncCache: mfaCacheFn, + } - handler, err := newCommandHandler(ctx, commandHandlerConfig) - if err != nil { - h.log.WithError(err).Error("Unable to create terminal.") - return trace.Wrap(err) - } - handler.ws = &noopCloserWS{ws} - - h.userConns.Add(1) - defer h.userConns.Add(-1) - - h.log.Infof("Executing command: %#v.", req) - httplib.MakeTracingHandler(handler, teleport.ComponentProxy).ServeHTTP(w, r) - - msgPayload, err := json.Marshal(struct { - NodeID string `json:"node_id"` - ExecutionID string `json:"execution_id"` - SessionID string `json:"session_id"` - }{ - NodeID: host.id, - ExecutionID: req.ExecutionID, - SessionID: string(sessionData.ID), - }) - - if err != nil { - return trace.Wrap(err) - } + handler, err := newCommandHandler(ctx, commandHandlerConfig) + if err != nil { + h.log.WithError(err).Error("Unable to create terminal.") + return trace.Wrap(err) + } + handler.ws = &noopCloserWS{ws} - err = clt.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{ - ConversationId: req.ConversationID, - Username: identity.TeleportUser, - Message: &assist.AssistantMessage{ - Type: string(assistlib.MessageKindCommandResult), - CreatedTime: timestamppb.New(time.Now().UTC()), - Payload: string(msgPayload), - }, - }) + h.userConns.Add(1) + defer h.userConns.Add(-1) - return trace.Wrap(err) - }() + h.log.Infof("Executing command: %#v.", req) + httplib.MakeTracingHandler(handler, teleport.ComponentProxy).ServeHTTP(w, r) + + msgPayload, err := json.Marshal(struct { + NodeID string `json:"node_id"` + ExecutionID string `json:"execution_id"` + SessionID string `json:"session_id"` + }{ + NodeID: host.id, + ExecutionID: req.ExecutionID, + SessionID: string(sessionData.ID), + }) if err != nil { - h.log.WithError(err).Warnf("Failed to start session: %v", host.hostName) - continue + return trace.Wrap(err) } + + err = clt.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{ + ConversationId: req.ConversationID, + Username: identity.TeleportUser, + Message: &assist.AssistantMessage{ + Type: string(assistlib.MessageKindCommandResult), + CreatedTime: timestamppb.New(time.Now().UTC()), + Payload: string(msgPayload), + }, + }) + + return trace.Wrap(err) } + runCommands(hosts, runCmd, h.log) + return nil, nil } +// runCommands runs the given command on the given hosts. +func runCommands(hosts []hostInfo, runCmd func(host *hostInfo) error, log logrus.FieldLogger) { + // Create a synchronization channel to limit the number of concurrent commands. + // The maximum number of concurrent commands is 30 - it is arbitrary. + syncChan := make(chan struct{}, 30) + // WaiteGroup to wait for all commands to finish. + wg := sync.WaitGroup{} + + for _, host := range hosts { + host := host + wg.Add(1) + + go func() { + defer wg.Done() + + // Limit the number of concurrent commands. + syncChan <- struct{}{} + defer func() { + // Release the command slot. + <-syncChan + }() + + if err := runCmd(&host); err != nil { + log.WithError(err).Warnf("Failed to start session: %v", host.hostName) + } + }() + } + + // Wait for all commands to finish. + wg.Wait() +} + // getMFACacheFn returns a function that caches the result of the given // get function. The cache is protected by a mutex, so it is safe to call // the returned function from multiple goroutines. diff --git a/lib/web/command_test.go b/lib/web/command_test.go index 4cb3d7b0a904d..7eb5860de1edf 100644 --- a/lib/web/command_test.go +++ b/lib/web/command_test.go @@ -25,6 +25,7 @@ import ( "net/http" "net/url" "strings" + "sync/atomic" "testing" "time" @@ -33,6 +34,7 @@ import ( "github.com/gorilla/websocket" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/lib/client" @@ -154,3 +156,29 @@ func waitForCommandOutput(stream io.Reader, substr string) error { } } } + +// Test_runCommands tests that runCommands runs the given command on all hosts. +// The commands should run in parallel, but we don't have a deterministic way to +// test that (sleep with checking the execution time in not deterministic). +func Test_runCommands(t *testing.T) { + counter := atomic.Int32{} + + runCmd := func(host *hostInfo) error { + counter.Add(1) + return nil + } + + hosts := make([]hostInfo, 0, 100) + for i := 0; i < 100; i++ { + hosts = append(hosts, hostInfo{ + hostName: fmt.Sprintf("localhost%d", i), + }) + } + + logger := logrus.New() + logger.Out = io.Discard + + runCommands(hosts, runCmd, logger) + + require.Equal(t, int32(100), counter.Load()) +}