Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,11 @@ var (
}
)

// LoginHook is a function that will be called on a successful login. This will likely be used
// for enterprise services that need to add in feature specific operations after a user has been
// successfully authenticated. An example would be creating objects based on the user.
type LoginHook func(context.Context, types.User) error

// Server keeps the cluster together. It acts as a certificate authority (CA) for
// a cluster and:
// - generates the keypair for the node it's running on
Expand Down Expand Up @@ -584,6 +589,10 @@ type Server struct {
// headlessAuthenticationWatcher is a headless authentication watcher,
// used to catch and propagate headless authentication request changes.
headlessAuthenticationWatcher *local.HeadlessAuthenticationWatcher

loginHooksMu sync.RWMutex
// loginHooks are a list of hooks that will be called on login.
loginHooks []LoginHook
}

// SetSAMLService registers svc as the SAMLService that provides the SAML
Expand Down Expand Up @@ -639,6 +648,41 @@ func (a *Server) GetLoginRuleEvaluator() loginrule.Evaluator {
return a.loginRuleEvaluator
}

// RegisterLoginHook will register a login hook with the auth server.
func (a *Server) RegisterLoginHook(hook LoginHook) {
a.loginHooksMu.Lock()
defer a.loginHooksMu.Unlock()

a.loginHooks = append(a.loginHooks, hook)
}

// CallLoginHooks will call the registered login hooks.
func (a *Server) CallLoginHooks(ctx context.Context, user types.User) error {
// Make a copy of the login hooks to operate on.
a.loginHooksMu.RLock()
loginHooks := make([]LoginHook, len(a.loginHooks))
copy(loginHooks, a.loginHooks)
a.loginHooksMu.RUnlock()

if len(loginHooks) == 0 {
return nil
}

var errs []error
for _, hook := range loginHooks {
errs = append(errs, hook(ctx, user))
}

return trace.NewAggregate(errs...)
}

// ResetLoginHooks will clear out the login hooks.
func (a *Server) ResetLoginHooks() {
a.loginHooksMu.Lock()
a.loginHooks = nil
a.loginHooksMu.Unlock()
}

