diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index df3554f1debb5..4fbafbee75389 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -3374,28 +3374,29 @@ func retryWithAccessRequest( } func promptUserForAccessRequestDetails(cf *CLIConf, req types.AccessRequest) error { - if cf.RequestMode != accessRequestModeRole { - return nil - } - // If this is a role access request, ensure that it only has one role. - switch len(req.GetRoles()) { - case 0: - return trace.AccessDenied("no roles to request that would grant access") - case 1: - return nil - default: - selectedRole, err := prompt.PickOne( - cf.Context, os.Stdout, prompt.NewContextReader(os.Stdin), - "Choose role to request", - req.GetRoles()) - if err != nil { - return trace.Wrap(err) + if cf.RequestMode == accessRequestModeRole { + // If this is a role access request, ensure that it only has one role. + switch len(req.GetRoles()) { + case 0: + return trace.AccessDenied("no roles to request that would grant access") + case 1: + // No need to choose a role, just set request reason. + default: + selectedRole, err := prompt.PickOne( + cf.Context, os.Stdout, prompt.NewContextReader(os.Stdin), + "Choose role to request", + req.GetRoles()) + if err != nil { + return trace.Wrap(err) + } + req.SetRoles([]string{selectedRole}) } - req.SetRoles([]string{selectedRole}) } + if err := setAccessRequestReason(cf, req); err != nil { return trace.Wrap(err) } + return nil } diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index 5cce530be144e..50bf98a7b19a8 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -48,6 +48,7 @@ import ( "github.com/ghodss/yaml" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1975,6 +1976,20 @@ func TestSSHAccessRequest(t *testing.T) { _, err = rootAuth.GetAuthServer().UpsertUser(ctx, alice) require.NoError(t, err) + err = Run(ctx, []string{ + "logout", + }, setHomePath(tmpHomePath)) + require.NoError(t, err) + + err = Run(ctx, []string{ + "login", + "--insecure", + "--proxy", proxyAddr.String(), + "--user", "alice", + }, setHomePath(tmpHomePath), setMockSSOLogin(rootAuth.GetAuthServer(), alice, connector.GetName())) + require.NoError(t, err) + + requestReason := uuid.New().String() // the first ssh request can fail if the proxy node watcher doesn't know // about the nodes yet, retry a few times until it works require.Eventually(t, func() bool { @@ -1984,7 +1999,7 @@ func TestSSHAccessRequest(t *testing.T) { "--debug", "--insecure", "--request-mode", tc.requestMode, - "--request-reason", "reason here to bypass prompt", + "--request-reason", requestReason, fmt.Sprintf("%s@%s", user.Username, sshHostname), "echo", "test", }, setHomePath(tmpHomePath)) @@ -1994,6 +2009,12 @@ func TestSSHAccessRequest(t *testing.T) { return err == nil }, 10*time.Second, 100*time.Millisecond, "failed to ssh with retries") + requests, err := rootAuth.GetAuthServer().GetAccessRequests(ctx, types.AccessRequestFilter{}) + require.NoError(t, err) + require.True(t, slices.ContainsFunc(requests, func(request types.AccessRequest) bool { + return request.GetRequestReason() == requestReason + }), "access request with the specified reason was not found") + // now that we have an approved access request, it should work without // prompting for a request reason err = Run(ctx, []string{