diff --git a/api/client/joinservice.go b/api/client/joinservice.go index bc30c8a541e6c..32b0f60dfe5db 100644 --- a/api/client/joinservice.go +++ b/api/client/joinservice.go @@ -18,6 +18,8 @@ package client import ( "context" + "errors" + "io" "github.com/gravitational/trace" @@ -60,6 +62,11 @@ type RegisterTPMChallengeResponseFunc func(challenge *proto.TPMEncryptedCredenti // *proto.OracleSignedRequest for a given challenge, or an error. type RegisterOracleChallengeResponseFunc func(challenge string) (*proto.OracleSignedRequest, error) +// RegisterUsingBoundKeypairChallengeResponseFunc is a function to be passed to +// RegisterUsingBoundKeypair. It must return a new follow-up request for the +// server response, or an error. +type RegisterUsingBoundKeypairChallengeResponseFunc func(challenge *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error) + // RegisterUsingIAMMethod registers the caller using the IAM join method and // returns signed certs to join the cluster. // @@ -262,6 +269,78 @@ func (c *JoinServiceClient) RegisterUsingOracleMethod( return certs, nil } +// RegisterUsingBoundKeypairMethod attempts to register the caller using +// bound-keypair join method. If successful, the public key registered with auth +// and a certificate bundle is returned, or an error. Clients must provide a +// callback to handle interactive challenges and keypair rotation requests. +func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod( + ctx context.Context, + initReq *proto.RegisterUsingBoundKeypairInitialRequest, + challengeFunc RegisterUsingBoundKeypairChallengeResponseFunc, +) (*proto.Certs, string, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + stream, err := c.grpcClient.RegisterUsingBoundKeypairMethod(ctx) + if err != nil { + return nil, "", trace.Wrap(err) + } + defer stream.CloseSend() + + err = stream.Send(&proto.RegisterUsingBoundKeypairMethodRequest{ + Payload: &proto.RegisterUsingBoundKeypairMethodRequest_Init{ + Init: initReq, + }, + }) + if err != nil { + return nil, "", trace.Wrap(err, "sending initial request") + } + + // Unlike other methods, the server may send multiple challenges, + // particularly during keypair rotation. We'll iterate through all responses + // here instead to ensure we handle everything. + for { + res, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return nil, "", trace.Wrap(err, "receiving intermediate bound keypair join response") + } + + switch kind := res.GetResponse().(type) { + case *proto.RegisterUsingBoundKeypairMethodResponse_Certs: + // If we get certs, we're done, so just return the result. + certs := kind.Certs.GetCerts() + if certs == nil { + return nil, "", trace.BadParameter("expected Certs, got %T", kind.Certs.Certs) + } + + // If we receive a cert bundle, we can return early. Even if we + // logically should have expected to receive a 2nd challenge if we + // e.g. requested keypair rotation, skipping it just means the new + // keypair won't be stored. That said, we'll rely on the server to + // raise an error if rotation fails or is otherwise skipped or not + // allowed. + + return certs, kind.Certs.GetPublicKey(), nil + default: + // Forward all other responses to the challenge handler. + nextRequest, err := challengeFunc(res) + if err != nil { + return nil, "", trace.Wrap(err, "solving challenge") + } + + if err := stream.Send(nextRequest); err != nil { + return nil, "", trace.Wrap(err, "sending solution") + } + } + } + + // Ideally the server will emit a proper error instead of just hanging up on + // us. + return nil, "", trace.AccessDenied("server declined to send certs during bound-keypair join attempt") +} + // RegisterUsingToken registers the caller using a token and returns signed // certs. // This is used where a more specific RPC has not been introduced for the join diff --git a/api/types/provisioning.go b/api/types/provisioning.go index 0ebe818037906..82b2dd9452442 100644 --- a/api/types/provisioning.go +++ b/api/types/provisioning.go @@ -83,6 +83,9 @@ const ( // JoinMethodAzureDevops indicates that the node will join using the Azure // Devops join method. JoinMethodAzureDevops JoinMethod = "azure_devops" + // JoinMethodBoundKeypair indicates the node will join using the Bound + // Keypair join method. See lib/boundkeypair for more. + JoinMethodBoundKeypair JoinMethod = "bound_keypair" ) var JoinMethods = []JoinMethod{ @@ -101,6 +104,7 @@ var JoinMethods = []JoinMethod{ JoinMethodTPM, JoinMethodTerraformCloud, JoinMethodOracle, + JoinMethodBoundKeypair, } func ValidateJoinMethod(method JoinMethod) error { @@ -193,6 +197,26 @@ func NewProvisionTokenFromSpec(token string, expires time.Time, spec ProvisionTo return t, nil } +// NewProvisionTokenFromSpecAndStatus returns a new provision token with the given spec. +func NewProvisionTokenFromSpecAndStatus( + token string, expires time.Time, + spec ProvisionTokenSpecV2, + status *ProvisionTokenStatusV2, +) (ProvisionToken, error) { + t := &ProvisionTokenV2{ + Metadata: Metadata{ + Name: token, + Expires: &expires, + }, + Spec: spec, + Status: status, + } + if err := t.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return t, nil +} + // MustCreateProvisionToken returns a new valid provision token // or panics, used in tests func MustCreateProvisionToken(token string, roles SystemRoles, expires time.Time) ProvisionToken { @@ -416,6 +440,18 @@ func (p *ProvisionTokenV2) CheckAndSetDefaults() error { if err := providerCfg.checkAndSetDefaults(); err != nil { return trace.Wrap(err, "spec.azure_devops: failed validation") } + case JoinMethodBoundKeypair: + providerCfg := p.Spec.BoundKeypair + if providerCfg == nil { + return trace.BadParameter( + "spec.bound_keypair: must be configured for the join method %q", + JoinMethodBoundKeypair, + ) + } + + if err := providerCfg.checkAndSetDefaults(); err != nil { + return trace.Wrap(err, "spec.bound_keypair: failed validation") + } default: return trace.BadParameter("unknown join method %q", p.Spec.JoinMethod) } @@ -994,3 +1030,16 @@ func (a *ProvisionTokenSpecV2AzureDevops) checkAndSetDefaults() error { } return nil } + +func (a *ProvisionTokenSpecV2BoundKeypair) checkAndSetDefaults() error { + if a.Onboarding == nil { + return trace.BadParameter("spec.bound_keypair.onboarding is required") + } + + if a.Onboarding.RegistrationSecret == "" && a.Onboarding.InitialPublicKey == "" { + return trace.BadParameter("at least one of [initial_join_secret, " + + "initial_public_key] is required in spec.bound_keypair.onboarding") + } + + return nil +} diff --git a/api/types/provisioning_test.go b/api/types/provisioning_test.go index dae665bb67e1e..cc62a40cab80f 100644 --- a/api/types/provisioning_test.go +++ b/api/types/provisioning_test.go @@ -1369,6 +1369,54 @@ func TestProvisionTokenV2_CheckAndSetDefaults(t *testing.T) { }, wantErr: true, }, + { + desc: "minimal bound keypair with pregenerated key", + token: &ProvisionTokenV2{ + Metadata: Metadata{ + Name: "test", + }, + Spec: ProvisionTokenSpecV2{ + Roles: []SystemRole{RoleNode}, + JoinMethod: JoinMethodBoundKeypair, + BoundKeypair: &ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ + InitialPublicKey: "asdf", + }, + }, + }, + }, + expected: &ProvisionTokenV2{ + Kind: "token", + Version: "v2", + Metadata: Metadata{ + Name: "test", + Namespace: "default", + }, + Spec: ProvisionTokenSpecV2{ + Roles: []SystemRole{RoleNode}, + JoinMethod: JoinMethodBoundKeypair, + BoundKeypair: &ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ + InitialPublicKey: "asdf", + }, + }, + }, + }, + }, + { + desc: "bound keypair missing onboarding config", + token: &ProvisionTokenV2{ + Metadata: Metadata{ + Name: "test", + }, + Spec: ProvisionTokenSpecV2{ + Roles: []SystemRole{RoleNode}, + JoinMethod: JoinMethodBoundKeypair, + BoundKeypair: &ProvisionTokenSpecV2BoundKeypair{}, + }, + }, + wantErr: true, + }, } for _, tc := range testcases { diff --git a/lib/auth/auth.go b/lib/auth/auth.go index cfcb7507ce800..66953ba1804c9 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -96,6 +96,7 @@ import ( "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/bitbucket" + "github.com/gravitational/teleport/lib/boundkeypair" "github.com/gravitational/teleport/lib/cache" "github.com/gravitational/teleport/lib/circleci" "github.com/gravitational/teleport/lib/cryptosuites" @@ -706,6 +707,12 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { as.bitbucketIDTokenValidator = bitbucket.NewIDTokenValidator(as.clock) } + if as.createBoundKeypairValidator == nil { + as.createBoundKeypairValidator = func(subject, clusterName string, publicKey crypto.PublicKey) (boundKeypairValidator, error) { + return boundkeypair.NewChallengeValidator(subject, clusterName, publicKey) + } + } + // Add in a login hook for generating state during user login. as.ulsGenerator, err = userloginstate.NewGenerator(userloginstate.GeneratorConfig{ Log: as.logger, @@ -1145,6 +1152,10 @@ type Server struct { bitbucketIDTokenValidator bitbucketIDTokenValidator + // createBoundKeypairValidator is a helper to create new bound keypair + // challenge validators. Used to override the implementation used in tests. + createBoundKeypairValidator createBoundKeypairValidator + // loadAllCAs tells tsh to load the host CAs for all clusters when trying to ssh into a node. loadAllCAs bool diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 72acbfbf1db80..00a417708c614 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -2440,6 +2440,12 @@ func (a *ServerWithRoles) UpsertToken(ctx context.Context, token types.Provision return trace.Wrap(err) } + // bound_keypair tokens have special creation/update logic and are handled + // separately + if token.GetJoinMethod() == types.JoinMethodBoundKeypair { + return trace.Wrap(a.authServer.UpsertBoundKeypairToken(ctx, token)) + } + if err := a.authServer.UpsertToken(ctx, token); err != nil { return trace.Wrap(err) } @@ -2465,6 +2471,12 @@ func (a *ServerWithRoles) CreateToken(ctx context.Context, token types.Provision return trace.Wrap(err) } + // bound_keypair tokens have special creation/update logic and are handled + // separately + if token.GetJoinMethod() == types.JoinMethodBoundKeypair { + return trace.Wrap(a.authServer.CreateBoundKeypairToken(ctx, token)) + } + if err := a.authServer.CreateToken(ctx, token); err != nil { return trace.Wrap(err) } diff --git a/lib/auth/authclient/clt.go b/lib/auth/authclient/clt.go index 7cebd85ce9895..5e5da68ee9ded 100644 --- a/lib/auth/authclient/clt.go +++ b/lib/auth/authclient/clt.go @@ -305,6 +305,15 @@ func (c *Client) DeleteAllTokens() error { return trace.NotImplemented(notImplementedMessage) } +// PatchToken not implemented: can only be called locally +func (c *Client) PatchToken( + ctx context.Context, + token string, + updateFn func(types.ProvisionToken) (types.ProvisionToken, error), +) (types.ProvisionToken, error) { + return nil, trace.NotImplemented(notImplementedMessage) +} + // AddUserLoginAttempt logs user login attempt func (c *Client) AddUserLoginAttempt(user string, attempt services.LoginAttempt, ttl time.Duration) error { panic("not implemented") @@ -1235,6 +1244,14 @@ type ProvisioningService interface { // CreateToken creates a new provision token for the auth server CreateToken(ctx context.Context, token types.ProvisionToken) error + // PatchToken performs a conditional update on the named token using + // `updateFn`, retrying internally if a comparison failure occurs. + PatchToken( + ctx context.Context, + token string, + updateFn func(types.ProvisionToken) (types.ProvisionToken, error), + ) (types.ProvisionToken, error) + // RegisterUsingToken calls the auth service API to register a new node via registration token // which has been previously issued via GenerateToken RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) diff --git a/lib/auth/join.go b/lib/auth/join.go index be71948c6d1ce..853da5688acc5 100644 --- a/lib/auth/join.go +++ b/lib/auth/join.go @@ -226,7 +226,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin if err := a.checkEC2JoinRequest(ctx, req); err != nil { return nil, trace.Wrap(err) } - case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM, types.JoinMethodOracle: + case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM, + types.JoinMethodOracle, types.JoinMethodBoundKeypair: // Some join methods require use of a specific RPC - reject those here. // This would generally be a developer error - but can be triggered if // the user has configured the wrong join method on the client-side. @@ -317,7 +318,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin // With all elements of the token validated, we can now generate & return // certificates. if req.Role == types.RoleBot { - certs, err = a.generateCertsBot( + certs, _, err = a.generateCertsBot( ctx, provisionToken, req, @@ -336,7 +337,7 @@ func (a *Server) generateCertsBot( req *types.RegisterUsingTokenRequest, rawJoinClaims any, attrs *workloadidentityv1pb.JoinAttrs, -) (*proto.Certs, error) { +) (*proto.Certs, string, error) { // bots use this endpoint but get a user cert // botResourceName must be set, enforced in CheckAndSetDefaults botName := provisionToken.GetBotName() @@ -344,7 +345,7 @@ func (a *Server) generateCertsBot( // Check this is a join method for bots we support. if !slices.Contains(machineidv1.SupportedJoinMethods, joinMethod) { - return nil, trace.BadParameter( + return nil, "", trace.BadParameter( "unsupported join method %q for bot", joinMethod, ) } @@ -439,7 +440,7 @@ func (a *Server) generateCertsBot( attrs, ) if err != nil { - return nil, trace.Wrap(err) + return nil, "", trace.Wrap(err) } joinEvent.BotInstanceID = botInstanceID @@ -461,7 +462,7 @@ func (a *Server) generateCertsBot( if err := a.emitter.EmitAuditEvent(ctx, joinEvent); err != nil { a.logger.WarnContext(ctx, "Failed to emit bot join event", "error", err) } - return certs, nil + return certs, botInstanceID, nil } func (a *Server) generateCerts( diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index 9461a412ed7c5..02db610c710c5 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -489,7 +489,7 @@ func (a *Server) RegisterUsingAzureMethodWithOpts( } if req.RegisterUsingTokenRequest.Role == types.RoleBot { - certs, err := a.generateCertsBot( + certs, _, err := a.generateCertsBot( ctx, provisionToken, req.RegisterUsingTokenRequest, diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go new file mode 100644 index 0000000000000..462053aadbbdc --- /dev/null +++ b/lib/auth/join_bound_keypair.go @@ -0,0 +1,480 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package auth + +import ( + "context" + "crypto" + "encoding/json" + "time" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/boundkeypair" + "github.com/gravitational/teleport/lib/boundkeypair/boundkeypairexperiment" + "github.com/gravitational/teleport/lib/jwt" + libsshutils "github.com/gravitational/teleport/lib/sshutils" +) + +type boundKeypairValidator interface { + IssueChallenge() (*boundkeypair.ChallengeDocument, error) + ValidateChallengeResponse(issued *boundkeypair.ChallengeDocument, compactResponse string) error +} + +type createBoundKeypairValidator func(subject string, clusterName string, publicKey crypto.PublicKey) (boundKeypairValidator, error) + +// validateBoundKeypairTokenSpec performs some basic validation checks on a +// bound_keypair-type join token. +func validateBoundKeypairTokenSpec(spec *types.ProvisionTokenSpecV2BoundKeypair) error { + // Various constant checks, shared between creation and update. Many of + // these checks are temporary and will be removed alongside the experiment + // flag. + if !boundkeypairexperiment.Enabled() { + return trace.BadParameter("bound keypair joining experiment is not enabled") + } + + if spec.RotateAfter != nil { + return trace.NotImplemented("spec.bound_keypair.rotate_after is not yet implemented") + } + + if spec.Onboarding.RegistrationSecret != "" { + return trace.NotImplemented("spec.bound_keypair.onboarding.registration_secret is not yet implemented") + } + + if spec.Onboarding.InitialPublicKey == "" { + return trace.NotImplemented("spec.bound_keypair.onboarding.initial_public_key is currently required") + } + + if spec.Recovery == nil { + return trace.BadParameter("spec.recovery: field is required") + } + + if spec.Recovery.Mode != boundkeypair.RecoveryModeInsecure { + return trace.NotImplemented("spec.bound_keypair.recovery.mode currently must be %s", boundkeypair.RecoveryModeInsecure) + } + + return nil +} + +func (a *Server) CreateBoundKeypairToken(ctx context.Context, token types.ProvisionToken) error { + if token.GetJoinMethod() != types.JoinMethodBoundKeypair { + return trace.BadParameter("must be called with a bound keypair token") + } + + tokenV2, ok := token.(*types.ProvisionTokenV2) + if !ok { + return trace.BadParameter("%v join method requires ProvisionTokenV2", types.JoinMethodOracle) + } + + spec := tokenV2.Spec.BoundKeypair + if spec == nil { + return trace.BadParameter("bound_keypair token requires non-nil spec.bound_keypair") + } + + if err := validateBoundKeypairTokenSpec(spec); err != nil { + return trace.Wrap(err) + } + + // Not as much to do here - ideally we'd like to prevent users from + // tampering with the status field, but we don't have a good mechanism to + // stop that that wouldn't also break backup and restore. For now, it's + // simpler and easier to just tell users not to edit those fields. + + // TODO (follow up PR): Populate initial_join_secret if needed. + + return trace.Wrap(a.CreateToken(ctx, tokenV2)) +} + +func (a *Server) UpsertBoundKeypairToken(ctx context.Context, token types.ProvisionToken) error { + if token.GetJoinMethod() != types.JoinMethodBoundKeypair { + return trace.BadParameter("must be called with a bound keypair token") + } + + tokenV2, ok := token.(*types.ProvisionTokenV2) + if !ok { + return trace.BadParameter("%v join method requires ProvisionTokenV2", types.JoinMethodOracle) + } + + spec := tokenV2.Spec.BoundKeypair + if spec == nil { + return trace.BadParameter("bound_keypair token requires non-nil spec.bound_keypair") + } + + if err := validateBoundKeypairTokenSpec(spec); err != nil { + return trace.Wrap(err) + } + + // TODO: Follow up with proper checking for a preexisting resource so + // generated fields are handled properly, i.e. initial secret generation. + + return trace.Wrap(a.UpsertToken(ctx, token)) +} + +// issueBoundKeypairChallenge creates a new challenge for the given marshaled +// public key in ssh authorized_keys format, requests a solution from the +// client using the given `challengeResponse` function, and validates the +// response. +func (a *Server) issueBoundKeypairChallenge( + ctx context.Context, + marshalledKey string, + challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc, +) error { + key, err := libsshutils.CryptoPublicKey([]byte(marshalledKey)) + if err != nil { + return trace.Wrap(err, "parsing bound public key") + } + + // The particular subject value doesn't strictly need to be the name of the + // bot or node (which may not be known, yet). Instead, we'll use the key ID, + // which could at least be useful for the client to know which key the + // challenge should be signed with. + keyID, err := jwt.KeyID(key) + if err != nil { + return trace.Wrap(err, "determining the key ID") + } + + clusterName, err := a.GetClusterName(ctx) + if err != nil { + return trace.Wrap(err) + } + + a.logger.DebugContext(ctx, "issuing bound keypair challenge", "key_id", keyID) + + validator, err := a.createBoundKeypairValidator(keyID, clusterName.GetClusterName(), key) + if err != nil { + return trace.Wrap(err) + } + + challenge, err := validator.IssueChallenge() + if err != nil { + return trace.Wrap(err, "generating a challenge document") + } + + marshalledChallenge, err := json.Marshal(challenge) + if err != nil { + return trace.Wrap(err) + } + + response, err := challengeResponse(&proto.RegisterUsingBoundKeypairMethodResponse{ + Response: &proto.RegisterUsingBoundKeypairMethodResponse_Challenge{ + Challenge: &proto.RegisterUsingBoundKeypairChallenge{ + PublicKey: marshalledKey, + Challenge: string(marshalledChallenge), + }, + }, + }) + if err != nil { + return trace.Wrap(err, "requesting a signed challenge") + } + + solutionResponse, ok := response.Payload.(*proto.RegisterUsingBoundKeypairMethodRequest_ChallengeResponse) + if !ok { + return trace.BadParameter("client provided unexpected challenge response type %T", response.Payload) + } + + if err := validator.ValidateChallengeResponse( + challenge, + string(solutionResponse.ChallengeResponse.Solution), + ); err != nil { + // TODO: Consider access denied instead? + return trace.Wrap(err, "validating challenge response") + } + + a.logger.InfoContext(ctx, "bound keypair challenge response verified successfully", "key_id", keyID) + + return nil +} + +// boundKeypairStatusMutator is a function called to mutate a bound keypair +// status during a call to PatchProvisionToken(). These functions may be called +// repeatedly if e.g. revision checks fail. To ensure invariants remain in +// place, mutator functions may make assertions to ensure the provided backend +// state is still sane before the update is committed. +type boundKeypairStatusMutator func(*types.ProvisionTokenSpecV2BoundKeypair, *types.ProvisionTokenStatusV2BoundKeypair) error + +// mutateStatusConsumeJoin consumes a "hard" join on the backend, incrementing +// the join counter. This verifies that the backend join count has not changed, +// and that total join count is at least the value when the mutator was created. +func mutateStatusConsumeJoin(mode boundkeypair.RecoveryMode, expectRecoveryCount uint32, expectMinRecoveryLimit uint32) boundKeypairStatusMutator { + now := time.Now() + + return func(spec *types.ProvisionTokenSpecV2BoundKeypair, status *types.ProvisionTokenStatusV2BoundKeypair) error { + // Ensure we have the expected number of rejoins left to prevent going + // below zero. + if status.RecoveryCount != expectRecoveryCount { + return trace.AccessDenied("unexpected backend state") + } + + // Ensure the allowed join count has at least not decreased, but allow + // for collision with potentially increased values. + if spec.Recovery.Limit < expectMinRecoveryLimit { + return trace.AccessDenied("unexpected backend state") + } + + if mode == boundkeypair.RecoveryModeStandard { + // TODO: to be removed in a future PR + return trace.NotImplemented("only unlimited rejoining is currently supported") + } + + status.RecoveryCount += 1 + status.LastRecoveredAt = &now + + return nil + } +} + +// mutateStatusBoundPublicKey is a mutator that updates the bound public key +// value. It ensures the backend public key is still the expected value before +// performing the update. +func mutateStatusBoundPublicKey(newPublicKey, expectPreviousKey string) boundKeypairStatusMutator { + return func(_ *types.ProvisionTokenSpecV2BoundKeypair, status *types.ProvisionTokenStatusV2BoundKeypair) error { + if status.BoundPublicKey != expectPreviousKey { + return trace.AccessDenied("unexpected backend state") + } + + status.BoundPublicKey = newPublicKey + + return nil + } +} + +// mutateStatusBoundBotInstance updates the bot instance ID currently bound to +// this token. It ensures the expected previous ID is still the bound value +// before performing the update. +func mutateStatusBoundBotInstance(newBotInstance, expectPreviousBotInstance string) boundKeypairStatusMutator { + return func(_ *types.ProvisionTokenSpecV2BoundKeypair, status *types.ProvisionTokenStatusV2BoundKeypair) error { + if status.BoundBotInstanceID != expectPreviousBotInstance { + return trace.AccessDenied("unexpected backend state") + } + + status.BoundBotInstanceID = newBotInstance + + return nil + } +} + +// RegisterUsingBoundKeypairMethod handles joining requests for the bound +// keypair join method. +func (a *Server) RegisterUsingBoundKeypairMethod( + ctx context.Context, + req *proto.RegisterUsingBoundKeypairInitialRequest, + challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc, +) (_ *proto.Certs, _ string, err error) { + var provisionToken types.ProvisionToken + var joinFailureMetadata any + defer func() { + // Emit a log message and audit event on join failure. + if err != nil { + a.handleJoinFailure( + ctx, err, provisionToken, joinFailureMetadata, req.JoinRequest, + ) + } + }() + + // First, check the specified token exists, and is a bound keypair-type join + // token. + if err := req.JoinRequest.CheckAndSetDefaults(); err != nil { + return nil, "", trace.Wrap(err) + } + + // Only bot joining is supported at the moment - unique ID verification is + // required and this is currently only implemented for bots. + if req.JoinRequest.Role != types.RoleBot { + return nil, "", trace.BadParameter("bound keypair joining is only supported for bots") + } + + provisionToken, err = a.checkTokenJoinRequestCommon(ctx, req.JoinRequest) + if err != nil { + return nil, "", trace.Wrap(err) + } + ptv2, ok := provisionToken.(*types.ProvisionTokenV2) + if !ok { + return nil, "", trace.BadParameter("expected *types.ProvisionTokenV2, got %T", provisionToken) + } + if ptv2.Spec.JoinMethod != types.JoinMethodBoundKeypair { + return nil, "", trace.BadParameter("specified join token is not for `%s` method", types.JoinMethodBoundKeypair) + } + + if ptv2.Status == nil { + ptv2.Status = &types.ProvisionTokenStatusV2{} + } + if ptv2.Status.BoundKeypair == nil { + ptv2.Status.BoundKeypair = &types.ProvisionTokenStatusV2BoundKeypair{} + } + + spec := ptv2.Spec.BoundKeypair + status := ptv2.Status.BoundKeypair + hasBoundPublicKey := status.BoundPublicKey != "" + hasBoundBotInstance := status.BoundBotInstanceID != "" + hasIncomingBotInstance := req.JoinRequest.BotInstanceID != "" + hasJoinsRemaining := status.RecoveryCount < spec.Recovery.Limit + + // if set, the bound bot instance will be updated in the backend + expectNewBotInstance := false + + // the bound public key; may change during initial join or rotation. used to + // inform the returned public key value. + boundPublicKey := status.BoundPublicKey + + // Mutators to use during the token resource status patch at the end. + var mutators []boundKeypairStatusMutator + + switch { + case !hasBoundPublicKey && !hasIncomingBotInstance: + // Normal initial join attempt. No bound key, and no incoming bot + // instance. Consumes a rejoin. + if spec.Onboarding.RegistrationSecret != "" { + return nil, "", trace.NotImplemented("initial joining secrets are not yet supported") + } + + if spec.Onboarding.InitialPublicKey == "" { + return nil, "", trace.BadParameter("an initial public key is required") + } + + if spec.Recovery.Mode == string(boundkeypair.RecoveryModeStandard) && !hasJoinsRemaining { + return nil, "", trace.AccessDenied("no joins remaining") + } + + if err := a.issueBoundKeypairChallenge( + ctx, + spec.Onboarding.InitialPublicKey, + challengeResponse, + ); err != nil { + return nil, "", trace.Wrap(err) + } + + // Now that we've confirmed the key, we can consider it bound. + mutators = append( + mutators, + mutateStatusBoundPublicKey(spec.Onboarding.InitialPublicKey, ""), + mutateStatusConsumeJoin(boundkeypair.RecoveryMode(spec.Recovery.Mode), status.RecoveryCount, spec.Recovery.Limit), + ) + + expectNewBotInstance = true + boundPublicKey = spec.Onboarding.InitialPublicKey + case !hasBoundPublicKey && hasIncomingBotInstance: + // Not allowed, at least at the moment. This would imply e.g. trying to + // change auth methods. + return nil, "", trace.BadParameter("cannot perform first bound keypair join with existing credentials") + case hasBoundPublicKey && !hasBoundBotInstance: + // TODO: Bad backend state, or maybe an incomplete previous join + // attempt. This shouldn't be a possible state, but we should handle it + // sanely anyway. + return nil, "", trace.BadParameter("bad backend state, please recreate the join token") + case hasBoundPublicKey && hasBoundBotInstance && hasIncomingBotInstance: + // Standard rejoin case, does not consume a rejoin. + if status.BoundBotInstanceID != req.JoinRequest.BotInstanceID { + return nil, "", trace.AccessDenied("bot instance mismatch") + } + + if err := a.issueBoundKeypairChallenge( + ctx, + spec.Onboarding.InitialPublicKey, + challengeResponse, + ); err != nil { + return nil, "", trace.Wrap(err) + } + + // Nothing else to do, no key change + case hasBoundPublicKey && hasBoundBotInstance && !hasIncomingBotInstance: + // Hard rejoin case, the client identity expired and a new bot instance + // is required. Consumes a rejoin. + if spec.Recovery.Mode == string(boundkeypair.RecoveryModeStandard) && !hasJoinsRemaining { + // Recovery limit only applies in "standard" mode. + return nil, "", trace.AccessDenied("no rejoins remaining") + } + + if err := a.issueBoundKeypairChallenge( + ctx, + status.BoundPublicKey, + challengeResponse, + ); err != nil { + return nil, "", trace.Wrap(err) + } + + mutators = append( + mutators, + mutateStatusConsumeJoin(boundkeypair.RecoveryMode(spec.Recovery.Mode), status.RecoveryCount, spec.Recovery.Limit), + ) + + expectNewBotInstance = true + default: + a.logger.ErrorContext( + ctx, "unexpected state", + "has_bound_public_key", hasBoundPublicKey, + "has_bound_bot_instance", hasBoundBotInstance, + "has_incoming_bot_instance", hasIncomingBotInstance, + "spec", spec, + "status", status, + ) + return nil, "", trace.BadParameter("unexpected state") + } + + if spec.RotateAfter != nil { + // TODO, to be implemented in a future PR. `boundPublicKey` will need to + // be updated. + return nil, "", trace.NotImplemented("key rotation not yet supported") + } + + // TODO: We should pass along the previous bot instance ID - if any - based + // on the join state, once that is implemented. It will need to be passed + // either via extended claims, or by a new protected field in the join + // request like the current bot instance ID, i.e. cleared when set by an + // untrusted source. + certs, botInstanceID, err := a.generateCertsBot( + ctx, + ptv2, + req.JoinRequest, + nil, // TODO: extended claims for this type? + nil, // TODO: workload id claims + ) + + if expectNewBotInstance { + mutators = append( + mutators, + mutateStatusBoundBotInstance(botInstanceID, status.BoundBotInstanceID), + ) + } + + if len(mutators) > 0 { + if _, err := a.PatchToken(ctx, ptv2.GetName(), func(token types.ProvisionToken) (types.ProvisionToken, error) { + ptv2, ok := provisionToken.(*types.ProvisionTokenV2) + if !ok { + return nil, trace.BadParameter("expected *types.ProvisionTokenV2, got %T", provisionToken) + } + + // Apply all mutators. Individual mutators may make additional + // assertions to ensure invariants haven't changed. + for _, mutator := range mutators { + if err := mutator(ptv2.Spec.BoundKeypair, ptv2.Status.BoundKeypair); err != nil { + return nil, trace.Wrap(err, "applying status mutator") + } + } + + return ptv2, nil + }); err != nil { + return nil, "", trace.Wrap(err, "committing updated token state, please try again") + } + } + + return certs, boundPublicKey, trace.Wrap(err) +} diff --git a/lib/auth/join_bound_keypair_test.go b/lib/auth/join_bound_keypair_test.go new file mode 100644 index 0000000000000..837648f43c8c7 --- /dev/null +++ b/lib/auth/join_bound_keypair_test.go @@ -0,0 +1,368 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package auth + +import ( + "context" + "crypto" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/boundkeypair" + "github.com/gravitational/teleport/lib/cryptosuites" + "github.com/gravitational/teleport/lib/sshutils" +) + +type mockBoundKeypairValidator struct { + subject string + clusterName string + publicKey crypto.PublicKey +} + +func (v *mockBoundKeypairValidator) IssueChallenge() (*boundkeypair.ChallengeDocument, error) { + return &boundkeypair.ChallengeDocument{ + Nonce: "fake", + }, nil +} + +func (v *mockBoundKeypairValidator) ValidateChallengeResponse(issued *boundkeypair.ChallengeDocument, compactResponse string) error { + // For testing, the solver will just reply with the marshaled public key, so + // we'll parse and compare it. + key, err := sshutils.CryptoPublicKey([]byte(compactResponse)) + if err != nil { + return trace.Wrap(err, "parsing bound public key") + } + + equal, ok := v.publicKey.(interface { + Equal(x crypto.PublicKey) bool + }) + if !ok { + return trace.BadParameter("unsupported public key type %T", key) + } + + if !equal.Equal(key) { + return trace.AccessDenied("incorrect public key") + } + + return nil +} + +func testBoundKeypair(t *testing.T) (crypto.Signer, string) { + key, err := cryptosuites.GeneratePrivateKeyWithAlgorithm(cryptosuites.ECDSAP256) + require.NoError(t, err) + + return key.Signer, string(key.MarshalSSHPublicKey()) +} + +func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { + ctx := context.Background() + + _, correctPublicKey := testBoundKeypair(t) + _, incorrectPublicKey := testBoundKeypair(t) + + srv := newTestTLSServer(t) + auth := srv.Auth() + auth.createBoundKeypairValidator = func(subject, clusterName string, publicKey crypto.PublicKey) (boundKeypairValidator, error) { + return &mockBoundKeypairValidator{ + //correctPublicKey: correctSigner.Public(), + + subject: subject, + clusterName: clusterName, + publicKey: publicKey, + }, nil + } + + _, err := CreateRole(ctx, auth, "example", types.RoleSpecV6{}) + require.NoError(t, err) + + adminClient, err := srv.NewClient(TestAdmin()) + require.NoError(t, err) + + _, err = adminClient.BotServiceClient().CreateBot(ctx, &machineidv1pb.CreateBotRequest{ + Bot: &machineidv1pb.Bot{ + Kind: types.KindBot, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &machineidv1pb.BotSpec{ + Roles: []string{"example"}, + }, + }, + }) + require.NoError(t, err) + + sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair() + require.NoError(t, err) + tlsPublicKey, err := PrivateKeyToPublicKeyTLS(sshPrivateKey) + require.NoError(t, err) + + makeToken := func(mutators ...func(v2 *types.ProvisionTokenV2)) types.ProvisionTokenV2 { + token := types.ProvisionTokenV2{ + Spec: types.ProvisionTokenSpecV2{ + JoinMethod: types.JoinMethodBoundKeypair, + Roles: []types.SystemRole{types.RoleBot}, + BotName: "test", + BoundKeypair: &types.ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ + InitialPublicKey: correctPublicKey, + }, + Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{ + // Only insecure is supported for now. + Mode: boundkeypair.RecoveryModeInsecure, + }, + }, + }, + Status: &types.ProvisionTokenStatusV2{ + BoundKeypair: &types.ProvisionTokenStatusV2BoundKeypair{}, + }, + } + for _, mutator := range mutators { + mutator(&token) + } + return token + } + + makeInitReq := func(mutators ...func(r *proto.RegisterUsingBoundKeypairInitialRequest)) *proto.RegisterUsingBoundKeypairInitialRequest { + req := &proto.RegisterUsingBoundKeypairInitialRequest{ + JoinRequest: &types.RegisterUsingTokenRequest{ + HostID: "host-id", + Role: types.RoleBot, + PublicTLSKey: tlsPublicKey, + PublicSSHKey: sshPublicKey, + }, + } + for _, mutator := range mutators { + mutator(req) + } + return req + } + + makeSolver := func(publicKey string) client.RegisterUsingBoundKeypairChallengeResponseFunc { + return func(challenge *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error) { + switch r := challenge.Response.(type) { + case *proto.RegisterUsingBoundKeypairMethodResponse_Challenge: + if r.Challenge.PublicKey != publicKey { + return nil, trace.BadParameter("wrong public key") + } + + return &proto.RegisterUsingBoundKeypairMethodRequest{ + Payload: &proto.RegisterUsingBoundKeypairMethodRequest_ChallengeResponse{ + ChallengeResponse: &proto.RegisterUsingBoundKeypairChallengeResponse{ + // For testing purposes, we'll just reply with the + // public key, to avoid needing to parse the JWT. + Solution: []byte(publicKey), + }, + }, + }, nil + default: + return nil, trace.BadParameter("invalid response type") + } + } + } + + tests := []struct { + name string + + token types.ProvisionTokenV2 + initReq *proto.RegisterUsingBoundKeypairInitialRequest + solver client.RegisterUsingBoundKeypairChallengeResponseFunc + + assertError require.ErrorAssertionFunc + assertSuccess func(t *testing.T, v2 *types.ProvisionTokenV2) + }{ + { + // no bound key, no bound bot instance, aka initial join without + // secret + name: "initial-join-success", + + token: makeToken(), + initReq: makeInitReq(), + solver: makeSolver(correctPublicKey), + + assertError: require.NoError, + assertSuccess: func(t *testing.T, v2 *types.ProvisionTokenV2) { + // join count should be incremented + require.Equal(t, uint32(1), v2.Status.BoundKeypair.RecoveryCount) + require.NotEmpty(t, v2.Status.BoundKeypair.BoundBotInstanceID) + require.NotEmpty(t, v2.Status.BoundKeypair.BoundPublicKey) + }, + }, + { + // no bound key, no bound bot instance, aka initial join without + // secret + name: "initial-join-with-wrong-key", + + token: makeToken(), + initReq: makeInitReq(), + solver: makeSolver(incorrectPublicKey), + + assertError: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "wrong public key") + }, + }, + { + // bound key, valid bound bot instance, aka "soft join" + name: "reauth-success", + + token: makeToken(func(v2 *types.ProvisionTokenV2) { + v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + v2.Status.BoundKeypair.BoundBotInstanceID = "asdf" + }), + initReq: makeInitReq(func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + r.JoinRequest.BotInstanceID = "asdf" + }), + solver: makeSolver(correctPublicKey), + + assertError: require.NoError, + assertSuccess: func(t *testing.T, v2 *types.ProvisionTokenV2) { + // join count should not be incremented + require.Equal(t, uint32(0), v2.Status.BoundKeypair.RecoveryCount) + }, + }, + { + // bound key, seemingly valid bot instance, but wrong key + // (should be impossible, but should fail anyway) + name: "reauth-with-wrong-key", + + token: makeToken(func(v2 *types.ProvisionTokenV2) { + v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + v2.Status.BoundKeypair.BoundBotInstanceID = "asdf" + }), + initReq: makeInitReq(func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + r.JoinRequest.BotInstanceID = "asdf" + }), + solver: makeSolver(incorrectPublicKey), + + assertError: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "wrong public key") + }, + }, + { + // bound key but no valid incoming bot instance, i.e. the certs + // expired and triggered a hard rejoin + name: "rejoin-success", + + token: makeToken(func(v2 *types.ProvisionTokenV2) { + v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + v2.Status.BoundKeypair.BoundBotInstanceID = "asdf" + }), + initReq: makeInitReq(), + solver: makeSolver(correctPublicKey), + + assertError: require.NoError, + assertSuccess: func(t *testing.T, v2 *types.ProvisionTokenV2) { + require.Equal(t, uint32(1), v2.Status.BoundKeypair.RecoveryCount) + + // Should generate a new bot instance + require.NotEmpty(t, v2.Status.BoundKeypair.BoundBotInstanceID) + require.NotEqual(t, "asdf", v2.Status.BoundKeypair.BoundBotInstanceID) + }, + }, + { + // Bad state: somehow a key was registered without a bot instance. + // This should fail and prompt the user to recreate the token. + name: "bound-key-no-instance", + + token: makeToken(func(v2 *types.ProvisionTokenV2) { + v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + }), + initReq: makeInitReq(), + solver: makeSolver(correctPublicKey), + + assertError: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "bad backend state") + }, + }, + { + // The client somehow presents certs that refer to a different + // instance, maybe tried switching auth methods. + name: "bound-key-wrong-instance", + + token: makeToken(func(v2 *types.ProvisionTokenV2) { + v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + v2.Status.BoundKeypair.BoundBotInstanceID = "qwerty" + }), + initReq: makeInitReq(func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + r.JoinRequest.BotInstanceID = "asdf" + }), + solver: makeSolver(correctPublicKey), + + assertError: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "bot instance mismatch") + }, + }, + { + // TODO: rotation is not yet implemented. + name: "rotation-requested", + + token: makeToken(func(v2 *types.ProvisionTokenV2) { + t := time.Now() + v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + v2.Status.BoundKeypair.BoundBotInstanceID = "asdf" + v2.Spec.BoundKeypair.RotateAfter = &t + // TODO: test clock? + }), + initReq: makeInitReq(), + solver: makeSolver(correctPublicKey), + + assertError: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "key rotation not yet supported") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, err := types.NewProvisionTokenFromSpecAndStatus( + tt.name, time.Now().Add(time.Minute), tt.token.Spec, tt.token.Status, + ) + require.NoError(t, err) + require.NoError(t, auth.CreateToken(ctx, token)) + tt.initReq.JoinRequest.Token = tt.name + + _, _, err = auth.RegisterUsingBoundKeypairMethod(ctx, tt.initReq, tt.solver) + tt.assertError(t, err) + + if tt.assertSuccess != nil { + pt, err := auth.GetToken(ctx, tt.name) + require.NoError(t, err) + + ptv2, ok := pt.(*types.ProvisionTokenV2) + require.True(t, ok) + + tt.assertSuccess(t, ptv2) + } + }) + } +} diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go index 0f0bddbb2899b..d20a3ddcb6b38 100644 --- a/lib/auth/join_iam.go +++ b/lib/auth/join_iam.go @@ -407,7 +407,7 @@ func (a *Server) RegisterUsingIAMMethodWithOpts( } if req.RegisterUsingTokenRequest.Role == types.RoleBot { - certs, err := a.generateCertsBot( + certs, _, err := a.generateCertsBot( ctx, provisionToken, req.RegisterUsingTokenRequest, diff --git a/lib/auth/join_oracle.go b/lib/auth/join_oracle.go index 457d238a437e7..95349ed8c7056 100644 --- a/lib/auth/join_oracle.go +++ b/lib/auth/join_oracle.go @@ -83,7 +83,7 @@ func (a *Server) registerUsingOracleMethod( } if tokenReq.Role == types.RoleBot { - certs, err := a.generateCertsBot( + certs, _, err := a.generateCertsBot( ctx, provisionToken, tokenReq, diff --git a/lib/auth/join_tpm.go b/lib/auth/join_tpm.go index df2e6b4e4cbcc..31334855ddd61 100644 --- a/lib/auth/join_tpm.go +++ b/lib/auth/join_tpm.go @@ -112,7 +112,7 @@ func (a *Server) RegisterUsingTPMMethod( } if initReq.JoinRequest.Role == types.RoleBot { - certs, err := a.generateCertsBot( + certs, _, err := a.generateCertsBot( ctx, ptv2, initReq.JoinRequest, diff --git a/lib/auth/machineid/machineidv1/bot_service.go b/lib/auth/machineid/machineidv1/bot_service.go index 139612f54b056..ca3aeeb0b451a 100644 --- a/lib/auth/machineid/machineidv1/bot_service.go +++ b/lib/auth/machineid/machineidv1/bot_service.go @@ -58,6 +58,7 @@ var SupportedJoinMethods = []types.JoinMethod{ types.JoinMethodTPM, types.JoinMethodTerraformCloud, types.JoinMethodBitbucket, + types.JoinMethodBoundKeypair, } // BotResourceName returns the default name for resources associated with the diff --git a/lib/joinserver/joinserver.go b/lib/joinserver/joinserver.go index b978bb8d61ef8..1d596995defad 100644 --- a/lib/joinserver/joinserver.go +++ b/lib/joinserver/joinserver.go @@ -55,6 +55,11 @@ type joinServiceClient interface { tokenReq *types.RegisterUsingTokenRequest, challengeResponse client.RegisterOracleChallengeResponseFunc, ) (*proto.Certs, error) + RegisterUsingBoundKeypairMethod( + ctx context.Context, + req *proto.RegisterUsingBoundKeypairInitialRequest, + challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc, + ) (*proto.Certs, string, error) RegisterUsingToken( ctx context.Context, req *types.RegisterUsingTokenRequest, @@ -369,6 +374,82 @@ func (s *JoinServiceGRPCServer) registerUsingOracleMethod(srv proto.JoinService_ })) } +// RegisterUsingBoundKeypairMethod registers the client using the bound-keypair +// join method, and if successful, returns a signed cert bundle for +// authenticated cluster access. +func (s *JoinServiceGRPCServer) RegisterUsingBoundKeypairMethod( + srv proto.JoinService_RegisterUsingBoundKeypairMethodServer, +) error { + return trace.Wrap(s.handleStreamingRegistration(srv.Context(), types.JoinMethodBoundKeypair, func() error { + return trace.Wrap(s.registerUsingBoundKeypair(srv)) + })) +} + +func (s *JoinServiceGRPCServer) registerUsingBoundKeypair(srv proto.JoinService_RegisterUsingBoundKeypairMethodServer) error { + ctx := srv.Context() + + // Get initial payload from the client + req, err := srv.Recv() + if err != nil { + return trace.Wrap(err, "receiving initial payload") + } + initReq := req.GetInit() + if initReq == nil { + return trace.BadParameter("expected non-nil Init payload") + } + + if initReq.InitialJoinSecret != "" { + // TODO: not supported yet. + return trace.NotImplemented("initial join secrets are not yet supported") + } + + if initReq.JoinRequest == nil { + return trace.BadParameter( + "expected JoinRequest in RegisterUsingBoundKeypairInitialRequest, got nil", + ) + } + if err := setClientRemoteAddr(ctx, initReq.JoinRequest); err != nil { + return trace.Wrap(err, "setting client address") + } + + setBotParameters(ctx, initReq.JoinRequest) + + certs, pubKey, err := s.joinServiceClient.RegisterUsingBoundKeypairMethod(ctx, initReq, func(resp *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error) { + // First, forward the challenge from Auth to the client. + err := srv.Send(resp) + if err != nil { + return nil, trace.Wrap( + err, "forwarding challenge to client", + ) + } + + // Get response from Client + req, err := srv.Recv() + if err != nil { + return nil, trace.Wrap( + err, "receiving challenge solution from client", + ) + } + + return req, nil + }) + if err != nil { + return trace.Wrap(err) + } + + slog.DebugContext(srv.Context(), "challenge ceremony complete, sending cert bundle") + + // finally, send the certs on the response stream + return trace.Wrap(srv.Send(&proto.RegisterUsingBoundKeypairMethodResponse{ + Response: &proto.RegisterUsingBoundKeypairMethodResponse_Certs{ + Certs: &proto.RegisterUsingBoundKeypairCertificates{ + Certs: certs, + PublicKey: pubKey, + }, + }, + })) +} + // RegisterUsingToken allows nodes and proxies to join the cluster using // legacy join methods which do not yet have their own RPC. // On the Auth server, this method will call the auth.Server's diff --git a/lib/joinserver/joinserver_test.go b/lib/joinserver/joinserver_test.go index f9bd3e6d8e7dc..dcd954f8d815c 100644 --- a/lib/joinserver/joinserver_test.go +++ b/lib/joinserver/joinserver_test.go @@ -42,14 +42,17 @@ import ( ) type mockJoinServiceClient struct { - sendChallenge string - returnCerts *proto.Certs - returnError error - gotIAMChallengeResponse *proto.RegisterUsingIAMMethodRequest - gotAzureChallengeResponse *proto.RegisterUsingAzureMethodRequest - gotTPMChallengeResponse *proto.RegisterUsingTPMMethodChallengeResponse - gotTPMInitReq *proto.RegisterUsingTPMMethodInitialRequest - gotRegisterUsingTokenReq *types.RegisterUsingTokenRequest + sendChallenge string + boundKeypairPublicKey string + returnCerts *proto.Certs + returnError error + gotIAMChallengeResponse *proto.RegisterUsingIAMMethodRequest + gotAzureChallengeResponse *proto.RegisterUsingAzureMethodRequest + gotTPMChallengeResponse *proto.RegisterUsingTPMMethodChallengeResponse + gotTPMInitReq *proto.RegisterUsingTPMMethodInitialRequest + gotBoundKeypairInitReq *proto.RegisterUsingBoundKeypairInitialRequest + gotBoundKeypairChallengeResponse *proto.RegisterUsingBoundKeypairMethodRequest + gotRegisterUsingTokenReq *types.RegisterUsingTokenRequest } func (c *mockJoinServiceClient) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) { @@ -86,6 +89,29 @@ func (c *mockJoinServiceClient) RegisterUsingTPMMethod( return c.returnCerts, c.returnError } +func (c *mockJoinServiceClient) RegisterUsingBoundKeypairMethod( + ctx context.Context, + req *proto.RegisterUsingBoundKeypairInitialRequest, + challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc, +) (*proto.Certs, string, error) { + c.gotBoundKeypairInitReq = req + resp, err := challengeResponse(&proto.RegisterUsingBoundKeypairMethodResponse{ + Response: &proto.RegisterUsingBoundKeypairMethodResponse_Challenge{ + Challenge: &proto.RegisterUsingBoundKeypairChallenge{ + PublicKey: c.boundKeypairPublicKey, + Challenge: c.sendChallenge, + }, + }, + }) + if err != nil { + return nil, "", trace.Wrap(err) + } + + c.gotBoundKeypairChallengeResponse = resp + + return c.returnCerts, c.boundKeypairPublicKey, c.returnError +} + func (c *mockJoinServiceClient) RegisterUsingOracleMethod( ctx context.Context, tokenReq *types.RegisterUsingTokenRequest, @@ -540,6 +566,118 @@ func TestJoinServiceGRPCServer_RegisterUsingTPMMethod(t *testing.T) { } } +// TestJoinServiceGRPCServer_RegisterUsingBoundKeypairMethodSimple tests the +// simplest bound keypair joining path, with no keypair registration or +// rotation. +func TestJoinServiceGRPCServer_RegisterUsingBoundKeypairMethodSimple(t *testing.T) { + t.Parallel() + testPack := newTestPack(t) + + standardResponse := &proto.RegisterUsingBoundKeypairMethodRequest{ + Payload: &proto.RegisterUsingBoundKeypairMethodRequest_ChallengeResponse{ + ChallengeResponse: &proto.RegisterUsingBoundKeypairChallengeResponse{ + Solution: []byte("header.payload.signature"), + }, + }, + } + + testCases := []struct { + desc string + publicKey string + challenge string + req *proto.RegisterUsingBoundKeypairInitialRequest + challengeResponse *proto.RegisterUsingBoundKeypairMethodRequest + challengeResponseErr error + authErr error + certs *proto.Certs + }{ + { + desc: "success case", + challenge: "foo", + req: &proto.RegisterUsingBoundKeypairInitialRequest{ + JoinRequest: &types.RegisterUsingTokenRequest{}, + }, + challengeResponse: standardResponse, + certs: &proto.Certs{SSH: []byte("qux")}, + }, + { + desc: "auth error", + challenge: "foo", + req: &proto.RegisterUsingBoundKeypairInitialRequest{ + JoinRequest: &types.RegisterUsingTokenRequest{}, + }, + challengeResponse: standardResponse, + authErr: trace.AccessDenied("not allowed"), + }, + { + desc: "challenge response error", + challenge: "foo", + req: &proto.RegisterUsingBoundKeypairInitialRequest{ + JoinRequest: &types.RegisterUsingTokenRequest{}, + }, + challengeResponse: nil, + challengeResponseErr: trace.BadParameter("testing error"), + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + testPack.mockAuthServer.sendChallenge = tc.challenge + testPack.mockAuthServer.returnCerts = tc.certs + testPack.mockAuthServer.returnError = tc.authErr + + challengeResponder := func( + challenge *proto.RegisterUsingBoundKeypairMethodResponse, + ) (*proto.RegisterUsingBoundKeypairMethodRequest, error) { + assert.Equal(t, &proto.RegisterUsingBoundKeypairMethodResponse{ + Response: &proto.RegisterUsingBoundKeypairMethodResponse_Challenge{ + Challenge: &proto.RegisterUsingBoundKeypairChallenge{ + PublicKey: tc.publicKey, + Challenge: tc.challenge, + }, + }, + }, challenge) + + return tc.challengeResponse, tc.challengeResponseErr + } + + for suffix, clt := range map[string]*client.JoinServiceClient{ + "_auth": testPack.authClient, + "_proxy": testPack.proxyClient, + } { + t.Run(tc.desc+suffix, func(t *testing.T) { + certs, pubKey, err := clt.RegisterUsingBoundKeypairMethod( + context.Background(), tc.req, challengeResponder, + ) + if tc.challengeResponseErr != nil { + require.ErrorIs(t, err, tc.challengeResponseErr) + return + } + if tc.authErr != nil { + require.Error(t, err) + require.Contains(t, err.Error(), tc.authErr.Error()) + return + } + require.NoError(t, err) + require.Equal(t, tc.certs, certs) + + expectedInitReq := tc.req + expectedInitReq.JoinRequest.RemoteAddr = "bufconn" + assert.Equal(t, expectedInitReq, testPack.mockAuthServer.gotBoundKeypairInitReq) + + assert.Equal( + t, + tc.challengeResponse, + testPack.mockAuthServer.gotBoundKeypairChallengeResponse, + ) + + require.Equal(t, tc.publicKey, pubKey) + }) + } + }) + } +} + func TestTimeout(t *testing.T) { t.Parallel() diff --git a/lib/services/local/provisioning.go b/lib/services/local/provisioning.go index 4d9eeac954a41..e2a018a988f8f 100644 --- a/lib/services/local/provisioning.go +++ b/lib/services/local/provisioning.go @@ -52,6 +52,62 @@ func (s *ProvisioningService) UpsertToken(ctx context.Context, p types.Provision return nil } +// PatchToken uses the supplied function to attempt to patch a token resource. +// Up to 3 update attempts will be made if the conditional update fails due to +// a revision comparison failure. +func (s *ProvisioningService) PatchToken( + ctx context.Context, + tokenName string, + updateFn func(types.ProvisionToken) (types.ProvisionToken, error), +) (types.ProvisionToken, error) { + const iterLimit = 3 + + for i := 0; i < iterLimit; i++ { + existing, err := s.GetToken(ctx, tokenName) + if err != nil { + return nil, trace.Wrap(err) + } + + // Note: CloneProvisionToken only supports ProvisionTokenV2. + clone, err := services.CloneProvisionToken(existing) + if err != nil { + return nil, trace.Wrap(err) + } + + updated, err := updateFn(clone) + if err != nil { + return nil, trace.Wrap(err) + } + + updatedMetadata := updated.GetMetadata() + existingMetadata := existing.GetMetadata() + + switch { + case updatedMetadata.GetName() != existingMetadata.GetName(): + return nil, trace.BadParameter("metadata.name: cannot be patched") + case updatedMetadata.GetRevision() != existingMetadata.GetRevision(): + return nil, trace.BadParameter("metadata.revision: cannot be patched") + } + + item, err := s.tokenToItem(updated) + if err != nil { + return nil, trace.Wrap(err) + } + + lease, err := s.ConditionalUpdate(ctx, *item) + if trace.IsCompareFailed(err) { + continue + } else if err != nil { + return nil, trace.Wrap(err) + } + + updated.SetRevision(lease.Revision) + return updated, nil + } + + return nil, trace.CompareFailed("failed to update provision token within %v iterations", iterLimit) +} + // CreateToken creates a new token for the auth server func (s *ProvisioningService) CreateToken(ctx context.Context, p types.ProvisionToken) error { item, err := s.tokenToItem(p) diff --git a/lib/services/provisioning.go b/lib/services/provisioning.go index 8e17507fec536..185787a6c63ec 100644 --- a/lib/services/provisioning.go +++ b/lib/services/provisioning.go @@ -25,6 +25,7 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/utils" ) @@ -48,6 +49,14 @@ type Provisioner interface { // GetTokens returns all non-expired tokens GetTokens(ctx context.Context) ([]types.ProvisionToken, error) + + // PatchToken performs a conditional update on the named token using + // `updateFn`, retrying internally if a comparison failure occurs. + PatchToken( + ctx context.Context, + token string, + updateFn func(types.ProvisionToken) (types.ProvisionToken, error), + ) (types.ProvisionToken, error) } // MustCreateProvisionToken returns a new valid provision token @@ -127,3 +136,16 @@ func MarshalProvisionToken(provisionToken types.ProvisionToken, opts ...MarshalO return nil, trace.BadParameter("unrecognized provision token version %T", provisionToken) } } + +// CloneProvisionToken returns a deep copy of the given provision token, per +// `apiutils.CloneProtoMsg()`. Fields in the clone may be modified without +// affecting the original. Only V2 is supported. +func CloneProvisionToken(provisionToken types.ProvisionToken) (types.ProvisionToken, error) { + switch provisionToken := provisionToken.(type) { + case *types.ProvisionTokenV2: + clone := apiutils.CloneProtoMsg(provisionToken) + return clone, nil + default: + return nil, trace.BadParameter("cannot clone unsupported provision token version %T", provisionToken) + } +} diff --git a/lib/web/join_tokens_test.go b/lib/web/join_tokens_test.go index 39aa9511abef7..d605984443c07 100644 --- a/lib/web/join_tokens_test.go +++ b/lib/web/join_tokens_test.go @@ -47,6 +47,8 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/automaticupgrades" + "github.com/gravitational/teleport/lib/boundkeypair" + "github.com/gravitational/teleport/lib/boundkeypair/boundkeypairexperiment" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/modules" @@ -479,6 +481,9 @@ func TestCreateTokenExpiry(t *testing.T) { }, }) + // TODO: Remove this once bound keypair experiment flag is removed. + boundkeypairexperiment.SetEnabled(true) + ctx := context.Background() username := "test-user@example.com" env := newWebPack(t, 1) @@ -623,6 +628,15 @@ func setMinimalConfigForMethod(spec *types.ProvisionTokenSpecV2, method types.Jo }, }, } + case types.JoinMethodBoundKeypair: + spec.BoundKeypair = &types.ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ + InitialPublicKey: "abcd", + }, + Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{ + Mode: boundkeypair.RecoveryModeInsecure, + }, + } } }