diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index 0c5ca3074fb84..5eaf4f31ec81e 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -864,10 +864,6 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, trace.BadParameter("bad backend state, please recreate the join token") case hasBoundPublicKey && hasBoundBotInstance && hasIncomingBotInstance: // Standard rejoin case, does not consume a rejoin. - if status.BoundBotInstanceID != req.JoinRequest.BotInstanceID { - return nil, trace.AccessDenied("bot instance mismatch") - } - if err := a.issueBoundKeypairChallenge( ctx, status.BoundPublicKey, @@ -892,16 +888,21 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, trace.AccessDenied("join state verification failed") } + // Join state verification will check the instance IDs in the token and + // join state document, but as a sanity check, we'll also ensure it + // matches the value extracted from the certs. + // + // It should not be possible for this check to fail at this point, as + // any event that might have cycled bot instance IDs should have also + // modified the join state causing a failure above. In any case, we'll + // keep this as a sanity check. + if status.BoundBotInstanceID != req.JoinRequest.BotInstanceID { + return nil, trace.AccessDenied("bot instance mismatch") + } + // Nothing else to do, no key change, no additional audit event; regular // bot join event will be emitted later. case hasBoundPublicKey && hasBoundBotInstance && !hasIncomingBotInstance: - // Hard rejoin case, the client identity expired and a new bot instance - // is required. Consumes a rejoin. - if recoveryMode == boundkeypair.RecoveryModeStandard && !hasJoinsRemaining { - // Recovery limit only applies in "standard" mode. - return nil, trace.AccessDenied("no rejoins remaining") - } - if err := a.issueBoundKeypairChallenge( ctx, status.BoundPublicKey, @@ -911,6 +912,13 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, trace.Wrap(err) } + // Hard rejoin case, the client identity expired and a new bot instance + // is required. Consumes a rejoin. + if recoveryMode == boundkeypair.RecoveryModeStandard && !hasJoinsRemaining { + // Recovery limit only applies in "standard" mode. + return nil, trace.AccessDenied("no rejoins remaining") + } + // Verify locks here now that we've verified private key ownership but // before we check join state. Otherwise, we could allow a lock creation // loop. diff --git a/lib/auth/join_bound_keypair_test.go b/lib/auth/join_bound_keypair_test.go index 81ea87b82a7d5..2f10a06d93878 100644 --- a/lib/auth/join_bound_keypair_test.go +++ b/lib/auth/join_bound_keypair_test.go @@ -1148,6 +1148,9 @@ func TestServer_RegisterUsingBoundKeypairMethod_GenerationCounter(t *testing.T) } func TestServer_RegisterUsingBoundKeypairMethod_JoinStateFailure(t *testing.T) { + // This tests that join state verification will trigger a lock if the + // original client and a secondary client both attempt to recover in + // sequence. t.Parallel() ctx := context.Background() @@ -1298,3 +1301,179 @@ func TestServer_RegisterUsingBoundKeypairMethod_JoinStateFailure(t *testing.T) { require.ErrorContains(t, err, "a client failed to verify its join state") }, 5*time.Second, 100*time.Millisecond) } + +func TestServer_RegisterUsingBoundKeypairMethod_JoinStateFailureDuringRenewal(t *testing.T) { + // Similar to _JoinStateFailure above, this exercises the case where the + // original client still has valid certs and isn't attempting a recovery of + // its own. + t.Parallel() + + ctx := context.Background() + + sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair() + require.NoError(t, err) + tlsPublicKey, err := authtest.PrivateKeyToPublicKeyTLS(sshPrivateKey) + require.NoError(t, err) + + _, correctPublicKey := testBoundKeypair(t) + + clock := clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()) + + srv := newTestTLSServer(t, withClock(clock)) + authServer := srv.Auth() + authServer.SetCreateBoundKeypairValidator(func(subject, clusterName string, publicKey crypto.PublicKey) (auth.BoundKeypairValidator, error) { + return &mockBoundKeypairValidator{ + subject: subject, + clusterName: clusterName, + publicKey: publicKey, + }, nil + }) + + _, err = authtest.CreateRole(ctx, authServer, "example", types.RoleSpecV6{}) + require.NoError(t, err) + + adminClient, err := srv.NewClient(authtest.TestAdmin()) + require.NoError(t, err) + + _, err = adminClient.BotServiceClient().CreateBot(ctx, &machineidv1pb.CreateBotRequest{ + Bot: &machineidv1pb.Bot{ + Kind: types.KindBot, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &machineidv1pb.BotSpec{ + Roles: []string{"example"}, + }, + }, + }) + require.NoError(t, err) + + token, err := types.NewProvisionTokenFromSpecAndStatus( + "bound-keypair-test", + time.Now().Add(2*time.Hour), + types.ProvisionTokenSpecV2{ + JoinMethod: types.JoinMethodBoundKeypair, + Roles: []types.SystemRole{types.RoleBot}, + BotName: "test", + BoundKeypair: &types.ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ + InitialPublicKey: correctPublicKey, + }, + Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{ + Limit: 3, + }, + }, + }, + &types.ProvisionTokenStatusV2{}, + ) + require.NoError(t, err) + require.NoError(t, authServer.CreateBoundKeypairToken(ctx, token)) + + makeInitReq := func(mutators ...func(r *proto.RegisterUsingBoundKeypairInitialRequest)) *proto.RegisterUsingBoundKeypairInitialRequest { + req := &proto.RegisterUsingBoundKeypairInitialRequest{ + JoinRequest: &types.RegisterUsingTokenRequest{ + HostID: "host-id", + Role: types.RoleBot, + PublicTLSKey: tlsPublicKey, + PublicSSHKey: sshPublicKey, + Token: "bound-keypair-test", + }, + } + for _, mutator := range mutators { + mutator(req) + } + return req + } + + withJoinState := func(state []byte) func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + return func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + r.PreviousJoinState = state + } + } + + withBotInstance := func(ident *tlsca.Identity) func(req *proto.RegisterUsingBoundKeypairInitialRequest) { + return func(req *proto.RegisterUsingBoundKeypairInitialRequest) { + req.JoinRequest.BotGeneration = int32(ident.Generation) + req.JoinRequest.BotInstanceID = ident.BotInstanceID + } + } + + // Perform the initial registration. + solver := newMockSolver(t, correctPublicKey) + firstResponse, err := authServer.RegisterUsingBoundKeypairMethod(ctx, makeInitReq(), solver.solver()) + require.NoError(t, err) + + // Parse the identity for subsequent use of the bot instance. + firstCert, err := tlsca.ParseCertificatePEM(firstResponse.Certs.TLS) + require.NoError(t, err) + firstIdent, err := tlsca.FromSubject(firstCert.Subject, firstCert.NotAfter) + require.NoError(t, err) + + // Perform a recovery, this time with a join state, simulating an attacker + // that has copied the certs. + secondResponse, err := authServer.RegisterUsingBoundKeypairMethod( + ctx, + makeInitReq(withJoinState(firstResponse.JoinState)), + solver.solver(), + ) + require.NotNil(t, secondResponse) + require.NoError(t, err) + + // Try an API call with these certs. + tlsCert, err := tls.X509KeyPair(secondResponse.Certs.TLS, sshPrivateKey) + require.NoError(t, err) + + client, err := srv.NewClientWithCert(tlsCert) + require.NoError(t, err) + _, err = client.Ping(ctx) + require.NoError(t, err) + + // Try once more, but this time with the first join state, simulating the + // original client authenticating again. + thirdResponse, err := authServer.RegisterUsingBoundKeypairMethod( + ctx, + makeInitReq( + withJoinState(firstResponse.JoinState), + + // Provide the previous identity to trigger the "standard rejoin" / + // renewal flow, rather than recovery. + withBotInstance(firstIdent), + ), + solver.solver(), + ) + require.Nil(t, thirdResponse) + + // Note: Exact error message depends on whether or not the lock is in + // effect, so we won't check it right now. + require.Error(t, err) + + // The token should now be locked - but only once. + locks, err := srv.Auth().GetLocks(ctx, true, types.LockTarget{ + JoinToken: "bound-keypair-test", + }) + require.NoError(t, err) + require.Len(t, locks, 1, "exactly one lock should be generated") + require.Contains(t, locks[0].Message(), "failed to verify its join state") + + // The previously working client should be locked. + require.Eventually(t, func() bool { + _, err = client.Ping(ctx) + return err != nil && strings.Contains(err.Error(), "access denied") + }, 5*time.Second, 100*time.Millisecond) + + // Repeat the above but with an Eventually() to consistently check the error + // message. Depending on exact timing / cache propagation / etc the lock may + // or may not be in force, but we also need to be absolutely certain to try + // to generate at least 2 locking events. + require.EventuallyWithT(t, func(t *assert.CollectT) { + nextResponse, err := authServer.RegisterUsingBoundKeypairMethod( + ctx, + makeInitReq(withJoinState(firstResponse.JoinState)), + solver.solver(), + ) + require.Nil(t, nextResponse) + require.Error(t, err) + require.ErrorContains(t, err, "a client failed to verify its join state") + }, 5*time.Second, 100*time.Millisecond) +}