diff --git a/lib/auth/join_ec2.go b/lib/auth/join_ec2.go index d9d8782dd27f7..515919b1a41a2 100644 --- a/lib/auth/join_ec2.go +++ b/lib/auth/join_ec2.go @@ -252,6 +252,25 @@ func dbExists(ctx context.Context, presence services.Presence, hostID string) (b return false, nil } +func oktaExists(ctx context.Context, presence services.Presence, hostID string) (bool, error) { + namespaces, err := presence.GetNamespaces() + if err != nil { + return false, trace.Wrap(err) + } + for _, namespace := range namespaces { + apps, err := presence.GetApplicationServers(ctx, namespace.GetName()) + if err != nil { + return false, trace.Wrap(err) + } + for _, app := range apps { + if app.GetName() == hostID && app.Origin() == types.OriginOkta { + return true, nil + } + } + } + return false, nil +} + func desktopServiceExists(ctx context.Context, presence services.Presence, hostID string) (bool, error) { svcs, err := presence.GetWindowsDesktopServices(ctx) if err != nil { @@ -292,9 +311,17 @@ func (a *Server) tryToDetectIdentityReuse(ctx context.Context, req *types.Regist instanceExists, err = dbExists(ctx, a, req.HostID) case types.RoleWindowsDesktop: instanceExists, err = desktopServiceExists(ctx, a, req.HostID) + case types.RoleOkta: + instanceExists, err = oktaExists(ctx, a, req.HostID) case types.RoleInstance: // no appropriate check exists for the Instance role instanceExists = false + case types.RoleDiscovery: + // no appropriate check exists for the Discovery role + instanceExists = false + case types.RoleMDM: + // no appropriate check exists for the MDM role + instanceExists = false default: return trace.BadParameter("unsupported role: %q", req.Role) } diff --git a/lib/auth/join_ec2_test.go b/lib/auth/join_ec2_test.go index bcd53396094fe..6ce6f5aad84e2 100644 --- a/lib/auth/join_ec2_test.go +++ b/lib/auth/join_ec2_test.go @@ -596,6 +596,9 @@ func TestHostUniqueCheck(t *testing.T) { types.RoleDatabase, types.RoleApp, types.RoleWindowsDesktop, + types.RoleMDM, + types.RoleDiscovery, + types.RoleOkta, }, Allow: []*types.TokenRule{ { @@ -739,6 +742,39 @@ func TestHostUniqueCheck(t *testing.T) { require.NoError(t, err) }, }, + { + role: types.RoleOkta, + upserter: func(name string) { + app, err := types.NewAppV3( + types.Metadata{ + Name: "test-app", + Namespace: defaults.Namespace, + }, + types.AppSpecV3{ + URI: "https://app.localhost", + }) + require.NoError(t, err) + appServer, err := types.NewAppServerV3( + types.Metadata{ + Name: name, + Namespace: defaults.Namespace, + }, + types.AppServerSpecV3{ + HostID: name, + App: app, + }) + require.NoError(t, err) + appServer.SetOrigin(types.OriginOkta) + _, err = a.UpsertApplicationServer(context.Background(), appServer) + require.NoError(t, err) + }, + }, + { + role: types.RoleDiscovery, + }, + { + role: types.RoleMDM, + }, } ctx = context.WithValue(ctx, ec2ClientKey{}, ec2ClientRunning{}) @@ -759,14 +795,16 @@ func TestHostUniqueCheck(t *testing.T) { _, err = a.RegisterUsingToken(ctx, &request) require.NoError(t, err) - // add the server - name := instance1.account + "-" + instance1.instanceID - tc.upserter(name) + if tc.upserter != nil { + // add the server + name := instance1.account + "-" + instance1.instanceID + tc.upserter(name) - // request should fail - _, err = a.RegisterUsingToken(ctx, &request) - expectedErr := &trace.AccessDeniedError{} - require.ErrorAs(t, err, &expectedErr) + // request should fail + _, err = a.RegisterUsingToken(ctx, &request) + expectedErr := &trace.AccessDeniedError{} + require.ErrorAs(t, err, &expectedErr) + } }) } }