diff --git a/integration/helpers/helpers.go b/integration/helpers/helpers.go index 7087877c80db3..d2c14bdbfb28e 100644 --- a/integration/helpers/helpers.go +++ b/integration/helpers/helpers.go @@ -205,11 +205,13 @@ func MustCreateUserIdentityFile(t *testing.T, tc *TeleInstance, username string, require.NoError(t, err) key.ClusterName = tc.Secrets.SiteName - sshCert, tlsCert, err := tc.Process.GetAuthServer().GenerateUserTestCerts( - key.MarshalSSHPublicKey(), username, ttl, - constants.CertificateFormatStandard, - tc.Secrets.SiteName, "", - ) + sshCert, tlsCert, err := tc.Process.GetAuthServer().GenerateUserTestCerts(auth.GenerateUserTestCertsRequest{ + Key: key.MarshalSSHPublicKey(), + Username: username, + TTL: ttl, + Compatibility: constants.CertificateFormatStandard, + RouteToCluster: tc.Secrets.SiteName, + }) require.NoError(t, err) key.Cert = sshCert diff --git a/integration/helpers/usercreds.go b/integration/helpers/usercreds.go index 84b7264958a11..b78677e030a0b 100644 --- a/integration/helpers/usercreds.go +++ b/integration/helpers/usercreds.go @@ -23,6 +23,7 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/service" @@ -115,8 +116,14 @@ func GenerateUserCreds(req UserCredsRequest) (*UserCreds, error) { } a := req.Process.GetAuthServer() sshPub := ssh.MarshalAuthorizedKey(priv.SSHPublicKey()) - sshCert, x509Cert, err := a.GenerateUserTestCerts( - sshPub, req.Username, ttl, constants.CertificateFormatStandard, req.RouteToCluster, req.SourceIP) + sshCert, x509Cert, err := a.GenerateUserTestCerts(auth.GenerateUserTestCertsRequest{ + Key: sshPub, + Username: req.Username, + TTL: ttl, + Compatibility: constants.CertificateFormatStandard, + RouteToCluster: req.RouteToCluster, + PinnedIP: req.SourceIP, + }) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/auth.go b/lib/auth/auth.go index dacb1b104937e..23483f82a99ed 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -1445,9 +1445,20 @@ func (a *Server) GenerateOpenSSHCert(ctx context.Context, req *proto.OpenSSHCert }, nil } +// GenerateUserTestCertsRequest is a request to generate test certificates. +type GenerateUserTestCertsRequest struct { + Key []byte + Username string + TTL time.Duration + Compatibility string + RouteToCluster string + PinnedIP string + MFAVerified string +} + // GenerateUserTestCerts is used to generate user certificate, used internally for tests -func (a *Server) GenerateUserTestCerts(key []byte, username string, ttl time.Duration, compatibility, routeToCluster, pinnedIP string) ([]byte, []byte, error) { - user, err := a.GetUser(username, false) +func (a *Server) GenerateUserTestCerts(req GenerateUserTestCertsRequest) ([]byte, []byte, error) { + user, err := a.GetUser(req.Username, false) if err != nil { return nil, nil, trace.Wrap(err) } @@ -1462,14 +1473,15 @@ func (a *Server) GenerateUserTestCerts(key []byte, username string, ttl time.Dur } certs, err := a.generateUserCert(certRequest{ user: user, - ttl: ttl, - compatibility: compatibility, - publicKey: key, - routeToCluster: routeToCluster, + ttl: req.TTL, + compatibility: req.Compatibility, + publicKey: req.Key, + routeToCluster: req.RouteToCluster, checker: checker, traits: user.GetTraits(), - loginIP: pinnedIP, - pinIP: pinnedIP != "", + loginIP: req.PinnedIP, + pinIP: req.PinnedIP != "", + mfaVerified: req.MFAVerified, }) if err != nil { return nil, nil, trace.Wrap(err) diff --git a/lib/client/api.go b/lib/client/api.go index fc09c5e73abbf..870ebbe2e6595 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -368,6 +368,9 @@ type Config struct { // MockSSOLogin is used in tests for mocking the SSO login response. MockSSOLogin SSOLoginFunc + // MockHeadlessLogin is used in tests for mocking the Headless login response. + MockHeadlessLogin SSHLoginFunc + // HomePath is where tsh stores profiles HomePath string @@ -3673,6 +3676,10 @@ func (tc *TeleportClient) mfaLocalLogin(ctx context.Context, priv *keys.PrivateK } func (tc *TeleportClient) headlessLogin(ctx context.Context, priv *keys.PrivateKey) (*auth.SSHLoginResponse, error) { + if tc.MockHeadlessLogin != nil { + return tc.MockHeadlessLogin(ctx, priv) + } + headlessAuthenticationID := services.NewHeadlessAuthenticationID(priv.MarshalSSHPublicKey()) webUILink, err := url.JoinPath("https://"+tc.WebProxyAddr, "web", "headless", headlessAuthenticationID) diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 248798c20e065..e34add3fb9b37 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -331,6 +331,9 @@ type CLIConf struct { // mockSSOLogin used in tests to override sso login handler in teleport client. mockSSOLogin client.SSOLoginFunc + // mockHeadlessLogin used in tests to override Headless login handler in teleport client. + mockHeadlessLogin client.SSHLoginFunc + // HomePath is where tsh stores profiles HomePath string @@ -3342,6 +3345,9 @@ func makeClientForProxy(cf *CLIConf, proxy string, useProfileLogin bool) (*clien return nil, trace.BadParameter("either --headless or --auth can be specified, not both") } cf.AuthConnector = constants.HeadlessConnector + if !cf.ExplicitUsername { + return nil, trace.BadParameter("user must be set explicitly for headless login with the --user flag or $TELEPORT_USER env variable") + } } if err := tryLockMemory(cf); err != nil { @@ -3492,6 +3498,7 @@ func makeClientForProxy(cf *CLIConf, proxy string, useProfileLogin bool) (*clien // pass along mock sso login if provided (only used in tests) c.MockSSOLogin = cf.mockSSOLogin + c.MockHeadlessLogin = cf.mockHeadlessLogin // Set tsh home directory c.HomePath = cf.HomePath diff --git a/tool/tsh/tsh_test.go b/tool/tsh/tsh_test.go index 5e2465c86cdfa..25b07f1989ceb 100644 --- a/tool/tsh/tsh_test.go +++ b/tool/tsh/tsh_test.go @@ -1745,6 +1745,114 @@ func tryCreateTrustedCluster(t *testing.T, authServer *auth.Server, trustedClust require.FailNow(t, "Timeout creating trusted cluster") } +func TestSSHHeadless(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + user, err := user.Current() + require.NoError(t, err) + + // Headless ssh should pass session mfa requirements + sshLoginRole, err := types.NewRole("ssh-login", types.RoleSpecV6{ + Options: types.RoleOptions{ + RequireMFAType: types.RequireMFAType_SESSION, + }, + Allow: types.RoleConditions{ + Logins: []string{user.Username}, + NodeLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, + }, + }) + require.NoError(t, err) + + alice, err := types.NewUser("alice@example.com") + require.NoError(t, err) + alice.SetRoles([]string{"ssh-login"}) + + rootAuth, rootProxy := makeTestServers(t, withBootstrap(sshLoginRole, alice)) + + authAddr, err := rootAuth.AuthAddr() + require.NoError(t, err) + + proxyAddr, err := rootProxy.ProxyWebAddr() + require.NoError(t, err) + + require.NoError(t, rootAuth.GetAuthServer().SetAuthPreference(ctx, &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + Type: constants.Local, + SecondFactor: constants.SecondFactorOptional, + Webauthn: &types.Webauthn{ + RPID: "127.0.0.1", + }, + }, + })) + + sshHostname := "test-ssh-server" + node := makeTestSSHNode(t, authAddr, withHostname(sshHostname), withSSHLabel("access", "true")) + sshHostID := node.Config.HostUUID + + hasNodes := func(hostIDs ...string) func() bool { + return func() bool { + nodes, err := rootAuth.GetAuthServer().GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) + foundCount := 0 + for _, node := range nodes { + if slices.Contains(hostIDs, node.GetName()) { + foundCount++ + } + } + return foundCount == len(hostIDs) + } + } + + // wait for auth to see nodes + require.Eventually(t, hasNodes(sshHostID), 10*time.Second, 100*time.Millisecond, "nodes never showed up") + + // perform "tsh --headless ssh" + err = Run(ctx, []string{ + "ssh", + "--insecure", + "--headless", + "--proxy", proxyAddr.String(), + "--user", "alice", + fmt.Sprintf("%s@%s", user.Username, sshHostname), + "echo", "test", + }, cliOption(func(cf *CLIConf) error { + cf.mockHeadlessLogin = mockHeadlessLogin(t, rootAuth.GetAuthServer(), alice) + return nil + })) + require.NoError(t, err) + + // "tsh --auth headless ssh" should also perform headless ssh + err = Run(ctx, []string{ + "ssh", + "--insecure", + "--auth", constants.HeadlessConnector, + "--proxy", proxyAddr.String(), + "--user", "alice", + fmt.Sprintf("%s@%s", user.Username, sshHostname), + "echo", "test", + }, cliOption(func(cf *CLIConf) error { + cf.mockHeadlessLogin = mockHeadlessLogin(t, rootAuth.GetAuthServer(), alice) + return nil + })) + require.NoError(t, err) + + // headless ssh should fail if user is not set. + err = Run(ctx, []string{ + "ssh", + "--insecure", + "--headless", + "--proxy", proxyAddr.String(), + fmt.Sprintf("%s@%s", user.Username, sshHostname), + "echo", "test", + }, cliOption(func(cf *CLIConf) error { + cf.mockHeadlessLogin = mockHeadlessLogin(t, rootAuth.GetAuthServer(), alice) + return nil + })) + require.Error(t, err) + require.ErrorIs(t, err, trace.BadParameter("user must be set explicitly for headless login with the --user flag or $TELEPORT_USER env variable")) +} + func TestFormatConnectCommand(t *testing.T) { t.Parallel() @@ -2601,11 +2709,45 @@ func mockSSOLogin(t *testing.T, authServer *auth.Server, user types.User) client // generate certificates for our user clusterName, err := authServer.GetClusterName() require.NoError(t, err) - sshCert, tlsCert, err := authServer.GenerateUserTestCerts( - priv.MarshalSSHPublicKey(), user.GetName(), time.Hour, - constants.CertificateFormatStandard, - clusterName.GetClusterName(), "", - ) + sshCert, tlsCert, err := authServer.GenerateUserTestCerts(auth.GenerateUserTestCertsRequest{ + Key: priv.MarshalSSHPublicKey(), + Username: user.GetName(), + TTL: time.Hour, + Compatibility: constants.CertificateFormatStandard, + RouteToCluster: clusterName.GetClusterName(), + }) + require.NoError(t, err) + + // load CA cert + authority, err := authServer.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.HostCA, + DomainName: clusterName.GetClusterName(), + }, false) + require.NoError(t, err) + + // build login response + return &auth.SSHLoginResponse{ + Username: user.GetName(), + Cert: sshCert, + TLSCert: tlsCert, + HostSigners: auth.AuthoritiesToTrustedCerts([]types.CertAuthority{authority}), + }, nil + } +} + +func mockHeadlessLogin(t *testing.T, authServer *auth.Server, user types.User) client.SSHLoginFunc { + return func(ctx context.Context, priv *keys.PrivateKey) (*auth.SSHLoginResponse, error) { + // generate certificates for our user + clusterName, err := authServer.GetClusterName() + require.NoError(t, err) + sshCert, tlsCert, err := authServer.GenerateUserTestCerts(auth.GenerateUserTestCertsRequest{ + Key: priv.MarshalSSHPublicKey(), + Username: user.GetName(), + TTL: time.Hour, + Compatibility: constants.CertificateFormatStandard, + RouteToCluster: clusterName.GetClusterName(), + MFAVerified: "mfa-verified", + }) require.NoError(t, err) // load CA cert