// CloseContext returns the close context
func (a *Server) CloseContext() context.Context {
return a.closeCtx
Expand Down
50 changes: 50 additions & 0 deletions lib/auth/auth_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package auth

import (
"context"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -553,8 +554,16 @@ func TestServer_Authenticate_passwordless(t *testing.T) {
proxyClient, err := svr.NewClient(TestBuiltin(types.RoleProxy))
require.NoError(t, err)

// used to keep track of calls to login hooks.
var loginHookCounter atomic.Int32
var loginHook LoginHook = func(_ context.Context, _ types.User) error {
loginHookCounter.Add(1)
return nil
}

tests := []struct {
name string
loginHooks []LoginHook
authenticate func(t *testing.T, resp *wanlib.CredentialAssertionResponse)
}{
{
Expand All @@ -573,6 +582,26 @@ func TestServer_Authenticate_passwordless(t *testing.T) {
require.Equal(t, user, loginResp.Username, "Unexpected username")
},
},
{
name: "ssh with login hooks",
loginHooks: []LoginHook{
loginHook,
loginHook,
},
authenticate: func(t *testing.T, resp *wanlib.CredentialAssertionResponse) {
loginResp, err := proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{
AuthenticateUserRequest: AuthenticateUserRequest{
Webauthn: resp,
PublicKey: []byte(sshPubKey),
},
TTL: 24 * time.Hour,
})
require.NoError(t, err, "Failed to perform passwordless authentication")
require.NotNil(t, loginResp, "SSH response nil")
require.NotEmpty(t, loginResp.Cert, "SSH certificate empty")
require.Equal(t, user, loginResp.Username, "Unexpected username")
},
},
{
name: "web",
authenticate: func(t *testing.T, resp *wanlib.CredentialAssertionResponse) {
Expand All @@ -583,9 +612,28 @@ func TestServer_Authenticate_passwordless(t *testing.T) {
require.Equal(t, user, session.GetUser(), "Unexpected username")
},
},
{
name: "web with login hooks",
loginHooks: []LoginHook{
loginHook,
},
authenticate: func(t *testing.T, resp *wanlib.CredentialAssertionResponse) {
session, err := proxyClient.AuthenticateWebUser(ctx, AuthenticateUserRequest{
Webauthn: resp,
})
require.NoError(t, err, "Failed to perform passwordless authentication")
require.Equal(t, user, session.GetUser(), "Unexpected username")
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
svr.Auth().ResetLoginHooks()
loginHookCounter.Store(0)
for _, hook := range test.loginHooks {
svr.Auth().RegisterLoginHook(hook)
}

// Fail a login attempt so have a non-empty list of attempts.
_, err := proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{
AuthenticateUserRequest: AuthenticateUserRequest{
Expand Down Expand Up @@ -621,6 +669,8 @@ func TestServer_Authenticate_passwordless(t *testing.T) {
attempts, err = authServer.GetUserLoginAttempts(user)
require.NoError(t, err)
require.Empty(t, attempts, "Login attempts not reset")

require.Equal(t, len(test.loginHooks), int(loginHookCounter.Load()))
})
}
}
Expand Down
4 changes: 4 additions & 0 deletions lib/auth/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,10 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODia
return nil, trace.Wrap(err, "Failed to create user from provided parameters.")
}

if err := a.CallLoginHooks(ctx, user); err != nil {
return nil, trace.Wrap(err)
}

// Auth was successful, return session, certificate, etc. to caller.
auth := GithubAuthResponse{
Req: GithubAuthRequestFromProto(req),
Expand Down
52 changes: 28 additions & 24 deletions lib/auth/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,17 @@ type SessionCreds struct {

// AuthenticateUser authenticates user based on the request type.
// Returns the username of the authenticated user.
func (s *Server) AuthenticateUser(ctx context.Context, req AuthenticateUserRequest) (string, error) {
user := req.Username
func (s *Server) AuthenticateUser(ctx context.Context, req AuthenticateUserRequest) (types.User, error) {
username := req.Username

mfaDev, actualUser, err := s.authenticateUser(ctx, req)
mfaDev, actualUsername, err := s.authenticateUser(ctx, req)
// err is handled below.
switch {
case user != "" && actualUser != "" && user != actualUser:
log.Warnf("Authenticate user mismatch (%q vs %q). Using request user (%q)", user, actualUser, user)
case user == "" && actualUser != "":
log.Debugf("User %q authenticated via passwordless", actualUser)
user = actualUser
case username != "" && actualUsername != "" && username != actualUsername:
log.Warnf("Authenticate user mismatch (%q vs %q). Using request user (%q)", username, actualUsername, username)
case username == "" && actualUsername != "":
log.Debugf("User %q authenticated via passwordless", actualUsername)
username = actualUsername
}

event := &apievents.UserLogin{
Expand All @@ -132,7 +132,7 @@ func (s *Server) AuthenticateUser(ctx context.Context, req AuthenticateUserReque
Code: events.UserLocalLoginFailureCode,
},
UserMetadata: apievents.UserMetadata{
User: user,
User: username,
},
Method: events.LoginMethodLocal,
}
Expand All @@ -148,13 +148,28 @@ func (s *Server) AuthenticateUser(ctx context.Context, req AuthenticateUserReque
event.UserAgent = req.ClientMetadata.UserAgent
}
}

var user types.User
if err != nil {
event.Code = events.UserLocalLoginFailureCode
event.Status.Success = false
event.Status.Error = err.Error()
} else {
event.Code = events.UserLocalLoginCode
event.Status.Success = true

var err error
user, err = s.GetUser(username, false /* withSecrets */)
if err != nil {
return nil, trace.Wrap(err)
}

// After we're sure that the user has been logged in successfully, we should call
// the registered login hooks. Login hooks can be registered by other processes to
// execute arbitrary operations after a successful login.
if err := s.CallLoginHooks(ctx, user); err != nil {
return nil, trace.Wrap(err)
}
}
if err := s.emitter.EmitAuditEvent(s.closeCtx, event); err != nil {
log.WithError(err).Warn("Failed to emit login event.")
Expand Down Expand Up @@ -443,13 +458,7 @@ func (s *Server) AuthenticateWebUser(ctx context.Context, req AuthenticateUserRe
return session, nil
}

actualUser, err := s.AuthenticateUser(ctx, req)
if err != nil {
return nil, trace.Wrap(err)
}
username = actualUser

user, err := s.GetUser(username, false /* withSecrets */)
user, err := s.AuthenticateUser(ctx, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -575,18 +584,13 @@ func (s *Server) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHReq
return nil, trace.Wrap(err)
}

actualUser, err := s.AuthenticateUser(ctx, req.AuthenticateUserRequest)
if err != nil {
return nil, trace.Wrap(err)
}
username = actualUser

// It's safe to extract the roles and traits directly from services.User as
// this endpoint is only used for local accounts.
user, err := s.GetUser(username, false /* withSecrets */)
user, err := s.AuthenticateUser(ctx, req.AuthenticateUserRequest)
if err != nil {
return nil, trace.Wrap(err)
}

accessInfo := services.AccessInfoFromUser(user)
checker, err := services.NewAccessChecker(accessInfo, clusterName.GetClusterName(), s)
if err != nil {
Expand Down Expand Up @@ -649,7 +653,7 @@ func (s *Server) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHReq
}
UserLoginCount.Inc()
return &SSHLoginResponse{
Username: username,
Username: user.GetName(),
Cert: certs.SSH,
TLSCert: certs.TLS,
HostSigners: AuthoritiesToTrustedCerts(hostCertAuthorities),
Expand Down