Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions lib/auth/join_ec2.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,25 @@ func dbExists(ctx context.Context, presence services.Presence, hostID string) (b
return false, nil
}

func oktaExists(ctx context.Context, presence services.Presence, hostID string) (bool, error) {
namespaces, err := presence.GetNamespaces()
if err != nil {
return false, trace.Wrap(err)
}
for _, namespace := range namespaces {
apps, err := presence.GetApplicationServers(ctx, namespace.GetName())
if err != nil {
return false, trace.Wrap(err)
}
for _, app := range apps {
if app.GetName() == hostID && app.Origin() == types.OriginOkta {
return true, nil
}
}
}
return false, nil
}

func desktopServiceExists(ctx context.Context, presence services.Presence, hostID string) (bool, error) {
svcs, err := presence.GetWindowsDesktopServices(ctx)
if err != nil {
Expand Down Expand Up @@ -292,9 +311,17 @@ func (a *Server) tryToDetectIdentityReuse(ctx context.Context, req *types.Regist
instanceExists, err = dbExists(ctx, a, req.HostID)
case types.RoleWindowsDesktop:
instanceExists, err = desktopServiceExists(ctx, a, req.HostID)
case types.RoleOkta:
instanceExists, err = oktaExists(ctx, a, req.HostID)
case types.RoleInstance:
// no appropriate check exists for the Instance role
instanceExists = false
case types.RoleDiscovery:
// no appropriate check exists for the Discovery role
instanceExists = false
case types.RoleMDM:
// no appropriate check exists for the MDM role
instanceExists = false
default:
return trace.BadParameter("unsupported role: %q", req.Role)
}
Expand Down
52 changes: 45 additions & 7 deletions lib/auth/join_ec2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,9 @@ func TestHostUniqueCheck(t *testing.T) {
types.RoleDatabase,
types.RoleApp,
types.RoleWindowsDesktop,
types.RoleMDM,
types.RoleDiscovery,
types.RoleOkta,
},
Allow: []*types.TokenRule{
{
Expand Down Expand Up @@ -739,6 +742,39 @@ func TestHostUniqueCheck(t *testing.T) {
require.NoError(t, err)
},
},
{
role: types.RoleOkta,
upserter: func(name string) {
app, err := types.NewAppV3(
types.Metadata{
Name: "test-app",
Namespace: defaults.Namespace,
},
types.AppSpecV3{
URI: "https://app.localhost",
})
require.NoError(t, err)
appServer, err := types.NewAppServerV3(
types.Metadata{
Name: name,
Namespace: defaults.Namespace,
},
types.AppServerSpecV3{
HostID: name,
App: app,
})
require.NoError(t, err)
appServer.SetOrigin(types.OriginOkta)
_, err = a.UpsertApplicationServer(context.Background(), appServer)
require.NoError(t, err)
},
},
{
role: types.RoleDiscovery,
},
{
role: types.RoleMDM,
},
}

ctx = context.WithValue(ctx, ec2ClientKey{}, ec2ClientRunning{})
Expand All @@ -759,14 +795,16 @@ func TestHostUniqueCheck(t *testing.T) {
_, err = a.RegisterUsingToken(ctx, &request)
require.NoError(t, err)

// add the server
name := instance1.account + "-" + instance1.instanceID
tc.upserter(name)
if tc.upserter != nil {
// add the server
name := instance1.account + "-" + instance1.instanceID
tc.upserter(name)

// request should fail
_, err = a.RegisterUsingToken(ctx, &request)
expectedErr := &trace.AccessDeniedError{}
require.ErrorAs(t, err, &expectedErr)
// request should fail
_, err = a.RegisterUsingToken(ctx, &request)
expectedErr := &trace.AccessDeniedError{}
require.ErrorAs(t, err, &expectedErr)
}
})
}
}