Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,10 @@ type Server struct {

// license is the Teleport Enterprise license used to start the auth server
license *liblicense.License

// headlessAuthenticationWatcher is a headless authentication watcher,
// used to catch and propagate headless authentication request changes.
headlessAuthenticationWatcher *local.HeadlessAuthenticationWatcher
}

// SetSAMLService registers svc as the SAMLService that provides the SAML
Expand Down Expand Up @@ -604,6 +608,12 @@ func (a *Server) checkLockInForce(mode constants.LockingMode, targets []types.Lo
return a.lockWatcher.CheckLockInForce(mode, targets...)
}

func (a *Server) SetHeadlessAuthenticationWatcher(headlessAuthenticationWatcher *local.HeadlessAuthenticationWatcher) {
a.lock.Lock()
defer a.lock.Unlock()
a.headlessAuthenticationWatcher = headlessAuthenticationWatcher
}

// runPeriodicOperations runs some periodic bookkeeping operations
// performed by auth server
func (a *Server) runPeriodicOperations() {
Expand Down
8 changes: 8 additions & 0 deletions lib/auth/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,14 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) {
}
srv.AuthServer.SetLockWatcher(srv.LockWatcher)

headlessAuthenticationWatcher, err := local.NewHeadlessAuthenticationWatcher(ctx, local.HeadlessAuthenticationWatcherConfig{
Backend: b,
})
if err != nil {
return nil, trace.Wrap(err)
}
srv.AuthServer.SetHeadlessAuthenticationWatcher(headlessAuthenticationWatcher)

srv.Authorizer, err = authz.NewAuthorizer(authz.AuthorizerOpts{
ClusterName: srv.ClusterName,
AccessPoint: srv.AuthServer,
Expand Down
8 changes: 8 additions & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1589,6 +1589,14 @@ func (process *TeleportProcess) initAuthService() error {
}
authServer.SetLockWatcher(lockWatcher)

headlessAuthenticationWatcher, err := local.NewHeadlessAuthenticationWatcher(process.ExitContext(), local.HeadlessAuthenticationWatcherConfig{
Backend: b,
})
if err != nil {
return trace.Wrap(err)
}
authServer.SetHeadlessAuthenticationWatcher(headlessAuthenticationWatcher)

process.setLocalAuth(authServer)

// The auth server runs its own upload completer, which is necessary in sync recording modes where
Expand Down
27 changes: 26 additions & 1 deletion lib/services/local/headlessauthn.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func (s *IdentityService) CreateHeadlessAuthenticationStub(ctx context.Context,
if _, err = s.Create(ctx, *item); err != nil {
return nil, trace.Wrap(err)
}

return headlessAuthn, nil
}

Expand Down Expand Up @@ -87,9 +88,31 @@ func (s *IdentityService) GetHeadlessAuthentication(ctx context.Context, name st
if err != nil {
return nil, trace.Wrap(err)
}

return headlessAuthn, nil
}

// GetHeadlessAuthentications returns all headless authentications from the backend.
func (s *IdentityService) GetHeadlessAuthentications(ctx context.Context) ([]*types.HeadlessAuthentication, error) {
rangeStart := headlessAuthenticationKey("")
rangeEnd := backend.RangeEnd(rangeStart)
items, err := s.GetRange(ctx, rangeStart, rangeEnd, 0)
if err != nil {
return nil, trace.Wrap(err)
}

headlessAuthns := make([]*types.HeadlessAuthentication, len(items.Items))
for i, item := range items.Items {
headlessAuthn, err := unmarshalHeadlessAuthenticationFromItem(&item)
if err != nil {
return nil, trace.Wrap(err)
}
headlessAuthns[i] = headlessAuthn
}

return headlessAuthns, nil
}

// DeleteHeadlessAuthentication deletes a headless authentication from the backend by name.
func (s *IdentityService) DeleteHeadlessAuthentication(ctx context.Context, name string) error {
err := s.Delete(ctx, headlessAuthenticationKey(name))
Expand Down Expand Up @@ -119,7 +142,9 @@ func unmarshalHeadlessAuthenticationFromItem(item *backend.Item) (*types.Headles
return nil, trace.Wrap(err, "error unmarshalling headless authentication from storage")
}

headlessAuthn.Metadata.Expires = &item.Expires
// Copy item.Expires without pointer to avoid race conditions with memory backend.
headlessAuthn.Metadata.Expires = new(time.Time)
*headlessAuthn.Metadata.Expires = item.Expires
if err := headlessAuthn.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
Expand Down
4 changes: 4 additions & 0 deletions lib/services/local/headlessauthn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ func TestIdentityService_HeadlessAuthenticationBackend(t *testing.T) {
retrieved, err := identity.GetHeadlessAuthentication(ctx, test.ha.Metadata.Name)
require.NoError(t, err, "GetHeadlessAuthentication returned non-nil error")
require.Equal(t, swapped, retrieved)

retrievedList, err := identity.GetHeadlessAuthentications(ctx)
require.NoError(t, err, "GetHeadlessAuthentications returned non-nil error")
require.Equal(t, []*types.HeadlessAuthentication{swapped}, retrievedList)
})
}
}
Expand Down
Loading