diff --git a/lib/auth/auth.go b/lib/auth/auth.go index ee37080245c19..506cf4b969fdd 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -3211,6 +3211,8 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types. mfaVerified: req.mfaVerified, activeAccessRequests: req.activeRequests, deviceID: req.deviceExtensions.DeviceID, + botInstanceID: req.botInstanceID, + joinToken: req.joinToken, }); err != nil { return nil, trace.Wrap(err) } @@ -3675,6 +3677,10 @@ type verifyLocksForUserCertsReq struct { // deviceID is the trusted device ID. // Eg: tlsca.Identity.DeviceExtensions.DeviceID deviceID string + // botInstanceID is the bot instance UUID, set only for bots. + botInstanceID string + // joinToken is the join token name, set only for non-token bots. + joinToken string } // verifyLocksForUserCerts verifies if any locks are in place before issuing new @@ -3694,6 +3700,12 @@ func (a *Server) verifyLocksForUserCerts(req verifyLocksForUserCertsReq) error { lockTargets = append(lockTargets, services.AccessRequestsToLockTargets(req.activeAccessRequests)..., ) + if req.botInstanceID != "" { + lockTargets = append(lockTargets, types.LockTarget{BotInstanceID: req.botInstanceID}) + } + if req.joinToken != "" { + lockTargets = append(lockTargets, types.LockTarget{JoinToken: req.joinToken}) + } return trace.Wrap(a.checkLockInForce(lockingMode, lockTargets)) } diff --git a/lib/auth/bot.go b/lib/auth/bot.go index 7b29c5cd61c78..c8c3a34b4423d 100644 --- a/lib/auth/bot.go +++ b/lib/auth/bot.go @@ -121,7 +121,13 @@ func (a *Server) legacyValidateGenerationLabel(ctx context.Context, username str // The current generations must match to continue: if currentIdentityGeneration != currentUserGeneration { - if err := a.tryLockBotDueToGenerationMismatch(ctx, user.GetName()); err != nil { + if err := a.tryLockBotDueToGenerationMismatch( + ctx, + certReq.botName, + certReq.botInstanceID, + certReq.joinToken, + certReq.renewable, + ); err != nil { 
a.logger.WarnContext(ctx, "Failed to lock bot when a generation mismatch was detected", "error", err, "bot", user.GetName(), @@ -223,19 +229,44 @@ func (a *Server) commitLegacyGenerationCounterToBotUser(ctx context.Context, use // tryLockBotDueToGenerationMismatch creates a lock for the given bot user and // emits a `RenewableCertificateGenerationMismatch` audit event. -func (a *Server) tryLockBotDueToGenerationMismatch(ctx context.Context, username string) error { - // TODO: In the future, consider only locking the current join method / token. +func (a *Server) tryLockBotDueToGenerationMismatch( + ctx context.Context, botName, botInstanceID, joinTokenName string, renewable bool, +) error { + var spec types.LockSpecV2 + if renewable { + // Renewable implies `token` joining. These are one-time use secrets + // and will not be embedded in the TLS identity, so we can't target + // the join token and should instead rely on the bot instance ID. As + // there is a 1:1 relationship between bot instance and "token"-type + // token, this should be functionally equivalent. + spec = types.LockSpecV2{ + Target: types.LockTarget{ + BotInstanceID: botInstanceID, + }, + Message: fmt.Sprintf( + "The bot instance %s/%s has been locked due to a certificate "+ + "generation mismatch, possibly indicating a stolen "+ + "certificate.", + botName, botInstanceID, + ), + CreatedAt: a.clock.Now(), + } + } else { + spec = types.LockSpecV2{ + Target: types.LockTarget{ + JoinToken: joinTokenName, + }, + Message: fmt.Sprintf( + "Bot joins via the token %q have been locked due to a "+ + "certificate generation mismatch by %s/%s, possibly "+ + "indicating a stolen certificate.", + joinTokenName, botName, botInstanceID, + ), + CreatedAt: a.clock.Now(), + } + } - // Lock the bot user indefinitely. 
- lock, err := types.NewLock(uuid.New().String(), types.LockSpecV2{ - Target: types.LockTarget{ - User: username, - }, - Message: fmt.Sprintf( - "The bot user %q has been locked due to a certificate generation mismatch, possibly indicating a stolen certificate.", - username, - ), - }) + lock, err := types.NewLock(uuid.New().String(), spec) if err != nil { return trace.Wrap(err) } @@ -360,7 +391,12 @@ func (a *Server) updateBotInstance( // If the incoming identity has a nonzero generation, validate it // using the legacy check. This will increment the counter on the // request automatically - if err := a.legacyValidateGenerationLabel(ctx, username, req, uint64(currentIdentityGeneration)); err != nil { + if err := a.legacyValidateGenerationLabel( + ctx, + username, + req, + uint64(currentIdentityGeneration), + ); err != nil { return trace.Wrap(err) } @@ -409,7 +445,7 @@ func (a *Server) updateBotInstance( // Generation counter enforcement depends on the type of cert and join // method (if any - token renewals technically have no join method.) if shouldEnforceGenerationCounter(req.renewable, authRecord.JoinMethod) { - if err := a.tryLockBotDueToGenerationMismatch(ctx, username); err != nil { + if err := a.tryLockBotDueToGenerationMismatch(ctx, botName, botInstanceID, req.joinToken, req.renewable); err != nil { log.WarnContext(ctx, "Failed to lock bot when a generation mismatch was detected", "error", err) } @@ -572,6 +608,14 @@ func (a *Server) generateInitialBotCerts( joinAttributes: joinAttrs, } + // Set the join token cert field for non-renewable identities. This is used + // for lock targeting; token name lock targets aren't particularly useful for + // token-joined bots and it's a secret value, so we don't bother setting it. + // (The renewable flag implies token joining.) + if !renewable { + certReq.joinToken = initialAuth.JoinToken + } + if existingInstanceID == "" { // If no existing instance ID is known, create a new one. 
uuid, err := uuid.NewRandom() @@ -614,14 +658,6 @@ func (a *Server) generateInitialBotCerts( } } - // Set the join token cert field for non-renewable identities. This is used - // for lock targeting; token name lock targets are particularly useful for - // token-joined bots and it's a secret value, so we don't bother setting it. - // (The renewable flag implies token joining.) - if !renewable { - certReq.joinToken = initialAuth.JoinToken - } - certs, err := a.generateUserCert(ctx, certReq) if err != nil { return nil, "", trace.Wrap(err) diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go index a341cb4668083..4439969e58e6a 100644 --- a/lib/auth/bot_test.go +++ b/lib/auth/bot_test.go @@ -549,7 +549,11 @@ func TestRegisterBotCertificateGenerationStolen(t *testing.T) { require.NoError(t, err) // Renew the certs once (e.g. this is the actual bot process) - _, certsReal, err := renewBotCerts(ctx, srv, result.Certs.TLS, bot.Status.UserName, result.PrivateKey) + renewedClient, certsReal, err := renewBotCerts(ctx, srv, result.Certs.TLS, bot.Status.UserName, result.PrivateKey) + require.NoError(t, err) + + // This client should be able to ping. + _, err = renewedClient.Ping(ctx) require.NoError(t, err) // Check the generation, it should be 2. @@ -564,12 +568,16 @@ func TestRegisterBotCertificateGenerationStolen(t *testing.T) { require.Error(t, err) require.True(t, trace.IsAccessDenied(err)) - // The user should now be locked. + // The bot instance should now be locked. locks, err := srv.Auth().GetLocks(ctx, true, types.LockTarget{ - User: "bot-test", + BotInstanceID: impersonatedIdent.BotInstanceID, }) require.NoError(t, err) require.NotEmpty(t, locks) + + // The original client should now be locked out. + _, err = renewedClient.Ping(ctx) + require.ErrorContains(t, err, "access denied") } // TestRegisterBotCertificateExtensions ensures bot cert extensions are present. 
diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index e12503daaf285..395c5c845f024 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -23,8 +23,11 @@ import ( "crypto" "crypto/subtle" "encoding/json" + "fmt" + "log/slog" "time" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client" @@ -350,7 +353,7 @@ type boundKeypairStatusMutator func(*types.ProvisionTokenSpecV2BoundKeypair, *ty // the recovery counter. This verifies that the backend recovery count has not // changed, and that total join count is at least the value when the mutator was // created. -func mutateStatusConsumeRecovery(mode boundkeypair.RecoveryMode, expectRecoveryCount uint32, expectMinRecoveryLimit uint32) boundKeypairStatusMutator { +func mutateStatusConsumeRecovery(expectRecoveryCount uint32, expectMinRecoveryLimit uint32) boundKeypairStatusMutator { now := time.Now() return func(spec *types.ProvisionTokenSpecV2BoundKeypair, status *types.ProvisionTokenStatusV2BoundKeypair) error { @@ -528,8 +531,126 @@ func (a *Server) emitBoundKeypairRotationEvent( } } +func (a *Server) tryLockBotInvalidJoinState( + ctx context.Context, + ptv2 *types.ProvisionTokenV2, + req *proto.RegisterUsingBoundKeypairInitialRequest, + validationError error, +) { + log := a.logger.With("join_token", ptv2.GetName(), "validation_error", validationError) + + if auditErr := a.emitter.EmitAuditEvent(a.closeCtx, &apievents.BoundKeypairJoinStateVerificationFailed{ + Metadata: apievents.Metadata{ + Type: events.BoundKeypairJoinStateVerificationFailed, + Code: events.BoundKeypairJoinStateVerificationFailedCode, + }, + Status: apievents.Status{ + Success: false, + Error: validationError.Error(), + }, + ConnectionMetadata: apievents.ConnectionMetadata{ + RemoteAddr: req.JoinRequest.RemoteAddr, + }, + TokenName: ptv2.GetName(), + BotName: ptv2.GetBotName(), + }); auditErr != nil { + log.WarnContext(ctx, "Failed to emit 
failed join state verification event", "error", auditErr) + } + + // Create a lock against this token. + lock, err := types.NewLock(uuid.New().String(), types.LockSpecV2{ + Target: types.LockTarget{ + JoinToken: ptv2.GetName(), + }, + Message: fmt.Sprintf( + "The join token %q has been locked by bot %q after a client "+ + "failed to verify its join state, possibly indicating a "+ + "stolen keypair.", + ptv2.GetName(), ptv2.GetBotName(), + ), + CreatedAt: a.clock.Now(), + }) + if err != nil { + log.ErrorContext(ctx, "Unable to create lock for bound keypair token", "error", err) + return + } + if err := a.UpsertLock(ctx, lock); err != nil { + log.ErrorContext(ctx, "Unable to create lock for bound keypair token after join state verification failed", "error", err) + } +} + +// verifyBoundKeypairJoinState verifies the client's provided join state +// document if the current state of the token indicates the join state must be +// verified. If verification is required and fails, this returns an error and +// locks the token until a cluster admin can ensure the token hasn't been +// compromised. If verification is not required, this is a no-op. Join state +// should be verified whenever a client rejoins, but only after they have proven +// ownership of their private key. +func (a *Server) verifyBoundKeypairJoinState( + ctx context.Context, + log *slog.Logger, + req *proto.RegisterUsingBoundKeypairInitialRequest, + ptv2 *types.ProvisionTokenV2, + ca types.CertAuthority, +) error { + recoveryMode, err := boundkeypair.ParseRecoveryMode(ptv2.Spec.BoundKeypair.Recovery.Mode) + if err != nil { + return trace.Wrap(err, "parsing recovery mode") + } + + // Join state is required after the initial join (first recovery), so long + // as the mode is not insecure. + // Note: we don't verify join state if it isn't expected. 
This is partly + // to ensure server-side recovery will work if join state desyncs - a + // cluster admin can change the recovery mode to insecure or reset the + // recovery counter to zero and start over with a fresh join state, with + // no client intervention. + joinStateRequired := ptv2.Status.BoundKeypair.RecoveryCount > 0 && recoveryMode != boundkeypair.RecoveryModeInsecure + if !joinStateRequired { + log.DebugContext( + ctx, + "skipping join state verification, not required due to token state", + "recovery_count", ptv2.Status.BoundKeypair.RecoveryCount, + "recovery_mode", ptv2.Spec.BoundKeypair.Recovery.Mode, + ) + return nil + } + + // If join state is required but missing, raise an error. + hasIncomingJoinState := len(req.PreviousJoinState) > 0 + if !hasIncomingJoinState { + return trace.AccessDenied("previous join state is required but was not provided") + } + + log.DebugContext(ctx, "join state verification required, verifying") + joinState, err := boundkeypair.VerifyJoinState( + ca, + string(req.PreviousJoinState), + &boundkeypair.JoinStateParams{ + Clock: a.clock, + ClusterName: ca.GetClusterName(), // equivalent to clusterName but saves a method param + Token: ptv2, + }, + ) + if err != nil { + log.ErrorContext(ctx, "bound keypair join state verification failed", "error", err) + a.tryLockBotInvalidJoinState(ctx, ptv2, req, err) + + return trace.AccessDenied("join state verification failed") + } + + // Now that we've verified it, make sure the previous bot instance ID is + // passed along to generateCerts. This will only be used if a new bot + // instance is generated. + req.JoinRequest.PreviousBotInstanceID = joinState.BotInstanceID + + log.DebugContext(ctx, "join state verified successfully", "join_state", joinState) + return nil +} + // RegisterUsingBoundKeypairMethod handles joining requests for the bound -// keypair join method. If successful, returns +// keypair join method. 
If successful, returns a certificate bundle and client +// joining parameters for use in subsequent join attempts. func (a *Server) RegisterUsingBoundKeypairMethod( ctx context.Context, req *proto.RegisterUsingBoundKeypairInitialRequest, @@ -596,11 +717,6 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, trace.Wrap(err, "parsing recovery mode") } - // Join state is required after the initial join (first recovery), so long - // as the mode is not insecure. - joinStateRequired := status.RecoveryCount > 0 && recoveryMode != boundkeypair.RecoveryModeInsecure - hasIncomingJoinState := len(req.PreviousJoinState) > 0 - // if set, the bound bot instance will be updated in the backend expectNewBotInstance := false @@ -625,62 +741,6 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, trace.Wrap(err) } - var joinState *boundkeypair.JoinState - if joinStateRequired { - // If join state is required but missing, raise an error. - if !hasIncomingJoinState { - return nil, trace.AccessDenied("previous join state is required but was not provided") - } - - log.DebugContext(ctx, "join state verification required, verifying") - joinState, err = boundkeypair.VerifyJoinState( - ca, - string(req.PreviousJoinState), - &boundkeypair.JoinStateParams{ - Clock: a.clock, - ClusterName: clusterName.GetClusterName(), - Token: ptv2, - }, - ) - if err != nil { - log.ErrorContext(ctx, "bound keypair join state verification failed", "error", err) - - // TODO: Once we have token-specific locking, generate a lock; this - // indicates the keypair may have been compromised. 
- if auditErr := a.emitter.EmitAuditEvent(a.closeCtx, &apievents.BoundKeypairJoinStateVerificationFailed{ - Metadata: apievents.Metadata{ - Type: events.BoundKeypairJoinStateVerificationFailed, - Code: events.BoundKeypairJoinStateVerificationFailedCode, - }, - Status: apievents.Status{ - Success: false, - Error: err.Error(), - }, - ConnectionMetadata: apievents.ConnectionMetadata{ - RemoteAddr: req.JoinRequest.RemoteAddr, - }, - TokenName: ptv2.GetName(), - BotName: ptv2.GetBotName(), - }); err != nil { - a.logger.WarnContext(ctx, "Failed to emit failed join state verification event", "error", auditErr) - } - return nil, trace.AccessDenied("join state verification failed") - } - - // Now that we've verified it, make sure the previous bot instance ID is - // passed along to generateCerts. This will only be used if a new bot - // instance is generated. - req.JoinRequest.PreviousBotInstanceID = joinState.BotInstanceID - - log.DebugContext(ctx, "join state verified successfully", "join_state", joinState) - - // Note: we don't verify join state if it isn't expected. This is partly - // to ensure server-side recovery will work if join state desyncs - a - // cluster admin can change the recovery mode to insecure or reset the - // recovery counter to zero and start over with a fresh join state, with - // no client intervention. - } - switch { case !hasBoundPublicKey && !hasIncomingBotInstance: // Normal initial join attempt. No bound key, and no incoming bot @@ -768,9 +828,11 @@ func (a *Server) RegisterUsingBoundKeypairMethod( // mutator. mutators = append( mutators, - mutateStatusConsumeRecovery(recoveryMode, status.RecoveryCount, spec.Recovery.Limit), + mutateStatusConsumeRecovery(status.RecoveryCount, spec.Recovery.Limit), ) + // Note: this is the initial join, so no join state to verify. 
+ recoveryCount += 1 expectNewBotInstance = true a.emitBoundKeypairRecoveryEvent(ctx, req, ptv2, boundPublicKey, recoveryCount, nil) @@ -797,6 +859,15 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, trace.Wrap(err) } + // Once we've verified the client has the matching private key, validate + // the join state. This must be done after a successful challenge to + // make sure an otherwise unauthorized client can't trigger a lockout. + // This also needs to be done before rotation to prevent an attacker + // from rotating the key. + if err := a.verifyBoundKeypairJoinState(ctx, log, req, ptv2, ca); err != nil { + return nil, trace.AccessDenied("join state verification failed") + } + // Nothing else to do, no key change, no additional audit event; regular // bot join event will be emitted later. case hasBoundPublicKey && hasBoundBotInstance && !hasIncomingBotInstance: @@ -816,9 +887,15 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, trace.Wrap(err) } + // As in the standard case above, once we've verified the client has the + // matching private key, validate the join state. 
+ if err := a.verifyBoundKeypairJoinState(ctx, log, req, ptv2, ca); err != nil { + return nil, trace.AccessDenied("join state verification failed") + } + mutators = append( mutators, - mutateStatusConsumeRecovery(recoveryMode, status.RecoveryCount, spec.Recovery.Limit), + mutateStatusConsumeRecovery(status.RecoveryCount, spec.Recovery.Limit), ) recoveryCount += 1 diff --git a/lib/auth/join_bound_keypair_test.go b/lib/auth/join_bound_keypair_test.go index fb0e4419e1e84..08ccd494c335c 100644 --- a/lib/auth/join_bound_keypair_test.go +++ b/lib/auth/join_bound_keypair_test.go @@ -21,6 +21,7 @@ package auth import ( "context" "crypto" + "crypto/tls" "testing" "time" @@ -162,11 +163,8 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { Roles: []types.SystemRole{types.RoleBot}, BotName: "test", BoundKeypair: &types.ProvisionTokenSpecV2BoundKeypair{ - Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ - InitialPublicKey: correctPublicKey, - }, + Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{}, Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{ - // Only insecure is supported for now. 
Mode: boundkeypair.RecoveryModeInsecure, }, }, @@ -190,6 +188,18 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { } } + withInitialKey := func(key string) func(*types.ProvisionTokenV2) { + return func(v2 *types.ProvisionTokenV2) { + v2.Spec.BoundKeypair.Onboarding.InitialPublicKey = key + } + } + + withBoundKey := func(key string) func(*types.ProvisionTokenV2) { + return func(v2 *types.ProvisionTokenV2) { + v2.Status.BoundKeypair.BoundPublicKey = key + } + } + makeJoinState := func(signer crypto.Signer, mutators ...func(s *boundkeypair.JoinStateParams)) string { params := &boundkeypair.JoinStateParams{ Clock: srv.Clock(), @@ -319,11 +329,11 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { assertSolverState func(t *testing.T, s *wrappedSolver) }{ { - // no bound key, no bound bot instance, aka initial join without - // secret + // an initial key but no bound key, and no bound bot instance. aka, + // initial join with preregistered key name: "initial-join-success", - token: makeToken(), + token: makeToken(withInitialKey(correctPublicKey)), initReq: makeInitReq(), solver: makeSolver(correctPublicKey), @@ -340,7 +350,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { // secret name: "initial-join-with-wrong-key", - token: makeToken(), + token: makeToken(withInitialKey(correctPublicKey)), initReq: makeInitReq(), solver: makeSolver(incorrectPublicKey), @@ -353,8 +363,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { // bound key, valid bound bot instance, aka "soft join" name: "reauth-success", - token: makeToken(func(v2 *types.ProvisionTokenV2) { - v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + token: makeToken(withBoundKey(correctPublicKey), func(v2 *types.ProvisionTokenV2) { v2.Status.BoundKeypair.BoundBotInstanceID = "asdf" }), initReq: makeInitReq(func(r *proto.RegisterUsingBoundKeypairInitialRequest) { @@ -373,8 +382,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t 
*testing.T) { // (should be impossible, but should fail anyway) name: "reauth-with-wrong-key", - token: makeToken(func(v2 *types.ProvisionTokenV2) { - v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + token: makeToken(withBoundKey(correctPublicKey), func(v2 *types.ProvisionTokenV2) { v2.Status.BoundKeypair.BoundBotInstanceID = "asdf" }), initReq: makeInitReq(func(r *proto.RegisterUsingBoundKeypairInitialRequest) { @@ -392,8 +400,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { // expired and triggered a hard rejoin name: "rejoin-success", - token: makeToken(func(v2 *types.ProvisionTokenV2) { - v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + token: makeToken(withBoundKey(correctPublicKey), func(v2 *types.ProvisionTokenV2) { v2.Status.BoundKeypair.BoundBotInstanceID = "asdf" }), initReq: makeInitReq(), @@ -413,9 +420,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { // This should fail and prompt the user to recreate the token. name: "bound-key-no-instance", - token: makeToken(func(v2 *types.ProvisionTokenV2) { - v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey - }), + token: makeToken(withBoundKey(correctPublicKey)), initReq: makeInitReq(), solver: makeSolver(correctPublicKey), @@ -445,7 +450,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { }, { name: "standard-initial-recovery-success", - token: makeToken(withRecovery("standard", 0, 1, "")), + token: makeToken(withRecovery("standard", 0, 1, ""), withInitialKey(correctPublicKey)), initReq: makeInitReq(), solver: makeSolver(correctPublicKey), assertError: require.NoError, @@ -458,7 +463,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { }, { name: "standard-success-second-recovery", - token: makeToken(withRecovery("standard", 1, 2, "id")), + token: makeToken(withRecovery("standard", 1, 2, "id"), withInitialKey(correctPublicKey)), initReq: makeInitReq(withJoinState(jwtSigner, withToken(withRecovery("standard", 1, 
2, "id")))), solver: makeSolver(correctPublicKey), assertError: require.NoError, @@ -471,11 +476,11 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { }, { name: "standard-failure-missing-join-state", - token: makeToken(withRecovery("standard", 1, 2, "id")), + token: makeToken(withRecovery("standard", 1, 2, "id"), withBoundKey(correctPublicKey)), initReq: makeInitReq(), solver: makeSolver(correctPublicKey), assertError: func(tt require.TestingT, err error, i ...any) { - require.ErrorContains(tt, err, "previous join state is required") + require.ErrorContains(tt, err, "join state verification failed") }, }, { @@ -490,7 +495,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { { // Attempts to join with an outdated join state document should fail. name: "standard-failure-recovery-count-mismatch", - token: makeToken(withRecovery("standard", 2, 3, "id")), + token: makeToken(withRecovery("standard", 2, 3, "id"), withBoundKey(correctPublicKey)), initReq: makeInitReq(withJoinState(jwtSigner, withToken(withRecovery("standard", 1, 3, "id")))), solver: makeSolver(correctPublicKey), assertError: func(tt require.TestingT, err error, i ...any) { @@ -499,7 +504,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { }, { name: "standard-failure-invalid-jwt", - token: makeToken(withRecovery("standard", 1, 2, "id")), + token: makeToken(withRecovery("standard", 1, 2, "id"), withBoundKey(correctPublicKey)), initReq: makeInitReq(func(r *proto.RegisterUsingBoundKeypairInitialRequest) { r.PreviousJoinState = []byte("asdf") }), @@ -510,7 +515,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { }, { name: "standard-failure-invalid-jwt-signature", - token: makeToken(withRecovery("standard", 1, 2, "id")), + token: makeToken(withRecovery("standard", 1, 2, "id"), withBoundKey(correctPublicKey)), initReq: makeInitReq(withJoinState(invalidJWTSigner, withToken(withRecovery("standard", 1, 2, "id")))), solver: 
makeSolver(correctPublicKey), assertError: func(tt require.TestingT, err error, i ...any) { @@ -519,7 +524,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { }, { name: "standard-failure-invalid-instance-id", - token: makeToken(withRecovery("standard", 1, 2, "foo")), + token: makeToken(withRecovery("standard", 1, 2, "foo"), withBoundKey(correctPublicKey)), initReq: makeInitReq(withJoinState(jwtSigner, withToken(withRecovery("standard", 1, 2, "id")))), solver: makeSolver(correctPublicKey), assertError: func(tt require.TestingT, err error, i ...any) { @@ -528,7 +533,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { }, { name: "standard-failure-invalid-cluster", - token: makeToken(withRecovery("standard", 1, 2, "foo")), + token: makeToken(withRecovery("standard", 1, 2, "foo"), withBoundKey(correctPublicKey)), initReq: makeInitReq(withJoinState(jwtSigner, withToken(withRecovery("standard", 1, 2, "id")), func(s *boundkeypair.JoinStateParams) { s.ClusterName = "wrong-cluster" })), @@ -539,7 +544,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { }, { name: "relaxed-success-count-over-limit", - token: makeToken(withRecovery("relaxed", 1, 0, "id")), + token: makeToken(withRecovery("relaxed", 1, 0, "id"), withBoundKey(correctPublicKey)), initReq: makeInitReq(withJoinState(jwtSigner, withToken(withRecovery("relaxed", 1, 0, "id")))), solver: makeSolver(correctPublicKey), assertError: require.NoError, @@ -1074,6 +1079,14 @@ func TestServer_RegisterUsingBoundKeypairMethod_GenerationCounter(t *testing.T) require.Equal(t, secondInstance, instance) } + // Try an API call with these certs. + tlsCert, err := tls.X509KeyPair(response.Certs.TLS, sshPrivateKey) + require.NoError(t, err) + + client := srv.NewClientWithCert(tlsCert) + _, err = client.Ping(ctx) + require.NoError(t, err) + // Provide an incorrect generation counter value. 
response, err = auth.RegisterUsingBoundKeypairMethod( ctx, @@ -1089,10 +1102,147 @@ func TestServer_RegisterUsingBoundKeypairMethod_GenerationCounter(t *testing.T) require.Nil(t, response) require.ErrorContains(t, err, "renewable cert generation mismatch") - // The bot user should now be locked. + // The token should now be locked. + locks, err := srv.Auth().GetLocks(ctx, true, types.LockTarget{ + JoinToken: "bound-keypair-test", + }) + require.NoError(t, err) + require.Len(t, locks, 1) + require.Contains(t, locks[0].Message(), "certificate generation mismatch") + + // Using the previously working client, make sure API calls no longer work. + _, err = client.Ping(ctx) + require.ErrorContains(t, err, "access denied") +} + +func TestServer_RegisterUsingBoundKeypairMethod_JoinStateFailure(t *testing.T) { + ctx := context.Background() + + // TODO: This prevents parallel execution; remove along with the experiment. + boundkeypairexperiment.SetEnabled(true) + + sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair() + require.NoError(t, err) + tlsPublicKey, err := PrivateKeyToPublicKeyTLS(sshPrivateKey) + require.NoError(t, err) + + _, correctPublicKey := testBoundKeypair(t) + + clock := clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()) + + srv := newTestTLSServer(t, withClock(clock)) + auth := srv.Auth() + auth.createBoundKeypairValidator = func(subject, clusterName string, publicKey crypto.PublicKey) (boundKeypairValidator, error) { + return &mockBoundKeypairValidator{ + subject: subject, + clusterName: clusterName, + publicKey: publicKey, + }, nil + } + + _, err = CreateRole(ctx, auth, "example", types.RoleSpecV6{}) + require.NoError(t, err) + + adminClient, err := srv.NewClient(TestAdmin()) + require.NoError(t, err) + + _, err = adminClient.BotServiceClient().CreateBot(ctx, &machineidv1pb.CreateBotRequest{ + Bot: &machineidv1pb.Bot{ + Kind: types.KindBot, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: 
&machineidv1pb.BotSpec{ + Roles: []string{"example"}, + }, + }, + }) + require.NoError(t, err) + + token, err := types.NewProvisionTokenFromSpecAndStatus( + "bound-keypair-test", + time.Now().Add(2*time.Hour), + types.ProvisionTokenSpecV2{ + JoinMethod: types.JoinMethodBoundKeypair, + Roles: []types.SystemRole{types.RoleBot}, + BotName: "test", + BoundKeypair: &types.ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ + InitialPublicKey: correctPublicKey, + }, + Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{ + Limit: 3, + }, + }, + }, + &types.ProvisionTokenStatusV2{}, + ) + require.NoError(t, err) + require.NoError(t, auth.CreateBoundKeypairToken(ctx, token)) + + makeInitReq := func(mutators ...func(r *proto.RegisterUsingBoundKeypairInitialRequest)) *proto.RegisterUsingBoundKeypairInitialRequest { + req := &proto.RegisterUsingBoundKeypairInitialRequest{ + JoinRequest: &types.RegisterUsingTokenRequest{ + HostID: "host-id", + Role: types.RoleBot, + PublicTLSKey: tlsPublicKey, + PublicSSHKey: sshPublicKey, + Token: "bound-keypair-test", + }, + } + for _, mutator := range mutators { + mutator(req) + } + return req + } + + withJoinState := func(state []byte) func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + return func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + r.PreviousJoinState = state + } + } + + // Perform the initial registration. + solver := newMockSolver(t, correctPublicKey) + firstResponse, err := auth.RegisterUsingBoundKeypairMethod(ctx, makeInitReq(), solver.solver()) + require.NoError(t, err) + + // Perform a recovery, this time with a join state. + secondResponse, err := auth.RegisterUsingBoundKeypairMethod( + ctx, + makeInitReq(withJoinState(firstResponse.JoinState)), + solver.solver(), + ) + require.NotNil(t, secondResponse) + require.NoError(t, err) + + // Try an API call with these certs. 
+ tlsCert, err := tls.X509KeyPair(secondResponse.Certs.TLS, sshPrivateKey) + require.NoError(t, err) + + client := srv.NewClientWithCert(tlsCert) + _, err = client.Ping(ctx) + require.NoError(t, err) + + // Try once more, but this time with the first join state. + thirdResponse, err := auth.RegisterUsingBoundKeypairMethod( + ctx, + makeInitReq(withJoinState(firstResponse.JoinState)), + solver.solver(), + ) + require.Nil(t, thirdResponse) + require.ErrorContains(t, err, "join state verification failed") + + // The token should now be locked. locks, err := srv.Auth().GetLocks(ctx, true, types.LockTarget{ - User: "bot-test", + JoinToken: "bound-keypair-test", }) require.NoError(t, err) - require.NotEmpty(t, locks) + require.Len(t, locks, 1) + require.Contains(t, locks[0].Message(), "failed to verify its join state") + + // The previously working client should now be locked. + _, err = client.Ping(ctx) + require.ErrorContains(t, err, "access denied") }