diff --git a/lib/auth/join_ec2.go b/lib/auth/join_ec2.go index d9d8782dd27f7..4b2d2c2ac7c84 100644 --- a/lib/auth/join_ec2.go +++ b/lib/auth/join_ec2.go @@ -252,6 +252,25 @@ func dbExists(ctx context.Context, presence services.Presence, hostID string) (b return false, nil } +func oktaExists(ctx context.Context, presence services.Presence, hostID string) (bool, error) { + namespaces, err := presence.GetNamespaces() + if err != nil { + return false, trace.Wrap(err) + } + for _, namespace := range namespaces { + apps, err := presence.GetApplicationServers(ctx, namespace.GetName()) + if err != nil { + return false, trace.Wrap(err) + } + for _, app := range apps { + if app.GetName() == hostID && app.Origin() == types.OriginOkta { + return true, nil + } + } + } + return false, nil +} + func desktopServiceExists(ctx context.Context, presence services.Presence, hostID string) (bool, error) { svcs, err := presence.GetWindowsDesktopServices(ctx) if err != nil { @@ -292,9 +311,14 @@ func (a *Server) tryToDetectIdentityReuse(ctx context.Context, req *types.Regist instanceExists, err = dbExists(ctx, a, req.HostID) case types.RoleWindowsDesktop: instanceExists, err = desktopServiceExists(ctx, a, req.HostID) + case types.RoleOkta: + instanceExists, err = oktaExists(ctx, a, req.HostID) case types.RoleInstance: // no appropriate check exists for the Instance role instanceExists = false + case types.RoleDiscovery: + // no appropriate check exists for the Discovery role + instanceExists = false default: return trace.BadParameter("unsupported role: %q", req.Role) } diff --git a/lib/auth/join_ec2_test.go b/lib/auth/join_ec2_test.go index 72c9b5b0ca782..ccd9ae79ce57a 100644 --- a/lib/auth/join_ec2_test.go +++ b/lib/auth/join_ec2_test.go @@ -97,9 +97,11 @@ VUP+3jgenPrd7OyCWPSwOoOBMhSlAAAAAAAA`), } ) -type ec2ClientNoInstance struct{} -type ec2ClientNotRunning struct{} -type ec2ClientRunning struct{} +type ( + ec2ClientNoInstance struct{} + ec2ClientNotRunning struct{} + ec2ClientRunning struct{} +) func (c ec2ClientNoInstance) DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { return &ec2.DescribeInstancesOutput{}, nil @@ -594,6 +596,8 @@ func TestHostUniqueCheck(t *testing.T) { types.RoleDatabase, types.RoleApp, types.RoleWindowsDesktop, + types.RoleDiscovery, + types.RoleOkta, }, Allow: []*types.TokenRule{ { @@ -650,7 +654,6 @@ func TestHostUniqueCheck(t *testing.T) { { role: types.RoleKube, upserter: func(name string) { - kube, err := types.NewKubernetesServerV3( types.Metadata{ Name: name, @@ -669,7 +672,6 @@ func TestHostUniqueCheck(t *testing.T) { require.NoError(t, err) _, err = a.UpsertKubernetesServer(context.Background(), kube) require.NoError(t, err) - }, }, { @@ -729,6 +731,36 @@ func TestHostUniqueCheck(t *testing.T) { require.NoError(t, err) }, }, + { + role: types.RoleOkta, + upserter: func(name string) { + app, err := types.NewAppV3( + types.Metadata{ + Name: "test-app", + Namespace: defaults.Namespace, + }, + types.AppSpecV3{ + URI: "https://app.localhost", + }) + require.NoError(t, err) + appServer, err := types.NewAppServerV3( + types.Metadata{ + Name: name, + Namespace: defaults.Namespace, + }, + types.AppServerSpecV3{ + HostID: name, + App: app, + }) + require.NoError(t, err) + appServer.SetOrigin(types.OriginOkta) + _, err = a.UpsertApplicationServer(context.Background(), appServer) + require.NoError(t, err) + }, + }, + { + role: types.RoleDiscovery, + }, } ctx = context.WithValue(ctx, ec2ClientKey{}, ec2ClientRunning{}) @@ -749,14 +781,16 @@ func TestHostUniqueCheck(t *testing.T) { _, err = a.RegisterUsingToken(ctx, &request) require.NoError(t, err) - // add the server - name := instance1.account + "-" + instance1.instanceID - tc.upserter(name) + if tc.upserter != nil { + // add the server + name := instance1.account + "-" + instance1.instanceID + tc.upserter(name) - // request should fail - _, err = a.RegisterUsingToken(ctx, &request) - expectedErr := &trace.AccessDeniedError{} - require.ErrorAs(t, err, &expectedErr) + // request should fail + _, err = a.RegisterUsingToken(ctx, &request) + expectedErr := &trace.AccessDeniedError{} + require.ErrorAs(t, err, &expectedErr) + } }) } }