Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions api/types/headlessauthn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,25 @@ limitations under the License.
package types

import (
"time"

"github.com/gravitational/trace"
)

// NewHeadlessAuthenticationStub creates a new a headless authentication resource with limited data.
// The stub is used to initiate headless login.
func NewHeadlessAuthenticationStub(name string, expires time.Time) (*HeadlessAuthentication, error) {
ha := &HeadlessAuthentication{
ResourceHeader: ResourceHeader{
Metadata: Metadata{
Name: name,
Expires: &expires,
},
},
}
return ha, ha.CheckAndSetDefaults()
Comment thread
Joerger marked this conversation as resolved.
}

// CheckAndSetDefaults does basic validation and default setting.
func (h *HeadlessAuthentication) CheckAndSetDefaults() error {
h.setStaticFields()
Expand Down
19 changes: 7 additions & 12 deletions lib/services/local/headlessauthn.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,12 @@ import (

// CreateHeadlessAuthenticationStub creates a headless authentication stub in the backend.
func (s *IdentityService) CreateHeadlessAuthenticationStub(ctx context.Context, name string) (*types.HeadlessAuthentication, error) {
expires := s.Clock().Now().Add(defaults.CallbackTimeout)
headlessAuthn := &types.HeadlessAuthentication{
ResourceHeader: types.ResourceHeader{
Metadata: types.Metadata{
Name: name,
Expires: &expires,
},
},
headlessAuthn, err := types.NewHeadlessAuthenticationStub(name, s.Clock().Now().Add(defaults.CallbackTimeout))
if err != nil {
return nil, trace.Wrap(err)
}

item, err := marshalHeadlessAuthenticationToItem(headlessAuthn)
item, err := MarshalHeadlessAuthenticationToItem(headlessAuthn)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -60,12 +55,12 @@ func (s *IdentityService) CompareAndSwapHeadlessAuthentication(ctx context.Conte
return nil, trace.Wrap(err)
}

oldItem, err := marshalHeadlessAuthenticationToItem(old)
oldItem, err := MarshalHeadlessAuthenticationToItem(old)
if err != nil {
return nil, trace.Wrap(err)
}

newItem, err := marshalHeadlessAuthenticationToItem(new)
newItem, err := MarshalHeadlessAuthenticationToItem(new)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -120,7 +115,7 @@ func (s *IdentityService) DeleteHeadlessAuthentication(ctx context.Context, name
return trace.Wrap(err)
}

func marshalHeadlessAuthenticationToItem(headlessAuthn *types.HeadlessAuthentication) (*backend.Item, error) {
func MarshalHeadlessAuthenticationToItem(headlessAuthn *types.HeadlessAuthentication) (*backend.Item, error) {
if err := headlessAuthn.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
Expand Down
43 changes: 16 additions & 27 deletions lib/services/local/headlessauthn_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ var watcherClosedErr = trace.Errorf("headless authentication watcher closed")
type HeadlessAuthenticationWatcherConfig struct {
// Backend is the storage backend used to create watchers.
Backend backend.Backend
// WatcherService is a service used to create new watchers.
// If nil, Backend will be used as the watcher service.
WatcherService interface {
NewWatcher(ctx context.Context, watch backend.Watch) (backend.Watcher, error)
}
// Log is a logger.
Log logrus.FieldLogger
// Clock is used to control time.
Expand All @@ -61,6 +66,9 @@ func (cfg *HeadlessAuthenticationWatcherConfig) CheckAndSetDefaults() error {
if cfg.Backend == nil {
return trace.BadParameter("missing parameter Backend")
}
if cfg.WatcherService == nil {
cfg.WatcherService = cfg.Backend
}
if cfg.Log == nil {
cfg.Log = logrus.StandardLogger()
cfg.Log.WithField("resource-kind", types.KindHeadlessAuthentication)
Expand Down Expand Up @@ -145,7 +153,7 @@ func (h *HeadlessAuthenticationWatcher) runWatchLoop(ctx context.Context) {
}

func (h *HeadlessAuthenticationWatcher) watch(ctx context.Context) error {
watcher, err := h.Backend.NewWatcher(ctx, backend.Watch{
watcher, err := h.WatcherService.NewWatcher(ctx, backend.Watch{
Name: types.KindHeadlessAuthentication,
MetricComponent: types.KindHeadlessAuthentication,
Prefixes: [][]byte{headlessAuthenticationKey("")},
Expand Down Expand Up @@ -211,29 +219,10 @@ func (h *HeadlessAuthenticationWatcher) notify(headlessAuthns ...*types.Headless
}
}

// CheckWaiter checks if there is an active waiter matching the given
// headless authentication ID. Used in tests.
func (h *HeadlessAuthenticationWatcher) CheckWaiter(name string) bool {
h.mux.Lock()
defer h.mux.Unlock()
for i := range h.waiters {
if h.waiters[i].name == name {
return true
}
}
return false
}

// Wait watches for the headless authentication with the given id to be added/updated
// in the backend, and waits for the given condition to be met, to result in an error,
// or for the given context to close.
func (h *HeadlessAuthenticationWatcher) Wait(ctx context.Context, name string, cond func(*types.HeadlessAuthentication) (bool, error)) (*types.HeadlessAuthentication, error) {
waiter, err := h.assignWaiter(ctx, name)
if err != nil {
return nil, trace.Wrap(err)
}
defer h.unassignWaiter(waiter)

checkBackend := func() (*types.HeadlessAuthentication, bool, error) {
headlessAuthn, err := h.identityService.GetHeadlessAuthentication(ctx, name)
if err != nil {
Expand All @@ -248,14 +237,20 @@ func (h *HeadlessAuthenticationWatcher) Wait(ctx context.Context, name string, c
return headlessAuthn, ok, nil
}

// With the waiter allocated, check if there is an existing entry in the backend.
// Before the waiter is allocated, check if there is an existing entry in the backend.
headlessAuthn, ok, err := checkBackend()
if err != nil && !trace.IsNotFound(err) {
return nil, trace.Wrap(err)
} else if ok {
return headlessAuthn, nil
}

waiter, err := h.assignWaiter(ctx, name)
if err != nil {
return nil, trace.Wrap(err)
}
defer h.unassignWaiter(waiter)

for {
select {
case <-waiter.stale:
Expand All @@ -270,12 +265,6 @@ func (h *HeadlessAuthenticationWatcher) Wait(ctx context.Context, name string, c
return headlessAuthn, nil
}
case headlessAuthn := <-waiter.ch:
select {
case <-waiter.stale:
// prioritize stale check.
continue
default:
}
if ok, err := cond(headlessAuthn); err != nil {
return nil, trace.Wrap(err)
} else if ok {
Expand Down
Loading