diff --git a/api/observability/tracing/tracing.go b/api/observability/tracing/tracing.go index 6e6f848404135..8eaab3915647e 100644 --- a/api/observability/tracing/tracing.go +++ b/api/observability/tracing/tracing.go @@ -47,3 +47,9 @@ func WithPropagationContext(ctx context.Context, pc PropagationContext, opts ... func DefaultProvider() oteltrace.TracerProvider { return otel.GetTracerProvider() } + +// NewTracer creates a new [oteltrace.Tracer] from the global default +// [oteltrace.TracerProvider] with the provided name +func NewTracer(name string) oteltrace.Tracer { + return DefaultProvider().Tracer(name) +} diff --git a/integration/integration_test.go b/integration/integration_test.go index b8d32d1911ba6..a2cc50fc93fc7 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -2276,7 +2276,7 @@ func twoClustersTunnel(t *testing.T, suite *integrationTestSuite, now time.Time, require.Equal(t, "hello world\n", outputA.String()) // Update trusted CAs. - err = tc.UpdateTrustedCA(ctx, a.Secrets.SiteName) + err = tc.UpdateTrustedCA(ctx, a.GetSiteAPI(a.Secrets.SiteName)) require.NoError(t, err) // The known_hosts file should have two certificates, the way bytes.Split diff --git a/lib/client/api.go b/lib/client/api.go index 07f895624b9bc..081d69ba2208e 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -565,12 +565,29 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error) } return trace.Wrap(err) } - if err := tc.ActivateKey(ctx, key); err != nil { + + if err := tc.ActivateKeyWithoutTrustedCerts(ctx, key); err != nil { + return trace.Wrap(err) + } + + clusterClient, err := tc.ConnectToCluster(ctx) + if err != nil { + return trace.Wrap(err) + } + defer clusterClient.Close() + + rootAuth, err := clusterClient.RootClient(ctx) + if err != nil { + return trace.Wrap(err) + } + defer rootAuth.Close() + + if err := tc.UpdateTrustedCA(ctx, rootAuth); err != nil { return trace.Wrap(err) } // Attempt device login. 
This activates a fresh key if successful. - if err := tc.AttemptDeviceLogin(ctx, key); err != nil { + if err := tc.AttemptDeviceLogin(ctx, key, rootAuth); err != nil { return trace.Wrap(err) } @@ -3332,7 +3349,14 @@ func (tc *TeleportClient) Login(ctx context.Context) (*Key, error) { // successful, as skipping the ceremony is valid for various reasons (Teleport // cluster doesn't support device authn, device wasn't enrolled, etc). // Use [TeleportClient.DeviceLogin] if you want more control over process. -func (tc *TeleportClient) AttemptDeviceLogin(ctx context.Context, key *Key) error { +func (tc *TeleportClient) AttemptDeviceLogin(ctx context.Context, key *Key, rootAuth auth.ClientI) error { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/AttemptDeviceLogin", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + pingResp, err := tc.Ping(ctx) if err != nil { return trace.Wrap(err) @@ -3343,11 +3367,15 @@ func (tc *TeleportClient) AttemptDeviceLogin(ctx context.Context, key *Key) erro return nil } - newCerts, err := tc.DeviceLogin(ctx, &devicepb.UserCertificates{ - // Augment the SSH certificate. - // The TLS certificate is already part of the connection. - SshAuthorizedKey: key.Cert, - }) + newCerts, err := tc.DeviceLogin( + ctx, + &devicepb.UserCertificates{ + // Augment the SSH certificate. + // The TLS certificate is already part of the connection. + SshAuthorizedKey: key.Cert, + }, + rootAuth, + ) switch { case errors.Is(err, devicetrust.ErrDeviceKeyNotFound): log.Debug("Device Trust: Skipping device authentication, device key not found") @@ -3370,7 +3398,7 @@ func (tc *TeleportClient) AttemptDeviceLogin(ctx context.Context, key *Key) erro Type: "CERTIFICATE", Bytes: newCerts.X509Der, }) - return trace.Wrap(tc.ActivateKey(ctx, &cp)) + return trace.Wrap(tc.ActivateKey(ctx, &cp, rootAuth)) } // DeviceLogin attempts to authenticate the current device with Teleport. 
@@ -3386,18 +3414,13 @@ func (tc *TeleportClient) AttemptDeviceLogin(ctx context.Context, key *Key) erro // `tsh login`). // // Device Trust is a Teleport Enterprise feature. -func (tc *TeleportClient) DeviceLogin(ctx context.Context, certs *devicepb.UserCertificates) (*devicepb.UserCertificates, error) { - proxyClient, err := tc.ConnectToProxy(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err := proxyClient.ConnectToRootCluster(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - defer authClient.Close() +func (tc *TeleportClient) DeviceLogin(ctx context.Context, certs *devicepb.UserCertificates, rootAuth auth.ClientI) (*devicepb.UserCertificates, error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/DeviceLogin", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() // Allow tests to override the default authn function. runCeremony := tc.dtAuthnRunCeremony @@ -3406,7 +3429,7 @@ func (tc *TeleportClient) DeviceLogin(ctx context.Context, certs *devicepb.UserC } // Login without a previous auto-enroll attempt. - devicesClient := authClient.DevicesClient() + devicesClient := rootAuth.DevicesClient() newCerts, loginErr := runCeremony(ctx, devicesClient, certs) // Success or auto-enroll impossible. if loginErr == nil || errors.Is(loginErr, devicetrust.ErrPlatformNotSupported) || trace.IsNotImplemented(loginErr) { @@ -3530,6 +3553,13 @@ type SSHLoginFunc func(context.Context, *keys.PrivateKey) (*auth.SSHLoginRespons // SSHLogin uses the given login function to login the client. This function handles // private key logic and parsing the resulting auth response. 
func (tc *TeleportClient) SSHLogin(ctx context.Context, sshLoginFunc SSHLoginFunc) (*Key, error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/SSHLogin", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + priv, err := tc.GetNewLoginKey(ctx) if err != nil { return nil, trace.Wrap(err) @@ -3591,6 +3621,13 @@ func (tc *TeleportClient) SSHLogin(ctx context.Context, sshLoginFunc SSHLoginFun // GetNewLoginKey gets a new private key for login. func (tc *TeleportClient) GetNewLoginKey(ctx context.Context) (priv *keys.PrivateKey, err error) { + _, span := tc.Tracer.Start( + ctx, + "teleportClient/GetNewLoginKey", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + switch tc.PrivateKeyPolicy { case keys.PrivateKeyPolicyHardwareKey: log.Debugf("Attempting to login with YubiKey private key.") @@ -3603,10 +3640,7 @@ func (tc *TeleportClient) GetNewLoginKey(ctx context.Context) (priv *keys.Privat priv, err = native.GeneratePrivateKey() } - if err != nil { - return nil, trace.Wrap(err) - } - return priv, nil + return priv, trace.Wrap(err) } // new SSHLogin generates a new SSHLogin using the given login key. @@ -3631,6 +3665,13 @@ func (tc *TeleportClient) newSSHLogin(priv *keys.PrivateKey) (SSHLogin, error) { } func (tc *TeleportClient) pwdlessLogin(ctx context.Context, priv *keys.PrivateKey) (*auth.SSHLoginResponse, error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/pwdlessLogin", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + // Only pass on the user if explicitly set, otherwise let the credential // picker kick in. 
user := "" @@ -3682,6 +3723,13 @@ func (tc *TeleportClient) localLogin(ctx context.Context, priv *keys.PrivateKey, // directLogin asks for a password + OTP token, makes a request to CA via proxy func (tc *TeleportClient) directLogin(ctx context.Context, secondFactorType constants.SecondFactorType, priv *keys.PrivateKey) (*auth.SSHLoginResponse, error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/directLogin", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + password, err := tc.AskPassword(ctx) if err != nil { return nil, trace.Wrap(err) @@ -3714,6 +3762,13 @@ func (tc *TeleportClient) directLogin(ctx context.Context, secondFactorType cons // mfaLocalLogin asks for a password and performs the challenge-response authentication func (tc *TeleportClient) mfaLocalLogin(ctx context.Context, priv *keys.PrivateKey) (*auth.SSHLoginResponse, error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/mfaLocalLogin", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + password, err := tc.AskPassword(ctx) if err != nil { return nil, trace.Wrap(err) @@ -3799,7 +3854,7 @@ func (tc *TeleportClient) ssoLogin(ctx context.Context, priv *keys.PrivateKey, c // ActivateKey saves the target session cert into the local // keystore (and into the ssh-agent) for future use. -func (tc *TeleportClient) ActivateKey(ctx context.Context, key *Key) error { +func (tc *TeleportClient) ActivateKey(ctx context.Context, key *Key, getter services.AuthorityGetter) error { ctx, span := tc.Tracer.Start( ctx, "teleportClient/ActivateKey", @@ -3817,21 +3872,27 @@ func (tc *TeleportClient) ActivateKey(ctx context.Context, key *Key) error { return trace.Wrap(err) } - // Connect to the Auth Server of the root cluster and fetch the known hosts. 
- rootClusterName := key.TrustedCerts[0].ClusterName - if err := tc.UpdateTrustedCA(ctx, rootClusterName); err != nil { - if len(tc.JumpHosts) == 0 { - return trace.Wrap(err) - } - errViaJumphost := err - // If JumpHosts was pointing at the leaf cluster (e.g. during 'tsh ssh - // -J leaf.example.com'), this could've caused the above error. Try to - // fetch CAs without JumpHosts to force it to use the root cluster. - if err := tc.WithoutJumpHosts(func(tc *TeleportClient) error { - return tc.UpdateTrustedCA(ctx, rootClusterName) - }); err != nil { - return trace.NewAggregate(errViaJumphost, err) - } + return trace.Wrap(tc.UpdateTrustedCA(ctx, getter)) +} + +// ActivateKeyWithoutTrustedCerts saves the target session cert into the local +// keystore (and into the ssh-agent) for future use. +func (tc *TeleportClient) ActivateKeyWithoutTrustedCerts(ctx context.Context, key *Key) error { + _, span := tc.Tracer.Start( + ctx, + "teleportClient/ActivateKeyWithoutTrustedCerts", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + + if tc.localAgent == nil { + // skip activation if no local agent is present + return nil + } + + // save the cert to the local storage (~/.tsh usually): + if err := tc.localAgent.AddKey(key); err != nil { + return trace.Wrap(err) } return nil @@ -3935,57 +3996,25 @@ func (tc *TeleportClient) ShowMOTD(ctx context.Context) error { return nil } -// GetTrustedCA returns a list of host certificate authorities -// trusted by the cluster client is authenticated with. -func (tc *TeleportClient) GetTrustedCA(ctx context.Context, clusterName string) ([]types.CertAuthority, error) { - ctx, span := tc.Tracer.Start( - ctx, - "teleportClient/GetTrustedCA", - oteltrace.WithSpanKind(oteltrace.SpanKindClient), - oteltrace.WithAttributes(attribute.String("cluster", clusterName)), - ) - defer span.End() - - // Connect to the proxy. 
- if !tc.Config.ProxySpecified() { - return nil, trace.BadParameter("proxy server is not specified") - } - proxyClient, err := tc.ConnectToProxy(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - defer proxyClient.Close() - - // Get a client to the Auth Server. - clt, err := proxyClient.ConnectToCluster(ctx, clusterName) - if err != nil { - return nil, trace.Wrap(err) - } - defer clt.Close() - - // Get the list of host certificates that this cluster knows about. - return clt.GetCertAuthorities(ctx, types.HostCA, false) -} - // UpdateTrustedCA connects to the Auth Server and fetches all host certificates // and updates ~/.tsh/keys/proxy/certs.pem and ~/.tsh/known_hosts. -func (tc *TeleportClient) UpdateTrustedCA(ctx context.Context, clusterName string) error { +func (tc *TeleportClient) UpdateTrustedCA(ctx context.Context, getter services.AuthorityGetter) error { ctx, span := tc.Tracer.Start( ctx, "teleportClient/UpdateTrustedCA", oteltrace.WithSpanKind(oteltrace.SpanKindClient), - oteltrace.WithAttributes(attribute.String("cluster", clusterName)), ) defer span.End() if tc.localAgent == nil { return trace.BadParameter("TeleportClient.UpdateTrustedCA called on a client without localAgent") } - // Get the list of host certificates that this cluster knows about. - hostCerts, err := tc.GetTrustedCA(ctx, clusterName) + + hostCerts, err := getter.GetCertAuthorities(ctx, types.HostCA, false) if err != nil { return trace.Wrap(err) } + trustedCerts := auth.AuthoritiesToTrustedCerts(hostCerts) // Update the CA pool and known hosts for all CAs the cluster knows about. @@ -4286,6 +4315,13 @@ func Username() (string, error) { // AskOTP prompts the user to enter the OTP token. 
func (tc *TeleportClient) AskOTP(ctx context.Context) (token string, err error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/AskOTP", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + stdin := prompt.Stdin() if !stdin.IsTerminal() { return "", trace.Wrap(prompt.ErrNotTerminal, "cannot perform OTP login without a terminal") @@ -4295,6 +4331,13 @@ func (tc *TeleportClient) AskOTP(ctx context.Context) (token string, err error) // AskPassword prompts the user to enter the password func (tc *TeleportClient) AskPassword(ctx context.Context) (pwd string, err error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/AskPassword", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + stdin := prompt.Stdin() if !stdin.IsTerminal() { return "", trace.Wrap(prompt.ErrNotTerminal, "cannot perform password login without a terminal") diff --git a/lib/client/api_login_test.go b/lib/client/api_login_test.go index 326ee11bdb5ea..da80f18f2d074 100644 --- a/lib/client/api_login_test.go +++ b/lib/client/api_login_test.go @@ -444,9 +444,19 @@ func TestTeleportClient_DeviceLogin(t *testing.T) { key, err := teleportClient.Login(ctx) require.NoError(t, err, "Login failed") require.NoError(t, - teleportClient.ActivateKey(ctx, key), + teleportClient.ActivateKeyWithoutTrustedCerts(ctx, key), "ActivateKey failed") + clusterClient, err := teleportClient.ConnectToCluster(ctx) + require.NoError(t, err) + defer clusterClient.Close() + + rootAuthClient, err := clusterClient.RootClient(ctx) + require.NoError(t, err) + defer rootAuthClient.Close() + + require.NoError(t, teleportClient.UpdateTrustedCA(ctx, rootAuthClient)) + // Prepare "device aware" certificates from key. // In a real scenario these would be augmented certs. 
block, _ := pem.Decode(key.TLSCert) @@ -487,16 +497,20 @@ func TestTeleportClient_DeviceLogin(t *testing.T) { require.NoError(t, authenticatedAction(), "Authenticated action failed *before* AttemptDeviceLogin") // Test! Exercise DeviceLogin. - got, err := teleportClient.DeviceLogin(ctx, &devicepb.UserCertificates{ - SshAuthorizedKey: key.Cert, - }) + got, err := teleportClient.DeviceLogin( + ctx, + &devicepb.UserCertificates{ + SshAuthorizedKey: key.Cert, + }, + rootAuthClient, + ) require.NoError(t, err, "DeviceLogin failed") require.Equal(t, validCerts, got, "DeviceLogin mismatch") assert.Equal(t, 1, runCeremonyCalls, `DeviceLogin didn't call dtAuthnRunCeremony()`) // Test! Exercise AttemptDeviceLogin. require.NoError(t, - teleportClient.AttemptDeviceLogin(ctx, key), + teleportClient.AttemptDeviceLogin(ctx, key, rootAuthClient), "AttemptDeviceLogin failed") assert.Equal(t, 2, runCeremonyCalls, `AttemptDeviceLogin didn't call dtAuthnRunCeremony()`) @@ -521,7 +535,7 @@ func TestTeleportClient_DeviceLogin(t *testing.T) { // Test! // AttemptDeviceLogin should obey Ping and not attempt the ceremony. require.NoError(t, - teleportClient.AttemptDeviceLogin(ctx, key), + teleportClient.AttemptDeviceLogin(ctx, key, rootAuthClient), "AttemptDeviceLogin failed") assert.False(t, runCeremonyCalled, "AttemptDeviceLogin called DeviceLogin/dtAuthnRunCeremony, despite the Ping response") }) @@ -551,9 +565,13 @@ func TestTeleportClient_DeviceLogin(t *testing.T) { }) // Test! 
- got, err := teleportClient.DeviceLogin(ctx, &devicepb.UserCertificates{ - SshAuthorizedKey: key.Cert, - }) + got, err := teleportClient.DeviceLogin( + ctx, + &devicepb.UserCertificates{ + SshAuthorizedKey: key.Cert, + }, + rootAuthClient, + ) require.NoError(t, err, "DeviceLogin failed") assert.Equal(t, got, validCerts, "DeviceLogin mismatch") assert.Equal(t, 2, runCeremonyCalls, "RunCeremony called an unexpected number of times") diff --git a/lib/client/cluster_client.go b/lib/client/cluster_client.go index fee7621dfea8f..5ea57aea821dd 100644 --- a/lib/client/cluster_client.go +++ b/lib/client/cluster_client.go @@ -16,6 +16,7 @@ package client import ( "context" + "sync" "github.com/gravitational/trace" "github.com/gravitational/trace/trail" @@ -38,6 +39,9 @@ type ClusterClient struct { AuthClient auth.ClientI Tracer oteltrace.Tracer cluster string + + mu sync.Mutex + rootCluster *auth.Client } // ClusterName returns the name of the cluster that the client @@ -48,8 +52,44 @@ func (c *ClusterClient) ClusterName() string { // Close terminates the connections to Auth and Proxy. func (c *ClusterClient) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.rootCluster == nil { + // close auth client first since it is tunneled through the proxy client + return trace.NewAggregate(c.AuthClient.Close(), c.ProxyClient.Close()) + } + // close auth client first since it is tunneled through the proxy client - return trace.NewAggregate(c.AuthClient.Close(), c.ProxyClient.Close()) + return trace.NewAggregate(c.AuthClient.Close(), c.rootCluster.Close(), c.ProxyClient.Close()) +} + +// RootClient return an [auth.ClientI] that is connected to the root +// cluster Auth server. 
+func (c *ClusterClient) RootClient(ctx context.Context) (auth.ClientI, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.rootCluster != nil { + return sharedAuthClient{ClientI: c.rootCluster}, nil + } + + root, err := c.tc.rootClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + + if root == c.cluster { + return sharedAuthClient{ClientI: c.AuthClient}, nil + } + + clt, err := auth.NewClient(c.ProxyClient.ClientConfig(ctx, root)) + if err != nil { + return nil, trace.Wrap(err) + } + + c.rootCluster = clt + return sharedAuthClient{ClientI: clt}, nil } // SessionSSHConfig returns the [ssh.ClientConfig] that should be used to connected to the diff --git a/lib/client/mfa.go b/lib/client/mfa.go index f91dd5f65e3f0..56aef72a9272f 100644 --- a/lib/client/mfa.go +++ b/lib/client/mfa.go @@ -29,6 +29,7 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/observability/tracing" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" "github.com/gravitational/teleport/lib/utils/prompt" @@ -117,6 +118,9 @@ func (tc *TeleportClient) PromptMFAChallenge(ctx context.Context, proxyAddr stri // PromptMFAChallenge prompts the user to complete MFA authentication // challenges. func PromptMFAChallenge(ctx context.Context, c *proto.MFAAuthenticateChallenge, proxyAddr string, opts *PromptMFAChallengeOpts) (*proto.MFAAuthenticateResponse, error) { + ctx, span := tracing.NewTracer("mfa").Start(ctx, "PromptMFAChallenge", oteltrace.WithSpanKind(oteltrace.SpanKindClient)) + defer span.End() + // Is there a challenge present? 
if c.TOTP == nil && c.WebauthnChallenge == nil { return &proto.MFAAuthenticateResponse{}, nil diff --git a/lib/teleterm/clusters/cluster_auth.go b/lib/teleterm/clusters/cluster_auth.go index d83ed99a81c89..6793b053b4e50 100644 --- a/lib/teleterm/clusters/cluster_auth.go +++ b/lib/teleterm/clusters/cluster_auth.go @@ -179,12 +179,24 @@ func (c *Cluster) login(ctx context.Context, sshLoginFunc client.SSHLoginFunc) e c.clusterClient.LocalAgent().UpdateUsername(key.Username) c.clusterClient.Username = key.Username - if err := c.clusterClient.ActivateKey(ctx, key); err != nil { + clusterClient, err := c.clusterClient.ConnectToCluster(ctx) + if err != nil { + return trace.Wrap(err) + } + defer clusterClient.Close() + + rootAuth, err := clusterClient.RootClient(ctx) + if err != nil { + return trace.Wrap(err) + } + defer rootAuth.Close() + + if err := c.clusterClient.ActivateKey(ctx, key, rootAuth); err != nil { return trace.Wrap(err) } // Attempt device login. This activates a fresh key if successful. 
- if err := c.clusterClient.AttemptDeviceLogin(ctx, key); err != nil { + if err := c.clusterClient.AttemptDeviceLogin(ctx, key, rootAuth); err != nil { return trace.Wrap(err) } diff --git a/tool/tsh/access_request.go b/tool/tsh/access_request.go index 67cbd429423a2..adbf1c910c26f 100644 --- a/tool/tsh/access_request.go +++ b/tool/tsh/access_request.go @@ -144,20 +144,31 @@ func onRequestShow(cf *CLIConf) error { cf.Username = tc.Username } - var req types.AccessRequest - err = tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error { - req, err = services.GetAccessRequest(cf.Context, clt, cf.RequestID) + clusterClient, err := tc.ConnectToCluster(cf.Context) + if err != nil { return trace.Wrap(err) - }) + } + defer clusterClient.Close() + + rootAuth, err := clusterClient.RootClient(cf.Context) if err != nil { return trace.Wrap(err) } + defer rootAuth.Close() + req, err := services.GetAccessRequest(cf.Context, rootAuth, cf.RequestID) + if err != nil { + return trace.Wrap(err) + } + + return trace.Wrap(printAccessRequest(cf, req)) +} + +func printAccessRequest(cf *CLIConf, req types.AccessRequest) error { format := strings.ToLower(cf.Format) switch format { case teleport.Text, "": - err = printRequest(cf, req) - if err != nil { + if err := printRequest(cf, req); err != nil { return trace.Wrap(err) } case teleport.JSON, teleport.YAML: @@ -274,7 +285,19 @@ func onRequestCreate(cf *CLIConf) error { return trace.Wrap(err) } - if err := executeAccessRequest(cf, tc); err != nil { + clusterClient, err := tc.ConnectToCluster(cf.Context) + if err != nil { + return trace.Wrap(err) + } + defer clusterClient.Close() + + rootAuth, err := clusterClient.RootClient(cf.Context) + if err != nil { + return trace.Wrap(err) + } + defer rootAuth.Close() + + if err := executeAccessRequest(cf, tc, rootAuth); err != nil { return trace.Wrap(err) } diff --git a/tool/tsh/kubectl.go b/tool/tsh/kubectl.go index cafe160e72d3a..0c793ccfacb6e 100644 --- a/tool/tsh/kubectl.go +++ 
b/tool/tsh/kubectl.go @@ -307,8 +307,21 @@ func createKubeAccessRequest(cf *CLIConf, resources []resourceKind, args []strin filepath.Join("/", tc.SiteName, rec.kind, kubeName, rec.subResourceName), ) } + + clusterClient, err := tc.ConnectToCluster(cf.Context) + if err != nil { + return trace.Wrap(err) + } + defer clusterClient.Close() + + rootAuth, err := clusterClient.RootClient(cf.Context) + if err != nil { + return trace.Wrap(err) + } + defer rootAuth.Close() + cf.Reason = fmt.Sprintf("Resource request automatically created for %v", args) - if err := executeAccessRequest(cf, tc); err != nil { + if err := executeAccessRequest(cf, tc, rootAuth); err != nil { // TODO(tigrato): intercept the error to validate the origin return trace.Wrap(err) } diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 9ea04fb0ee171..68d2a528fd5e8 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -1659,6 +1659,18 @@ func onLogin(cf *CLIConf) error { // client is already logged in and profile is not expired if profile != nil && !profile.IsExpired(clockwork.NewRealClock()) { + clusterClient, err := tc.ConnectToCluster(cf.Context) + if err != nil { + return trace.Wrap(err) + } + defer clusterClient.Close() + + rootAuth, err := clusterClient.RootClient(cf.Context) + if err != nil { + return trace.Wrap(err) + } + defer rootAuth.Close() + switch { // in case if nothing is specified, re-fetch kube clusters and print // current status @@ -1727,7 +1739,7 @@ func onLogin(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - if err := executeAccessRequest(cf, tc); err != nil { + if err := executeAccessRequest(cf, tc, rootAuth); err != nil { return trace.Wrap(err) } if err := updateKubeConfigOnLogin(cf, tc, updateKubeConfigOption); err != nil { @@ -1764,7 +1776,23 @@ func onLogin(cf *CLIConf) error { // "authoritative" source. 
cf.Username = tc.Username - if err := tc.ActivateKey(cf.Context, key); err != nil { + if err := tc.ActivateKeyWithoutTrustedCerts(cf.Context, key); err != nil { + return trace.Wrap(err) + } + + clusterClient, err := tc.ConnectToCluster(cf.Context) + if err != nil { + return trace.Wrap(err) + } + defer clusterClient.Close() + + rootAuth, err := clusterClient.RootClient(cf.Context) + if err != nil { + return trace.Wrap(err) + } + defer rootAuth.Close() + + if err := tc.UpdateTrustedCA(cf.Context, rootAuth); err != nil { return trace.Wrap(err) } @@ -1774,8 +1802,7 @@ func onLogin(cf *CLIConf) error { // key.TrustedCA at this point only has the CA of the root cluster we // logged into. We need to fetch all the CAs for leaf clusters too, to // make them available in the identity file. - rootClusterName := key.TrustedCerts[0].ClusterName - authorities, err := tc.GetTrustedCA(cf.Context, rootClusterName) + authorities, err := rootAuth.GetCertAuthorities(cf.Context, types.HostCA, false) if err != nil { return trace.Wrap(err) } @@ -1807,7 +1834,7 @@ func onLogin(cf *CLIConf) error { // Attempt device login. This activates a fresh key if successful. // We do not save the resulting in the identity file above on purpose, as this // certificate is bound to the present device. 
- if err := tc.AttemptDeviceLogin(cf.Context, key); err != nil { + if err := tc.AttemptDeviceLogin(cf.Context, key, rootAuth); err != nil { return trace.Wrap(err) } @@ -1824,23 +1851,14 @@ func onLogin(cf *CLIConf) error { } if autoRequest && cf.DesiredRoles == "" && cf.RequestID == "" { - var capabailities *types.AccessCapabilities - err = tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error { - cap, err := clt.GetAccessCapabilities(cf.Context, types.AccessCapabilitiesRequest{ - User: cf.Username, - }) - if err != nil { - return trace.Wrap(err) - } - - capabailities = cap - - return nil + capabailities, err := rootAuth.GetAccessCapabilities(cf.Context, types.AccessCapabilitiesRequest{ + User: cf.Username, }) if err != nil { logoutErr := tc.Logout() return trace.NewAggregate(err, logoutErr) } + if capabailities.RequireReason && cf.RequestReason == "" { msg := "--request-reason must be specified" if capabailities.RequestPrompt != "" { @@ -1857,7 +1875,7 @@ func onLogin(cf *CLIConf) error { if cf.DesiredRoles != "" || cf.RequestID != "" { fmt.Println("") // visually separate access request output - if err := executeAccessRequest(cf, tc); err != nil { + if err := executeAccessRequest(cf, tc, rootAuth); err != nil { logoutErr := tc.Logout() return trace.NewAggregate(err, logoutErr) } @@ -1886,7 +1904,7 @@ func onLogin(cf *CLIConf) error { alertSeverityMax = types.AlertSeverity_HIGH } - if err := common.ShowClusterAlerts(cf.Context, tc, os.Stderr, map[string]string{ + if err := common.ShowClusterAlerts(cf.Context, clusterClient.AuthClient, cf.Stderr(), map[string]string{ types.AlertOnLogin: "yes", }, types.AlertSeverity_LOW, alertSeverityMax); err != nil { log.WithError(err).Warn("Failed to display cluster alerts.") @@ -2291,23 +2309,18 @@ func serializeNodesWithClusters(nodes []nodeListing, format string) (string, err return string(out), trace.Wrap(err) } -func getAccessRequest(ctx context.Context, tc *client.TeleportClient, requestID, username string) 
(types.AccessRequest, error) { - var req types.AccessRequest - err := tc.WithRootClusterClient(ctx, func(clt auth.ClientI) error { - reqs, err := clt.GetAccessRequests(ctx, types.AccessRequestFilter{ - ID: requestID, - User: username, - }) - if err != nil { - return trace.Wrap(err) - } - if len(reqs) != 1 { - return trace.BadParameter(`invalid access request "%v"`, requestID) - } - req = reqs[0] - return nil +func getAccessRequest(ctx context.Context, rootAuth auth.ClientI, requestID, username string) (types.AccessRequest, error) { + reqs, err := rootAuth.GetAccessRequests(ctx, types.AccessRequestFilter{ + ID: requestID, + User: username, }) - return req, trace.Wrap(err) + if err != nil { + return nil, trace.Wrap(err) + } + if len(reqs) != 1 { + return nil, trace.BadParameter(`invalid access request "%v"`, requestID) + } + return reqs[0], nil } func createAccessRequest(cf *CLIConf) (types.AccessRequest, error) { @@ -2341,7 +2354,7 @@ func createAccessRequest(cf *CLIConf) (types.AccessRequest, error) { return req, nil } -func executeAccessRequest(cf *CLIConf, tc *client.TeleportClient) error { +func executeAccessRequest(cf *CLIConf, tc *client.TeleportClient, rootAuth auth.ClientI) error { if cf.DesiredRoles == "" && cf.RequestID == "" && len(cf.RequestedResourceIDs) == 0 { return trace.BadParameter("at least one role or resource or a request ID must be specified") } @@ -2359,7 +2372,7 @@ func executeAccessRequest(cf *CLIConf, tc *client.TeleportClient) error { var err error if cf.RequestID != "" { // This access request already exists, fetch it. - req, err = getAccessRequest(cf.Context, tc, cf.RequestID, cf.Username) + req, err = getAccessRequest(cf.Context, rootAuth, cf.RequestID, cf.Username) if err != nil { return trace.Wrap(err) } @@ -2384,7 +2397,7 @@ func executeAccessRequest(cf *CLIConf, tc *client.TeleportClient) error { defer requestWatcher.Close() if !cf.NoWait { // Don't initialize the watcher unless we'll actually use it. 
- if err := requestWatcher.initialize(cf.Context, tc); err != nil { + if err := requestWatcher.initialize(cf.Context, rootAuth); err != nil { return trace.Wrap(err) } } @@ -2394,15 +2407,12 @@ func executeAccessRequest(cf *CLIConf, tc *client.TeleportClient) error { cf.RequestID = req.GetName() fmt.Fprint(os.Stdout, "Creating request...\n") // always create access request against the root cluster - if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error { - err := clt.CreateAccessRequest(cf.Context, req) - return trace.Wrap(err) - }); err != nil { + if err := rootAuth.CreateAccessRequest(cf.Context, req); err != nil { return trace.Wrap(err) } } - onRequestShow(cf) + printAccessRequest(cf, req) fmt.Println("") // Don't wait for request to get resolved, just print out request info. @@ -3061,19 +3071,29 @@ func retryWithAccessRequest(cf *CLIConf, tc *client.TeleportClient, fn func() er } req.SetRequestReason(requestReason) + clusterClient, err := tc.ConnectToCluster(cf.Context) + if err != nil { + return trace.Wrap(err) + } + defer clusterClient.Close() + + rootAuth, err := clusterClient.RootClient(cf.Context) + if err != nil { + return trace.Wrap(err) + } + defer rootAuth.Close() + // Watch for resolution events on the given request. Start watcher and wait // for it to be ready before creating the request to avoid a potential race. requestWatcher := newAccessRequestWatcher(req) defer requestWatcher.Close() - if err := requestWatcher.initialize(cf.Context, tc); err != nil { + if err := requestWatcher.initialize(cf.Context, rootAuth); err != nil { return trace.Wrap(err) } fmt.Fprint(os.Stdout, "Creating request...\n") // Always create access request against the root cluster. 
- if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error { - return trace.Wrap(clt.CreateAccessRequest(cf.Context, req)) - }); err != nil { + if err := rootAuth.CreateAccessRequest(cf.Context, req); err != nil { return trace.Wrap(err) } @@ -3081,7 +3101,15 @@ func retryWithAccessRequest(cf *CLIConf, tc *client.TeleportClient, fn func() er cf.Username = tc.Username } // re-fetch the request to display it with roles populated. - onRequestShow(cf) + req, err = services.GetAccessRequest(cf.Context, rootAuth, cf.RequestID) + if err != nil { + return trace.Wrap(err) + } + + if err := printAccessRequest(cf, req); err != nil { + return trace.Wrap(err) + } + fmt.Println("") // Wait for the request to be resolved. @@ -3760,8 +3788,6 @@ func proxyHostsErrorMsgDefault(proxyAddress string, ports []int) string { // // If successful, setClientWebProxyAddr will modify the client Config in-place. func setClientWebProxyAddr(ctx context.Context, cf *CLIConf, c *client.Config) error { - ctx, span := cf.tracer.Start(ctx, "makeClientForProxy/setClientWebProxyAddr") - defer span.End() // If the user has specified a proxy on the command line, and one has not // already been specified from configuration... @@ -4185,7 +4211,6 @@ func host(in string) string { type accessRequestWatcher struct { req types.AccessRequest watcher types.Watcher - closers []io.Closer sync.RWMutex } @@ -4200,7 +4225,7 @@ func newAccessRequestWatcher(req types.AccessRequest) *accessRequestWatcher { // initialize sets up the underlying event watcher, when this returns without // error the watcher is guaranteed to be in a ready state. Call this before // creating the request to prevent a race. 
-func (w *accessRequestWatcher) initialize(ctx context.Context, tc *client.TeleportClient) error { +func (w *accessRequestWatcher) initialize(ctx context.Context, rootClient auth.ClientI) error { w.Lock() defer w.Unlock() @@ -4208,23 +4233,11 @@ func (w *accessRequestWatcher) initialize(ctx context.Context, tc *client.Telepo return trace.BadParameter("cannot re-initialize accessRequestWatcher") } - proxyClient, err := tc.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - w.closers = append(w.closers, proxyClient) - - rootClient, err := proxyClient.ConnectToRootCluster(ctx) - if err != nil { - return trace.Wrap(err) - } - w.closers = append(w.closers, rootClient) - filter := types.AccessRequestFilter{ User: w.req.GetUser(), ID: w.req.GetName(), } - w.watcher, err = rootClient.NewWatcher(ctx, types.Watch{ + watcher, err := rootClient.NewWatcher(ctx, types.Watch{ Name: "await-request-approval", Kinds: []types.WatchKind{{ Kind: types.KindAccessRequest, @@ -4234,7 +4247,7 @@ func (w *accessRequestWatcher) initialize(ctx context.Context, tc *client.Telepo if err != nil { return trace.Wrap(err) } - w.closers = append(w.closers, w.watcher) + w.watcher = watcher // Wait for OpInit event so that returned watcher is ready. select { @@ -4284,23 +4297,16 @@ func (w *accessRequestWatcher) awaitResolution() (types.AccessRequest, error) { } } -// Close closes the clients held by the watcher. +// Close terminates the watcher. func (w *accessRequestWatcher) Close() error { - var errs []error - // Close in reverse order, like defer. w.RLock() - for i := len(w.closers) - 1; i >= 0; i-- { - errs = append(errs, w.closers[i].Close()) - } - w.RUnlock() + defer w.RUnlock() - // Closed the watcher above, awaitResolution should now terminate and we can - // grab the lock. - w.Lock() - w.closers = nil - w.Unlock() + if w.watcher == nil { + return nil + } - return trace.NewAggregate(errs...) 
+ return trace.Wrap(w.watcher.Close()) } func onRequestResolution(cf *CLIConf, tc *client.TeleportClient, req types.AccessRequest) error {