diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index 1ec2042eb1bca..8f24e70ed5bf9 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -71,6 +71,9 @@ func populateRegistrationSecret(v2 *types.ProvisionTokenV2) error { if v2.Spec.BoundKeypair == nil { v2.Spec.BoundKeypair = &types.ProvisionTokenSpecV2BoundKeypair{} } + if v2.Spec.BoundKeypair.Onboarding == nil { + v2.Spec.BoundKeypair.Onboarding = &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{} + } if v2.Status == nil { v2.Status = &types.ProvisionTokenStatusV2{} diff --git a/lib/auth/join_bound_keypair_test.go b/lib/auth/join_bound_keypair_test.go index 9c7033b5bf371..d2020fa90f01d 100644 --- a/lib/auth/join_bound_keypair_test.go +++ b/lib/auth/join_bound_keypair_test.go @@ -1477,3 +1477,111 @@ func TestServer_RegisterUsingBoundKeypairMethod_JoinStateFailureDuringRenewal(t require.ErrorContains(t, err, "a client failed to verify its join state") }, 5*time.Second, 100*time.Millisecond) } + +func TestServer_CreateBoundKeypairToken(t *testing.T) { + t.Parallel() + // Most creation/validation functionality is tested in api/ as part of + // CheckAndSetDefaults() or in lib/services, but there's some specific logic + // at this layer to generate the default registration secret if needed we + // should test. + clock := clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()) + srv := newTestTLSServer(t, withClock(clock)) + authServer := srv.Auth() + + tests := []struct { + name string + token *types.ProvisionTokenV2 + wantErr require.ErrorAssertionFunc + assertion func(t require.TestingT, token *types.ProvisionTokenV2) + }{ + { + name: "nil onboarding spec", + token: &types.ProvisionTokenV2{ + Kind: types.KindToken, + Version: types.V2, + Metadata: types.Metadata{ + Name: "empty-onboarding", + }, + Spec: types.ProvisionTokenSpecV2{ + JoinMethod: types.JoinMethodBoundKeypair, + Roles: []types.SystemRole{types.RoleBot}, + BotName: "test", + BoundKeypair: &types.ProvisionTokenSpecV2BoundKeypair{ + Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{ + Mode: "insecure", + }, + }, + }, + }, + wantErr: require.NoError, + assertion: func(t require.TestingT, token *types.ProvisionTokenV2) { + require.NotEmpty(t, token.Status.BoundKeypair.RegistrationSecret) + }, + }, + { + name: "set onboarding spec with secret", + token: &types.ProvisionTokenV2{ + Kind: types.KindToken, + Version: types.V2, + Metadata: types.Metadata{ + Name: "set-onboarding-with-secret", + }, + Spec: types.ProvisionTokenSpecV2{ + JoinMethod: types.JoinMethodBoundKeypair, + Roles: []types.SystemRole{types.RoleBot}, + BotName: "test", + BoundKeypair: &types.ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ + RegistrationSecret: "my-initial-secret", + }, + Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{ + Mode: "insecure", + }, + }, + }, + }, + wantErr: require.NoError, + assertion: func(t require.TestingT, token *types.ProvisionTokenV2) { + require.Equal(t, "my-initial-secret", token.Status.BoundKeypair.RegistrationSecret) + }, + }, + { + name: "set onboarding spec with no secret", + token: &types.ProvisionTokenV2{ + Kind: types.KindToken, + Version: types.V2, + Metadata: types.Metadata{ + Name: "set-onboarding-with-no-secret", + }, + Spec: types.ProvisionTokenSpecV2{ + JoinMethod: types.JoinMethodBoundKeypair, + Roles: []types.SystemRole{types.RoleBot}, + BotName: "test", + BoundKeypair: &types.ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{}, + Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{ + Mode: "insecure", + }, + }, + }, + }, + wantErr: require.NoError, + assertion: func(t require.TestingT, token *types.ProvisionTokenV2) { + require.NotEmpty(t, token.Status.BoundKeypair.RegistrationSecret) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := authServer.CreateBoundKeypairToken(t.Context(), tt.token) + tt.wantErr(t, err) + + if tt.assertion != nil { + got, err := authServer.GetToken(t.Context(), tt.token.GetName()) + require.NoError(t, err) + tt.assertion(t, got.(*types.ProvisionTokenV2)) + } + }) + } +}