From e8f3b851113bf60097fb89059ea6b1cbb44456b1 Mon Sep 17 00:00:00 2001 From: joerger Date: Fri, 17 Feb 2023 17:35:01 -0800 Subject: [PATCH 01/18] Add Headless Authn service. --- lib/auth/auth.go | 34 ++++++++++ lib/auth/auth_login_test.go | 59 ++++++++++++++++++ lib/auth/auth_with_roles.go | 58 +++++++++++++++-- lib/auth/methods.go | 96 +++++++++++++++++++++++++++-- lib/services/identity.go | 10 +++ rfd/0105-headless-authentication.md | 4 +- 6 files changed, 249 insertions(+), 12 deletions(-) diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 71cddd4a02bd0..62a1705162166 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -4703,6 +4703,40 @@ func (a *Server) GetLicense(ctx context.Context) (string, error) { return fmt.Sprintf("%s%s", a.license.CertPEM, a.license.KeyPEM), nil } +// GetOrWaitForHeadlessAuthentication returns a headless authentication from the backend by name. +// If it does not yet exist, an empty item will be inserted and this function will wait until +// the item is updated with the request details from the headless login request. +func (a *Server) GetOrWaitForHeadlessAuthentication(ctx context.Context, name string) (*types.HeadlessAuthentication, error) { + if headlessAuthn, err := a.Services.GetHeadlessAuthentication(ctx, name); err == nil { + return headlessAuthn, nil + } else if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + + if _, err := a.Services.CreateHeadlessAuthenticationStub(ctx, name); err != nil { + return nil, trace.Wrap(err) + } + + // wait for the headless authentication to be updated with valid login details. + headlessAuthn, err := a.headlessAuthenticationWatcher.Wait(ctx, name, func(ha *types.HeadlessAuthentication) (bool, error) { + return services.ValidateHeadlessAuthentication(ha) == nil, nil + }) + if err != nil { + return nil, trace.Wrap(err) + } + return headlessAuthn, nil +} + +// CompareAndSwapHeadlessAuthentication performs a compare +// and swap replacement on a headless authentication resource. +func (a *Server) CompareAndSwapHeadlessAuthentication(ctx context.Context, old, new *types.HeadlessAuthentication) (*types.HeadlessAuthentication, error) { + headlessAuthn, err := a.Services.CompareAndSwapHeadlessAuthentication(ctx, old, new) + if err != nil { + return nil, trace.Wrap(err) + } + return headlessAuthn, nil +} + // authKeepAliver is a keep aliver using auth server directly type authKeepAliver struct { sync.RWMutex diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index dc41aba905e67..9b07611b6a240 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/teleport/lib/auth/mocku2f" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services" ) func TestServer_CreateAuthenticateChallenge_authPreference(t *testing.T) { @@ -701,6 +702,64 @@ func TestServer_Authenticate_nonPasswordlessRequiresUsername(t *testing.T) { } } +func TestServer_Authenticate_headless(t *testing.T) { + t.Parallel() + srv := newTestTLSServer(t) + + // We don't mind about the specifics of the configuration, as long as we have + // a user and TOTP/WebAuthn devices. + mfa := configureForMFA(t, srv) + username := mfa.User + + proxyClient, err := srv.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) + + headlessID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) + ctx := context.Background() + + // Approve the headless login in a goroutine + errC := make(chan error) + go func() { + defer close(errC) + + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + defer cancel() + + headlessAuthn, err := srv.Auth().GetOrWaitForHeadlessAuthentication(ctx, headlessID) + if err != nil { + errC <- err + return + } + + // create a shallow copy with approval for the compare and swap below. + approvedHeadlessAuthn := *headlessAuthn + approvedHeadlessAuthn.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED + approvedHeadlessAuthn.MfaDevice = mfa.WebDev.MFA + _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, headlessAuthn, &approvedHeadlessAuthn) + if err != nil { + errC <- err + return + } + }() + + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + defer cancel() + + _, err = proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ + AuthenticateUserRequest: AuthenticateUserRequest{ + Username: username, + PublicKey: []byte(sshPubKey), + HeadlessAuthenticationID: headlessID, + ClientMetadata: &ForwardedClientMetadata{ + RemoteAddr: "0.0.0.0", + }, + }, + TTL: 24 * time.Hour, + }) + require.NoError(t, err) + require.NoError(t, <-errC) +} + type configureMFAResp struct { User, Password string TOTPDev, WebDev *TestDevice diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 26bdb21ee5900..2d2e935a62dd0 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -5722,14 +5722,62 @@ func (a *ServerWithRoles) DeleteAllUserGroups(ctx context.Context) error { // GetHeadlessAuthentication retrieves a headless authentication by id. func (a *ServerWithRoles) GetHeadlessAuthentication(ctx context.Context, id string) (*types.HeadlessAuthentication, error) { - // TODO (joerger): Add implementation - follow up PR - return nil, trace.NotImplemented("GetHeadlessAuthentication is not implemented") + headlessAuthn, err := a.authServer.GetOrWaitForHeadlessAuthentication(ctx, id) + if err != nil { + return nil, trace.Wrap(err) + } + + // User can always get their own headless authentication state. Otherwise, check for associated rule. + if !hasLocalUserRole(a.context) || headlessAuthn.User != a.context.User.GetName() { + if err := a.action(apidefaults.Namespace, types.KindHeadlessAuthentication, types.VerbRead); err != nil { + // If the headless authentication can not be accessed by the user, we will return a not + // found error. This method would usually time out above if the headless authentication + // does not exist, so we mimick this behavior here. + <-ctx.Done() + return nil, trace.Wrap(ctx.Err()) + } + } + + return headlessAuthn, nil } // UpdateHeadlessAuthenticationState updates a headless authentication state. -func (a *ServerWithRoles) UpdateHeadlessAuthenticationState(ctx context.Context, id string, state types.HeadlessAuthenticationState, mfaResp *proto.MFAAuthenticateResponse) error { - // TODO (joerger): Add implementation - follow up PR - return trace.NotImplemented("UpdateHeadlessAuthenticationState is not implemented") +func (a *ServerWithRoles) UpdateHeadlessAuthenticationState(ctx context.Context, id string, newState types.HeadlessAuthenticationState, mfaResp *proto.MFAAuthenticateResponse) error { + headlessAuthn, err := a.authServer.GetHeadlessAuthentication(ctx, id) + if err != nil { + return trace.Wrap(err) + } + + // Only users can approve their own headless auth requests. + if !hasLocalUserRole(a.context) || headlessAuthn.User != a.context.User.GetName() { + return trace.AccessDenied("cannot update a different user's headless authentication state") + } + + // Shallow copy headless authn for compare and swap below. + replaceHeadlessAuthn := *headlessAuthn + replaceHeadlessAuthn.State = newState + + // The user must authenticate with MFA to change the state to approved. + if newState == types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED { + if mfaResp == nil { + return trace.BadParameter("expected MFA auth challenge response") + } + + // Only WebAuthn is supported in headless login flow for superior phishing prevention. + if _, ok := mfaResp.Response.(*proto.MFAAuthenticateResponse_Webauthn); !ok { + return trace.BadParameter("expected WebAuthn challenge response, but got %T", mfaResp.Response) + } + + mfaDevice, _, err := a.authServer.validateMFAAuthResponse(ctx, mfaResp, headlessAuthn.User, false /* passwordless */) + if err != nil { + return trace.Wrap(err) + } + + replaceHeadlessAuthn.MfaDevice = mfaDevice + } + + _, err = a.authServer.CompareAndSwapHeadlessAuthentication(ctx, headlessAuthn, &replaceHeadlessAuthn) + return trace.Wrap(err) } // NewAdminAuthServer returns auth server authorized as admin, diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 9cc024ca0ee79..05941cf1acc1d 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -17,6 +17,7 @@ limitations under the License. package auth import ( + "bytes" "context" "errors" "time" @@ -79,7 +80,7 @@ func (a *AuthenticateUserRequest) CheckAndSetDefaults() error { case a.Username == "" && a.Webauthn != nil: // OK, passwordless. case a.Username == "": return trace.BadParameter("missing parameter 'username'") - case a.Pass == nil && a.Webauthn == nil && a.OTP == nil && a.Session == nil: + case a.Pass == nil && a.Webauthn == nil && a.OTP == nil && a.Session == nil && a.HeadlessAuthenticationID == "": return trace.BadParameter("at least one authentication method is required") } return nil @@ -166,6 +167,9 @@ var ( // invalidUserpass2FError is the error for when either the provided username, // password, or second factor is incorrect. invalidUserPass2FError = trace.AccessDenied("invalid username, password or second factor") + // invalidHeadlessAuthenticationError is the generic error returned for failed headless + // authentication attempts. + invalidHeadlessAuthenticationError = trace.AccessDenied("invalid Headless authentication") ) // IsInvalidLocalCredentialError checks if an error resulted from an incorrect username, @@ -216,6 +220,15 @@ func (s *Server) authenticateUser(ctx context.Context, req AuthenticateUserReque return res.mfaDev, nil } authErr = invalidUserPass2FError + case req.HeadlessAuthenticationID != "": + authenticateFn = func() (*types.MFADevice, error) { + mfaDevice, err := s.authenticateHeadless(ctx, req) + if err != nil { + return nil, trace.Wrap(err) + } + return mfaDevice, nil + } + authErr = invalidHeadlessAuthenticationError } if authenticateFn != nil { var dev *types.MFADevice @@ -234,8 +247,8 @@ func (s *Server) authenticateUser(ctx context.Context, req AuthenticateUserReque return nil, "", trace.Wrap(authErr) case dev == nil: log.Debugf( - "MFA authentication returned nil device (Webauthn = %v, TOTP = %v): %v.", - req.Webauthn != nil, req.OTP != nil, err) + "MFA authentication returned nil device (Webauthn = %v, TOTP = %v, Headless = %v): %v.", + req.Webauthn != nil, req.OTP != nil, req.HeadlessAuthenticationID != "", err) return nil, "", trace.Wrap(authErr) default: return dev, user, nil @@ -312,6 +325,64 @@ func (s *Server) authenticatePasswordless(ctx context.Context, req AuthenticateU return dev, user, nil } +func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserRequest) (*types.MFADevice, error) { + // Wait up to one minute for the headless auth request to be approved. + waitCtx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + headlessAuthn := &types.HeadlessAuthentication{ + ResourceHeader: types.ResourceHeader{ + Metadata: types.Metadata{ + Name: req.HeadlessAuthenticationID, + }, + }, + User: req.Username, + PublicKey: req.PublicKey, + ClientIpAddress: req.ClientMetadata.RemoteAddr, + } + headlessAuthn.SetExpiry(s.clock.Now().Add(time.Minute)) + if err := services.ValidateHeadlessAuthentication(headlessAuthn); err != nil { + return nil, trace.Wrap(err) + } + + // Wait for a headless authenticated stub to be inserted by an authenticated + // call to GetHeadlessAuthentication. We do this to avoid immediately inserting + // backend items from an unauthenticated endpoint. + headlessAuthnStub, err := s.headlessAuthenticationWatcher.Wait(waitCtx, req.HeadlessAuthenticationID, func(ha *types.HeadlessAuthentication) (bool, error) { + // Only headless authentication stub can be inserted without the standard validation. + if services.ValidateHeadlessAuthentication(ha) == nil { + return false, trace.AlreadyExists("headless auth request already exists") + } + return true, nil + }) + if err != nil { + return nil, trace.Wrap(err) + } + + // Update headless authentication with login details and wait for it to be approved/denied. + if _, err := s.CompareAndSwapHeadlessAuthentication(ctx, headlessAuthnStub, headlessAuthn); err != nil { + return nil, trace.Wrap(err) + } + + headlessAuthn, err = s.headlessAuthenticationWatcher.Wait(waitCtx, req.HeadlessAuthenticationID, func(ha *types.HeadlessAuthentication) (bool, error) { + switch ha.State { + case types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED: + if ha.MfaDevice == nil { + return false, trace.AccessDenied("expected mfa approval for headless authentication approval") + } + return true, nil + case types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED: + return false, trace.AccessDenied("headless authentication denied") + } + return false, nil + }) + if err != nil { + return nil, trace.Wrap(err) + } + + return headlessAuthn.MfaDevice, nil +} + // AuthenticateWebUser authenticates web user, creates and returns a web session // if authentication is successful. In case the existing session ID is used to authenticate, // returns the existing session instead of creating a new one @@ -509,7 +580,7 @@ func (s *Server) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHReq return nil, trace.BadParameter("source IP pinning is enabled but client IP is unknown") } - certs, err := s.generateUserCert(certRequest{ + certReq := certRequest{ user: user, ttl: req.TTL, publicKey: req.PublicKey, @@ -520,7 +591,22 @@ func (s *Server) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHReq kubernetesCluster: req.KubernetesCluster, loginIP: clientIP, attestationStatement: req.AttestationStatement, - }) + } + + // For headless authentication, a short-lived mfa-verified cert should be generated. + if req.HeadlessAuthenticationID != "" { + ha, err := s.GetHeadlessAuthentication(ctx, req.HeadlessAuthenticationID) + if err != nil { + return nil, trace.Wrap(err) + } + if !bytes.Equal(req.PublicKey, ha.PublicKey) { + return nil, trace.AccessDenied("headless authentication public key mismatch") + } + certReq.mfaVerified = ha.MfaDevice.Metadata.Name + certReq.ttl = time.Minute + } + + certs, err := s.generateUserCert(certReq) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/services/identity.go b/lib/services/identity.go index 2c1661038cd45..80376afb440df 100644 --- a/lib/services/identity.go +++ b/lib/services/identity.go @@ -257,6 +257,16 @@ type Identity interface { // GetKeyAttestationData gets a verified public key attestation response. GetKeyAttestationData(ctx context.Context, publicKey crypto.PublicKey) (*keys.AttestationData, error) + // CreateHeadlessAuthenticationStub creates a headless authentication stub. + CreateHeadlessAuthenticationStub(ctx context.Context, name string) (*types.HeadlessAuthentication, error) + + // CompareAndSwapHeadlessAuthentication performs a compare + // and swap replacement on a headless authentication resource. + CompareAndSwapHeadlessAuthentication(ctx context.Context, old, new *types.HeadlessAuthentication) (*types.HeadlessAuthentication, error) + + // GetHeadlessAuthentication retrieves a headless authentication by name. + GetHeadlessAuthentication(ctx context.Context, name string) (*types.HeadlessAuthentication, error) + types.WebSessionsGetter types.WebTokensGetter diff --git a/rfd/0105-headless-authentication.md b/rfd/0105-headless-authentication.md index 380dc8bfbf29d..4c652d55a0f5b 100644 --- a/rfd/0105-headless-authentication.md +++ b/rfd/0105-headless-authentication.md @@ -335,8 +335,8 @@ type AuthenticateUserRequest struct { Session *SessionCreds `json:"session,omitempty"` // ClientMetadata includes forwarded information about a client ClientMetadata *ForwardedClientMetadata `json:"client_metadata,omitempty"` - // Headless determines whether headless authentication will be used - Headless bool `json:"headless"` + // HeadlessAuthenticationID is the ID for a headless authentication resource. + HeadlessAuthenticationID string `json:"headless_authentication_id"` } ``` From 0924c1d3ad93987c25c79944a1c81584af94ee8c Mon Sep 17 00:00:00 2001 From: joerger Date: Wed, 8 Mar 2023 19:16:45 -0800 Subject: [PATCH 02/18] Add/fix 3 minute headless login timeout. --- lib/auth/clt.go | 12 ++++++++++++ lib/auth/methods.go | 14 +++++++------- lib/services/local/headlessauthn.go | 1 + rfd/0105-headless-authentication.md | 2 +- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 33ed1f5c5354a..ac6e37405ea8c 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -842,6 +842,18 @@ func (c *Client) AuthenticateWebUser(ctx context.Context, req AuthenticateUserRe // AuthenticateSSHUser authenticates SSH console user, creates and returns a pair of signed TLS and SSH // short lived certificates as a result func (c *Client) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHRequest) (*SSHLoginResponse, error) { + if req.HeadlessAuthenticationID != "" { + // Replace the client timeout with the default callback timeout for this request. + previousResponseHeaderTimeout := c.HTTPClient.transport.ResponseHeaderTimeout + previousClientTimeout := c.HTTPClient.HTTPClient().Timeout + c.HTTPClient.transport.ResponseHeaderTimeout = defaults.CallbackTimeout + c.HTTPClient.HTTPClient().Timeout = defaults.CallbackTimeout + defer func() { + c.HTTPClient.transport.ResponseHeaderTimeout = previousResponseHeaderTimeout + c.HTTPClient.HTTPClient().Timeout = previousClientTimeout + }() + } + out, err := c.PostJSON( ctx, c.Endpoint("users", req.Username, "ssh", "authenticate"), diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 05941cf1acc1d..46b1dbe8141f0 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -31,6 +31,7 @@ import ( apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/keys" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" @@ -326,10 +327,6 @@ func (s *Server) authenticatePasswordless(ctx context.Context, req AuthenticateU } func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserRequest) (*types.MFADevice, error) { - // Wait up to one minute for the headless auth request to be approved. - waitCtx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - headlessAuthn := &types.HeadlessAuthentication{ ResourceHeader: types.ResourceHeader{ Metadata: types.Metadata{ @@ -340,7 +337,10 @@ func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserR PublicKey: req.PublicKey, ClientIpAddress: req.ClientMetadata.RemoteAddr, } - headlessAuthn.SetExpiry(s.clock.Now().Add(time.Minute)) + + // Headless Authentication should expire when the callback expires. + headlessAuthn.SetExpiry(s.clock.Now().Add(defaults.CallbackTimeout)) + if err := services.ValidateHeadlessAuthentication(headlessAuthn); err != nil { return nil, trace.Wrap(err) } @@ -348,7 +348,7 @@ func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserR // Wait for a headless authenticated stub to be inserted by an authenticated // call to GetHeadlessAuthentication. We do this to avoid immediately inserting // backend items from an unauthenticated endpoint. - headlessAuthnStub, err := s.headlessAuthenticationWatcher.Wait(waitCtx, req.HeadlessAuthenticationID, func(ha *types.HeadlessAuthentication) (bool, error) { + headlessAuthnStub, err := s.headlessAuthenticationWatcher.Wait(ctx, req.HeadlessAuthenticationID, func(ha *types.HeadlessAuthentication) (bool, error) { // Only headless authentication stub can be inserted without the standard validation. if services.ValidateHeadlessAuthentication(ha) == nil { return false, trace.AlreadyExists("headless auth request already exists") @@ -364,7 +364,7 @@ func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserR return nil, trace.Wrap(err) } - headlessAuthn, err = s.headlessAuthenticationWatcher.Wait(waitCtx, req.HeadlessAuthenticationID, func(ha *types.HeadlessAuthentication) (bool, error) { + headlessAuthn, err = s.headlessAuthenticationWatcher.Wait(ctx, req.HeadlessAuthenticationID, func(ha *types.HeadlessAuthentication) (bool, error) { switch ha.State { case types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED: if ha.MfaDevice == nil { diff --git a/lib/services/local/headlessauthn.go b/lib/services/local/headlessauthn.go index a0ad206835a34..58a554d297785 100644 --- a/lib/services/local/headlessauthn.go +++ b/lib/services/local/headlessauthn.go @@ -30,6 +30,7 @@ import ( // CreateHeadlessAuthenticationStub creates a headless authentication stub in the backend. func (s *IdentityService) CreateHeadlessAuthenticationStub(ctx context.Context, name string) (*types.HeadlessAuthentication, error) { + // Stub should be replaced shortly after creation. expires := s.Clock().Now().Add(time.Minute) headlessAuthn := &types.HeadlessAuthentication{ ResourceHeader: types.ResourceHeader{ diff --git a/rfd/0105-headless-authentication.md b/rfd/0105-headless-authentication.md index 4c652d55a0f5b..e53c263bc2aeb 100644 --- a/rfd/0105-headless-authentication.md +++ b/rfd/0105-headless-authentication.md @@ -153,7 +153,7 @@ A request id will be derived from the client's public key so that an attacker ca Note: We could also use the public key directly (base64 encoded), but we choose to use a UUID to shorten the URL and improve its readability. -As [explained above](#unauthenticated-headless-login-endpoint), the Auth server will write the request details to the backend under `/headless_authentication/` on demand. It will have a 1 minute TTL, by which point the user should have completed the headless authentication flow. The request will begin in the pending state. The Auth server then waits for the user to approve the authentication request using a resource watcher. +As [explained above](#unauthenticated-headless-login-endpoint), the Auth server will write the request details to the backend under `/headless_authentication/` on demand. It will have a 3 minute TTL, matching the callback timeout of the request. The request will begin in the pending state. The Auth server then waits for the user to approve the authentication request using a resource watcher. #### Local authentication From f9d0e3c511869f66412d1801b9b23bcad76a9804 Mon Sep 17 00:00:00 2001 From: joerger Date: Mon, 13 Mar 2023 11:00:31 -0700 Subject: [PATCH 03/18] * Prevent repeated updates to headless authentication state * Prevent user lock out from headless authentication failure * Delete headless authentication on failed attempts --- lib/auth/auth_login_test.go | 125 +++++++++++++++++++++++++----------- lib/auth/auth_with_roles.go | 17 +++-- lib/auth/methods.go | 32 +++++++-- lib/services/identity.go | 3 + 4 files changed, 128 insertions(+), 49 deletions(-) diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index 9b07611b6a240..6da1374a02737 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -717,47 +717,96 @@ func TestServer_Authenticate_headless(t *testing.T) { headlessID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) ctx := context.Background() - // Approve the headless login in a goroutine - errC := make(chan error) - go func() { - defer close(errC) - - ctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) - defer cancel() - - headlessAuthn, err := srv.Auth().GetOrWaitForHeadlessAuthentication(ctx, headlessID) - if err != nil { - errC <- err - return - } + timeout := time.Millisecond * 100 - // create a shallow copy with approval for the compare and swap below. - approvedHeadlessAuthn := *headlessAuthn - approvedHeadlessAuthn.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED - approvedHeadlessAuthn.MfaDevice = mfa.WebDev.MFA - _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, headlessAuthn, &approvedHeadlessAuthn) - if err != nil { - errC <- err - return - } - }() - - ctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) - defer cancel() - - _, err = proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ - AuthenticateUserRequest: AuthenticateUserRequest{ - Username: username, - PublicKey: []byte(sshPubKey), - HeadlessAuthenticationID: headlessID, - ClientMetadata: &ForwardedClientMetadata{ - RemoteAddr: "0.0.0.0", + updateHeadlessAuthnInGoroutine := func(ctx context.Context, update func(*types.HeadlessAuthentication)) chan error { + errC := make(chan error) + go func() { + defer close(errC) + + headlessAuthn, err := srv.Auth().GetOrWaitForHeadlessAuthentication(ctx, headlessID) + if err != nil { + errC <- err + return + } + + // create a shallow copy and update for the compare and swap below. + replaceHeadlessAuthn := *headlessAuthn + update(&replaceHeadlessAuthn) + + _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, headlessAuthn, &replaceHeadlessAuthn) + if err != nil { + errC <- err + return + } + }() + return errC + } + + for _, tc := range []struct { + name string + update func(*types.HeadlessAuthentication) + checkErr require.ErrorAssertionFunc + }{ + { + name: "OK approved", + update: func(ha *types.HeadlessAuthentication) { + ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED + ha.MfaDevice = mfa.WebDev.MFA + }, + checkErr: require.NoError, + }, { + name: "NOK approved without MFA", + update: func(ha *types.HeadlessAuthentication) { + ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED + }, + checkErr: require.Error, + }, { + name: "NOK user mismatch", + update: func(ha *types.HeadlessAuthentication) { + ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED + ha.MfaDevice = mfa.WebDev.MFA + ha.User = "other-user" + }, + checkErr: require.Error, + }, { + name: "NOK denied", + update: func(ha *types.HeadlessAuthentication) { + ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED }, + checkErr: require.Error, + }, { + name: "NOK timeout", + update: func(ha *types.HeadlessAuthentication) { + time.Sleep(timeout) + }, + checkErr: require.Error, }, - TTL: 24 * time.Hour, - }) - require.NoError(t, err) - require.NoError(t, <-errC) + } { + t.Run(tc.name, func(t *testing.T) { + t.Cleanup(func() { + srv.Auth().DeleteHeadlessAuthentication(ctx, headlessID) + }) + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + errC := updateHeadlessAuthnInGoroutine(ctx, tc.update) + _, err = proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ + AuthenticateUserRequest: AuthenticateUserRequest{ + Username: username, + PublicKey: []byte(sshPubKey), + HeadlessAuthenticationID: headlessID, + ClientMetadata: &ForwardedClientMetadata{ + RemoteAddr: "0.0.0.0", + }, + }, + TTL: defaults.CallbackTimeout, + }) + tc.checkErr(t, err) + require.NoError(t, <-errC) + }) + } } type configureMFAResp struct { diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 2d2e935a62dd0..e9e5302b3471a 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -5742,7 +5742,7 @@ func (a *ServerWithRoles) GetHeadlessAuthentication(ctx context.Context, id stri } // UpdateHeadlessAuthenticationState updates a headless authentication state. -func (a *ServerWithRoles) UpdateHeadlessAuthenticationState(ctx context.Context, id string, newState types.HeadlessAuthenticationState, mfaResp *proto.MFAAuthenticateResponse) error { +func (a *ServerWithRoles) UpdateHeadlessAuthenticationState(ctx context.Context, id string, state types.HeadlessAuthenticationState, mfaResp *proto.MFAAuthenticateResponse) error { headlessAuthn, err := a.authServer.GetHeadlessAuthentication(ctx, id) if err != nil { return trace.Wrap(err) @@ -5753,12 +5753,17 @@ func (a *ServerWithRoles) UpdateHeadlessAuthenticationState(ctx context.Context, return trace.AccessDenied("cannot update a different user's headless authentication state") } + if headlessAuthn.State != types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_PENDING { + return trace.AccessDenied("cannot update a headless authentication state from a non-pending state") + } + // Shallow copy headless authn for compare and swap below. replaceHeadlessAuthn := *headlessAuthn - replaceHeadlessAuthn.State = newState + replaceHeadlessAuthn.State = state - // The user must authenticate with MFA to change the state to approved. - if newState == types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED { + switch state { + case types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED: + // The user must authenticate with MFA to change the state to approved. if mfaResp == nil { return trace.BadParameter("expected MFA auth challenge response") } @@ -5774,6 +5779,10 @@ func (a *ServerWithRoles) UpdateHeadlessAuthenticationState(ctx context.Context, } replaceHeadlessAuthn.MfaDevice = mfaDevice + case types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED: + // continue to compare and swap without MFA. + default: + return trace.AccessDenied("cannot update a headless authentication state to %v", state.String()) } _, err = a.authServer.CompareAndSwapHeadlessAuthentication(ctx, headlessAuthn, &replaceHeadlessAuthn) diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 46b1dbe8141f0..7d17c581c0fbb 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -222,11 +222,14 @@ func (s *Server) authenticateUser(ctx context.Context, req AuthenticateUserReque } authErr = invalidUserPass2FError case req.HeadlessAuthenticationID != "": + // handle authentication before the user lock to prevent locking out users + // due to timed-out/canceled headless authentication attempts. + mfaDevice, err := s.authenticateHeadless(ctx, req) + if err != nil { + log.Debugf("Headless Authentication for user %q failed while waiting for approval: %v", user, err) + return nil, "", trace.Wrap(invalidHeadlessAuthenticationError) + } authenticateFn = func() (*types.MFADevice, error) { - mfaDevice, err := s.authenticateHeadless(ctx, req) - if err != nil { - return nil, trace.Wrap(err) - } return mfaDevice, nil } authErr = invalidHeadlessAuthenticationError @@ -326,7 +329,16 @@ func (s *Server) authenticatePasswordless(ctx context.Context, req AuthenticateU return dev, user, nil } -func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserRequest) (*types.MFADevice, error) { +func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserRequest) (mfa *types.MFADevice, err error) { + // Delete the headless authentication upon failure. + defer func() { + if err != nil { + if err := s.DeleteHeadlessAuthentication(ctx, req.HeadlessAuthenticationID); err != nil && !trace.IsNotFound(err) { + log.Debugf("Failed to delete headless authentication: %v", err) + } + } + }() + headlessAuthn := &types.HeadlessAuthentication{ ResourceHeader: types.ResourceHeader{ Metadata: types.Metadata{ @@ -359,11 +371,12 @@ func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserR return nil, trace.Wrap(err) } - // Update headless authentication with login details and wait for it to be approved/denied. + // Update headless authentication with login details. if _, err := s.CompareAndSwapHeadlessAuthentication(ctx, headlessAuthnStub, headlessAuthn); err != nil { return nil, trace.Wrap(err) } + // Wait for the request to be approved/denied. headlessAuthn, err = s.headlessAuthenticationWatcher.Wait(ctx, req.HeadlessAuthenticationID, func(ha *types.HeadlessAuthentication) (bool, error) { switch ha.State { case types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED: @@ -377,7 +390,12 @@ func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserR return false, nil }) if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(ErrSAMLRequiresEnterprise) + } + + // Verify that the headless authentication has not been tampered with. + if headlessAuthn.User != req.Username { + return nil, trace.AccessDenied("user mismatch") } return headlessAuthn.MfaDevice, nil diff --git a/lib/services/identity.go b/lib/services/identity.go index 80376afb440df..b4d4e83a5e3f3 100644 --- a/lib/services/identity.go +++ b/lib/services/identity.go @@ -267,6 +267,9 @@ type Identity interface { // GetHeadlessAuthentication retrieves a headless authentication by name. GetHeadlessAuthentication(ctx context.Context, name string) (*types.HeadlessAuthentication, error) + // DeleteHeadlessAuthentication deletes a headless authentication from the backend by name. + DeleteHeadlessAuthentication(ctx context.Context, name string) error + types.WebSessionsGetter types.WebTokensGetter From aa9e8e02fd14c83128c43fc43a36814dcf987b4a Mon Sep 17 00:00:00 2001 From: joerger Date: Mon, 13 Mar 2023 16:59:34 -0700 Subject: [PATCH 04/18] Add auth_with_roles test. --- lib/auth/auth_with_roles.go | 4 +- lib/auth/auth_with_roles_test.go | 88 ++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 2 deletions(-) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index e9e5302b3471a..e276778d4aea8 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -5748,9 +5748,9 @@ func (a *ServerWithRoles) UpdateHeadlessAuthenticationState(ctx context.Context, return trace.Wrap(err) } - // Only users can approve their own headless auth requests. + // Only users can approve/deny their own headless auth requests. if !hasLocalUserRole(a.context) || headlessAuthn.User != a.context.User.GetName() { - return trace.AccessDenied("cannot update a different user's headless authentication state") + return trace.NotFound("not found") } if headlessAuthn.State != types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_PENDING { diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 535afafa56bcb..a916c0d0f421c 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -38,6 +38,7 @@ import ( "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/installers" + "github.com/gravitational/teleport/api/types/webauthn" "github.com/gravitational/teleport/api/types/wrappers" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/sshutils" @@ -4214,3 +4215,90 @@ func TestUnimplementedClients(t *testing.T) { require.True(t, trace.IsNotImplemented(err), err) }) } + +func TestHeadlessAuthentication(t *testing.T) { + ctx := context.Background() + srv := newTestTLSServer(t) + + mfa := configureForMFA(t, srv) + + user1, _, err := CreateUserAndRole(srv.Auth(), mfa.User, nil, nil) + require.NoError(t, err) + client1, err := srv.NewClient(TestUser(user1.GetName())) + require.NoError(t, err) + + user2, _, err := CreateUserAndRole(srv.Auth(), "user2", nil, nil) + require.NoError(t, err) + client2, err := srv.NewClient(TestUser(user2.GetName())) + require.NoError(t, err) + + // Insert a headless authentication resource into the backend. + headlessID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) + headlessAuthn := &types.HeadlessAuthentication{ + ResourceHeader: types.ResourceHeader{ + Metadata: types.Metadata{ + Name: headlessID, + }, + }, + User: user1.GetName(), + PublicKey: []byte(sshPubKey), + ClientIpAddress: "0.0.0.0", + } + headlessAuthn.SetExpiry(time.Now().Add(time.Minute)) + + stub, err := srv.Auth().CreateHeadlessAuthenticationStub(ctx, headlessID) + require.NoError(t, err) + _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, stub, headlessAuthn) + require.NoError(t, err) + + // user2 should fail to get headless authentication, and wait for ctx to timeout as if not found + // to prevent leaking other user's headless authentication attempts. + failedGetCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + defer cancel() + + _, err = client2.GetHeadlessAuthentication(failedGetCtx, headlessID) + require.Error(t, err) + require.Contains(t, err.Error(), "context deadline exceeded", "expected context deadline error but got: %v", err) + + // user1 should successfully get headless authentication with up to date login details + retrievedHeadlessAuthn, err := client1.GetHeadlessAuthentication(ctx, headlessID) + require.NoError(t, err) + require.Equal(t, headlessAuthn, retrievedHeadlessAuthn) + + // user2 should fail to update authentication state + err = client2.UpdateHeadlessAuthenticationState(ctx, headlessID, types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED, nil) + require.Error(t, err) + require.True(t, trace.IsNotFound(err), "expected not found error but got: %v", err) + + // user1 should successfully update authentication state to denied + err = client1.UpdateHeadlessAuthenticationState(ctx, headlessID, types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED, nil) + require.NoError(t, err) + + // reset to original state + retrievedHeadlessAuthn, err = client1.GetHeadlessAuthentication(ctx, headlessID) + require.NoError(t, err) + _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, retrievedHeadlessAuthn, headlessAuthn) + require.NoError(t, err) + + // user1 should fail to update authentication state to approved without mfa + err = client1.UpdateHeadlessAuthenticationState(ctx, headlessID, types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: &webauthn.CredentialAssertionResponse{ + Type: "bad response", + }, + }, + }) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err), "expected access denied error but got: %v", err) + + // user1 should successfully update authentication state to approved with MFA + challenge, err := client1.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{ + Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{}, + }) + require.NoError(t, err) + resp, err := mfa.WebDev.SolveAuthn(challenge) + require.NoError(t, err) + + err = client1.UpdateHeadlessAuthenticationState(ctx, headlessID, types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, resp) + require.NoError(t, err) +} From 15642496307921e47755dd4814ece818e043a760 Mon Sep 17 00:00:00 2001 From: joerger Date: Mon, 13 Mar 2023 17:01:19 -0700 Subject: [PATCH 05/18] Extend timeout in test to reduce flakiness. --- lib/auth/auth_login_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index 6da1374a02737..9a64fb8e153b5 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -717,7 +717,7 @@ func TestServer_Authenticate_headless(t *testing.T) { headlessID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) ctx := context.Background() - timeout := time.Millisecond * 100 + timeout := time.Millisecond * 500 updateHeadlessAuthnInGoroutine := func(ctx context.Context, update func(*types.HeadlessAuthentication)) chan error { errC := make(chan error) From c9ed8738616af75f6d0156fc24f20b65728bca81 Mon Sep 17 00:00:00 2001 From: joerger Date: Mon, 13 Mar 2023 17:39:59 -0700 Subject: [PATCH 06/18] Fix error typo. --- lib/auth/methods.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 7d17c581c0fbb..8a4bc8e66491c 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -390,7 +390,7 @@ func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserR return false, nil }) if err != nil { - return nil, trace.Wrap(ErrSAMLRequiresEnterprise) + return nil, trace.Wrap(err) } // Verify that the headless authentication has not been tampered with. From cd9b920995070e07e84aadb7dde78cf2d54fef28 Mon Sep 17 00:00:00 2001 From: joerger Date: Tue, 14 Mar 2023 10:41:37 -0700 Subject: [PATCH 07/18] Add context timeouts, remove initial GetHeadlessAuthentication call. --- lib/auth/auth.go | 34 +++++++++++++------------------- lib/auth/auth_login_test.go | 2 +- lib/auth/auth_with_roles.go | 9 ++++++--- lib/auth/auth_with_roles_test.go | 6 +++--- lib/auth/methods.go | 5 +++++ lib/services/identity.go | 3 --- 6 files changed, 29 insertions(+), 30 deletions(-) diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 62a1705162166..ec76314f23aa3 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -4703,38 +4703,32 @@ func (a *Server) GetLicense(ctx context.Context) (string, error) { return fmt.Sprintf("%s%s", a.license.CertPEM, a.license.KeyPEM), nil } -// GetOrWaitForHeadlessAuthentication returns a headless authentication from the backend by name. -// If it does not yet exist, an empty item will be inserted and this function will wait until -// the item is updated with the request details from the headless login request. -func (a *Server) GetOrWaitForHeadlessAuthentication(ctx context.Context, name string) (*types.HeadlessAuthentication, error) { - if headlessAuthn, err := a.Services.GetHeadlessAuthentication(ctx, name); err == nil { - return headlessAuthn, nil - } else if !trace.IsNotFound(err) { +// GetHeadlessAuthentication returns a headless authentication from the backend by name. +// If it does not yet exist, a stub will be created to signal the login process to upsert +// login details. This method will wait for the updated headless authentication and return it. +func (a *Server) GetHeadlessAuthentication(ctx context.Context, name string) (*types.HeadlessAuthentication, error) { + // Try to create a stub if it doesn't already exist, then wait for full login details. + if _, err := a.Services.CreateHeadlessAuthenticationStub(ctx, name); err != nil && !trace.IsAlreadyExists(err) { return nil, trace.Wrap(err) } - if _, err := a.Services.CreateHeadlessAuthenticationStub(ctx, name); err != nil { - return nil, trace.Wrap(err) - } + // wait for the headless authentication to be updated with valid login details + // by the login process. If the headless authentication is already updated, + // Wait will return it immediately. + waitCtx, cancel := context.WithTimeout(ctx, defaults.HTTPRequestTimeout) + defer cancel() - // wait for the headless authentication to be updated with valid login details. - headlessAuthn, err := a.headlessAuthenticationWatcher.Wait(ctx, name, func(ha *types.HeadlessAuthentication) (bool, error) { + headlessAuthn, err := a.headlessAuthenticationWatcher.Wait(waitCtx, name, func(ha *types.HeadlessAuthentication) (bool, error) { return services.ValidateHeadlessAuthentication(ha) == nil, nil }) - if err != nil { - return nil, trace.Wrap(err) - } - return headlessAuthn, nil + return headlessAuthn, trace.Wrap(err) } // CompareAndSwapHeadlessAuthentication performs a compare // and swap replacement on a headless authentication resource. func (a *Server) CompareAndSwapHeadlessAuthentication(ctx context.Context, old, new *types.HeadlessAuthentication) (*types.HeadlessAuthentication, error) { headlessAuthn, err := a.Services.CompareAndSwapHeadlessAuthentication(ctx, old, new) - if err != nil { - return nil, trace.Wrap(err) - } - return headlessAuthn, nil + return headlessAuthn, trace.Wrap(err) } // authKeepAliver is a keep aliver using auth server directly diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index 9a64fb8e153b5..ad56c58523150 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -724,7 +724,7 @@ func TestServer_Authenticate_headless(t *testing.T) { go func() { defer close(errC) - headlessAuthn, err := srv.Auth().GetOrWaitForHeadlessAuthentication(ctx, headlessID) + headlessAuthn, err := srv.Auth().GetHeadlessAuthentication(ctx, headlessID) if err != nil { errC <- err return diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index e276778d4aea8..29088fd59d49a 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -5722,7 +5722,10 @@ func (a *ServerWithRoles) DeleteAllUserGroups(ctx context.Context) error { // GetHeadlessAuthentication retrieves a headless authentication by id. func (a *ServerWithRoles) GetHeadlessAuthentication(ctx context.Context, id string) (*types.HeadlessAuthentication, error) { - headlessAuthn, err := a.authServer.GetOrWaitForHeadlessAuthentication(ctx, id) + waitCtx, cancel := context.WithTimeout(ctx, defaults.HTTPRequestTimeout) + defer cancel() + + headlessAuthn, err := a.authServer.GetHeadlessAuthentication(waitCtx, id) if err != nil { return nil, trace.Wrap(err) } @@ -5733,8 +5736,8 @@ func (a *ServerWithRoles) GetHeadlessAuthentication(ctx context.Context, id stri // If the headless authentication can not be accessed by the user, we will return a not // found error. This method would usually time out above if the headless authentication // does not exist, so we mimick this behavior here. - <-ctx.Done() - return nil, trace.Wrap(ctx.Err()) + <-waitCtx.Done() + return nil, trace.Wrap(waitCtx.Err()) } } diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index a916c0d0f421c..f3ad006b94e87 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -4251,10 +4251,10 @@ func TestHeadlessAuthentication(t *testing.T) { _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, stub, headlessAuthn) require.NoError(t, err) - // user2 should fail to get headless authentication, and wait for ctx to timeout as if not found + // user2 should fail to get headless authentication, and return the ctx error // to prevent leaking other user's headless authentication attempts. - failedGetCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) - defer cancel() + failedGetCtx, cancel := context.WithCancel(ctx) + cancel() _, err = client2.GetHeadlessAuthentication(failedGetCtx, headlessID) require.Error(t, err) diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 8a4bc8e66491c..61806afc383a0 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -330,6 +330,11 @@ func (s *Server) authenticatePasswordless(ctx context.Context, req AuthenticateU } func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserRequest) (mfa *types.MFADevice, err error) { + // this authentication requires two client callbacks to create a headless authentication + // stub and approve/deny the headless authentication, so we use a standard callback timeout. + ctx, cancel := context.WithTimeout(ctx, defaults.CallbackTimeout) + defer cancel() + // Delete the headless authentication upon failure. defer func() { if err != nil { diff --git a/lib/services/identity.go b/lib/services/identity.go index b4d4e83a5e3f3..ceb280c0d3b75 100644 --- a/lib/services/identity.go +++ b/lib/services/identity.go @@ -264,9 +264,6 @@ type Identity interface { // and swap replacement on a headless authentication resource. CompareAndSwapHeadlessAuthentication(ctx context.Context, old, new *types.HeadlessAuthentication) (*types.HeadlessAuthentication, error) - // GetHeadlessAuthentication retrieves a headless authentication by name. - GetHeadlessAuthentication(ctx context.Context, name string) (*types.HeadlessAuthentication, error) - // DeleteHeadlessAuthentication deletes a headless authentication from the backend by name. DeleteHeadlessAuthentication(ctx context.Context, name string) error From 6a9c09592d0a2400be912731a2392bff66737b2c Mon Sep 17 00:00:00 2001 From: joerger Date: Tue, 14 Mar 2023 12:28:37 -0700 Subject: [PATCH 08/18] Resolve comments. --- lib/auth/auth_with_roles_test.go | 2 +- lib/auth/methods.go | 7 +++++-- lib/services/local/headlessauthn.go | 4 ++-- rfd/0105-headless-authentication.md | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index f3ad006b94e87..708237d07b78c 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -4258,7 +4258,7 @@ func TestHeadlessAuthentication(t *testing.T) { _, err = client2.GetHeadlessAuthentication(failedGetCtx, headlessID) require.Error(t, err) - require.Contains(t, err.Error(), "context deadline exceeded", "expected context deadline error but got: %v", err) + require.ErrorContains(t, err, "context deadline exceeded", "expected context deadline error but got: %v", err) // user1 should successfully get headless authentication with up to date login details retrievedHeadlessAuthn, err := client1.GetHeadlessAuthentication(ctx, headlessID) diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 61806afc383a0..25f1c5299b0e9 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -69,9 +69,12 @@ type AuthenticateUserRequest struct { } // ForwardedClientMetadata can be used by the proxy web API to forward information about -// the client to the auth service for logging purposes. +// the client to the auth service. type ForwardedClientMetadata struct { - UserAgent string `json:"user_agent,omitempty"` + UserAgent string `json:"user_agent,omitempty"` + // RemoteAddr is the IP address of the end user. This IP address is derived + // either from a direct client connection, or from a PROXY protocol header + // if the connection is forwarded through a load balancer. RemoteAddr string `json:"remote_addr,omitempty"` } diff --git a/lib/services/local/headlessauthn.go b/lib/services/local/headlessauthn.go index 58a554d297785..c8b0a70bb3cf4 100644 --- a/lib/services/local/headlessauthn.go +++ b/lib/services/local/headlessauthn.go @@ -24,14 +24,14 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" ) // CreateHeadlessAuthenticationStub creates a headless authentication stub in the backend. func (s *IdentityService) CreateHeadlessAuthenticationStub(ctx context.Context, name string) (*types.HeadlessAuthentication, error) { - // Stub should be replaced shortly after creation. - expires := s.Clock().Now().Add(time.Minute) + expires := s.Clock().Now().Add(defaults.CallbackTimeout) headlessAuthn := &types.HeadlessAuthentication{ ResourceHeader: types.ResourceHeader{ Metadata: types.Metadata{ diff --git a/rfd/0105-headless-authentication.md b/rfd/0105-headless-authentication.md index e53c263bc2aeb..aadb9a757fc7d 100644 --- a/rfd/0105-headless-authentication.md +++ b/rfd/0105-headless-authentication.md @@ -153,7 +153,7 @@ A request id will be derived from the client's public key so that an attacker ca Note: We could also use the public key directly (base64 encoded), but we choose to use a UUID to shorten the URL and improve its readability. -As [explained above](#unauthenticated-headless-login-endpoint), the Auth server will write the request details to the backend under `/headless_authentication/` on demand. It will have a 3 minute TTL, matching the callback timeout of the request. The request will begin in the pending state. The Auth server then waits for the user to approve the authentication request using a resource watcher. +As [explained above](#unauthenticated-headless-login-endpoint), the Auth server will write the request details to the backend under `/headless_authentication/` on demand. It will have a short TTL, matching the callback timeout of the request. The request will begin in the pending state. The Auth server then waits for the user to approve the authentication request using a resource watcher. #### Local authentication From ed497d28c154c0a33f9fe1bf08d2583ae8f49e5d Mon Sep 17 00:00:00 2001 From: joerger Date: Tue, 14 Mar 2023 18:40:28 -0700 Subject: [PATCH 09/18] Move http client to it's own file; Add ability to clone HTTP client for per-request configuration changes. --- lib/auth/auth_with_roles.go | 24 - lib/auth/clt.go | 998 +------------------------------- lib/auth/http_client.go | 1087 +++++++++++++++++++++++++++++++++++ lib/service/connect.go | 5 +- lib/services/presence.go | 7 - lib/web/apiserver.go | 21 +- 6 files changed, 1130 insertions(+), 1012 deletions(-) create mode 100644 lib/auth/http_client.go diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 29088fd59d49a..d4a902e859853 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -1109,30 +1109,6 @@ func (a *ServerWithRoles) UpsertNode(ctx context.Context, s types.Server) (*type return a.authServer.UpsertNode(ctx, s) } -// DELETE IN: 5.1.0 -// -// This logic has moved to KeepAliveServer. -func (a *ServerWithRoles) KeepAliveNode(ctx context.Context, handle types.KeepAlive) error { - if !a.hasBuiltinRole(types.RoleNode) { - return trace.AccessDenied("[10] access denied") - } - clusterName, err := a.GetDomainName(ctx) - if err != nil { - return trace.Wrap(err) - } - serverName, err := ExtractHostID(a.context.User.GetName(), clusterName) - if err != nil { - return trace.AccessDenied("[10] access denied") - } - if serverName != handle.Name { - return trace.AccessDenied("[10] access denied") - } - if err := a.action(apidefaults.Namespace, types.KindNode, types.VerbUpdate); err != nil { - return trace.Wrap(err) - } - return a.authServer.KeepAliveNode(ctx, handle) -} - // KeepAliveServer updates expiry time of a server resource. func (a *ServerWithRoles) KeepAliveServer(ctx context.Context, handle types.KeepAlive) error { clusterName, err := a.GetDomainName(ctx) diff --git a/lib/auth/clt.go b/lib/auth/clt.go index ac6e37405ea8c..7860db683b668 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -18,39 +18,26 @@ package auth import ( "context" - "crypto/tls" - "encoding/json" "fmt" "net" - "net/http" "net/url" - "strconv" - "strings" "time" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" - "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/okta" "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/constants" - apidefaults "github.com/gravitational/teleport/api/defaults" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" loginrulepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/loginrule/v1" pluginspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1" samlidppb "github.com/gravitational/teleport/api/gen/proto/go/teleport/samlidp/v1" - tracehttp "github.com/gravitational/teleport/api/observability/tracing/http" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" - "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" - "github.com/gravitational/teleport/lib/utils" ) const ( @@ -62,6 +49,9 @@ const ( MissingNamespaceError = "missing required parameter: namespace" ) +// APIClient is aliased here so that it can be embedded in Client. +type APIClient = client.Client + // Client is the Auth API client. It works by connecting to auth servers // via gRPC and HTTP. // @@ -103,43 +93,14 @@ func NewClient(cfg client.Config, params ...roundtrip.ClientParam) (*Client, err } // apiClient configures the tls.Config, so we clone it and reuse it for http. - tlsConfig := apiClient.Config().Clone() - httpClient, err := NewHTTPClient(cfg, tlsConfig, params...) - if err != nil { - return nil, trace.Wrap(err) - } - - return &Client{ - APIClient: apiClient, - HTTPClient: httpClient, - }, nil -} - -// APIClient is aliased here so that it can be embedded in Client. -type APIClient = client.Client - -// HTTPClient is a teleport HTTP API client. -type HTTPClient struct { - roundtrip.Client - // transport defines the methods by which the client can reach the server. - transport *http.Transport - // TLS holds the TLS config for the http client. - tls *tls.Config -} - -// NewHTTPClient creates a new HTTP client with TLS authentication and the given dialer. -func NewHTTPClient(cfg client.Config, tls *tls.Config, params ...roundtrip.ClientParam) (*HTTPClient, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, err - } - - dialer := cfg.Dialer - if dialer == nil { + httpTLS := apiClient.Config().Clone() + httpDialer := cfg.Dialer + if httpDialer == nil { if len(cfg.Addrs) == 0 { return nil, trace.BadParameter("no addresses to dial") } - contextDialer := client.NewDialer(cfg.Context, cfg.KeepAlivePeriod, cfg.DialTimeout, client.WithTLSConfig(tls)) - dialer = client.ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) { + contextDialer := client.NewDialer(cfg.Context, cfg.KeepAlivePeriod, cfg.DialTimeout, client.WithTLSConfig(httpTLS)) + httpDialer = client.ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) { for _, addr := range cfg.Addrs { conn, err = contextDialer.DialContext(ctx, network, addr) if err == nil { @@ -150,148 +111,22 @@ func NewHTTPClient(cfg client.Config, tls *tls.Config, params ...roundtrip.Clien return nil, err }) } - - // Set the next protocol. This is needed due to the Auth Server using a - // multiplexer for protocol detection. Unless next protocol is specified - // it will attempt to upgrade to HTTP2 and at that point there is no way - // to distinguish between HTTP2/JSON or GPRC. - tls.NextProtos = []string{teleport.HTTPNextProtoTLS} - // Configure ALPN SNI direct dial TLS routing information used by ALPN SNI proxy in order to - // dial auth service without using SSH tunnels. - tls = client.ConfigureALPN(tls, cfg.ALPNSNIAuthDialClusterName) - - transport := &http.Transport{ - // notice that below roundtrip.Client is passed - // teleport.APIDomain as an address for the API server, this is - // to make sure client verifies the DNS name of the API server and - // custom DialContext overrides this DNS name to the real address. - // In addition this dialer tries multiple addresses if provided - DialContext: dialer.DialContext, - ResponseHeaderTimeout: apidefaults.DefaultIOTimeout, - TLSClientConfig: tls, - - // Increase the size of the connection pool. This substantially improves the - // performance of Teleport under load as it reduces the number of TLS - // handshakes performed. - MaxIdleConns: defaults.HTTPMaxIdleConns, - MaxIdleConnsPerHost: defaults.HTTPMaxIdleConnsPerHost, - - // Limit the total number of connections to the Auth Server. Some hosts allow a low - // number of connections per process (ulimit) to a host. This is a problem for - // enhanced session recording auditing which emits so many events to the - // Audit Log (using the Auth Client) that the connection pool often does not - // have a free connection to return, so just opens a new one. This quickly - // leads to hitting the OS limit and the client returning out of file - // descriptors error. - MaxConnsPerHost: defaults.HTTPMaxConnsPerHost, - - // IdleConnTimeout defines the maximum amount of time before idle connections - // are closed. Leaving this unset will lead to connections open forever and - // will cause memory leaks in a long running process. - IdleConnTimeout: defaults.HTTPIdleTimeout, - } - - cb, err := breaker.New(cfg.CircuitBreakerConfig) - if err != nil { - return nil, trace.Wrap(err) + httpClientCfg := &HTTPClientConfig{ + TLS: httpTLS, + Dialer: httpDialer, + ALPNSNIAuthDialClusterName: cfg.ALPNSNIAuthDialClusterName, } - - clientParams := append( - []roundtrip.ClientParam{ - roundtrip.HTTPClient(&http.Client{ - Timeout: defaults.HTTPRequestTimeout, - Transport: tracehttp.NewTransport(breaker.NewRoundTripper(cb, transport)), - }), - roundtrip.SanitizerEnabled(true), - }, - params..., - ) - - // Since the client uses a custom dialer and SNI is used for TLS handshake, the address - // used here is arbitrary as it just needs to be set to pass http request validation. - httpClient, err := roundtrip.NewClient("https://"+constants.APIDomain, CurrentVersion, clientParams...) + httpClient, err := NewHTTPClient(httpClientCfg, params...) if err != nil { return nil, trace.Wrap(err) } - return &HTTPClient{ - Client: *httpClient, - transport: transport, - tls: tls, + return &Client{ + APIClient: apiClient, + HTTPClient: httpClient, }, nil } -// Close closes the HTTP client connection to the auth server. -func (c *HTTPClient) Close() { - c.transport.CloseIdleConnections() -} - -// TLSConfig returns the HTTP client's TLS config. -func (c *HTTPClient) TLSConfig() *tls.Config { - return c.tls -} - -// GetTransport returns the HTTP client's transport. -func (c *HTTPClient) GetTransport() *http.Transport { - return c.transport -} - -// ClientTimeout sets idle and dial timeouts of the HTTP transport -// used by the client. -func ClientTimeout(timeout time.Duration) roundtrip.ClientParam { - return func(c *roundtrip.Client) error { - transport, ok := (c.HTTPClient().Transport).(*http.Transport) - if !ok { - return nil - } - transport.IdleConnTimeout = timeout - transport.ResponseHeaderTimeout = timeout - return nil - } -} - -// PostJSON is a generic method that issues http POST request to the server -func (c *Client) PostJSON(ctx context.Context, endpoint string, val interface{}) (*roundtrip.Response, error) { - return httplib.ConvertResponse(c.Client.PostJSON(ctx, endpoint, val)) -} - -// PutJSON is a generic method that issues http PUT request to the server -func (c *Client) PutJSON(ctx context.Context, endpoint string, val interface{}) (*roundtrip.Response, error) { - return httplib.ConvertResponse(c.Client.PutJSON(ctx, endpoint, val)) -} - -// PostForm is a generic method that issues http POST request to the server -func (c *Client) PostForm(ctx context.Context, endpoint string, vals url.Values, files ...roundtrip.File) (*roundtrip.Response, error) { - return httplib.ConvertResponse(c.Client.PostForm(ctx, endpoint, vals, files...)) -} - -// Get issues http GET request to the server -func (c *Client) Get(ctx context.Context, u string, params url.Values) (*roundtrip.Response, error) { - return httplib.ConvertResponse(c.Client.Get(ctx, u, params)) -} - -// Delete issues http Delete Request to the server -func (c *Client) Delete(ctx context.Context, u string) (*roundtrip.Response, error) { - return httplib.ConvertResponse(c.Client.Delete(ctx, u)) -} - -// ProcessKubeCSR processes CSR request against Kubernetes CA, returns -// signed certificate if successful. -func (c *Client) ProcessKubeCSR(req KubeCSR) (*KubeCSRResponse, error) { - if err := req.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - out, err := c.PostJSON(context.TODO(), c.Endpoint("kube", "csr"), req) - if err != nil { - return nil, trace.Wrap(err) - } - var re KubeCSRResponse - if err := json.Unmarshal(out.Bytes(), &re); err != nil { - return nil, trace.Wrap(err) - } - return &re, nil -} - func (c *Client) Close() error { c.HTTPClient.Close() return c.APIClient.Close() @@ -302,48 +137,6 @@ func (c *Client) CreateCertAuthority(ca types.CertAuthority) error { return trace.NotImplemented(notImplementedMessage) } -// RotateCertAuthority starts or restarts certificate authority rotation process. -func (c *Client) RotateCertAuthority(ctx context.Context, req RotateRequest) error { - _, err := c.PostJSON(ctx, c.Endpoint("authorities", string(req.Type), "rotate"), req) - return trace.Wrap(err) -} - -// RotateExternalCertAuthority rotates external certificate authority, -// this method is used to update only public keys and certificates of the -// the certificate authorities of trusted clusters. -func (c *Client) RotateExternalCertAuthority(ctx context.Context, ca types.CertAuthority) error { - if err := services.ValidateCertAuthority(ca); err != nil { - return trace.Wrap(err) - } - data, err := services.MarshalCertAuthority(ca) - if err != nil { - return trace.Wrap(err) - } - _, err = c.PostJSON(ctx, c.Endpoint("authorities", string(ca.GetType()), "rotate", "external"), - &rotateExternalCertAuthorityRawReq{CA: data}) - return trace.Wrap(err) -} - -// UpsertCertAuthority updates or inserts new cert authority -func (c *Client) UpsertCertAuthority(ca types.CertAuthority) error { - if err := services.ValidateCertAuthority(ca); err != nil { - return trace.Wrap(err) - } - data, err := services.MarshalCertAuthority(ca) - if err != nil { - return trace.Wrap(err) - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("authorities", string(ca.GetType())), - &upsertCertAuthorityRawReq{CA: data}) - return trace.Wrap(err) -} - -// CompareAndSwapCertAuthority updates existing cert authority if the existing cert authority -// value matches the value stored in the backend. -func (c *Client) CompareAndSwapCertAuthority(new, existing types.CertAuthority) error { - return trace.BadParameter("this function is not supported on the client") -} - // GetCertAuthorities returns a list of certificate authorities func (c *Client) GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool) ([]types.CertAuthority, error) { if err := caType.Check(); err != nil { @@ -355,27 +148,8 @@ func (c *Client) GetCertAuthorities(ctx context.Context, caType types.CertAuthTy case err == nil: return cas, nil case trace.IsNotImplemented(err): - resp, err := c.Get(ctx, c.Endpoint("authorities", string(caType)), url.Values{ - "load_keys": []string{fmt.Sprintf("%t", loadKeys)}, - }) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(resp.Bytes(), &items); err != nil { - return nil, err - } - cas := make([]types.CertAuthority, 0, len(items)) - for _, raw := range items { - ca, err := services.UnmarshalCertAuthority(raw) - if err != nil { - return nil, trace.Wrap(err) - } - - cas = append(cas, ca) - } - - return cas, nil + cas, err := c.HTTPClient.GetCertAuthorities(ctx, caType, loadKeys) + return cas, trace.Wrap(err) default: return nil, trace.Wrap(err) } @@ -417,7 +191,7 @@ func (c *Client) DeleteCertAuthority(ctx context.Context, id types.CertAuthID) e case err == nil: return nil case trace.IsNotImplemented(err): - _, err := c.Delete(ctx, c.Endpoint("authorities", string(id.Type), id.DomainName)) + err = c.HTTPClient.DeleteCertAuthority(id) return trace.Wrap(err) default: return trace.Wrap(err) @@ -439,181 +213,21 @@ func (c *Client) UpdateUserCARoleMap(ctx context.Context, name string, roleMap t return trace.NotImplemented(notImplementedMessage) } -// RegisterUsingToken calls the auth service API to register a new node using a registration token -// which was previously issued via CreateToken/UpsertToken. -func (c *Client) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) { - if err := req.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - out, err := c.PostJSON(ctx, c.Endpoint("tokens", "register"), req) - if err != nil { - return nil, trace.Wrap(err) - } - - var certs proto.Certs - if err := json.Unmarshal(out.Bytes(), &certs); err != nil { - return nil, trace.Wrap(err) - } - - return &certs, nil -} - -// DELETE IN: 5.1.0 -// -// This logic has been moved to KeepAliveServer. -// -// KeepAliveNode updates node keep alive information. -func (c *Client) KeepAliveNode(ctx context.Context, keepAlive types.KeepAlive) error { - return trace.BadParameter("not implemented, use StreamKeepAlives instead") -} - // KeepAliveServer not implemented: can only be called locally. func (c *Client) KeepAliveServer(ctx context.Context, keepAlive types.KeepAlive) error { return trace.BadParameter("not implemented, use StreamKeepAlives instead") } -// UpsertReverseTunnel is used by admins to create a new reverse tunnel -// to the remote proxy to bypass firewall restrictions -func (c *Client) UpsertReverseTunnel(tunnel types.ReverseTunnel) error { - data, err := services.MarshalReverseTunnel(tunnel) - if err != nil { - return trace.Wrap(err) - } - args := &upsertReverseTunnelRawReq{ - ReverseTunnel: data, - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("reversetunnels"), args) - return trace.Wrap(err) -} - // GetReverseTunnel not implemented: can only be called locally. func (c *Client) GetReverseTunnel(name string, opts ...services.MarshalOption) (types.ReverseTunnel, error) { return nil, trace.NotImplemented(notImplementedMessage) } -// GetReverseTunnels returns the list of created reverse tunnels -func (c *Client) GetReverseTunnels(ctx context.Context, opts ...services.MarshalOption) ([]types.ReverseTunnel, error) { - out, err := c.Get(ctx, c.Endpoint("reversetunnels"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(out.Bytes(), &items); err != nil { - return nil, trace.Wrap(err) - } - tunnels := make([]types.ReverseTunnel, len(items)) - for i, raw := range items { - tunnel, err := services.UnmarshalReverseTunnel(raw) - if err != nil { - return nil, trace.Wrap(err) - } - tunnels[i] = tunnel - } - return tunnels, nil -} - -// DeleteReverseTunnel deletes reverse tunnel by domain name -func (c *Client) DeleteReverseTunnel(domainName string) error { - // this is to avoid confusing error in case if domain empty for example - // HTTP route will fail producing generic not found error - // instead we catch the error here - if strings.TrimSpace(domainName) == "" { - return trace.BadParameter("empty domain name") - } - _, err := c.Delete(context.TODO(), c.Endpoint("reversetunnels", domainName)) - return trace.Wrap(err) -} - -// UpsertTunnelConnection upserts tunnel connection -func (c *Client) UpsertTunnelConnection(conn types.TunnelConnection) error { - data, err := services.MarshalTunnelConnection(conn) - if err != nil { - return trace.Wrap(err) - } - args := &upsertTunnelConnectionRawReq{ - TunnelConnection: data, - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("tunnelconnections"), args) - return trace.Wrap(err) -} - -// GetTunnelConnections returns tunnel connections for a given cluster -func (c *Client) GetTunnelConnections(clusterName string, opts ...services.MarshalOption) ([]types.TunnelConnection, error) { - if clusterName == "" { - return nil, trace.BadParameter("missing cluster name parameter") - } - out, err := c.Get(context.TODO(), c.Endpoint("tunnelconnections", clusterName), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(out.Bytes(), &items); err != nil { - return nil, trace.Wrap(err) - } - conns := make([]types.TunnelConnection, len(items)) - for i, raw := range items { - conn, err := services.UnmarshalTunnelConnection(raw) - if err != nil { - return nil, trace.Wrap(err) - } - conns[i] = conn - } - return conns, nil -} - -// GetAllTunnelConnections returns all tunnel connections -func (c *Client) GetAllTunnelConnections(opts ...services.MarshalOption) ([]types.TunnelConnection, error) { - out, err := c.Get(context.TODO(), c.Endpoint("tunnelconnections"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(out.Bytes(), &items); err != nil { - return nil, trace.Wrap(err) - } - conns := make([]types.TunnelConnection, len(items)) - for i, raw := range items { - conn, err := services.UnmarshalTunnelConnection(raw) - if err != nil { - return nil, trace.Wrap(err) - } - conns[i] = conn - } - return conns, nil -} - -// DeleteTunnelConnection deletes tunnel connection by name -func (c *Client) DeleteTunnelConnection(clusterName string, connName string) error { - if clusterName == "" { - return trace.BadParameter("missing parameter cluster name") - } - if connName == "" { - return trace.BadParameter("missing parameter connection name") - } - _, err := c.Delete(context.TODO(), c.Endpoint("tunnelconnections", clusterName, connName)) - return trace.Wrap(err) -} - -// DeleteTunnelConnections deletes all tunnel connections for cluster -func (c *Client) DeleteTunnelConnections(clusterName string) error { - if clusterName == "" { - return trace.BadParameter("missing parameter cluster name") - } - _, err := c.Delete(context.TODO(), c.Endpoint("tunnelconnections", clusterName)) - return trace.Wrap(err) -} - // DeleteAllTokens not implemented: can only be called locally. func (c *Client) DeleteAllTokens() error { return trace.NotImplemented(notImplementedMessage) } -// DeleteAllTunnelConnections deletes all tunnel connections -func (c *Client) DeleteAllTunnelConnections() error { - _, err := c.Delete(context.TODO(), c.Endpoint("tunnelconnections")) - return trace.Wrap(err) -} - // AddUserLoginAttempt logs user login attempt func (c *Client) AddUserLoginAttempt(user string, attempt services.LoginAttempt, ttl time.Duration) error { panic("not implemented") @@ -624,103 +238,6 @@ func (c *Client) GetUserLoginAttempts(user string) ([]services.LoginAttempt, err panic("not implemented") } -// GetRemoteClusters returns a list of remote clusters -func (c *Client) GetRemoteClusters(opts ...services.MarshalOption) ([]types.RemoteCluster, error) { - out, err := c.Get(context.TODO(), c.Endpoint("remoteclusters"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(out.Bytes(), &items); err != nil { - return nil, trace.Wrap(err) - } - conns := make([]types.RemoteCluster, len(items)) - for i, raw := range items { - conn, err := services.UnmarshalRemoteCluster(raw) - if err != nil { - return nil, trace.Wrap(err) - } - conns[i] = conn - } - return conns, nil -} - -// GetRemoteCluster returns a remote cluster by name -func (c *Client) GetRemoteCluster(clusterName string) (types.RemoteCluster, error) { - if clusterName == "" { - return nil, trace.BadParameter("missing cluster name") - } - out, err := c.Get(context.TODO(), c.Endpoint("remoteclusters", clusterName), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - return services.UnmarshalRemoteCluster(out.Bytes()) -} - -// DeleteRemoteCluster deletes remote cluster by name -func (c *Client) DeleteRemoteCluster(ctx context.Context, clusterName string) error { - if clusterName == "" { - return trace.BadParameter("missing parameter cluster name") - } - - _, err := c.Delete(ctx, c.Endpoint("remoteclusters", clusterName)) - return trace.Wrap(err) -} - -// DeleteAllRemoteClusters deletes all remote clusters -func (c *Client) DeleteAllRemoteClusters() error { - _, err := c.Delete(context.TODO(), c.Endpoint("remoteclusters")) - return trace.Wrap(err) -} - -// CreateRemoteCluster creates remote cluster resource -func (c *Client) CreateRemoteCluster(rc types.RemoteCluster) error { - data, err := services.MarshalRemoteCluster(rc) - if err != nil { - return trace.Wrap(err) - } - args := &createRemoteClusterRawReq{ - RemoteCluster: data, - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("remoteclusters"), args) - return trace.Wrap(err) -} - -// UpsertAuthServer is used by auth servers to report their presence -// to other auth servers in form of hearbeat expiring after ttl period. -func (c *Client) UpsertAuthServer(s types.Server) error { - data, err := services.MarshalServer(s) - if err != nil { - return trace.Wrap(err) - } - args := &upsertServerRawReq{ - Server: data, - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("authservers"), args) - return trace.Wrap(err) -} - -// GetAuthServers returns the list of auth servers registered in the cluster. -func (c *Client) GetAuthServers() ([]types.Server, error) { - out, err := c.Get(context.TODO(), c.Endpoint("authservers"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(out.Bytes(), &items); err != nil { - return nil, trace.Wrap(err) - } - re := make([]types.Server, len(items)) - for i, raw := range items { - server, err := services.UnmarshalServer(raw, types.KindAuthServer) - if err != nil { - return nil, trace.Wrap(err) - } - re[i] = server - } - return re, nil -} - // DeleteAllAuthServers not implemented: can only be called locally. func (c *Client) DeleteAllAuthServers() error { return trace.NotImplemented(notImplementedMessage) @@ -731,344 +248,11 @@ func (c *Client) DeleteAuthServer(name string) error { return trace.NotImplemented(notImplementedMessage) } -// UpsertProxy is used by proxies to report their presence -// to other auth servers in form of heartbeat expiring after ttl period. -func (c *Client) UpsertProxy(s types.Server) error { - data, err := services.MarshalServer(s) - if err != nil { - return trace.Wrap(err) - } - args := &upsertServerRawReq{ - Server: data, - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("proxies"), args) - return trace.Wrap(err) -} - -// GetProxies returns the list of auth servers registered in the cluster. -func (c *Client) GetProxies() ([]types.Server, error) { - out, err := c.Get(context.TODO(), c.Endpoint("proxies"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(out.Bytes(), &items); err != nil { - return nil, trace.Wrap(err) - } - re := make([]types.Server, len(items)) - for i, raw := range items { - server, err := services.UnmarshalServer(raw, types.KindProxy) - if err != nil { - return nil, trace.Wrap(err) - } - re[i] = server - } - return re, nil -} - -// DeleteAllProxies deletes all proxies -func (c *Client) DeleteAllProxies() error { - _, err := c.Delete(context.TODO(), c.Endpoint("proxies")) - if err != nil { - return trace.Wrap(err) - } - return nil -} - -// DeleteProxy deletes proxy by name -func (c *Client) DeleteProxy(name string) error { - if name == "" { - return trace.BadParameter("missing parameter name") - } - _, err := c.Delete(context.TODO(), c.Endpoint("proxies", name)) - if err != nil { - return trace.Wrap(err) - } - return nil -} - -// UpsertUser user updates user entry. -func (c *Client) UpsertUser(user types.User) error { - data, err := services.MarshalUser(user) - if err != nil { - return trace.Wrap(err) - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("users"), &upsertUserRawReq{User: data}) - return trace.Wrap(err) -} - // CompareAndSwapUser not implemented: can only be called locally func (c *Client) CompareAndSwapUser(ctx context.Context, new, expected types.User) error { return trace.NotImplemented(notImplementedMessage) } -// ExtendWebSession creates a new web session for a user based on another -// valid web session -func (c *Client) ExtendWebSession(ctx context.Context, req WebSessionReq) (types.WebSession, error) { - out, err := c.PostJSON(ctx, c.Endpoint("users", req.User, "web", "sessions"), req) - if err != nil { - return nil, trace.Wrap(err) - } - return services.UnmarshalWebSession(out.Bytes()) -} - -// CreateWebSession creates a new web session for a user -func (c *Client) CreateWebSession(ctx context.Context, user string) (types.WebSession, error) { - out, err := c.PostJSON( - ctx, - c.Endpoint("users", user, "web", "sessions"), - WebSessionReq{User: user}, - ) - if err != nil { - return nil, trace.Wrap(err) - } - return services.UnmarshalWebSession(out.Bytes()) -} - -// AuthenticateWebUser authenticates web user, creates and returns web session -// in case if authentication is successful -func (c *Client) AuthenticateWebUser(ctx context.Context, req AuthenticateUserRequest) (types.WebSession, error) { - out, err := c.PostJSON( - ctx, - c.Endpoint("users", req.Username, "web", "authenticate"), - req, - ) - if err != nil { - return nil, trace.Wrap(err) - } - return services.UnmarshalWebSession(out.Bytes()) -} - -// AuthenticateSSHUser authenticates SSH console user, creates and returns a pair of signed TLS and SSH -// short lived certificates as a result -func (c *Client) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHRequest) (*SSHLoginResponse, error) { - if req.HeadlessAuthenticationID != "" { - // Replace the client timeout with the default callback timeout for this request. - previousResponseHeaderTimeout := c.HTTPClient.transport.ResponseHeaderTimeout - previousClientTimeout := c.HTTPClient.HTTPClient().Timeout - c.HTTPClient.transport.ResponseHeaderTimeout = defaults.CallbackTimeout - c.HTTPClient.HTTPClient().Timeout = defaults.CallbackTimeout - defer func() { - c.HTTPClient.transport.ResponseHeaderTimeout = previousResponseHeaderTimeout - c.HTTPClient.HTTPClient().Timeout = previousClientTimeout - }() - } - - out, err := c.PostJSON( - ctx, - c.Endpoint("users", req.Username, "ssh", "authenticate"), - req, - ) - if err != nil { - return nil, trace.Wrap(err) - } - var re SSHLoginResponse - if err := json.Unmarshal(out.Bytes(), &re); err != nil { - return nil, trace.Wrap(err) - } - return &re, nil -} - -// GetWebSessionInfo checks if a web sesion is valid, returns session id in case if -// it is valid, or error otherwise. -func (c *Client) GetWebSessionInfo(ctx context.Context, user, sessionID string) (types.WebSession, error) { - out, err := c.Get( - ctx, - c.Endpoint("users", user, "web", "sessions", sessionID), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - return services.UnmarshalWebSession(out.Bytes()) -} - -// DeleteWebSession deletes the web session specified with sid for the given user -func (c *Client) DeleteWebSession(ctx context.Context, user string, sid string) error { - _, err := c.Delete(ctx, c.Endpoint("users", user, "web", "sessions", sid)) - return trace.Wrap(err) -} - -// GenerateHostCert takes the public key in the Open SSH “authorized_keys“ -// plain text format, signs it using Host Certificate Authority private key and returns the -// resulting certificate. -func (c *Client) GenerateHostCert( - ctx context.Context, key []byte, hostID, nodeName string, principals []string, clusterName string, role types.SystemRole, ttl time.Duration, -) ([]byte, error) { - out, err := c.PostJSON(ctx, c.Endpoint("ca", "host", "certs"), - generateHostCertReq{ - Key: key, - HostID: hostID, - NodeName: nodeName, - Principals: principals, - ClusterName: clusterName, - Roles: types.SystemRoles{role}, - TTL: ttl, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - var cert string - if err := json.Unmarshal(out.Bytes(), &cert); err != nil { - return nil, err - } - - return []byte(cert), nil -} - -// ValidateOIDCAuthCallback validates OIDC auth callback returned from redirect -func (c *Client) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) { - out, err := c.PostJSON(ctx, c.Endpoint("oidc", "requests", "validate"), ValidateOIDCAuthCallbackReq{ - Query: q, - }) - if err != nil { - return nil, trace.Wrap(err) - } - var rawResponse *OIDCAuthRawResponse - if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { - return nil, trace.Wrap(err) - } - response := OIDCAuthResponse{ - Username: rawResponse.Username, - Identity: rawResponse.Identity, - Cert: rawResponse.Cert, - Req: rawResponse.Req, - TLSCert: rawResponse.TLSCert, - } - if len(rawResponse.Session) != 0 { - session, err := services.UnmarshalWebSession(rawResponse.Session) - if err != nil { - return nil, trace.Wrap(err) - } - response.Session = session - } - response.HostSigners = make([]types.CertAuthority, len(rawResponse.HostSigners)) - for i, raw := range rawResponse.HostSigners { - ca, err := services.UnmarshalCertAuthority(raw) - if err != nil { - return nil, trace.Wrap(err) - } - response.HostSigners[i] = ca - } - return &response, nil -} - -// ValidateSAMLResponse validates response returned by SAML identity provider -func (c *Client) ValidateSAMLResponse(ctx context.Context, re string, connectorID string) (*SAMLAuthResponse, error) { - out, err := c.PostJSON(ctx, c.Endpoint("saml", "requests", "validate"), ValidateSAMLResponseReq{ - Response: re, - ConnectorID: connectorID, - }) - if err != nil { - return nil, trace.Wrap(err) - } - var rawResponse *SAMLAuthRawResponse - if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { - return nil, trace.Wrap(err) - } - response := SAMLAuthResponse{ - Username: rawResponse.Username, - Identity: rawResponse.Identity, - Cert: rawResponse.Cert, - Req: rawResponse.Req, - TLSCert: rawResponse.TLSCert, - } - if len(rawResponse.Session) != 0 { - session, err := services.UnmarshalWebSession(rawResponse.Session) - if err != nil { - return nil, trace.Wrap(err) - } - response.Session = session - } - response.HostSigners = make([]types.CertAuthority, len(rawResponse.HostSigners)) - for i, raw := range rawResponse.HostSigners { - ca, err := services.UnmarshalCertAuthority(raw) - if err != nil { - return nil, trace.Wrap(err) - } - response.HostSigners[i] = ca - } - return &response, nil -} - -// ValidateGithubAuthCallback validates Github auth callback returned from redirect -func (c *Client) ValidateGithubAuthCallback(ctx context.Context, q url.Values) (*GithubAuthResponse, error) { - out, err := c.PostJSON(ctx, c.Endpoint("github", "requests", "validate"), - validateGithubAuthCallbackReq{Query: q}) - if err != nil { - return nil, trace.Wrap(err) - } - var rawResponse githubAuthRawResponse - if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { - return nil, trace.Wrap(err) - } - response := GithubAuthResponse{ - Username: rawResponse.Username, - Identity: rawResponse.Identity, - Cert: rawResponse.Cert, - Req: rawResponse.Req, - TLSCert: rawResponse.TLSCert, - } - if len(rawResponse.Session) != 0 { - session, err := services.UnmarshalWebSession( - rawResponse.Session) - if err != nil { - return nil, trace.Wrap(err) - } - response.Session = session - } - response.HostSigners = make([]types.CertAuthority, len(rawResponse.HostSigners)) - for i, raw := range rawResponse.HostSigners { - ca, err := services.UnmarshalCertAuthority(raw) - if err != nil { - return nil, trace.Wrap(err) - } - response.HostSigners[i] = ca - } - return &response, nil -} - -// GetSessionChunk allows clients to receive a byte array (chunk) from a recorded -// session stream, starting from 'offset', up to 'max' in length. The upper bound -// of 'max' is set to events.MaxChunkBytes -func (c *Client) GetSessionChunk(namespace string, sid session.ID, offsetBytes, maxBytes int) ([]byte, error) { - if namespace == "" { - return nil, trace.BadParameter(MissingNamespaceError) - } - response, err := c.Get(context.TODO(), c.Endpoint("namespaces", namespace, "sessions", string(sid), "stream"), url.Values{ - "offset": []string{strconv.Itoa(offsetBytes)}, - "bytes": []string{strconv.Itoa(maxBytes)}, - }) - if err != nil { - log.Error(err) - return nil, trace.Wrap(err) - } - return response.Bytes(), nil -} - -// Returns events that happen during a session sorted by time -// (oldest first). -// -// afterN allows to filter by "newer than N" value where N is the cursor ID -// of previously returned bunch (good for polling for latest) -func (c *Client) GetSessionEvents(namespace string, sid session.ID, afterN int) (retval []events.EventFields, err error) { - if namespace == "" { - return nil, trace.BadParameter(MissingNamespaceError) - } - query := make(url.Values) - if afterN > 0 { - query.Set("after", strconv.Itoa(afterN)) - } - response, err := c.Get(context.TODO(), c.Endpoint("namespaces", namespace, "sessions", string(sid), "events"), query) - if err != nil { - return nil, trace.Wrap(err) - } - retval = make([]events.EventFields, 0) - if err := json.Unmarshal(response.Bytes(), &retval); err != nil { - return nil, trace.Wrap(err) - } - return retval, nil -} - // StreamSessionEvents streams all events from a given session recording. An error is returned on the first // channel if one is encountered. Otherwise the event channel is closed when the stream ends. // The event channel is not closed on error to prevent race conditions in downstream select statements. @@ -1096,120 +280,16 @@ func (c *Client) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order return events, lastKey, nil } -// GetNamespaces returns a list of namespaces -func (c *Client) GetNamespaces() ([]types.Namespace, error) { - out, err := c.Get(context.TODO(), c.Endpoint("namespaces"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var re []types.Namespace - if err := utils.FastUnmarshal(out.Bytes(), &re); err != nil { - return nil, trace.Wrap(err) - } - return re, nil -} - -// GetNamespace returns namespace by name -func (c *Client) GetNamespace(name string) (*types.Namespace, error) { - if name == "" { - return nil, trace.BadParameter("missing namespace name") - } - out, err := c.Get(context.TODO(), c.Endpoint("namespaces", name), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - return services.UnmarshalNamespace(out.Bytes()) -} - -// UpsertNamespace upserts namespace -func (c *Client) UpsertNamespace(ns types.Namespace) error { - _, err := c.PostJSON(context.TODO(), c.Endpoint("namespaces"), upsertNamespaceReq{Namespace: ns}) - return trace.Wrap(err) -} - -// DeleteNamespace deletes namespace by name -func (c *Client) DeleteNamespace(name string) error { - _, err := c.Delete(context.TODO(), c.Endpoint("namespaces", name)) - return trace.Wrap(err) -} - // CreateRole not implemented: can only be called locally. func (c *Client) CreateRole(ctx context.Context, role types.Role) error { return trace.NotImplemented(notImplementedMessage) } -// GetClusterName returns a cluster name -func (c *Client) GetClusterName(opts ...services.MarshalOption) (types.ClusterName, error) { - out, err := c.Get(context.TODO(), c.Endpoint("configuration", "name"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - - cn, err := services.UnmarshalClusterName(out.Bytes()) - if err != nil { - return nil, trace.Wrap(err) - } - - return cn, err -} - -// SetClusterName sets cluster name once, will -// return Already Exists error if the name is already set -func (c *Client) SetClusterName(cn types.ClusterName) error { - data, err := services.MarshalClusterName(cn) - if err != nil { - return trace.Wrap(err) - } - - _, err = c.PostJSON(context.TODO(), c.Endpoint("configuration", "name"), &setClusterNameReq{ClusterName: data}) - if err != nil { - return trace.Wrap(err) - } - - return nil -} - // UpsertClusterName not implemented: can only be called locally. func (c *Client) UpsertClusterName(cn types.ClusterName) error { return trace.NotImplemented(notImplementedMessage) } -// DeleteStaticTokens deletes static tokens -func (c *Client) DeleteStaticTokens() error { - _, err := c.Delete(context.TODO(), c.Endpoint("configuration", "static_tokens")) - return trace.Wrap(err) -} - -// GetStaticTokens returns a list of static register tokens -func (c *Client) GetStaticTokens() (types.StaticTokens, error) { - out, err := c.Get(context.TODO(), c.Endpoint("configuration", "static_tokens"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - - st, err := services.UnmarshalStaticTokens(out.Bytes()) - if err != nil { - return nil, trace.Wrap(err) - } - - return st, err -} - -// SetStaticTokens sets a list of static register tokens -func (c *Client) SetStaticTokens(st types.StaticTokens) error { - data, err := services.MarshalStaticTokens(st) - if err != nil { - return trace.Wrap(err) - } - - _, err = c.PostJSON(context.TODO(), c.Endpoint("configuration", "static_tokens"), &setStaticTokensReq{StaticTokens: data}) - if err != nil { - return trace.Wrap(err) - } - - return nil -} - // DeleteClusterName not implemented: can only be called locally. func (c *Client) DeleteClusterName() error { return trace.NotImplemented(notImplementedMessage) @@ -1250,31 +330,6 @@ func (c *Client) DeleteAllUsers() error { return trace.NotImplemented(notImplementedMessage) } -func (c *Client) ValidateTrustedCluster(ctx context.Context, validateRequest *ValidateTrustedClusterRequest) (*ValidateTrustedClusterResponse, error) { - validateRequestRaw, err := validateRequest.ToRaw() - if err != nil { - return nil, trace.Wrap(err) - } - - out, err := c.PostJSON(ctx, c.Endpoint("trustedclusters", "validate"), validateRequestRaw) - if err != nil { - return nil, trace.Wrap(err) - } - - var validateResponseRaw ValidateTrustedClusterResponseRaw - err = json.Unmarshal(out.Bytes(), &validateResponseRaw) - if err != nil { - return nil, trace.Wrap(err) - } - - validateResponse, err := validateResponseRaw.ToNative() - if err != nil { - return nil, trace.Wrap(err) - } - - return validateResponse, nil -} - // CreateResetPasswordToken creates reset password token func (c *Client) CreateResetPasswordToken(ctx context.Context, req CreateUserTokenRequest) (types.UserToken, error) { return c.APIClient.CreateResetPasswordToken(ctx, &proto.CreateResetPasswordTokenRequest{ @@ -1284,21 +339,6 @@ func (c *Client) CreateResetPasswordToken(ctx context.Context, req CreateUserTok }) } -// CreateBot creates a bot and associated resources. -func (c *Client) CreateBot(ctx context.Context, req *proto.CreateBotRequest) (*proto.CreateBotResponse, error) { - return c.APIClient.CreateBot(ctx, req) -} - -// DeleteBot deletes a certificate renewal bot and associated resources. -func (c *Client) DeleteBot(ctx context.Context, botName string) error { - return c.APIClient.DeleteBot(ctx, botName) -} - -// GetBotUsers fetches all bot users. -func (c *Client) GetBotUsers(ctx context.Context) ([]types.User, error) { - return c.APIClient.GetBotUsers(ctx) -} - // GetDatabaseServers returns all registered database proxy servers. func (c *Client) GetDatabaseServers(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.DatabaseServer, error) { return c.APIClient.GetDatabaseServers(ctx, namespace) diff --git a/lib/auth/http_client.go b/lib/auth/http_client.go new file mode 100644 index 0000000000000..e60aa0fc0398e --- /dev/null +++ b/lib/auth/http_client.go @@ -0,0 +1,1087 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package auth + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/gravitational/roundtrip" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/breaker" + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" + tracehttp "github.com/gravitational/teleport/api/observability/tracing/http" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/httplib" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/utils" +) + +// HTTPClientConfig contains configuration for an HTTP client. +type HTTPClientConfig struct { + // TLS holds the TLS config for the http client. + TLS *tls.Config + // MaxIdleConns controls the maximum number of idle (keep-alive) connections across all hosts. + MaxIdleConns int + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle (keep-alive) connections to keep per-host. + MaxIdleConnsPerHost int + // MaxConnsPerHost limits the total number of connections per host, including connections in the dialing, + // active, and idle states. On limit violation, dials will block. + MaxConnsPerHost int + // RequestTimeout specifies a time limit for requests made by this Client. + RequestTimeout time.Duration + // IdleConnTimeout defines the maximum amount of time before idle connections are closed. + IdleConnTimeout time.Duration + // ResponseHeaderTimeout specifies the amount of time to wait for a server's + // response headers after fully writing the request (including its body, if any). + // This time does not include the time to read the response body. + ResponseHeaderTimeout time.Duration + // Dialer is a custom dialer used to dial a server. The Dialer should + // have custom logic to provide an address to the dialer. If set, Dialer + // takes precedence over all other connection options. + Dialer client.ContextDialer + // ALPNSNIAuthDialClusterName if present the client will include ALPN SNI routing information in TLS Hello message + // allowing to dial auth service through Teleport Proxy directly without using SSH Tunnels. + ALPNSNIAuthDialClusterName string + // CircuitBreakerConfig defines how the circuit breaker should behave. + CircuitBreakerConfig breaker.Config +} + +// CheckAndSetDefaults validates and sets defaults for HTTP configuration. +func (c *HTTPClientConfig) CheckAndSetDefaults() error { + if c.TLS == nil { + return trace.BadParameter("missing TLS config") + } + + if c.Dialer == nil { + return trace.BadParameter("missing dialer") + } + + // Set the next protocol. This is needed due to the Auth Server using a + // multiplexer for protocol detection. Unless next protocol is specified + // it will attempt to upgrade to HTTP2 and at that point there is no way + // to distinguish between HTTP2/JSON or GPRC. + c.TLS.NextProtos = []string{teleport.HTTPNextProtoTLS} + + // Configure ALPN SNI direct dial TLS routing information used by ALPN SNI proxy in order to + // dial auth service without using SSH tunnels. + c.TLS = client.ConfigureALPN(c.TLS, c.ALPNSNIAuthDialClusterName) + + if c.CircuitBreakerConfig.Trip == nil || c.CircuitBreakerConfig.IsSuccessful == nil { + c.CircuitBreakerConfig = breaker.DefaultBreakerConfig(clockwork.NewRealClock()) + } + + // One or both of these timeouts should be set to ensure there is a timeout in place. + if c.RequestTimeout == 0 && c.ResponseHeaderTimeout == 0 { + c.RequestTimeout = defaults.HTTPRequestTimeout + c.ResponseHeaderTimeout = apidefaults.DefaultIOTimeout + } + + // Leaving this unset will lead to connections open forever and will cause memory leaks in a long running process. + if c.IdleConnTimeout == 0 { + c.IdleConnTimeout = defaults.HTTPIdleTimeout + } + + // Increase the size of the connection pool. This substantially improves the + // performance of Teleport under load as it reduces the number of TLS + // handshakes performed. + if c.MaxIdleConns == 0 { + c.MaxIdleConns = defaults.HTTPMaxIdleConns + } + if c.MaxIdleConnsPerHost == 0 { + c.MaxIdleConnsPerHost = defaults.HTTPMaxIdleConnsPerHost + } + + // Limit the total number of connections to the Auth Server. Some hosts allow a low + // number of connections per process (ulimit) to a host. This is a problem for + // enhanced session recording auditing which emits so many events to the + // Audit Log (using the Auth Client) that the connection pool often does not + // have a free connection to return, so just opens a new one. This quickly + // leads to hitting the OS limit and the client returning out of file + // descriptors error. + if c.MaxConnsPerHost == 0 { + c.MaxConnsPerHost = defaults.HTTPMaxConnsPerHost + } + + return nil +} + +// Clone creates a new client with the same configuration. +func (c *HTTPClientConfig) Clone() *HTTPClientConfig { + return &HTTPClientConfig{ + TLS: c.TLS.Clone(), + MaxIdleConns: c.MaxIdleConns, + MaxIdleConnsPerHost: c.MaxIdleConnsPerHost, + MaxConnsPerHost: c.MaxConnsPerHost, + RequestTimeout: c.RequestTimeout, + IdleConnTimeout: c.IdleConnTimeout, + ResponseHeaderTimeout: c.ResponseHeaderTimeout, + Dialer: c.Dialer, + ALPNSNIAuthDialClusterName: c.ALPNSNIAuthDialClusterName, + CircuitBreakerConfig: c.CircuitBreakerConfig, + } +} + +// HTTPClient is a teleport HTTP API client. +type HTTPClient struct { + roundtrip.Client + // cfg is the http client configuration. + cfg *HTTPClientConfig + // transport defines the methods by which the client can reach the server. + transport *http.Transport +} + +// NewHTTPClient creates a new HTTP client with TLS authentication and the given dialer. +func NewHTTPClient(cfg *HTTPClientConfig, params ...roundtrip.ClientParam) (*HTTPClient, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, err + } + + transport := &http.Transport{ + DialContext: cfg.Dialer.DialContext, + ResponseHeaderTimeout: cfg.ResponseHeaderTimeout, + TLSClientConfig: cfg.TLS, + MaxIdleConns: cfg.MaxIdleConns, + MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost, + MaxConnsPerHost: cfg.MaxConnsPerHost, + IdleConnTimeout: cfg.IdleConnTimeout, + } + + roundtripClient, err := newRoundtripClient(cfg, transport) + if err != nil { + return nil, trace.Wrap(err) + } + + return &HTTPClient{ + cfg: cfg, + Client: *roundtripClient, + transport: transport, + }, nil +} + +func newRoundtripClient(cfg *HTTPClientConfig, transport *http.Transport, params ...roundtrip.ClientParam) (*roundtrip.Client, error) { + cb, err := breaker.New(cfg.CircuitBreakerConfig) + if err != nil { + return nil, trace.Wrap(err) + } + + clientParams := append( + []roundtrip.ClientParam{ + roundtrip.HTTPClient(&http.Client{ + Timeout: cfg.RequestTimeout, + Transport: tracehttp.NewTransport(breaker.NewRoundTripper(cb, transport)), + }), + roundtrip.SanitizerEnabled(true), + }, + params..., + ) + + // Since the client uses a custom dialer and SNI is used for TLS handshake, the address + // used here is arbitrary as it just needs to be set to pass http request validation. + roundtripClient, err := roundtrip.NewClient("https://"+constants.APIDomain, CurrentVersion, clientParams...) + if err != nil { + return nil, trace.Wrap(err) + } + + return roundtripClient, nil +} + +// Clone creates a new client with the same configuration. +func (c *HTTPClient) Clone(params ...roundtrip.ClientParam) (*HTTPClient, error) { + cfg := c.cfg.Clone() + transport := c.transport.Clone() + + roundtripClient, err := newRoundtripClient(c.cfg, transport) + if err != nil { + return nil, trace.Wrap(err) + } + + return &HTTPClient{ + Client: *roundtripClient, + cfg: cfg, + transport: transport, + }, nil +} + +// ClientParamRequestTimeout sets request timeout of the HTTP transport used by the client. +func ClientParamTimeout(timeout time.Duration) roundtrip.ClientParam { + return func(c *roundtrip.Client) error { + c.HTTPClient().Timeout = timeout + return nil + } +} + +// ClientParamResponseHeaderTimeout sets response header timeout of the HTTP transport used by the client. +func ClientParamResponseHeaderTimeout(timeout time.Duration) roundtrip.ClientParam { + return func(c *roundtrip.Client) error { + transport, ok := (c.HTTPClient().Transport).(*http.Transport) + if !ok { + return nil + } + transport.ResponseHeaderTimeout = timeout + return nil + } +} + +// ClientParamIdleConnTimeout sets idle connection header timeout of the HTTP transport used by the client. +func ClientParamIdleConnTimeout(timeout time.Duration) roundtrip.ClientParam { + return func(c *roundtrip.Client) error { + transport, ok := (c.HTTPClient().Transport).(*http.Transport) + if !ok { + return nil + } + transport.IdleConnTimeout = timeout + return nil + } +} + +// Close closes the HTTP client connection to the auth server. +func (c *HTTPClient) Close() { + c.transport.CloseIdleConnections() +} + +// TLSConfig returns the HTTP client's TLS config. +func (c *HTTPClient) TLSConfig() *tls.Config { + return c.transport.TLSClientConfig +} + +// GetTransport returns the HTTP client's transport. +func (c *HTTPClient) GetTransport() *http.Transport { + return c.transport +} + +// PostJSON is a generic method that issues http POST request to the server +func (c *HTTPClient) PostJSON(ctx context.Context, endpoint string, val interface{}) (*roundtrip.Response, error) { + return httplib.ConvertResponse(c.Client.PostJSON(ctx, endpoint, val)) +} + +// PutJSON is a generic method that issues http PUT request to the server +func (c *HTTPClient) PutJSON(ctx context.Context, endpoint string, val interface{}) (*roundtrip.Response, error) { + return httplib.ConvertResponse(c.Client.PutJSON(ctx, endpoint, val)) +} + +// PostForm is a generic method that issues http POST request to the server +func (c *HTTPClient) PostForm(ctx context.Context, endpoint string, vals url.Values, files ...roundtrip.File) (*roundtrip.Response, error) { + return httplib.ConvertResponse(c.Client.PostForm(ctx, endpoint, vals, files...)) +} + +// Get issues http GET request to the server +func (c *HTTPClient) Get(ctx context.Context, u string, params url.Values) (*roundtrip.Response, error) { + return httplib.ConvertResponse(c.Client.Get(ctx, u, params)) +} + +// Delete issues http Delete Request to the server +func (c *HTTPClient) Delete(ctx context.Context, u string) (*roundtrip.Response, error) { + return httplib.ConvertResponse(c.Client.Delete(ctx, u)) +} + +// ProcessKubeCSR processes CSR request against Kubernetes CA, returns +// signed certificate if successful. +func (c *HTTPClient) ProcessKubeCSR(req KubeCSR) (*KubeCSRResponse, error) { + if err := req.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + out, err := c.PostJSON(context.TODO(), c.Endpoint("kube", "csr"), req) + if err != nil { + return nil, trace.Wrap(err) + } + var re KubeCSRResponse + if err := json.Unmarshal(out.Bytes(), &re); err != nil { + return nil, trace.Wrap(err) + } + return &re, nil +} + +// RotateCertAuthority starts or restarts certificate authority rotation process. +func (c *HTTPClient) RotateCertAuthority(ctx context.Context, req RotateRequest) error { + _, err := c.PostJSON(ctx, c.Endpoint("authorities", string(req.Type), "rotate"), req) + return trace.Wrap(err) +} + +// RotateExternalCertAuthority rotates external certificate authority, +// this method is used to update only public keys and certificates of the +// the certificate authorities of trusted clusters. +func (c *HTTPClient) RotateExternalCertAuthority(ctx context.Context, ca types.CertAuthority) error { + if err := services.ValidateCertAuthority(ca); err != nil { + return trace.Wrap(err) + } + data, err := services.MarshalCertAuthority(ca) + if err != nil { + return trace.Wrap(err) + } + _, err = c.PostJSON(ctx, c.Endpoint("authorities", string(ca.GetType()), "rotate", "external"), + &rotateExternalCertAuthorityRawReq{CA: data}) + return trace.Wrap(err) +} + +// UpsertCertAuthority updates or inserts new cert authority +func (c *HTTPClient) UpsertCertAuthority(ca types.CertAuthority) error { + if err := services.ValidateCertAuthority(ca); err != nil { + return trace.Wrap(err) + } + data, err := services.MarshalCertAuthority(ca) + if err != nil { + return trace.Wrap(err) + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("authorities", string(ca.GetType())), + &upsertCertAuthorityRawReq{CA: data}) + return trace.Wrap(err) +} + +// GetCertAuthorities returns a list of certificate authorities +func (c *HTTPClient) GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool, opts ...services.MarshalOption) ([]types.CertAuthority, error) { + resp, err := c.Get(ctx, c.Endpoint("authorities", string(caType)), url.Values{ + "load_keys": []string{fmt.Sprintf("%t", loadKeys)}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(resp.Bytes(), &items); err != nil { + return nil, err + } + cas := make([]types.CertAuthority, 0, len(items)) + for _, raw := range items { + ca, err := services.UnmarshalCertAuthority(raw) + if err != nil { + return nil, trace.Wrap(err) + } + + cas = append(cas, ca) + } + + return cas, nil +} + +// DeleteCertAuthority deletes cert authority by ID +func (c *HTTPClient) DeleteCertAuthority(id types.CertAuthID) error { + _, err := c.Delete(context.TODO(), c.Endpoint("authorities", string(id.Type), id.DomainName)) + return trace.Wrap(err) +} + +// RegisterUsingToken calls the auth service API to register a new node using a registration token +// which was previously issued via CreateToken/UpsertToken. +func (c *HTTPClient) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) { + if err := req.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + out, err := c.PostJSON(ctx, c.Endpoint("tokens", "register"), req) + if err != nil { + return nil, trace.Wrap(err) + } + + var certs proto.Certs + if err := json.Unmarshal(out.Bytes(), &certs); err != nil { + return nil, trace.Wrap(err) + } + + return &certs, nil +} + +// UpsertReverseTunnel is used by admins to create a new reverse tunnel +// to the remote proxy to bypass firewall restrictions +func (c *HTTPClient) UpsertReverseTunnel(tunnel types.ReverseTunnel) error { + data, err := services.MarshalReverseTunnel(tunnel) + if err != nil { + return trace.Wrap(err) + } + args := &upsertReverseTunnelRawReq{ + ReverseTunnel: data, + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("reversetunnels"), args) + return trace.Wrap(err) +} + +// GetReverseTunnels returns the list of created reverse tunnels +func (c *HTTPClient) GetReverseTunnels(ctx context.Context, opts ...services.MarshalOption) ([]types.ReverseTunnel, error) { + out, err := c.Get(ctx, c.Endpoint("reversetunnels"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + tunnels := make([]types.ReverseTunnel, len(items)) + for i, raw := range items { + tunnel, err := services.UnmarshalReverseTunnel(raw) + if err != nil { + return nil, trace.Wrap(err) + } + tunnels[i] = tunnel + } + return tunnels, nil +} + +// DeleteReverseTunnel deletes reverse tunnel by domain name +func (c *HTTPClient) DeleteReverseTunnel(domainName string) error { + // this is to avoid confusing error in case if domain empty for example + // HTTP route will fail producing generic not found error + // instead we catch the error here + if strings.TrimSpace(domainName) == "" { + return trace.BadParameter("empty domain name") + } + _, err := c.Delete(context.TODO(), c.Endpoint("reversetunnels", domainName)) + return trace.Wrap(err) +} + +// UpsertTunnelConnection upserts tunnel connection +func (c *HTTPClient) UpsertTunnelConnection(conn types.TunnelConnection) error { + data, err := services.MarshalTunnelConnection(conn) + if err != nil { + return trace.Wrap(err) + } + args := &upsertTunnelConnectionRawReq{ + TunnelConnection: data, + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("tunnelconnections"), args) + return trace.Wrap(err) +} + +// GetTunnelConnections returns tunnel connections for a given cluster +func (c *HTTPClient) GetTunnelConnections(clusterName string, opts ...services.MarshalOption) ([]types.TunnelConnection, error) { + if clusterName == "" { + return nil, trace.BadParameter("missing cluster name parameter") + } + out, err := c.Get(context.TODO(), c.Endpoint("tunnelconnections", clusterName), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + conns := make([]types.TunnelConnection, len(items)) + for i, raw := range items { + conn, err := services.UnmarshalTunnelConnection(raw) + if err != nil { + return nil, trace.Wrap(err) + } + conns[i] = conn + } + return conns, nil +} + +// GetAllTunnelConnections returns all tunnel connections +func (c *HTTPClient) GetAllTunnelConnections(opts ...services.MarshalOption) ([]types.TunnelConnection, error) { + out, err := c.Get(context.TODO(), c.Endpoint("tunnelconnections"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + conns := make([]types.TunnelConnection, len(items)) + for i, raw := range items { + conn, err := services.UnmarshalTunnelConnection(raw) + if err != nil { + return nil, trace.Wrap(err) + } + conns[i] = conn + } + return conns, nil +} + +// DeleteTunnelConnection deletes tunnel connection by name +func (c *HTTPClient) DeleteTunnelConnection(clusterName string, connName string) error { + if clusterName == "" { + return trace.BadParameter("missing parameter cluster name") + } + if connName == "" { + return trace.BadParameter("missing parameter connection name") + } + _, err := c.Delete(context.TODO(), c.Endpoint("tunnelconnections", clusterName, connName)) + return trace.Wrap(err) +} + +// DeleteTunnelConnections deletes all tunnel connections for cluster +func (c *HTTPClient) DeleteTunnelConnections(clusterName string) error { + if clusterName == "" { + return trace.BadParameter("missing parameter cluster name") + } + _, err := c.Delete(context.TODO(), c.Endpoint("tunnelconnections", clusterName)) + return trace.Wrap(err) +} + +// DeleteAllTunnelConnections deletes all tunnel connections +func (c *HTTPClient) DeleteAllTunnelConnections() error { + _, err := c.Delete(context.TODO(), c.Endpoint("tunnelconnections")) + return trace.Wrap(err) +} + +// GetRemoteClusters returns a list of remote clusters +func (c *HTTPClient) GetRemoteClusters(opts ...services.MarshalOption) ([]types.RemoteCluster, error) { + out, err := c.Get(context.TODO(), c.Endpoint("remoteclusters"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + conns := make([]types.RemoteCluster, len(items)) + for i, raw := range items { + conn, err := services.UnmarshalRemoteCluster(raw) + if err != nil { + return nil, trace.Wrap(err) + } + conns[i] = conn + } + return conns, nil +} + +// GetRemoteCluster returns a remote cluster by name +func (c *HTTPClient) GetRemoteCluster(clusterName string) (types.RemoteCluster, error) { + if clusterName == "" { + return nil, trace.BadParameter("missing cluster name") + } + out, err := c.Get(context.TODO(), c.Endpoint("remoteclusters", clusterName), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + return services.UnmarshalRemoteCluster(out.Bytes()) +} + +// DeleteRemoteCluster deletes remote cluster by name +func (c *HTTPClient) DeleteRemoteCluster(ctx context.Context, clusterName string) error { + if clusterName == "" { + return trace.BadParameter("missing parameter cluster name") + } + _, err := c.Delete(ctx, c.Endpoint("remoteclusters", clusterName)) + return trace.Wrap(err) +} + +// DeleteAllRemoteClusters deletes all remote clusters +func (c *HTTPClient) DeleteAllRemoteClusters() error { + _, err := c.Delete(context.TODO(), c.Endpoint("remoteclusters")) + return trace.Wrap(err) +} + +// CreateRemoteCluster creates remote cluster resource +func (c *HTTPClient) CreateRemoteCluster(rc types.RemoteCluster) error { + data, err := services.MarshalRemoteCluster(rc) + if err != nil { + return trace.Wrap(err) + } + args := &createRemoteClusterRawReq{ + RemoteCluster: data, + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("remoteclusters"), args) + return trace.Wrap(err) +} + +// UpsertAuthServer is used by auth servers to report their presence +// to other auth servers in form of hearbeat expiring after ttl period. +func (c *HTTPClient) UpsertAuthServer(s types.Server) error { + data, err := services.MarshalServer(s) + if err != nil { + return trace.Wrap(err) + } + args := &upsertServerRawReq{ + Server: data, + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("authservers"), args) + return trace.Wrap(err) +} + +// GetAuthServers returns the list of auth servers registered in the cluster. +func (c *HTTPClient) GetAuthServers() ([]types.Server, error) { + out, err := c.Get(context.TODO(), c.Endpoint("authservers"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + re := make([]types.Server, len(items)) + for i, raw := range items { + server, err := services.UnmarshalServer(raw, types.KindAuthServer) + if err != nil { + return nil, trace.Wrap(err) + } + re[i] = server + } + return re, nil +} + +// UpsertProxy is used by proxies to report their presence +// to other auth servers in form of heartbeat expiring after ttl period. +func (c *HTTPClient) UpsertProxy(s types.Server) error { + data, err := services.MarshalServer(s) + if err != nil { + return trace.Wrap(err) + } + args := &upsertServerRawReq{ + Server: data, + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("proxies"), args) + return trace.Wrap(err) +} + +// GetProxies returns the list of auth servers registered in the cluster. +func (c *HTTPClient) GetProxies() ([]types.Server, error) { + out, err := c.Get(context.TODO(), c.Endpoint("proxies"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + re := make([]types.Server, len(items)) + for i, raw := range items { + server, err := services.UnmarshalServer(raw, types.KindProxy) + if err != nil { + return nil, trace.Wrap(err) + } + re[i] = server + } + return re, nil +} + +// DeleteAllProxies deletes all proxies +func (c *HTTPClient) DeleteAllProxies() error { + _, err := c.Delete(context.TODO(), c.Endpoint("proxies")) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +// DeleteProxy deletes proxy by name +func (c *HTTPClient) DeleteProxy(name string) error { + if name == "" { + return trace.BadParameter("missing parameter name") + } + _, err := c.Delete(context.TODO(), c.Endpoint("proxies", name)) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +// UpsertUser user updates user entry. +func (c *HTTPClient) UpsertUser(user types.User) error { + data, err := services.MarshalUser(user) + if err != nil { + return trace.Wrap(err) + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("users"), &upsertUserRawReq{User: data}) + return trace.Wrap(err) +} + +// ExtendWebSession creates a new web session for a user based on another +// valid web session +func (c *HTTPClient) ExtendWebSession(ctx context.Context, req WebSessionReq) (types.WebSession, error) { + out, err := c.PostJSON(ctx, c.Endpoint("users", req.User, "web", "sessions"), req) + if err != nil { + return nil, trace.Wrap(err) + } + return services.UnmarshalWebSession(out.Bytes()) +} + +// CreateWebSession creates a new web session for a user +func (c *HTTPClient) CreateWebSession(ctx context.Context, user string) (types.WebSession, error) { + out, err := c.PostJSON( + ctx, + c.Endpoint("users", user, "web", "sessions"), + WebSessionReq{User: user}, + ) + if err != nil { + return nil, trace.Wrap(err) + } + return services.UnmarshalWebSession(out.Bytes()) +} + +// AuthenticateWebUser authenticates web user, creates and returns web session +// in case if authentication is successful +func (c *HTTPClient) AuthenticateWebUser(ctx context.Context, req AuthenticateUserRequest) (types.WebSession, error) { + out, err := c.PostJSON( + ctx, + c.Endpoint("users", req.Username, "web", "authenticate"), + req, + ) + if err != nil { + return nil, trace.Wrap(err) + } + return services.UnmarshalWebSession(out.Bytes()) +} + +// AuthenticateSSHUser authenticates SSH console user, creates and returns a pair of signed TLS and SSH +// short lived certificates as a result +func (c *HTTPClient) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHRequest) (*SSHLoginResponse, error) { + out, err := c.PostJSON( + ctx, + c.Endpoint("users", req.Username, "ssh", "authenticate"), + req, + ) + if err != nil { + return nil, trace.Wrap(err) + } + var re SSHLoginResponse + if err := json.Unmarshal(out.Bytes(), &re); err != nil { + return nil, trace.Wrap(err) + } + return &re, nil +} + +// GetWebSessionInfo checks if a web sesion is valid, returns session id in case if +// it is valid, or error otherwise. +func (c *HTTPClient) GetWebSessionInfo(ctx context.Context, user, sessionID string) (types.WebSession, error) { + out, err := c.Get( + ctx, + c.Endpoint("users", user, "web", "sessions", sessionID), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + return services.UnmarshalWebSession(out.Bytes()) +} + +// DeleteWebSession deletes the web session specified with sid for the given user +func (c *HTTPClient) DeleteWebSession(ctx context.Context, user string, sid string) error { + _, err := c.Delete(ctx, c.Endpoint("users", user, "web", "sessions", sid)) + return trace.Wrap(err) +} + +// GenerateHostCert takes the public key in the Open SSH “authorized_keys“ +// plain text format, signs it using Host Certificate Authority private key and returns the +// resulting certificate. +func (c *HTTPClient) GenerateHostCert( + ctx context.Context, key []byte, hostID, nodeName string, principals []string, clusterName string, role types.SystemRole, ttl time.Duration, +) ([]byte, error) { + out, err := c.PostJSON(ctx, c.Endpoint("ca", "host", "certs"), + generateHostCertReq{ + Key: key, + HostID: hostID, + NodeName: nodeName, + Principals: principals, + ClusterName: clusterName, + Roles: types.SystemRoles{role}, + TTL: ttl, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + var cert string + if err := json.Unmarshal(out.Bytes(), &cert); err != nil { + return nil, err + } + + return []byte(cert), nil +} + +// ValidateOIDCAuthCallback validates OIDC auth callback returned from redirect +func (c *HTTPClient) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) { + out, err := c.PostJSON(ctx, c.Endpoint("oidc", "requests", "validate"), ValidateOIDCAuthCallbackReq{ + Query: q, + }) + if err != nil { + return nil, trace.Wrap(err) + } + var rawResponse *OIDCAuthRawResponse + if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { + return nil, trace.Wrap(err) + } + response := OIDCAuthResponse{ + Username: rawResponse.Username, + Identity: rawResponse.Identity, + Cert: rawResponse.Cert, + Req: rawResponse.Req, + TLSCert: rawResponse.TLSCert, + } + if len(rawResponse.Session) != 0 { + session, err := services.UnmarshalWebSession(rawResponse.Session) + if err != nil { + return nil, trace.Wrap(err) + } + response.Session = session + } + response.HostSigners = make([]types.CertAuthority, len(rawResponse.HostSigners)) + for i, raw := range rawResponse.HostSigners { + ca, err := services.UnmarshalCertAuthority(raw) + if err != nil { + return nil, trace.Wrap(err) + } + response.HostSigners[i] = ca + } + return &response, nil +} + +// ValidateSAMLResponse validates response returned by SAML identity provider +func (c *HTTPClient) ValidateSAMLResponse(ctx context.Context, re string, connectorID string) (*SAMLAuthResponse, error) { + out, err := c.PostJSON(ctx, c.Endpoint("saml", "requests", "validate"), ValidateSAMLResponseReq{ + Response: re, + ConnectorID: connectorID, + }) + if err != nil { + return nil, trace.Wrap(err) + } + var rawResponse *SAMLAuthRawResponse + if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { + return nil, trace.Wrap(err) + } + response := SAMLAuthResponse{ + Username: rawResponse.Username, + Identity: rawResponse.Identity, + Cert: rawResponse.Cert, + Req: rawResponse.Req, + TLSCert: rawResponse.TLSCert, + } + if len(rawResponse.Session) != 0 { + session, err := services.UnmarshalWebSession(rawResponse.Session) + if err != nil { + return nil, trace.Wrap(err) + } + response.Session = session + } + response.HostSigners = make([]types.CertAuthority, len(rawResponse.HostSigners)) + for i, raw := range rawResponse.HostSigners { + ca, err := services.UnmarshalCertAuthority(raw) + if err != nil { + return nil, trace.Wrap(err) + } + response.HostSigners[i] = ca + } + return &response, nil +} + +// ValidateGithubAuthCallback validates Github auth callback returned from redirect +func (c *HTTPClient) ValidateGithubAuthCallback(ctx context.Context, q url.Values) (*GithubAuthResponse, error) { + out, err := c.PostJSON(ctx, c.Endpoint("github", "requests", "validate"), + validateGithubAuthCallbackReq{Query: q}) + if err != nil { + return nil, trace.Wrap(err) + } + var rawResponse githubAuthRawResponse + if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { + return nil, trace.Wrap(err) + } + response := GithubAuthResponse{ + Username: rawResponse.Username, + Identity: rawResponse.Identity, + Cert: rawResponse.Cert, + Req: rawResponse.Req, + TLSCert: rawResponse.TLSCert, + } + if len(rawResponse.Session) != 0 { + session, err := services.UnmarshalWebSession( + rawResponse.Session) + if err != nil { + return nil, trace.Wrap(err) + } + response.Session = session + } + response.HostSigners = make([]types.CertAuthority, len(rawResponse.HostSigners)) + for i, raw := range rawResponse.HostSigners { + ca, err := services.UnmarshalCertAuthority(raw) + if err != nil { + return nil, trace.Wrap(err) + } + response.HostSigners[i] = ca + } + return &response, nil +} + +// GetSessionChunk allows clients to receive a byte array (chunk) from a recorded +// session stream, starting from 'offset', up to 'max' in length. The upper bound +// of 'max' is set to events.MaxChunkBytes +func (c *HTTPClient) GetSessionChunk(namespace string, sid session.ID, offsetBytes, maxBytes int) ([]byte, error) { + if namespace == "" { + return nil, trace.BadParameter(MissingNamespaceError) + } + response, err := c.Get(context.TODO(), c.Endpoint("namespaces", namespace, "sessions", string(sid), "stream"), url.Values{ + "offset": []string{strconv.Itoa(offsetBytes)}, + "bytes": []string{strconv.Itoa(maxBytes)}, + }) + if err != nil { + log.Error(err) + return nil, trace.Wrap(err) + } + return response.Bytes(), nil +} + +// Returns events that happen during a session sorted by time +// (oldest first). +// +// afterN allows to filter by "newer than N" value where N is the cursor ID +// of previously returned bunch (good for polling for latest) +func (c *HTTPClient) GetSessionEvents(namespace string, sid session.ID, afterN int) (retval []events.EventFields, err error) { + if namespace == "" { + return nil, trace.BadParameter(MissingNamespaceError) + } + query := make(url.Values) + if afterN > 0 { + query.Set("after", strconv.Itoa(afterN)) + } + response, err := c.Get(context.TODO(), c.Endpoint("namespaces", namespace, "sessions", string(sid), "events"), query) + if err != nil { + return nil, trace.Wrap(err) + } + retval = make([]events.EventFields, 0) + if err := json.Unmarshal(response.Bytes(), &retval); err != nil { + return nil, trace.Wrap(err) + } + return retval, nil +} + +// GetNamespaces returns a list of namespaces +func (c *HTTPClient) GetNamespaces() ([]types.Namespace, error) { + out, err := c.Get(context.TODO(), c.Endpoint("namespaces"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var re []types.Namespace + if err := utils.FastUnmarshal(out.Bytes(), &re); err != nil { + return nil, trace.Wrap(err) + } + return re, nil +} + +// GetNamespace returns namespace by name +func (c *HTTPClient) GetNamespace(name string) (*types.Namespace, error) { + if name == "" { + return nil, trace.BadParameter("missing namespace name") + } + out, err := c.Get(context.TODO(), c.Endpoint("namespaces", name), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + return services.UnmarshalNamespace(out.Bytes()) +} + +// UpsertNamespace upserts namespace +func (c *HTTPClient) UpsertNamespace(ns types.Namespace) error { + _, err := c.PostJSON(context.TODO(), c.Endpoint("namespaces"), upsertNamespaceReq{Namespace: ns}) + return trace.Wrap(err) +} + +// DeleteNamespace deletes namespace by name +func (c *HTTPClient) DeleteNamespace(name string) error { + _, err := c.Delete(context.TODO(), c.Endpoint("namespaces", name)) + return trace.Wrap(err) +} + +// GetClusterName returns a cluster name +func (c *HTTPClient) GetClusterName(opts ...services.MarshalOption) (types.ClusterName, error) { + out, err := c.Get(context.TODO(), c.Endpoint("configuration", "name"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + + cn, err := services.UnmarshalClusterName(out.Bytes()) + if err != nil { + return nil, trace.Wrap(err) + } + + return cn, err +} + +// SetClusterName sets cluster name once, will +// return Already Exists error if the name is already set +func (c *HTTPClient) SetClusterName(cn types.ClusterName) error { + data, err := services.MarshalClusterName(cn) + if err != nil { + return trace.Wrap(err) + } + + _, err = c.PostJSON(context.TODO(), c.Endpoint("configuration", "name"), &setClusterNameReq{ClusterName: data}) + if err != nil { + return trace.Wrap(err) + } + + return nil +} + +// DeleteStaticTokens deletes static tokens +func (c *HTTPClient) DeleteStaticTokens() error { + _, err := c.Delete(context.TODO(), c.Endpoint("configuration", "static_tokens")) + return trace.Wrap(err) +} + +// GetStaticTokens returns a list of static register tokens +func (c *HTTPClient) GetStaticTokens() (types.StaticTokens, error) { + out, err := c.Get(context.TODO(), c.Endpoint("configuration", "static_tokens"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + + st, err := services.UnmarshalStaticTokens(out.Bytes()) + if err != nil { + return nil, trace.Wrap(err) + } + + return st, err +} + +// SetStaticTokens sets a list of static register tokens +func (c *HTTPClient) SetStaticTokens(st types.StaticTokens) error { + data, err := services.MarshalStaticTokens(st) + if err != nil { + return trace.Wrap(err) + } + + _, err = c.PostJSON(context.TODO(), c.Endpoint("configuration", "static_tokens"), &setStaticTokensReq{StaticTokens: data}) + if err != nil { + return trace.Wrap(err) + } + + return nil +} + +func (c *HTTPClient) ValidateTrustedCluster(ctx context.Context, validateRequest *ValidateTrustedClusterRequest) (*ValidateTrustedClusterResponse, error) { + validateRequestRaw, err := validateRequest.ToRaw() + if err != nil { + return nil, trace.Wrap(err) + } + + out, err := c.PostJSON(ctx, c.Endpoint("trustedclusters", "validate"), validateRequestRaw) + if err != nil { + return nil, trace.Wrap(err) + } + + var validateResponseRaw ValidateTrustedClusterResponseRaw + err = json.Unmarshal(out.Bytes(), &validateResponseRaw) + if err != nil { + return nil, trace.Wrap(err) + } + + validateResponse, err := validateResponseRaw.ToNative() + if err != nil { + return nil, trace.Wrap(err) + } + + return validateResponse, nil +} diff --git a/lib/service/connect.go b/lib/service/connect.go index 0c291957b1146..09be003fa9895 100644 --- a/lib/service/connect.go +++ b/lib/service/connect.go @@ -1198,7 +1198,10 @@ func (process *TeleportProcess) newClientThroughTunnel(authServers []utils.NetAd func (process *TeleportProcess) newClientDirect(authServers []utils.NetAddr, tlsConfig *tls.Config, role types.SystemRole) (*auth.Client, error) { var cltParams []roundtrip.ClientParam if process.Config.ClientTimeout != 0 { - cltParams = []roundtrip.ClientParam{auth.ClientTimeout(process.Config.ClientTimeout)} + cltParams = []roundtrip.ClientParam{ + auth.ClientParamIdleConnTimeout(process.Config.ClientTimeout), + auth.ClientParamResponseHeaderTimeout(process.Config.ClientTimeout), + } } var dialOpts []grpc.DialOption diff --git a/lib/services/presence.go b/lib/services/presence.go index f15415d757646..72558f78ebe45 100644 --- a/lib/services/presence.go +++ b/lib/services/presence.go @@ -61,13 +61,6 @@ type Presence interface { // specified duration with second resolution if it's >= 1 second. UpsertNode(ctx context.Context, server types.Server) (*types.KeepAlive, error) - // DELETE IN: 5.1.0 - // - // This logic has been moved to KeepAliveServer. - // - // KeepAliveNode updates node TTL in the storage - KeepAliveNode(ctx context.Context, h types.KeepAlive) error - // GetAuthServers returns a list of registered servers GetAuthServers() ([]types.Server, error) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index a9c5d79151583..93a0edaeaaa4c 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3117,6 +3117,10 @@ func (h *Handler) createSSHCert(w http.ResponseWriter, r *http.Request, p httpro return nil, trace.Wrap(err) } + var authenticationClient interface { + AuthenticateSSHUser(ctx context.Context, req auth.AuthenticateSSHRequest) (*auth.SSHLoginResponse, error) + } = authClient + authReq := auth.AuthenticateUserRequest{ Username: req.User, PublicKey: req.PubKey, @@ -3136,11 +3140,26 @@ func (h *Handler) createSSHCert(w http.ResponseWriter, r *http.Request, p httpro case constants.SecondFactorWebauthn: // WebAuthn only supports this endpoint for headless login. authReq.HeadlessAuthenticationID = req.HeadlessAuthenticationID + + clt, ok := authClient.(*auth.Client) + if !ok { + return nil, trace.Errorf("expected client type *auth.Client but got %T", authClient) + } + + clientParams := []roundtrip.ClientParam{ + auth.ClientParamTimeout(defaults.CallbackTimeout), + auth.ClientParamResponseHeaderTimeout(defaults.CallbackTimeout), + } + + authenticationClient, err = clt.HTTPClient.Clone(clientParams...) + if err != nil { + return nil, trace.Wrap(err) + } default: return nil, trace.AccessDenied("unsupported second factor type: %q", cap.GetSecondFactor()) } - loginResp, err := authClient.AuthenticateSSHUser(r.Context(), auth.AuthenticateSSHRequest{ + loginResp, err := authenticationClient.AuthenticateSSHUser(r.Context(), auth.AuthenticateSSHRequest{ AuthenticateUserRequest: authReq, CompatibilityMode: req.Compatibility, TTL: req.TTL, From 1f7abd1800630984bbaf402764b00e97fe18ab0b Mon Sep 17 00:00:00 2001 From: joerger Date: Tue, 14 Mar 2023 18:58:11 -0700 Subject: [PATCH 10/18] Fix flaky test. --- lib/auth/auth_login_test.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index ad56c58523150..697cbbeda934e 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -717,7 +717,7 @@ func TestServer_Authenticate_headless(t *testing.T) { headlessID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) ctx := context.Background() - timeout := time.Millisecond * 500 + timeout := time.Millisecond * 200 updateHeadlessAuthnInGoroutine := func(ctx context.Context, update func(*types.HeadlessAuthentication)) chan error { errC := make(chan error) @@ -776,10 +776,8 @@ func TestServer_Authenticate_headless(t *testing.T) { }, checkErr: require.Error, }, { - name: "NOK timeout", - update: func(ha *types.HeadlessAuthentication) { - time.Sleep(timeout) - }, + name: "NOK timeout", + update: func(ha *types.HeadlessAuthentication) {}, checkErr: require.Error, }, } { From 461aee223a1591e1926f65374409248af318a507 Mon Sep 17 00:00:00 2001 From: joerger Date: Wed, 15 Mar 2023 11:08:57 -0700 Subject: [PATCH 11/18] Remove shared state from test. --- lib/auth/auth_login_test.go | 49 ++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index 697cbbeda934e..57b214cb54257 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -704,22 +704,13 @@ func TestServer_Authenticate_nonPasswordlessRequiresUsername(t *testing.T) { func TestServer_Authenticate_headless(t *testing.T) { t.Parallel() - srv := newTestTLSServer(t) - - // We don't mind about the specifics of the configuration, as long as we have - // a user and TOTP/WebAuthn devices. - mfa := configureForMFA(t, srv) - username := mfa.User - - proxyClient, err := srv.NewClient(TestBuiltin(types.RoleProxy)) - require.NoError(t, err) - headlessID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) ctx := context.Background() + headlessID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) + const timeout = time.Millisecond * 200 - timeout := time.Millisecond * 200 - - updateHeadlessAuthnInGoroutine := func(ctx context.Context, update func(*types.HeadlessAuthentication)) chan error { + type updateHeadlessAuthnFn func(*types.HeadlessAuthentication, *types.MFADevice) + updateHeadlessAuthnInGoroutine := func(ctx context.Context, srv *TestTLSServer, mfa *types.MFADevice, update updateHeadlessAuthnFn) chan error { errC := make(chan error) go func() { defer close(errC) @@ -732,7 +723,7 @@ func TestServer_Authenticate_headless(t *testing.T) { // create a shallow copy and update for the compare and swap below. replaceHeadlessAuthn := *headlessAuthn - update(&replaceHeadlessAuthn) + update(&replaceHeadlessAuthn, mfa) _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, headlessAuthn, &replaceHeadlessAuthn) if err != nil { @@ -745,43 +736,55 @@ func TestServer_Authenticate_headless(t *testing.T) { for _, tc := range []struct { name string - update func(*types.HeadlessAuthentication) + update updateHeadlessAuthnFn checkErr require.ErrorAssertionFunc }{ { name: "OK approved", - update: func(ha *types.HeadlessAuthentication) { + update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) { ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED - ha.MfaDevice = mfa.WebDev.MFA + ha.MfaDevice = mfa }, checkErr: require.NoError, }, { name: "NOK approved without MFA", - update: func(ha *types.HeadlessAuthentication) { + update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) { ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED }, checkErr: require.Error, }, { name: "NOK user mismatch", - update: func(ha *types.HeadlessAuthentication) { + update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) { ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED - ha.MfaDevice = mfa.WebDev.MFA + ha.MfaDevice = mfa ha.User = "other-user" }, checkErr: require.Error, }, { name: "NOK denied", - update: func(ha *types.HeadlessAuthentication) { + update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) { ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED }, checkErr: require.Error, }, { name: "NOK timeout", - update: func(ha *types.HeadlessAuthentication) {}, + update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) {}, checkErr: require.Error, }, } { t.Run(tc.name, func(t *testing.T) { + tc := tc + t.Parallel() + + srv := newTestTLSServer(t) + proxyClient, err := srv.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) + + // We don't mind about the specifics of the configuration, as long as we have + // a user and TOTP/WebAuthn devices. + mfa := configureForMFA(t, srv) + username := mfa.User + t.Cleanup(func() { srv.Auth().DeleteHeadlessAuthentication(ctx, headlessID) }) @@ -789,7 +792,7 @@ func TestServer_Authenticate_headless(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - errC := updateHeadlessAuthnInGoroutine(ctx, tc.update) + errC := updateHeadlessAuthnInGoroutine(ctx, srv, mfa.WebDev.MFA, tc.update) _, err = proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ AuthenticateUserRequest: AuthenticateUserRequest{ Username: username, From 390b3adfca3f3cf322108c210209e0fc479a0d6d Mon Sep 17 00:00:00 2001 From: joerger Date: Wed, 15 Mar 2023 13:27:22 -0700 Subject: [PATCH 12/18] Update error handling and testing for auth_with_roles. --- lib/auth/auth_with_roles.go | 31 ++-- lib/auth/auth_with_roles_test.go | 264 ++++++++++++++++++++++++------- 2 files changed, 222 insertions(+), 73 deletions(-) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index d4a902e859853..600d3b4c9b3a9 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -5698,23 +5698,22 @@ func (a *ServerWithRoles) DeleteAllUserGroups(ctx context.Context) error { // GetHeadlessAuthentication retrieves a headless authentication by id. func (a *ServerWithRoles) GetHeadlessAuthentication(ctx context.Context, id string) (*types.HeadlessAuthentication, error) { - waitCtx, cancel := context.WithTimeout(ctx, defaults.HTTPRequestTimeout) + // GetHeadlessAuthentication will wait for the headless details + // if they don't yet exist in the backend. + ctx, cancel := context.WithTimeout(ctx, defaults.HTTPRequestTimeout) defer cancel() - headlessAuthn, err := a.authServer.GetHeadlessAuthentication(waitCtx, id) + headlessAuthn, err := a.authServer.GetHeadlessAuthentication(ctx, id) if err != nil { return nil, trace.Wrap(err) } - // User can always get their own headless authentication state. Otherwise, check for associated rule. + // Only users can get their own headless authentication requests. if !hasLocalUserRole(a.context) || headlessAuthn.User != a.context.User.GetName() { - if err := a.action(apidefaults.Namespace, types.KindHeadlessAuthentication, types.VerbRead); err != nil { - // If the headless authentication can not be accessed by the user, we will return a not - // found error. This method would usually time out above if the headless authentication - // does not exist, so we mimick this behavior here. - <-waitCtx.Done() - return nil, trace.Wrap(waitCtx.Err()) - } + // This method would usually time out above if the headless authentication + // does not exist, so we mimick this behavior here for users without access. + <-ctx.Done() + return nil, trace.Wrap(ctx.Err()) } return headlessAuthn, nil @@ -5722,14 +5721,22 @@ func (a *ServerWithRoles) GetHeadlessAuthentication(ctx context.Context, id stri // UpdateHeadlessAuthenticationState updates a headless authentication state. func (a *ServerWithRoles) UpdateHeadlessAuthenticationState(ctx context.Context, id string, state types.HeadlessAuthenticationState, mfaResp *proto.MFAAuthenticateResponse) error { + // GetHeadlessAuthentication will wait for the headless details + // if they don't yet exist in the backend. + ctx, cancel := context.WithTimeout(ctx, defaults.HTTPRequestTimeout) + defer cancel() + headlessAuthn, err := a.authServer.GetHeadlessAuthentication(ctx, id) if err != nil { return trace.Wrap(err) } - // Only users can approve/deny their own headless auth requests. + // Only users can update their own headless authentication requests. if !hasLocalUserRole(a.context) || headlessAuthn.User != a.context.User.GetName() { - return trace.NotFound("not found") + // This method would usually time out above if the headless authentication + // does not exist, so we mimick this behavior here for users without access. + <-ctx.Done() + return trace.Wrap(ctx.Err()) } if headlessAuthn.State != types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_PENDING { diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 708237d07b78c..aee9c2ac36259 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -4216,23 +4216,9 @@ func TestUnimplementedClients(t *testing.T) { }) } -func TestHeadlessAuthentication(t *testing.T) { - ctx := context.Background() - srv := newTestTLSServer(t) - - mfa := configureForMFA(t, srv) - - user1, _, err := CreateUserAndRole(srv.Auth(), mfa.User, nil, nil) - require.NoError(t, err) - client1, err := srv.NewClient(TestUser(user1.GetName())) - require.NoError(t, err) - - user2, _, err := CreateUserAndRole(srv.Auth(), "user2", nil, nil) - require.NoError(t, err) - client2, err := srv.NewClient(TestUser(user2.GetName())) - require.NoError(t, err) - - // Insert a headless authentication resource into the backend. +// getTestHeadlessAuthenticationID returns the headless authentication resource +// used across headless authentication tests. +func getTestHeadlessAuthn(t *testing.T, user string) *types.HeadlessAuthentication { headlessID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) headlessAuthn := &types.HeadlessAuthentication{ ResourceHeader: types.ResourceHeader{ @@ -4240,65 +4226,221 @@ func TestHeadlessAuthentication(t *testing.T) { Name: headlessID, }, }, - User: user1.GetName(), + User: user, PublicKey: []byte(sshPubKey), ClientIpAddress: "0.0.0.0", } headlessAuthn.SetExpiry(time.Now().Add(time.Minute)) - stub, err := srv.Auth().CreateHeadlessAuthenticationStub(ctx, headlessID) - require.NoError(t, err) - _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, stub, headlessAuthn) + err := headlessAuthn.CheckAndSetDefaults() require.NoError(t, err) - // user2 should fail to get headless authentication, and return the ctx error - // to prevent leaking other user's headless authentication attempts. - failedGetCtx, cancel := context.WithCancel(ctx) - cancel() + return headlessAuthn +} - _, err = client2.GetHeadlessAuthentication(failedGetCtx, headlessID) - require.Error(t, err) - require.ErrorContains(t, err, "context deadline exceeded", "expected context deadline error but got: %v", err) +func TestGetHeadlessAuthentication(t *testing.T) { + ctx := context.Background() + username := "teleport-user" + headlessAuthn := getTestHeadlessAuthn(t, username) + otherUsername := "other-user" - // user1 should successfully get headless authentication with up to date login details - retrievedHeadlessAuthn, err := client1.GetHeadlessAuthentication(ctx, headlessID) - require.NoError(t, err) - require.Equal(t, headlessAuthn, retrievedHeadlessAuthn) + for _, tc := range []struct { + name string + headlessID string + identity TestIdentity + assertError require.ErrorAssertionFunc + expectedHeadlessAuthn *types.HeadlessAuthentication + }{ + { + name: "OK same user", + headlessID: headlessAuthn.GetName(), + identity: TestUser(username), + assertError: require.NoError, + expectedHeadlessAuthn: headlessAuthn, + }, { + name: "NOK not found", + headlessID: uuid.NewString(), + identity: TestUser(username), + assertError: func(t require.TestingT, err error, i ...interface{}) { - // user2 should fail to update authentication state - err = client2.UpdateHeadlessAuthenticationState(ctx, headlessID, types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED, nil) - require.Error(t, err) - require.True(t, trace.IsNotFound(err), "expected not found error but got: %v", err) + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, { + name: "NOK different user", + headlessID: headlessAuthn.GetName(), + identity: TestUser(otherUsername), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, { + name: "NOK admin", + headlessID: headlessAuthn.GetName(), + identity: TestAdmin(), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + tc := tc + t.Parallel() - // user1 should successfully update authentication state to denied - err = client1.UpdateHeadlessAuthenticationState(ctx, headlessID, types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED, nil) - require.NoError(t, err) + srv := newTestTLSServer(t) + _, _, err := CreateUserAndRole(srv.Auth(), username, nil, nil) + require.NoError(t, err) + _, _, err = CreateUserAndRole(srv.Auth(), otherUsername, nil, nil) + require.NoError(t, err) - // reset to original state - retrievedHeadlessAuthn, err = client1.GetHeadlessAuthentication(ctx, headlessID) - require.NoError(t, err) - _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, retrievedHeadlessAuthn, headlessAuthn) - require.NoError(t, err) + // create headless authn + stub, err := srv.Auth().CreateHeadlessAuthenticationStub(ctx, headlessAuthn.GetName()) + require.NoError(t, err) + _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, stub, headlessAuthn) + require.NoError(t, err) - // user1 should fail to update authentication state to approved without mfa - err = client1.UpdateHeadlessAuthenticationState(ctx, headlessID, types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, &proto.MFAAuthenticateResponse{ - Response: &proto.MFAAuthenticateResponse_Webauthn{ - Webauthn: &webauthn.CredentialAssertionResponse{ - Type: "bad response", + client, err := srv.NewClient(tc.identity) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + defer cancel() + + retrievedHeadlessAuthn, err := client.GetHeadlessAuthentication(ctx, tc.headlessID) + tc.assertError(t, err) + require.Equal(t, tc.expectedHeadlessAuthn, retrievedHeadlessAuthn) + }) + } +} + +func TestUpdateHeadlessAuthenticationState(t *testing.T) { + ctx := context.Background() + otherUsername := "other-user" + + for _, tc := range []struct { + name string + // defaults to the mfa identity tied to the headless authentication created + identity TestIdentity + // defaults to id of the headless authentication created + headlessID string + state types.HeadlessAuthenticationState + withMFA bool + assertError require.ErrorAssertionFunc + }{ + { + name: "OK same user denied", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED, + assertError: require.NoError, + }, { + name: "OK same user approved with mfa", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, + withMFA: true, + assertError: require.NoError, + }, { + name: "NOK same user approved without mfa", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, + withMFA: false, + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err), "expected access denied error but got: %v", err) + }, + }, { + name: "NOK not found", + headlessID: uuid.NewString(), + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED, + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, { + name: "NOK different user denied", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED, + identity: TestUser(otherUsername), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, { + name: "NOK different user approved", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, + identity: TestUser(otherUsername), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, { + name: "NOK admin denied", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED, + identity: TestAdmin(), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, { + name: "NOK admin approved", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, + identity: TestAdmin(), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) }, }, - }) - require.Error(t, err) - require.True(t, trace.IsAccessDenied(err), "expected access denied error but got: %v", err) + } { + t.Run(tc.name, func(t *testing.T) { + tc := tc + t.Parallel() - // user1 should successfully update authentication state to approved with MFA - challenge, err := client1.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{ - Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{}, - }) - require.NoError(t, err) - resp, err := mfa.WebDev.SolveAuthn(challenge) - require.NoError(t, err) + srv := newTestTLSServer(t) + mfa := configureForMFA(t, srv) - err = client1.UpdateHeadlessAuthenticationState(ctx, headlessID, types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, resp) - require.NoError(t, err) + _, _, err := CreateUserAndRole(srv.Auth(), otherUsername, nil, nil) + require.NoError(t, err) + + // create headless authn + headlessAuthn := getTestHeadlessAuthn(t, mfa.User) + stub, err := srv.Auth().CreateHeadlessAuthenticationStub(ctx, headlessAuthn.GetName()) + require.NoError(t, err) + _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, stub, headlessAuthn) + require.NoError(t, err) + + // default to mfa user + if tc.identity.I == nil { + tc.identity = TestUser(mfa.User) + } + + client, err := srv.NewClient(tc.identity) + require.NoError(t, err) + + resp := &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: &webauthn.CredentialAssertionResponse{ + Type: "bad response", + }, + }, + } + if tc.withMFA { + client, err := srv.NewClient(TestUser(mfa.User)) + require.NoError(t, err) + + challenge, err := client.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{ + Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{}, + }) + require.NoError(t, err) + + resp, err = mfa.WebDev.SolveAuthn(challenge) + require.NoError(t, err) + } + + // default to same headlessAuthn + if tc.headlessID == "" { + tc.headlessID = headlessAuthn.GetName() + } + + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + defer cancel() + + err = client.UpdateHeadlessAuthenticationState(ctx, tc.headlessID, tc.state, resp) + tc.assertError(t, err) + }) + } } From 31ed159980cdc729313f50a3658a55275fa001f8 Mon Sep 17 00:00:00 2001 From: joerger Date: Wed, 15 Mar 2023 18:05:23 -0700 Subject: [PATCH 13/18] Fix rebase misshap. --- lib/auth/clt.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 7860db683b668..02f54fa80ddcd 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -137,6 +137,12 @@ func (c *Client) CreateCertAuthority(ca types.CertAuthority) error { return trace.NotImplemented(notImplementedMessage) } +// CompareAndSwapCertAuthority updates existing cert authority if the existing cert authority +// value matches the value stored in the backend. +func (c *Client) CompareAndSwapCertAuthority(new, existing types.CertAuthority) error { + return trace.BadParameter("this function is not supported on the client") +} + // GetCertAuthorities returns a list of certificate authorities func (c *Client) GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool) ([]types.CertAuthority, error) { if err := caType.Check(); err != nil { From 3eee16be3bc77ec4a82c181609e0ffcb1007d31a Mon Sep 17 00:00:00 2001 From: joerger Date: Wed, 15 Mar 2023 18:38:04 -0700 Subject: [PATCH 14/18] Fix race condition in test. --- lib/auth/auth_with_roles_test.go | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index aee9c2ac36259..fdcad9cae4d86 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -4241,7 +4241,6 @@ func getTestHeadlessAuthn(t *testing.T, user string) *types.HeadlessAuthenticati func TestGetHeadlessAuthentication(t *testing.T) { ctx := context.Background() username := "teleport-user" - headlessAuthn := getTestHeadlessAuthn(t, username) otherUsername := "other-user" for _, tc := range []struct { @@ -4252,32 +4251,27 @@ func TestGetHeadlessAuthentication(t *testing.T) { expectedHeadlessAuthn *types.HeadlessAuthentication }{ { - name: "OK same user", - headlessID: headlessAuthn.GetName(), - identity: TestUser(username), - assertError: require.NoError, - expectedHeadlessAuthn: headlessAuthn, + name: "OK same user", + identity: TestUser(username), + assertError: require.NoError, }, { name: "NOK not found", headlessID: uuid.NewString(), identity: TestUser(username), assertError: func(t require.TestingT, err error, i ...interface{}) { - require.Error(t, err) require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) }, }, { - name: "NOK different user", - headlessID: headlessAuthn.GetName(), - identity: TestUser(otherUsername), + name: "NOK different user", + identity: TestUser(otherUsername), assertError: func(t require.TestingT, err error, i ...interface{}) { require.Error(t, err) require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) }, }, { - name: "NOK admin", - headlessID: headlessAuthn.GetName(), - identity: TestAdmin(), + name: "NOK admin", + identity: TestAdmin(), assertError: func(t require.TestingT, err error, i ...interface{}) { require.Error(t, err) require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) @@ -4295,6 +4289,7 @@ func TestGetHeadlessAuthentication(t *testing.T) { require.NoError(t, err) // create headless authn + headlessAuthn := getTestHeadlessAuthn(t, username) stub, err := srv.Auth().CreateHeadlessAuthenticationStub(ctx, headlessAuthn.GetName()) require.NoError(t, err) _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, stub, headlessAuthn) @@ -4306,9 +4301,16 @@ func TestGetHeadlessAuthentication(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) defer cancel() + // default to same headlessAuthn + if tc.headlessID == "" { + tc.headlessID = headlessAuthn.GetName() + } + retrievedHeadlessAuthn, err := client.GetHeadlessAuthentication(ctx, tc.headlessID) tc.assertError(t, err) - require.Equal(t, tc.expectedHeadlessAuthn, retrievedHeadlessAuthn) + if err == nil { + require.Equal(t, headlessAuthn, retrievedHeadlessAuthn) + } }) } } From feec9735734a0a5d854056e059c05acc3b5cf0f1 Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Thu, 16 Mar 2023 12:43:02 -0400 Subject: [PATCH 15/18] update e ref --- e | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/e b/e index f64397a5610b2..76f403a8e21bd 160000 --- a/e +++ b/e @@ -1 +1 @@ -Subproject commit f64397a5610b2bde64f8598fc9b4ed43f5065f42 +Subproject commit 76f403a8e21bde4d6f7170b560e2acf7345b2158 From 6d4cbd744dbfdab00f27f537af1f93bc4e72e24a Mon Sep 17 00:00:00 2001 From: joerger Date: Thu, 16 Mar 2023 11:41:20 -0700 Subject: [PATCH 16/18] Fix ctx missing. --- lib/auth/clt.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/auth/clt.go b/lib/auth/clt.go index aae9f7ef33ac1..bbc1d1f23413c 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -216,7 +216,7 @@ func (c *Client) DeleteCertAuthority(ctx context.Context, id types.CertAuthID) e // Fallback to HTTP API // DELETE IN 14.0.0 case trace.IsNotImplemented(err): - err = c.HTTPClient.DeleteCertAuthority(id) + err = c.HTTPClient.DeleteCertAuthority(ctx, id) return trace.Wrap(err) default: return trace.Wrap(err) From ef97f18244bb073ca58ee162a3d788483c2b1ffc Mon Sep 17 00:00:00 2001 From: joerger Date: Thu, 16 Mar 2023 12:03:42 -0700 Subject: [PATCH 17/18] Extend test timeout to prevent flakiness. --- lib/auth/auth_with_roles_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index fdcad9cae4d86..3094db9a9e4c2 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -4438,7 +4438,7 @@ func TestUpdateHeadlessAuthenticationState(t *testing.T) { tc.headlessID = headlessAuthn.GetName() } - ctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() err = client.UpdateHeadlessAuthenticationState(ctx, tc.headlessID, tc.state, resp) From e27445a6992f8a7ea04a8821eef50ea6a6799c7e Mon Sep 17 00:00:00 2001 From: joerger Date: Thu, 16 Mar 2023 13:25:43 -0700 Subject: [PATCH 18/18] Fix issue with roundtrip.ClientParams not being applied due to roundtripper wrapping. --- api/breaker/round_tripper.go | 5 ++ api/observability/tracing/http/http.go | 51 ++++++++++---------- lib/auth/auth_with_roles.go | 6 +++ lib/auth/clt.go | 3 ++ lib/auth/http_client.go | 65 ++++++++++++++++---------- lib/auth/methods.go | 12 ++--- lib/web/apiserver.go | 12 ++--- 7 files changed, 88 insertions(+), 66 deletions(-) diff --git a/api/breaker/round_tripper.go b/api/breaker/round_tripper.go index 5eec8c61b1520..17b7b9eeb415b 100644 --- a/api/breaker/round_tripper.go +++ b/api/breaker/round_tripper.go @@ -59,3 +59,8 @@ func (t *RoundTripper) RoundTrip(request *http.Request) (*http.Response, error) return v.(*http.Response), err } + +// Unwrap returns the inner round tripper. +func (t *RoundTripper) Unwrap() http.RoundTripper { + return t.tripper +} diff --git a/api/observability/tracing/http/http.go b/api/observability/tracing/http/http.go index 8d8eeca7bfe12..28567fffddbfc 100644 --- a/api/observability/tracing/http/http.go +++ b/api/observability/tracing/http/http.go @@ -15,6 +15,7 @@ package http import ( + "net/http" nethttp "net/http" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" @@ -43,39 +44,37 @@ func HandlerFormatter(operation string, r *nethttp.Request) string { // https://github.com/open-telemetry/opentelemetry-go-contrib/issues/3543. // Once the issue is resolved the wrapper may be discarded. func NewTransport(rt nethttp.RoundTripper) nethttp.RoundTripper { - return enforceCloseIdleConnections( - otelhttp.NewTransport(rt, - otelhttp.WithSpanNameFormatter(TransportFormatter), - ), rt) + return &roundTripWrapper{ + RoundTripper: otelhttp.NewTransport(rt, otelhttp.WithSpanNameFormatter(TransportFormatter)), + inner: rt, + } +} + +type closeIdler interface { + CloseIdleConnections() +} + +type roundTripWrapper struct { + nethttp.RoundTripper + inner nethttp.RoundTripper +} + +// Unwrap returns the inner round tripper. +func (w *roundTripWrapper) Unwrap() http.RoundTripper { + return w.inner } -// enforceCloseIdleConnections ensures that the returned [nethttp.RoundTripper] +// CloseIdleConnections ensures that the returned [nethttp.RoundTripper] // has a CloseIdleConnections method. Since [otelhttp.Transport] does not implement // this any [nethttp.Client.CloseIdleConnections] calls result in a noop instead // of forwarding the request onto its wrapped [nethttp.RoundTripper]. // // DELETE WHEN https://github.com/open-telemetry/opentelemetry-go-contrib/issues/3543 // has been resolved. -func enforceCloseIdleConnections(wrapper, wrapped nethttp.RoundTripper) nethttp.RoundTripper { - type closeIdler interface { - CloseIdleConnections() +func (w *roundTripWrapper) CloseIdleConnections() { + if c, ok := w.RoundTripper.(closeIdler); ok { + c.CloseIdleConnections() + } else if c, ok := w.inner.(closeIdler); ok { + c.CloseIdleConnections() } - - type unwrapper struct { - nethttp.RoundTripper - closeIdler - } - - if _, ok := wrapper.(closeIdler); ok { - return wrapper - } - - if c, ok := wrapped.(closeIdler); ok { - return &unwrapper{ - RoundTripper: wrapper, - closeIdler: c, - } - } - - return wrapper } diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 7fda6bd5235d6..df4f41d46b267 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -23,6 +23,7 @@ import ( "strings" "time" + "github.com/gravitational/roundtrip" "github.com/gravitational/trace" "github.com/sirupsen/logrus" collectortracev1 "go.opentelemetry.io/proto/otlp/collector/trace/v1" @@ -5785,6 +5786,11 @@ func (a *ServerWithRoles) UpdateHeadlessAuthenticationState(ctx context.Context, return trace.Wrap(err) } +// CloneHTTPClient creates a new HTTP client with the same configuration. +func (a *ServerWithRoles) CloneHTTPClient(params ...roundtrip.ClientParam) (*HTTPClient, error) { + return nil, trace.NotImplemented("not implemented") +} + // NewAdminAuthServer returns auth server authorized as admin, // used for auth server cached access func NewAdminAuthServer(authServer *Server, alog events.AuditLogSessionStreamer) (ClientI, error) { diff --git a/lib/auth/clt.go b/lib/auth/clt.go index bbc1d1f23413c..3eb42be3a0bc7 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -831,4 +831,7 @@ type ClientI interface { // still get an Okta client when calling this method, but all RPCs will return // "not implemented" errors (as per the default gRPC behavior). OktaClient() *okta.Client + + // CloneHTTPClient creates a new HTTP client with the same configuration. + CloneHTTPClient(params ...roundtrip.ClientParam) (*HTTPClient, error) } diff --git a/lib/auth/http_client.go b/lib/auth/http_client.go index a3c83953af4cf..5fca47b82dfa0 100644 --- a/lib/auth/http_client.go +++ b/lib/auth/http_client.go @@ -154,11 +154,9 @@ func (c *HTTPClientConfig) Clone() *HTTPClientConfig { // HTTPClient is a teleport HTTP API client. type HTTPClient struct { - roundtrip.Client + *roundtrip.Client // cfg is the http client configuration. cfg *HTTPClientConfig - // transport defines the methods by which the client can reach the server. - transport *http.Transport } // NewHTTPClient creates a new HTTP client with TLS authentication and the given dialer. @@ -183,9 +181,8 @@ func NewHTTPClient(cfg *HTTPClientConfig, params ...roundtrip.ClientParam) (*HTT } return &HTTPClient{ - cfg: cfg, - Client: *roundtripClient, - transport: transport, + cfg: cfg, + Client: roundtripClient, }, nil } @@ -216,20 +213,24 @@ func newRoundtripClient(cfg *HTTPClientConfig, transport *http.Transport, params return roundtripClient, nil } -// Clone creates a new client with the same configuration. -func (c *HTTPClient) Clone(params ...roundtrip.ClientParam) (*HTTPClient, error) { +// CloneHTTPClient creates a new HTTP client with the same configuration. +func (c *HTTPClient) CloneHTTPClient(params ...roundtrip.ClientParam) (*HTTPClient, error) { cfg := c.cfg.Clone() - transport := c.transport.Clone() - roundtripClient, err := newRoundtripClient(c.cfg, transport) + // We copy the transport which may have had roundtrip.ClientParams applied on initial creation. + transport, err := c.getTransport() + if err != nil { + return nil, trace.Wrap(err) + } + + roundtripClient, err := newRoundtripClient(c.cfg, transport, params...) if err != nil { return nil, trace.Wrap(err) } return &HTTPClient{ - Client: *roundtripClient, - cfg: cfg, - transport: transport, + Client: roundtripClient, + cfg: cfg, }, nil } @@ -244,11 +245,9 @@ func ClientParamTimeout(timeout time.Duration) roundtrip.ClientParam { // ClientParamResponseHeaderTimeout sets response header timeout of the HTTP transport used by the client. func ClientParamResponseHeaderTimeout(timeout time.Duration) roundtrip.ClientParam { return func(c *roundtrip.Client) error { - transport, ok := (c.HTTPClient().Transport).(*http.Transport) - if !ok { - return nil + if t, err := getHTTPTransport(c); err == nil { + t.ResponseHeaderTimeout = timeout } - transport.ResponseHeaderTimeout = timeout return nil } } @@ -256,28 +255,44 @@ func ClientParamResponseHeaderTimeout(timeout time.Duration) roundtrip.ClientPar // ClientParamIdleConnTimeout sets idle connection header timeout of the HTTP transport used by the client. func ClientParamIdleConnTimeout(timeout time.Duration) roundtrip.ClientParam { return func(c *roundtrip.Client) error { - transport, ok := (c.HTTPClient().Transport).(*http.Transport) - if !ok { - return nil + if t, err := getHTTPTransport(c); err == nil { + t.IdleConnTimeout = timeout } - transport.IdleConnTimeout = timeout return nil } } // Close closes the HTTP client connection to the auth server. func (c *HTTPClient) Close() { - c.transport.CloseIdleConnections() + c.Client.HTTPClient().CloseIdleConnections() } // TLSConfig returns the HTTP client's TLS config. func (c *HTTPClient) TLSConfig() *tls.Config { - return c.transport.TLSClientConfig + return c.cfg.TLS } // GetTransport returns the HTTP client's transport. -func (c *HTTPClient) GetTransport() *http.Transport { - return c.transport +func (c *HTTPClient) getTransport() (*http.Transport, error) { + return getHTTPTransport(c.Client) +} + +func getHTTPTransport(c *roundtrip.Client) (*http.Transport, error) { + type wrapper interface { + Unwrap() http.RoundTripper + } + + transport := c.HTTPClient().Transport + for { + switch t := transport.(type) { + case wrapper: + transport = t.Unwrap() + case *http.Transport: + return t, nil + default: + return nil, trace.BadParameter("unexpected transport type %T", t) + } + } } // PostJSON is a generic method that issues http POST request to the server diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 25f1c5299b0e9..a578ec6594890 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -333,20 +333,20 @@ func (s *Server) authenticatePasswordless(ctx context.Context, req AuthenticateU } func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserRequest) (mfa *types.MFADevice, err error) { - // this authentication requires two client callbacks to create a headless authentication - // stub and approve/deny the headless authentication, so we use a standard callback timeout. - ctx, cancel := context.WithTimeout(ctx, defaults.CallbackTimeout) - defer cancel() - // Delete the headless authentication upon failure. defer func() { if err != nil { - if err := s.DeleteHeadlessAuthentication(ctx, req.HeadlessAuthenticationID); err != nil && !trace.IsNotFound(err) { + if err := s.DeleteHeadlessAuthentication(s.CloseContext(), req.HeadlessAuthenticationID); err != nil && !trace.IsNotFound(err) { log.Debugf("Failed to delete headless authentication: %v", err) } } }() + // this authentication requires two client callbacks to create a headless authentication + // stub and approve/deny the headless authentication, so we use a standard callback timeout. + ctx, cancel := context.WithTimeout(ctx, defaults.CallbackTimeout) + defer cancel() + headlessAuthn := &types.HeadlessAuthentication{ ResourceHeader: types.ResourceHeader{ Metadata: types.Metadata{ diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 93a0edaeaaa4c..b4227adf1b0a8 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3141,17 +3141,11 @@ func (h *Handler) createSSHCert(w http.ResponseWriter, r *http.Request, p httpro // WebAuthn only supports this endpoint for headless login. authReq.HeadlessAuthenticationID = req.HeadlessAuthenticationID - clt, ok := authClient.(*auth.Client) - if !ok { - return nil, trace.Errorf("expected client type *auth.Client but got %T", authClient) - } - - clientParams := []roundtrip.ClientParam{ + // create a new http client with a standard callback timeout. + authenticationClient, err = authClient.CloneHTTPClient( auth.ClientParamTimeout(defaults.CallbackTimeout), auth.ClientParamResponseHeaderTimeout(defaults.CallbackTimeout), - } - - authenticationClient, err = clt.HTTPClient.Clone(clientParams...) + ) if err != nil { return nil, trace.Wrap(err) }