diff --git a/lib/auth/auth.go b/lib/auth/auth.go index d238769e1517f..5381a4d544347 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -471,6 +471,11 @@ var ( } ) +// LoginHook is a function that will be called on a successful login. This will likely be used +// for enterprise services that need to add in feature specific operations after a user has been +// successfully authenticated. An example would be creating objects based on the user. +type LoginHook func(context.Context, types.User) error + // Server keeps the cluster together. It acts as a certificate authority (CA) for // a cluster and: // - generates the keypair for the node it's running on @@ -586,6 +591,10 @@ type Server struct { // headlessAuthenticationWatcher is a headless authentication watcher, // used to catch and propagate headless authentication request changes. headlessAuthenticationWatcher *local.HeadlessAuthenticationWatcher + + loginHooksMu sync.RWMutex + // loginHooks are a list of hooks that will be called on login. + loginHooks []LoginHook } // SetSAMLService registers svc as the SAMLService that provides the SAML @@ -641,6 +650,41 @@ func (a *Server) GetLoginRuleEvaluator() loginrule.Evaluator { return a.loginRuleEvaluator } +// RegisterLoginHook will register a login hook with the auth server. +func (a *Server) RegisterLoginHook(hook LoginHook) { + a.loginHooksMu.Lock() + defer a.loginHooksMu.Unlock() + + a.loginHooks = append(a.loginHooks, hook) +} + +// CallLoginHooks will call the registered login hooks. +func (a *Server) CallLoginHooks(ctx context.Context, user types.User) error { + // Make a copy of the login hooks to operate on. + a.loginHooksMu.RLock() + loginHooks := make([]LoginHook, len(a.loginHooks)) + copy(loginHooks, a.loginHooks) + a.loginHooksMu.RUnlock() + + if len(loginHooks) == 0 { + return nil + } + + var errs []error + for _, hook := range loginHooks { + errs = append(errs, hook(ctx, user)) + } + + return trace.NewAggregate(errs...) +} + +// ResetLoginHooks will clear out the login hooks. +func (a *Server) ResetLoginHooks() { + a.loginHooksMu.Lock() + a.loginHooks = nil + a.loginHooksMu.Unlock() +} + // CloseContext returns the close context func (a *Server) CloseContext() context.Context { return a.closeCtx diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index d1523ec7597cf..40059bab3774b 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -16,6 +16,7 @@ package auth import ( "context" + "sync/atomic" "testing" "time" @@ -553,8 +554,16 @@ func TestServer_Authenticate_passwordless(t *testing.T) { proxyClient, err := svr.NewClient(TestBuiltin(types.RoleProxy)) require.NoError(t, err) + // used to keep track of calls to login hooks. + var loginHookCounter atomic.Int32 + var loginHook LoginHook = func(_ context.Context, _ types.User) error { + loginHookCounter.Add(1) + return nil + } + tests := []struct { name string + loginHooks []LoginHook authenticate func(t *testing.T, resp *wanlib.CredentialAssertionResponse) }{ { @@ -573,6 +582,26 @@ func TestServer_Authenticate_passwordless(t *testing.T) { require.Equal(t, user, loginResp.Username, "Unexpected username") }, }, + { + name: "ssh with login hooks", + loginHooks: []LoginHook{ + loginHook, + loginHook, + }, + authenticate: func(t *testing.T, resp *wanlib.CredentialAssertionResponse) { + loginResp, err := proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ + AuthenticateUserRequest: AuthenticateUserRequest{ + Webauthn: resp, + PublicKey: []byte(sshPubKey), + }, + TTL: 24 * time.Hour, + }) + require.NoError(t, err, "Failed to perform passwordless authentication") + require.NotNil(t, loginResp, "SSH response nil") + require.NotEmpty(t, loginResp.Cert, "SSH certificate empty") + require.Equal(t, user, loginResp.Username, "Unexpected username") + }, + }, { name: "web", authenticate: func(t *testing.T, resp *wanlib.CredentialAssertionResponse) { @@ -583,9 +612,28 @@ func TestServer_Authenticate_passwordless(t *testing.T) { require.Equal(t, user, session.GetUser(), "Unexpected username") }, }, + { + name: "web with login hooks", + loginHooks: []LoginHook{ + loginHook, + }, + authenticate: func(t *testing.T, resp *wanlib.CredentialAssertionResponse) { + session, err := proxyClient.AuthenticateWebUser(ctx, AuthenticateUserRequest{ + Webauthn: resp, + }) + require.NoError(t, err, "Failed to perform passwordless authentication") + require.Equal(t, user, session.GetUser(), "Unexpected username") + }, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + svr.Auth().ResetLoginHooks() + loginHookCounter.Store(0) + for _, hook := range test.loginHooks { + svr.Auth().RegisterLoginHook(hook) + } + // Fail a login attempt so have a non-empty list of attempts. _, err := proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ AuthenticateUserRequest: AuthenticateUserRequest{ @@ -621,6 +669,8 @@ func TestServer_Authenticate_passwordless(t *testing.T) { attempts, err = authServer.GetUserLoginAttempts(user) require.NoError(t, err) require.Empty(t, attempts, "Login attempts not reset") + + require.Equal(t, len(test.loginHooks), int(loginHookCounter.Load())) }) } } diff --git a/lib/auth/github.go b/lib/auth/github.go index 6030d7e91b957..ea3a78e7fd9e3 100644 --- a/lib/auth/github.go +++ b/lib/auth/github.go @@ -618,6 +618,10 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODia return nil, trace.Wrap(err, "Failed to create user from provided parameters.") } + if err := a.CallLoginHooks(ctx, user); err != nil { + return nil, trace.Wrap(err) + } + // Auth was successful, return session, certificate, etc. to caller. auth := GithubAuthResponse{ Req: GithubAuthRequestFromProto(req), diff --git a/lib/auth/methods.go b/lib/auth/methods.go index f6475a9e22271..a95ce5342309a 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -112,17 +112,17 @@ type SessionCreds struct { // AuthenticateUser authenticates user based on the request type. // Returns the username of the authenticated user. -func (s *Server) AuthenticateUser(ctx context.Context, req AuthenticateUserRequest) (string, error) { - user := req.Username +func (s *Server) AuthenticateUser(ctx context.Context, req AuthenticateUserRequest) (types.User, error) { + username := req.Username - mfaDev, actualUser, err := s.authenticateUser(ctx, req) + mfaDev, actualUsername, err := s.authenticateUser(ctx, req) // err is handled below. switch { - case user != "" && actualUser != "" && user != actualUser: - log.Warnf("Authenticate user mismatch (%q vs %q). Using request user (%q)", user, actualUser, user) - case user == "" && actualUser != "": - log.Debugf("User %q authenticated via passwordless", actualUser) - user = actualUser + case username != "" && actualUsername != "" && username != actualUsername: + log.Warnf("Authenticate user mismatch (%q vs %q). Using request user (%q)", username, actualUsername, username) + case username == "" && actualUsername != "": + log.Debugf("User %q authenticated via passwordless", actualUsername) + username = actualUsername } event := &apievents.UserLogin{ @@ -131,7 +131,7 @@ func (s *Server) AuthenticateUser(ctx context.Context, req AuthenticateUserReque Code: events.UserLocalLoginFailureCode, }, UserMetadata: apievents.UserMetadata{ - User: user, + User: username, }, Method: events.LoginMethodLocal, } @@ -147,6 +147,8 @@ func (s *Server) AuthenticateUser(ctx context.Context, req AuthenticateUserReque event.UserAgent = req.ClientMetadata.UserAgent } } + + var user types.User if err != nil { event.Code = events.UserLocalLoginFailureCode event.Status.Success = false @@ -154,6 +156,19 @@ func (s *Server) AuthenticateUser(ctx context.Context, req AuthenticateUserReque } else { event.Code = events.UserLocalLoginCode event.Status.Success = true + + var err error + user, err = s.GetUser(username, false /* withSecrets */) + if err != nil { + return nil, trace.Wrap(err) + } + + // After we're sure that the user has been logged in successfully, we should call + // the registered login hooks. Login hooks can be registered by other processes to + // execute arbitrary operations after a successful login. + if err := s.CallLoginHooks(ctx, user); err != nil { + return nil, trace.Wrap(err) + } } if err := s.emitter.EmitAuditEvent(s.closeCtx, event); err != nil { log.WithError(err).Warn("Failed to emit login event.") @@ -442,13 +457,7 @@ func (s *Server) AuthenticateWebUser(ctx context.Context, req AuthenticateUserRe return session, nil } - actualUser, err := s.AuthenticateUser(ctx, req) - if err != nil { - return nil, trace.Wrap(err) - } - username = actualUser - - user, err := s.GetUser(username, false /* withSecrets */) + user, err := s.AuthenticateUser(ctx, req) if err != nil { return nil, trace.Wrap(err) } @@ -566,18 +575,13 @@ func (s *Server) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHReq return nil, trace.Wrap(err) } - actualUser, err := s.AuthenticateUser(ctx, req.AuthenticateUserRequest) - if err != nil { - return nil, trace.Wrap(err) - } - username = actualUser - // It's safe to extract the roles and traits directly from services.User as // this endpoint is only used for local accounts. - user, err := s.GetUser(username, false /* withSecrets */) + user, err := s.AuthenticateUser(ctx, req.AuthenticateUserRequest) if err != nil { return nil, trace.Wrap(err) } + accessInfo := services.AccessInfoFromUser(user) checker, err := services.NewAccessChecker(accessInfo, clusterName.GetClusterName(), s) if err != nil { @@ -640,7 +644,7 @@ func (s *Server) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHReq } UserLoginCount.Inc() return &SSHLoginResponse{ - Username: username, + Username: user.GetName(), Cert: certs.SSH, TLSCert: certs.TLS, HostSigners: AuthoritiesToTrustedCerts(hostCertAuthorities),