diff --git a/lib/client/api.go b/lib/client/api.go index 15d6b8d929c40..0c7f53e2d60af 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -582,25 +582,35 @@ func IsErrorResolvableWithRelogin(err error) bool { trace.IsBadParameter(err) || trace.IsTrustError(err) || keys.IsPrivateKeyPolicyError(err) || trace.IsNotFound(err) } -// LoadProfile populates Config with the values stored in the given -// profiles directory. If profileDir is an empty string, the default profile -// directory ~/.tsh is used. -func (c *Config) LoadProfile(ps ProfileStore, proxyAddr string) error { +// GetProfile gets the profile for the specified proxy address, or +// the current profile if no proxy is specified. +func (c *Config) GetProfile(ps ProfileStore, proxyAddr string) (*profile.Profile, error) { var proxyHost string var err error if proxyAddr == "" { proxyHost, err = ps.CurrentProfile() if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } } else { proxyHost, err = utils.Host(proxyAddr) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } } profile, err := ps.GetProfile(proxyHost) + if err != nil { + return nil, trace.Wrap(err) + } + return profile, nil +} + +// LoadProfile populates Config with the values stored in the given +// profiles directory. If profileDir is an empty string, the default profile +// directory ~/.tsh is used. +func (c *Config) LoadProfile(ps ProfileStore, proxyAddr string) error { + profile, err := c.GetProfile(ps, proxyAddr) if err != nil { return trace.Wrap(err) } diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 1d2f834dc592f..9fc0cd97e5296 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -3191,6 +3191,61 @@ func makeClient(cf *CLIConf) (*client.TeleportClient, error) { // makeClient takes the command-line configuration and a proxy address and constructs & returns // a fully configured TeleportClient object func makeClientForProxy(cf *CLIConf, proxy string) (*client.TeleportClient, error) { + c, err := loadClientConfigFromCLIConf(cf, proxy) + if err != nil { + return nil, trace.Wrap(err) + } + + ctx, span := c.Tracer.Start(cf.Context, "makeClientForProxy/init") + defer span.End() + + tc, err := client.NewClient(c) + if err != nil { + return nil, trace.Wrap(err) + } + + // Load SSH key for the cluster indicated in the profile. + // Handle gracefully if the profile is empty, the key cannot + // be found, or the key isn't supported as an agent key. + profile, profileError := c.GetProfile(c.ClientStore, proxy) + if profileError == nil { + if err := tc.LoadKeyForCluster(ctx, profile.SiteName); err != nil { + if !trace.IsNotFound(err) && !trace.IsConnectionProblem(err) { + return nil, trace.Wrap(err) + } + log.WithError(err).Infof("Could not load key for %s into the local agent.", cf.SiteName) + } + } + + // If we are missing client profile information, ping the webproxy + // for proxy info and load it into the client config. + if profileError != nil || cf.IdentityFileIn != "" { + log.Debug("Pinging the proxy to fetch listening addresses for non-web ports.") + _, err := tc.Ping(cf.Context) + if err != nil { + return nil, trace.Wrap(err) + } + + // Identityfile uses a placeholder profile. Save missing profile info. + if cf.IdentityFileIn != "" { + if err := tc.SaveProfile(true); err != nil { + return nil, trace.Wrap(err) + } + } + } + + return tc, nil +} + +func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, error) { + if cf.TracingProvider == nil { + cf.TracingProvider = tracing.NoopProvider() + } + + cf.tracer = cf.TracingProvider.Tracer(teleport.ComponentTSH) + ctx, span := cf.tracer.Start(cf.Context, "loadClientConfigFromCLIConf/init") + defer span.End() + // Parse OpenSSH style options. options, err := parseOptions(cf.Options) if err != nil { @@ -3245,13 +3300,8 @@ func makeClientForProxy(cf *CLIConf, proxy string) (*client.TeleportClient, erro // 1: start with the defaults c := client.MakeDefaultConfig() - if cf.TracingProvider == nil { - cf.TracingProvider = tracing.NoopProvider() - } - c.Tracer = cf.TracingProvider.Tracer(teleport.ComponentTSH) - cf.tracer = c.Tracer - ctx, span := c.Tracer.Start(cf.Context, "makeClientForProxy/init") - defer span.End() + + c.Tracer = cf.tracer // Force the use of proxy template below. useProxyTemplate := strings.Contains(cf.ProxyJump, "{{proxy}}") @@ -3316,9 +3366,10 @@ func makeClientForProxy(cf *CLIConf, proxy string) (*client.TeleportClient, erro return nil, trace.BadParameter("either --headless or --auth can be specified, not both") } cf.AuthConnector = constants.HeadlessConnector - if !cf.ExplicitUsername { - return nil, trace.BadParameter("user must be set explicitly for headless login with the --user flag or $TELEPORT_USER env variable") - } + } + + if cf.AuthConnector == constants.HeadlessConnector && !cf.ExplicitUsername { + return nil, trace.BadParameter("user must be set explicitly for headless login with the --user flag or $TELEPORT_USER env variable") } if err := tryLockMemory(cf); err != nil { @@ -3373,7 +3424,6 @@ func makeClientForProxy(cf *CLIConf, proxy string) (*client.TeleportClient, erro if len(dPorts) > 0 { c.DynamicForwardedPorts = dPorts } - profileSiteName := c.SiteName if cf.SiteName != "" { c.SiteName = cf.SiteName } @@ -3477,47 +3527,13 @@ func makeClientForProxy(cf *CLIConf, proxy string) (*client.TeleportClient, erro c.NonInteractive = true } - tc, err := client.NewClient(c) - if err != nil { - return nil, trace.Wrap(err) - } + c.Stderr = cf.Stderr() + c.Stdout = cf.Stdout() - // Load SSH key for the cluster indicated in the profile. - // Handle gracefully if the profile is empty, the key cannot - // be found, or the key isn't supported as an agent key. - if profileSiteName != "" { - if err := tc.LoadKeyForCluster(ctx, profileSiteName); err != nil { - if !trace.IsNotFound(err) && !trace.IsConnectionProblem(err) { - return nil, trace.Wrap(err) - } - log.WithError(err).Infof("Could not load key for %s into the local agent.", profileSiteName) - } - } - - // If we are missing client profile information, ping the webproxy - // for proxy info and load it into the client config. - if profileErr != nil || cf.IdentityFileIn != "" { - log.Debug("Pinging the proxy to fetch listening addresses for non-web ports.") - _, err := tc.Ping(cf.Context) - if err != nil { - return nil, trace.Wrap(err) - } - - // Identityfile uses a placeholder profile. Save missing profile info. - if cf.IdentityFileIn != "" { - if err := tc.SaveProfile(true); err != nil { - return nil, trace.Wrap(err) - } - } - } - - tc.Config.Stderr = cf.Stderr() - tc.Config.Stdout = cf.Stdout() - - tc.Config.Reason = cf.Reason - tc.Config.Invited = cf.Invited - tc.Config.DisplayParticipantRequirements = cf.displayParticipantRequirements - return tc, nil + c.Reason = cf.Reason + c.Invited = cf.Invited + c.DisplayParticipantRequirements = cf.displayParticipantRequirements + return c, nil } func initClientStore(cf *CLIConf, proxy string) (*client.Store, error) { @@ -4239,9 +4255,6 @@ func reissueWithRequests(cf *CLIConf, tc *client.TeleportClient, newRequests []s if err != nil { return trace.Wrap(err) } - if profile.IsVirtual { - return trace.BadParameter("cannot reissue certificates while using an identity file (-i)") - } params := client.ReissueParams{ AccessRequests: newRequests, DropAccessRequests: dropRequests, diff --git a/tool/tsh/tsh_test.go b/tool/tsh/tsh_test.go index 954a068675741..aa7ffffe826fa 100644 --- a/tool/tsh/tsh_test.go +++ b/tool/tsh/tsh_test.go @@ -38,6 +38,7 @@ import ( "github.com/ghodss/yaml" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" otlp "go.opentelemetry.io/proto/otlp/trace/v1" "golang.org/x/exp/slices" @@ -1924,15 +1925,95 @@ func tryCreateTrustedCluster(t *testing.T, authServer *auth.Server, trustedClust require.FailNow(t, "Timeout creating trusted cluster") } +func TestSSHHeadlessCLIFlags(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + modifyCLIConf func(c *CLIConf) + assertErr require.ErrorAssertionFunc + assertConfig func(t require.TestingT, c *client.Config) + }{ + { + name: "OK --auth headless", + modifyCLIConf: func(c *CLIConf) { + c.AuthConnector = constants.HeadlessConnector + c.ExplicitUsername = true + }, + assertErr: require.NoError, + assertConfig: func(t require.TestingT, c *client.Config) { + require.Equal(t, constants.HeadlessConnector, c.AuthConnector) + }, + }, { + name: "OK --headless", + modifyCLIConf: func(c *CLIConf) { + c.Headless = true + c.ExplicitUsername = true + }, + assertErr: require.NoError, + assertConfig: func(t require.TestingT, c *client.Config) { + require.Equal(t, constants.HeadlessConnector, c.AuthConnector) + }, + }, { + name: "NOK --headless with mismatched auth connector", + modifyCLIConf: func(c *CLIConf) { + c.Headless = true + c.AuthConnector = constants.LocalConnector + c.ExplicitUsername = true + }, + assertErr: func(t require.TestingT, err error, msgAndArgs ...interface{}) { + require.True(t, trace.IsBadParameter(err), "expected trace.BadParameter error but got %v", err) + }, + }, { + name: "NOK --auth headless without explicit user", + modifyCLIConf: func(c *CLIConf) { + c.AuthConnector = constants.HeadlessConnector + }, + assertErr: func(t require.TestingT, err error, msgAndArgs ...interface{}) { + require.True(t, trace.IsBadParameter(err), "expected trace.BadParameter error but got %v", err) + }, + }, { + name: "NOK --headless without explicit user", + modifyCLIConf: func(c *CLIConf) { + c.Headless = true + }, + assertErr: func(t require.TestingT, err error, msgAndArgs ...interface{}) { + require.True(t, trace.IsBadParameter(err), "expected trace.BadParameter error but got %v", err) + }, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + // minimal configuration (with defaults) + conf := &CLIConf{ + Proxy: "proxy:3080", + UserHost: "localhost", + HomePath: t.TempDir(), + } + + tc.modifyCLIConf(conf) + + c, err := loadClientConfigFromCLIConf(conf, "proxy:3080") + tc.assertErr(t, err) + if tc.assertConfig != nil { + tc.assertConfig(t, c) + } + }) + } +} + func TestSSHHeadless(t *testing.T) { + modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise}) + ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + t.Cleanup(cancel) user, err := user.Current() require.NoError(t, err) // Headless ssh should pass session mfa requirements - sshLoginRole, err := types.NewRole("ssh-login", types.RoleSpecV6{ + nodeAccess, err := types.NewRole("node-access", types.RoleSpecV6{ Options: types.RoleOptions{ RequireMFAType: types.RequireMFAType_SESSION, }, @@ -1945,12 +2026,27 @@ func TestSSHHeadless(t *testing.T) { alice, err := types.NewUser("alice@example.com") require.NoError(t, err) - alice.SetRoles([]string{"ssh-login"}) + alice.SetRoles([]string{"node-access"}) - rootAuth, rootProxy := makeTestServers(t, withBootstrap(sshLoginRole, alice)) + requester, err := types.NewRole("requester", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Request: &types.AccessRequestConditions{ + SearchAsRoles: []string{"node-access"}, + }, + }, + }) + require.NoError(t, err) - authAddr, err := rootAuth.AuthAddr() + bob, err := types.NewUser("bob@example.com") require.NoError(t, err) + bob.SetRoles([]string{"requester"}) + + sshHostName := "test-ssh-host" + rootAuth, rootProxy := makeTestServers(t, withBootstrap(nodeAccess, alice, requester, bob), withConfig(func(cfg *service.Config) { + cfg.Hostname = sshHostName + cfg.SSH.Enabled = true + cfg.SSH.Addr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())} + })) proxyAddr, err := rootProxy.ProxyWebAddr() require.NoError(t, err) @@ -1965,71 +2061,50 @@ func TestSSHHeadless(t *testing.T) { }, })) - sshHostname := "test-ssh-server" - node := makeTestSSHNode(t, authAddr, withHostname(sshHostname), withSSHLabel("access", "true")) - sshHostID := node.Config.HostUUID - - hasNodes := func(hostIDs ...string) func() bool { - return func() bool { - nodes, err := rootAuth.GetAuthServer().GetNodes(ctx, apidefaults.Namespace) - require.NoError(t, err) - foundCount := 0 - for _, node := range nodes { - if slices.Contains(hostIDs, node.GetName()) { - foundCount++ - } - } - return foundCount == len(hostIDs) + go func() { + if err := approveAllAccessRequests(ctx, rootAuth.GetAuthServer()); err != nil { + assert.ErrorIs(t, err, context.Canceled, "unexpected error from approveAllAccessRequests") } - } - - // wait for auth to see nodes - require.Eventually(t, hasNodes(sshHostID), 10*time.Second, 100*time.Millisecond, "nodes never showed up") - - // perform "tsh --headless ssh" - err = Run(ctx, []string{ - "ssh", - "--insecure", - "--headless", - "--proxy", proxyAddr.String(), - "--user", "alice", - fmt.Sprintf("%s@%s", user.Username, sshHostname), - "echo", "test", - }, cliOption(func(cf *CLIConf) error { - cf.mockHeadlessLogin = mockHeadlessLogin(t, rootAuth.GetAuthServer(), alice) - return nil - })) - require.NoError(t, err) + // Cancel the context, so Run calls don't block + cancel() + }() - // "tsh --auth headless ssh" should also perform headless ssh - err = Run(ctx, []string{ - "ssh", - "--insecure", - "--auth", constants.HeadlessConnector, - "--proxy", proxyAddr.String(), - "--user", "alice", - fmt.Sprintf("%s@%s", user.Username, sshHostname), - "echo", "test", - }, cliOption(func(cf *CLIConf) error { - cf.mockHeadlessLogin = mockHeadlessLogin(t, rootAuth.GetAuthServer(), alice) - return nil - })) - require.NoError(t, err) + for _, tc := range []struct { + name string + args []string + assertErr require.ErrorAssertionFunc + }{ + { + name: "node access", + args: []string{"--headless", "--user", "alice"}, + assertErr: require.NoError, + }, { + name: "resource request", + args: []string{"--headless", "--user", "bob", "--request-reason", "reason here to bypass prompt"}, + assertErr: require.NoError, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + args := append([]string{ + "ssh", + "-d", + "--insecure", + "--proxy", proxyAddr.String(), + }, tc.args...) + args = append(args, + fmt.Sprintf("%s@%s", user.Username, sshHostName), + "echo", "test", + ) - // headless ssh should fail if user is not set. - err = Run(ctx, []string{ - "ssh", - "--insecure", - "--headless", - "--proxy", proxyAddr.String(), - fmt.Sprintf("%s@%s", user.Username, sshHostname), - "echo", "test", - }, cliOption(func(cf *CLIConf) error { - cf.mockHeadlessLogin = mockHeadlessLogin(t, rootAuth.GetAuthServer(), alice) - return nil - })) - require.Error(t, err) - require.ErrorIs(t, err, trace.BadParameter("user must be set explicitly for headless login with the --user flag or $TELEPORT_USER env variable")) + err := Run(ctx, args, cliOption(func(cf *CLIConf) error { + cf.mockHeadlessLogin = mockHeadlessLogin(t, rootAuth.GetAuthServer(), alice) + return nil + })) + tc.assertErr(t, err) + }) + } } func TestFormatConnectCommand(t *testing.T) {