diff --git a/lib/auth/bot.go b/lib/auth/bot.go index 1f431c3fa4834..a32fdcff66550 100644 --- a/lib/auth/bot.go +++ b/lib/auth/bot.go @@ -258,6 +258,26 @@ func (a *Server) tryLockBotDueToGenerationMismatch(ctx context.Context, username return nil } +// shouldEnforceGenerationCounter decides if generation counter checks should be +// enforced for a given join method. Note that in certain situations the counter +// may still not technically be enforced, for example, when onboarding a new bot +// or recovering a bound keypair bot. +func shouldEnforceGenerationCounter(renewable bool, joinMethod string) bool { + if renewable { + return true + } + + // Note: token renewals are handled by the `renewable` check above, since + // those certs are issued via `ServerWithRoles.generateUserCerts()` and do + // not have an associated join method. + switch joinMethod { + case string(types.JoinMethodBoundKeypair): + return true + default: + return false + } +} + // updateBotInstance updates the bot instance associated with the context // identity, if any. If the optional `templateAuthRecord` is provided, various // metadata fields will be copied into the newly generated auth record. @@ -386,9 +406,9 @@ func (a *Server) updateBotInstance( return trace.AccessDenied("a current identity generation must be provided") } else if currentIdentityGeneration > 0 && currentIdentityGeneration != instanceGeneration { - // For now, continue to only enforce generation counter checks on - // renewable (i.e. token) identities. - if req.renewable { + // Generation counter enforcement depends on the type of cert and join + // method (if any - token renewals technically have no join method.) + if shouldEnforceGenerationCounter(req.renewable, authRecord.JoinMethod) { if err := a.tryLockBotDueToGenerationMismatch(ctx, username); err != nil { log.WarnContext(ctx, "Failed to lock bot when a generation mismatch was detected", "error", err) } @@ -422,6 +442,8 @@ func (a *Server) updateBotInstance( // compatibility, but only if this is a renewable identity. Previous // versions only expect a nonzero generation counter for token joins, so // setting this for other methods will break compatibility. + // Note: new join methods that enforce generation counter checks will not + // write a generation counter to user labels (e.g. bound keypair). if req.renewable { if err := a.commitLegacyGenerationCounterToBotUser(ctx, username, uint64(newGeneration)); err != nil { log.WarnContext(ctx, "unable to commit legacy generation counter to bot user", "error", err) diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index 70dfcb78cb07d..1c307a925a3b5 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -643,6 +643,9 @@ func (a *Server) RegisterUsingBoundKeypairMethod( nil, // TODO: extended claims for this type? nil, // TODO: workload id claims ) + if err != nil { + return nil, trace.Wrap(err) + } if expectNewBotInstance { mutators = append( diff --git a/lib/auth/join_bound_keypair_test.go b/lib/auth/join_bound_keypair_test.go index 8c9795ab29c6a..c8c7fa353bc7a 100644 --- a/lib/auth/join_bound_keypair_test.go +++ b/lib/auth/join_bound_keypair_test.go @@ -36,8 +36,10 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/boundkeypair" + "github.com/gravitational/teleport/lib/boundkeypair/boundkeypairexperiment" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/teleport/lib/tlsca" ) type mockBoundKeypairValidator struct { @@ -736,3 +738,221 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { }) } } + +type mockSolver struct { + publicKey string +} + +func (m *mockSolver) solver() client.RegisterUsingBoundKeypairChallengeResponseFunc { + return func(challenge *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error) { + switch r := challenge.Response.(type) { + case *proto.RegisterUsingBoundKeypairMethodResponse_Rotation: + return &proto.RegisterUsingBoundKeypairMethodRequest{ + Payload: &proto.RegisterUsingBoundKeypairMethodRequest_RotationResponse{ + RotationResponse: &proto.RegisterUsingBoundKeypairRotationResponse{ + PublicKey: m.publicKey, + }, + }, + }, nil + case *proto.RegisterUsingBoundKeypairMethodResponse_Challenge: + return &proto.RegisterUsingBoundKeypairMethodRequest{ + Payload: &proto.RegisterUsingBoundKeypairMethodRequest_ChallengeResponse{ + ChallengeResponse: &proto.RegisterUsingBoundKeypairChallengeResponse{ + // For testing purposes, we'll just reply with the + // public key, to avoid needing to parse the JWT. + Solution: []byte(r.Challenge.PublicKey), + }, + }, + }, nil + default: + return nil, trace.BadParameter("not supported") + + } + } +} + +func newMockSolver(t *testing.T, pubKey string) *mockSolver { + t.Helper() + + return &mockSolver{ + publicKey: pubKey, + } +} + +func testExtractBotParamsFromCerts(t *testing.T, certs *proto.Certs) (string, uint64) { + t.Helper() + + parsed, err := tlsca.ParseCertificatePEM(certs.TLS) + require.NoError(t, err) + ident, err := tlsca.FromSubject(parsed.Subject, parsed.NotAfter) + require.NoError(t, err) + + return ident.BotInstanceID, ident.Generation +} + +func TestServer_RegisterUsingBoundKeypairMethod_GenerationCounter(t *testing.T) { + ctx := context.Background() + + // TODO: This prevents parallel execution; remove along with the experiment. + boundkeypairexperiment.SetEnabled(true) + + sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair() + require.NoError(t, err) + tlsPublicKey, err := PrivateKeyToPublicKeyTLS(sshPrivateKey) + require.NoError(t, err) + + _, correctPublicKey := testBoundKeypair(t) + + clock := clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()) + + srv := newTestTLSServer(t, withClock(clock)) + auth := srv.Auth() + auth.createBoundKeypairValidator = func(subject, clusterName string, publicKey crypto.PublicKey) (boundKeypairValidator, error) { + return &mockBoundKeypairValidator{ + subject: subject, + clusterName: clusterName, + publicKey: publicKey, + }, nil + } + + _, err = CreateRole(ctx, auth, "example", types.RoleSpecV6{}) + require.NoError(t, err) + + adminClient, err := srv.NewClient(TestAdmin()) + require.NoError(t, err) + + _, err = adminClient.BotServiceClient().CreateBot(ctx, &machineidv1pb.CreateBotRequest{ + Bot: &machineidv1pb.Bot{ + Kind: types.KindBot, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &machineidv1pb.BotSpec{ + Roles: []string{"example"}, + }, + }, + }) + require.NoError(t, err) + + token, err := types.NewProvisionTokenFromSpecAndStatus( + "bound-keypair-test", + time.Now().Add(2*time.Hour), + types.ProvisionTokenSpecV2{ + JoinMethod: types.JoinMethodBoundKeypair, + Roles: []types.SystemRole{types.RoleBot}, + BotName: "test", + BoundKeypair: &types.ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ + InitialPublicKey: correctPublicKey, + }, + Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{ + Limit: 2, + }, + }, + }, + &types.ProvisionTokenStatusV2{}, + ) + require.NoError(t, err) + require.NoError(t, auth.CreateBoundKeypairToken(ctx, token)) + + makeInitReq := func(mutators ...func(r *proto.RegisterUsingBoundKeypairInitialRequest)) *proto.RegisterUsingBoundKeypairInitialRequest { + req := &proto.RegisterUsingBoundKeypairInitialRequest{ + JoinRequest: &types.RegisterUsingTokenRequest{ + HostID: "host-id", + Role: types.RoleBot, + PublicTLSKey: tlsPublicKey, + PublicSSHKey: sshPublicKey, + Token: "bound-keypair-test", + }, + } + for _, mutator := range mutators { + mutator(req) + } + return req + } + + withJoinState := func(state []byte) func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + return func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + r.PreviousJoinState = state + } + } + + withBotParamsFromIdent := func(t *testing.T, certs *proto.Certs) func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + id, gen := testExtractBotParamsFromCerts(t, certs) + + return func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + r.JoinRequest.BotInstanceID = id + r.JoinRequest.BotGeneration = int32(gen) + } + } + + solver := newMockSolver(t, correctPublicKey) + response, err := auth.RegisterUsingBoundKeypairMethod(ctx, makeInitReq(), solver.solver()) + require.NoError(t, err) + + instance, generation := testExtractBotParamsFromCerts(t, response.Certs) + require.Equal(t, uint64(1), generation) + + firstInstance := instance + + // Register several times. + for i := range 10 { + response, err = auth.RegisterUsingBoundKeypairMethod( + ctx, + makeInitReq(withJoinState(response.JoinState), withBotParamsFromIdent(t, response.Certs)), + solver.solver(), + ) + require.NoError(t, err) + + instance, generation := testExtractBotParamsFromCerts(t, response.Certs) + require.Equal(t, uint64(i+2), generation) + require.Equal(t, firstInstance, instance) + } + + // Perform a recovery to get a new instance and reset the counter. + response, err = auth.RegisterUsingBoundKeypairMethod(ctx, makeInitReq(withJoinState(response.JoinState)), solver.solver()) + require.NoError(t, err) + + instance, generation = testExtractBotParamsFromCerts(t, response.Certs) + require.Equal(t, uint64(1), generation, "generation counter should reset") + require.NotEqual(t, instance, firstInstance) + + secondInstance := instance + + // Register several more times. + for i := range 10 { + response, err = auth.RegisterUsingBoundKeypairMethod( + ctx, + makeInitReq(withJoinState(response.JoinState), withBotParamsFromIdent(t, response.Certs)), + solver.solver(), + ) + require.NoError(t, err) + + instance, generation := testExtractBotParamsFromCerts(t, response.Certs) + require.Equal(t, uint64(i+2), generation) + require.Equal(t, secondInstance, instance) + } + + // Provide an incorrect generation counter value. + response, err = auth.RegisterUsingBoundKeypairMethod( + ctx, + makeInitReq( + withJoinState(response.JoinState), + withBotParamsFromIdent(t, response.Certs), + func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + r.JoinRequest.BotGeneration = 1 + }, + ), + solver.solver(), + ) + require.Nil(t, response) + require.ErrorContains(t, err, "renewable cert generation mismatch") + + // The bot user should now be locked. + locks, err := srv.Auth().GetLocks(ctx, true, types.LockTarget{ + User: "bot-test", + }) + require.NoError(t, err) + require.NotEmpty(t, locks) +}