From c8c7f653ee2c0bed03df830d0038300a0aad36ec Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Mon, 28 Apr 2025 21:05:36 -0600 Subject: [PATCH 01/17] MWI: Minimal bound-keypair joining implementation This includes a minimal implementation of bound-keypair joining. This first iteration requires preregistered public keys, and requires `unlimited` and `insecure` flags to be set on bound keypair tokens. Minimal client-side implementation will be in a follow up PR. RFD: #52546 Closes #53373 --- api/client/joinservice.go | 79 +++ api/types/provisioning.go | 29 ++ lib/auth/auth_with_roles.go | 12 + lib/auth/authclient/clt.go | 17 + lib/auth/join.go | 13 +- lib/auth/join/boundkeypair/boundkeypair.go | 213 ++++++++ lib/auth/join_azure.go | 2 +- lib/auth/join_bound_keypair.go | 464 ++++++++++++++++++ lib/auth/join_iam.go | 2 +- lib/auth/join_oracle.go | 2 +- lib/auth/join_tpm.go | 2 +- lib/auth/machineid/machineidv1/bot_service.go | 1 + lib/boundkeypair/bound_keypair.go | 119 +++++ lib/boundkeypair/experiment/experiment.go | 40 ++ lib/cryptosuites/suites.go | 14 +- lib/joinserver/joinserver.go | 98 ++++ lib/jwt/jwt.go | 5 +- lib/services/local/provisioning.go | 56 +++ lib/services/provisioning.go | 22 + 19 files changed, 1175 insertions(+), 15 deletions(-) create mode 100644 lib/auth/join/boundkeypair/boundkeypair.go create mode 100644 lib/auth/join_bound_keypair.go create mode 100644 lib/boundkeypair/bound_keypair.go create mode 100644 lib/boundkeypair/experiment/experiment.go diff --git a/api/client/joinservice.go b/api/client/joinservice.go index bc30c8a541e6c..dace70202bfeb 100644 --- a/api/client/joinservice.go +++ b/api/client/joinservice.go @@ -18,6 +18,7 @@ package client import ( "context" + "io" "github.com/gravitational/trace" @@ -60,6 +61,11 @@ type RegisterTPMChallengeResponseFunc func(challenge *proto.TPMEncryptedCredenti // *proto.OracleSignedRequest for a given challenge, or an error. type RegisterOracleChallengeResponseFunc func(challenge string) (*proto.OracleSignedRequest, error) +// RegisterUsingBoundKeypairChallengeResponseFunc is a function to be passed to +// RegisterUsingBoundKeypair. It must return a message containing a signed +// response for the given challenge, or an error. +type RegisterUsingBoundKeypairChallengeResponseFunc func(publicKey string, challenge string) (*proto.RegisterUsingBoundKeypairChallengeResponse, error) + // RegisterUsingIAMMethod registers the caller using the IAM join method and // returns signed certs to join the cluster. // @@ -262,6 +268,79 @@ func (c *JoinServiceClient) RegisterUsingOracleMethod( return certs, nil } +func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod( + ctx context.Context, + initReq *proto.RegisterUsingBoundKeypairInitialRequest, + challengeFunc RegisterUsingBoundKeypairChallengeResponseFunc, +) (*proto.Certs, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + stream, err := c.grpcClient.RegisterUsingBoundKeypairMethod(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + defer stream.CloseSend() + + err = stream.Send(&proto.RegisterUsingBoundKeypairMethodRequest{ + Payload: &proto.RegisterUsingBoundKeypairMethodRequest_Init{ + Init: initReq, + }, + }) + if err != nil { + return nil, trace.Wrap(err, "sending initial request") + } + + // Unlike other methods, the server may send multiple challenges, + // particularly during keypair rotation. We'll iterate through all responses + // here instead to ensure we handle everything. + for { + res, err := stream.Recv() + if err == io.EOF { + break + } else if err != nil { + return nil, trace.Wrap(err, "receiving intermediate bound keypair join response") + } + + switch kind := res.GetResponse().(type) { + case *proto.RegisterUsingBoundKeypairMethodResponse_Challenge: + solution, err := challengeFunc(kind.Challenge.PublicKey, kind.Challenge.Challenge) + if err != nil { + return nil, trace.Wrap(err, "solving challenge") + } + + err = stream.Send(&proto.RegisterUsingBoundKeypairMethodRequest{ + Payload: &proto.RegisterUsingBoundKeypairMethodRequest_ChallengeResponse{ + ChallengeResponse: solution, + }, + }) + if err != nil { + return nil, trace.Wrap(err, "sending solution") + } + case *proto.RegisterUsingBoundKeypairMethodResponse_Certs: + certs := kind.Certs.GetCerts() + if certs == nil { + return nil, trace.BadParameter("expected Certs, got %T", kind.Certs.Certs) + } + + // If we receive a cert bundle, we can return early. Even if we + // logically should have expected to receive a 2nd challenge if we + // e.g. requested keypair rotation, skipping it just means the new + // keypair won't be stored. That said, we'll rely on the server to + // raise an error if rotation fails or is otherwise skipped or not + // allowed. + + return certs, nil + default: + return nil, trace.BadParameter("received unexpected challenge response: %v", res.GetResponse()) + } + } + + // Ideally the server will emit a proper error instead of just hanging up on + // us. + return nil, trace.AccessDenied("server declined to send certs during bound-keypair join attempt") +} + // RegisterUsingToken registers the caller using a token and returns signed // certs. // This is used where a more specific RPC has not been introduced for the join diff --git a/api/types/provisioning.go b/api/types/provisioning.go index 94ef7b356d545..3c49fe11bb4f0 100644 --- a/api/types/provisioning.go +++ b/api/types/provisioning.go @@ -80,6 +80,9 @@ const ( // JoinMethodOracle indicates that the node will join using the Oracle join // method. JoinMethodOracle JoinMethod = "oracle" + // JoinMethodBoundKeypair indicates the node will join using the Bound + // Keypair join method. See lib/boundkeypair for more. + JoinMethodBoundKeypair JoinMethod = "bound_keypair" ) var JoinMethods = []JoinMethod{ @@ -97,6 +100,7 @@ var JoinMethods = []JoinMethod{ JoinMethodTPM, JoinMethodTerraformCloud, JoinMethodOracle, + JoinMethodBoundKeypair, } func ValidateJoinMethod(method JoinMethod) error { @@ -401,6 +405,18 @@ func (p *ProvisionTokenV2) CheckAndSetDefaults() error { if err := providerCfg.checkAndSetDefaults(); err != nil { return trace.Wrap(err, "spec.oracle: failed validation") } + case JoinMethodBoundKeypair: + providerCfg := p.Spec.BoundKeypair + if providerCfg == nil { + return trace.BadParameter( + "spec.bound_keypair: must be configured for the join method %q", + JoinMethodBoundKeypair, + ) + } + + if err := providerCfg.checkAndSetDefaults(); err != nil { + return trace.Wrap(err, "spec.bound_keypair: failed validation") + } default: return trace.BadParameter("unknown join method %q", p.Spec.JoinMethod) } @@ -951,3 +967,16 @@ func (a *ProvisionTokenSpecV2Oracle) checkAndSetDefaults() error { } return nil } + +func (a *ProvisionTokenSpecV2BoundKeypair) checkAndSetDefaults() error { + if a.Onboarding == nil { + return trace.BadParameter("spec.bound_keypair.onboarding is required") + } + + if a.Onboarding.InitialJoinSecret == "" && a.Onboarding.InitialPublicKey == "" { + return trace.BadParameter("at least one of [initial_join_secret, " + + "initial_public_key] is required in spec.bound_keypair.onboarding") + } + + return nil +} diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index a80f4914f21b5..129b8bc7ab20e 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -2434,6 +2434,12 @@ func (a *ServerWithRoles) UpsertToken(ctx context.Context, token types.Provision return trace.Wrap(err) } + // bound_keypair tokens have special creation/update logic and are handled + // separately + if token.GetJoinMethod() == types.JoinMethodBoundKeypair { + return trace.Wrap(a.authServer.UpsertBoundKeypairToken(ctx, token)) + } + if err := a.authServer.UpsertToken(ctx, token); err != nil { return trace.Wrap(err) } @@ -2459,6 +2465,12 @@ func (a *ServerWithRoles) CreateToken(ctx context.Context, token types.Provision return trace.Wrap(err) } + // bound_keypair tokens have special creation/update logic and are handled + // separately + if token.GetJoinMethod() == types.JoinMethodBoundKeypair { + return trace.Wrap(a.authServer.CreateBoundKeypairToken(ctx, token)) + } + if err := a.authServer.CreateToken(ctx, token); err != nil { return trace.Wrap(err) } diff --git a/lib/auth/authclient/clt.go b/lib/auth/authclient/clt.go index 7cebd85ce9895..5e5da68ee9ded 100644 --- a/lib/auth/authclient/clt.go +++ b/lib/auth/authclient/clt.go @@ -305,6 +305,15 @@ func (c *Client) DeleteAllTokens() error { return trace.NotImplemented(notImplementedMessage) } +// PatchToken not implemented: can only be called locally +func (c *Client) PatchToken( + ctx context.Context, + token string, + updateFn func(types.ProvisionToken) (types.ProvisionToken, error), +) (types.ProvisionToken, error) { + return nil, trace.NotImplemented(notImplementedMessage) +} + // AddUserLoginAttempt logs user login attempt func (c *Client) AddUserLoginAttempt(user string, attempt services.LoginAttempt, ttl time.Duration) error { panic("not implemented") @@ -1235,6 +1244,14 @@ type ProvisioningService interface { // CreateToken creates a new provision token for the auth server CreateToken(ctx context.Context, token types.ProvisionToken) error + // PatchToken performs a conditional update on the named token using + // `updateFn`, retrying internally if a comparison failure occurs. + PatchToken( + ctx context.Context, + token string, + updateFn func(types.ProvisionToken) (types.ProvisionToken, error), + ) (types.ProvisionToken, error) + // RegisterUsingToken calls the auth service API to register a new node via registration token // which has been previously issued via GenerateToken RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) diff --git a/lib/auth/join.go b/lib/auth/join.go index c96d96840034b..38acfc57a77e2 100644 --- a/lib/auth/join.go +++ b/lib/auth/join.go @@ -227,7 +227,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin if err := a.checkEC2JoinRequest(ctx, req); err != nil { return nil, trace.Wrap(err) } - case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM, types.JoinMethodOracle: + case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM, + types.JoinMethodOracle, types.JoinMethodBoundKeypair: // These join methods must use gRPC register methods return nil, trace.AccessDenied("this token is only valid for the %s "+ "join method but the node has connected to the wrong endpoint, make "+ @@ -322,7 +323,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin // With all elements of the token validated, we can now generate & return // certificates. if req.Role == types.RoleBot { - certs, err = a.generateCertsBot( + certs, _, err = a.generateCertsBot( ctx, provisionToken, req, @@ -341,7 +342,7 @@ func (a *Server) generateCertsBot( req *types.RegisterUsingTokenRequest, rawJoinClaims any, attrs *workloadidentityv1pb.JoinAttrs, -) (*proto.Certs, error) { +) (*proto.Certs, string, error) { // bots use this endpoint but get a user cert // botResourceName must be set, enforced in CheckAndSetDefaults botName := provisionToken.GetBotName() @@ -349,7 +350,7 @@ func (a *Server) generateCertsBot( // Check this is a join method for bots we support. if !slices.Contains(machineidv1.SupportedJoinMethods, joinMethod) { - return nil, trace.BadParameter( + return nil, "", trace.BadParameter( "unsupported join method %q for bot", joinMethod, ) } @@ -444,7 +445,7 @@ func (a *Server) generateCertsBot( attrs, ) if err != nil { - return nil, trace.Wrap(err) + return nil, "", trace.Wrap(err) } joinEvent.BotInstanceID = botInstanceID @@ -466,7 +467,7 @@ func (a *Server) generateCertsBot( if err := a.emitter.EmitAuditEvent(ctx, joinEvent); err != nil { a.logger.WarnContext(ctx, "Failed to emit bot join event", "error", err) } - return certs, nil + return certs, botInstanceID, nil } func (a *Server) generateCerts( diff --git a/lib/auth/join/boundkeypair/boundkeypair.go b/lib/auth/join/boundkeypair/boundkeypair.go new file mode 100644 index 0000000000000..545c1dbef8f0e --- /dev/null +++ b/lib/auth/join/boundkeypair/boundkeypair.go @@ -0,0 +1,213 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package boundkeypair + +import ( + "context" + "os" + "path/filepath" + + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/auth/join" + "github.com/gravitational/teleport/lib/cryptosuites" + "github.com/gravitational/trace" + "golang.org/x/crypto/ssh" +) + +const ( + PrivateKeyPath = "id_bkp" + PublicKeyPath = PrivateKeyPath + ".pub" + JoinStatePath = "bkp_state" + + StandardFileWriteMode = 0600 +) + +// ClientState contains state parameters stored on disk needed to complete the +// bound keypair join process. +type ClientState struct { + // PrivateKey is the parsed private key. + PrivateKey *keys.PrivateKey + + // PrivateKeyBytes contains the private key bytes. This value should always + // be nonempty. + PrivateKeyBytes []byte + + // PublicKeyBytes contains the public key bytes. This value is not used at + // runtime, and is only set when a public key should be written to disk, + // like on first creation or during rotation. To consistently access the + // public key, use `.PrivateKey.Public()`. + PublicKeyBytes []byte + + // JoinStateBytes contains join state bytes. This value will be empty if + // this client has not yet joined. + JoinStateBytes []byte +} + +// ToJoinParams creates joining parameters for use with `join.Register()` from +// this client state. +func (c *ClientState) ToJoinParams(initialJoinSecret string) *join.BoundKeypairParams { + if len(c.JoinStateBytes) > 0 { + // This identity has been bound, so don't pass along the join secret (if + // any) + initialJoinSecret = "" + } + + return &join.BoundKeypairParams{ + // Note: pass the internal signer because go-jose does type assertions + // on the standard library types. + CurrentKey: c.PrivateKey.Signer, + PreviousJoinState: c.JoinStateBytes, + InitialJoinSecret: initialJoinSecret, + } +} + +// ToPublicKeyBytes returns the public key bytes in ssh authorized_keys format. +func (c *ClientState) ToPublicKeyBytes() ([]byte, error) { + sshPubKey, err := ssh.NewPublicKey(c.PrivateKey.Public()) + if err != nil { + return nil, trace.Wrap(err, "creating ssh public key") + } + + return ssh.MarshalAuthorizedKey(sshPubKey), nil +} + +type FS interface { + Read(ctx context.Context, name string) ([]byte, error) + Write(ctx context.Context, name string, data []byte) error +} + +type StandardFS struct { + parentDir string +} + +func (f *StandardFS) Read(ctx context.Context, name string) ([]byte, error) { + data, err := os.ReadFile(name) + if err != nil { + return nil, trace.Wrap(err) + } + + return data, nil +} + +func (f *StandardFS) Write(ctx context.Context, name string, data []byte) error { + path := filepath.Join(f.parentDir, name) + + return trace.Wrap(os.WriteFile(path, data, StandardFileWriteMode)) +} + +// NewStandardFS creates a new standard FS implementation. +func NewStandardFS(parentDir string) FS { + return &StandardFS{ + parentDir: parentDir, + } +} + +// LoadClientState attempts to load bound keypair client state from the given +// filesystem implementation. Callers should expect to handle NotFound errors +// returned here if a private key is not found; this indicates no prior client +// state exists and initial secret joining should be attempted if possible. If +// a keypair has been pregenerated, no prior join state will exist, and the +// join state will be empty; any corresponding errors while reading nonexistent +// join state documents will be ignored. +func LoadClientState(ctx context.Context, fs FS) (*ClientState, error) { + privateKeyBytes, err := fs.Read(ctx, PrivateKeyPath) + if err != nil { + return nil, trace.Wrap(err, "reading private key") + } + + joinStateBytes, err := fs.Read(ctx, JoinStatePath) + if trace.IsNotFound(err) { + // Join state doesn't exist, this is allowed. + } else if err != nil { + return nil, trace.Wrap(err, "reading previous join state") + } + + pk, err := keys.ParsePrivateKey(privateKeyBytes) + if err != nil { + return nil, trace.Wrap(err, "parsing private key") + } + + return &ClientState{ + PrivateKey: pk, + + PrivateKeyBytes: privateKeyBytes, + JoinStateBytes: joinStateBytes, + }, nil +} + +// StoreClientState writes bound keypair client state to the given filesystem +// wrapper. Public keys and join state will only be written if +func StoreClientState(ctx context.Context, fs FS, state *ClientState) error { + if err := fs.Write(ctx, PrivateKeyPath, state.PrivateKeyBytes); err != nil { + return trace.Wrap(err, "writing private key") + } + + // TODO: maybe consider just not writing the public key at all. End users + // aren't really meant to look in the internal storage, and we can just + // derive the public key whenever we want. + + // Only write the public key if it was explicitly provided. This helps save + // an unnecessary file write. + if len(state.PublicKeyBytes) > 0 { + if err := fs.Write(ctx, PublicKeyPath, state.PublicKeyBytes); err != nil { + return trace.Wrap(err, "writing public key") + } + } + + if len(state.JoinStateBytes) > 0 { + if err := fs.Write(ctx, JoinStatePath, state.JoinStateBytes); err != nil { + return trace.Wrap(err, "writing previous join state") + } + } + + return nil +} + +// NewUnboundClientState creates a new client state that has not yet been bound, +// i.e. a new keypair that has not been registered with Auth, and no prior join +// state. +func NewUnboundClientState(ctx context.Context, getSuite cryptosuites.GetSuiteFunc) (*ClientState, error) { + key, err := cryptosuites.GenerateKey(ctx, getSuite, cryptosuites.BoundKeypairJoining) + if err != nil { + return nil, trace.Wrap(err, "generating keypair") + } + + privateKeyBytes, err := keys.MarshalPrivateKey(key) + if err != nil { + return nil, trace.Wrap(err, "marshalling private key") + } + + sshPubKey, err := ssh.NewPublicKey(key.Public()) + if err != nil { + return nil, trace.Wrap(err, "creating ssh public key") + } + + publicKeyBytes := ssh.MarshalAuthorizedKey(sshPubKey) + + pk, err := keys.NewPrivateKey(key, privateKeyBytes) + if err != nil { + return nil, trace.Wrap(err) + } + + return &ClientState{ + PrivateKeyBytes: privateKeyBytes, + PublicKeyBytes: publicKeyBytes, + PrivateKey: pk, + }, nil +} diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index 9461a412ed7c5..02db610c710c5 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -489,7 +489,7 @@ func (a *Server) RegisterUsingAzureMethodWithOpts( } if req.RegisterUsingTokenRequest.Role == types.RoleBot { - certs, err := a.generateCertsBot( + certs, _, err := a.generateCertsBot( ctx, provisionToken, req.RegisterUsingTokenRequest, diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go new file mode 100644 index 0000000000000..a2044d65338cd --- /dev/null +++ b/lib/auth/join_bound_keypair.go @@ -0,0 +1,464 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package auth + +import ( + "context" + "encoding/json" + "time" + + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/boundkeypair" + "github.com/gravitational/teleport/lib/boundkeypair/experiment" + "github.com/gravitational/teleport/lib/jwt" + libsshutils "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/trace" +) + +// validateBoundKeypairTokenSpec performs some basic validation checks on a +// bound_keypair-type join token. +func validateBoundKeypairTokenSpec(spec *types.ProvisionTokenSpecV2BoundKeypair) error { + // Various constant checks, shared between creation and update. Many of + // these checks are temporary and will be removed alongside the experiment + // flag. + if !experiment.Enabled() { + return trace.BadParameter("bound keypair joining experiment is not enabled") + } + + if spec.RotateOnNextRenewal { + return trace.NotImplemented("spec.bound_keypair.rotate_on_next_renewal is not yet implemented") + } + + if spec.Onboarding.InitialJoinSecret != "" { + return trace.NotImplemented("spec.bound_keypair.initial_join_secret is not yet implemented") + } + + if spec.Onboarding.InitialPublicKey == "" { + return trace.NotImplemented("spec.bound_keypair.initial_public_key is currently required") + } + + if !spec.Joining.Unlimited { + return trace.NotImplemented("spec.bound_keypair.joining.unlimited cannot currently be `false`") + } + + if !spec.Joining.Insecure { + return trace.NotImplemented("spec.bound_keypair.joining.insecure cannot currently be `false`") + } + + return nil +} + +func (a *Server) initialBoundKeypairStatus(spec *types.ProvisionTokenSpecV2BoundKeypair) *types.ProvisionTokenStatusV2BoundKeypair { + return &types.ProvisionTokenStatusV2BoundKeypair{ + InitialJoinSecret: spec.Onboarding.InitialJoinSecret, + BoundPublicKey: spec.Onboarding.InitialPublicKey, + } +} + +func (a *Server) CreateBoundKeypairToken(ctx context.Context, token types.ProvisionToken) error { + if token.GetJoinMethod() != types.JoinMethodBoundKeypair { + return trace.BadParameter("must be called with a bound keypair token") + } + + tokenV2, ok := token.(*types.ProvisionTokenV2) + if !ok { + return trace.BadParameter("%v join method requires ProvisionTokenV2", types.JoinMethodOracle) + } + + spec := tokenV2.Spec.BoundKeypair + if spec == nil { + return trace.BadParameter("bound_keypair token requires non-nil spec.bound_keypair") + } + + if err := validateBoundKeypairTokenSpec(spec); err != nil { + return trace.Wrap(err) + } + + // TODO: Is this wise? End users _shouldn't_ modify this, but this could + // interfere with cluster backup/restore. Options seem to be: + // - Let users create/update with status fields. They can break things, but + // maybe that's okay. No backup/restore implications. + // - Ignore status fields during creation and update. Any set value will be + // discarded here, and during update. This would still have consequences + // during cluster restores, but wouldn't raise errors, and the status + // field would otherwise be protected from easy tampering. Users might be + // confused as no user-visible errors would be raised if they used + // `tctl edit`. + // - Raise an error if status fields are changed. Worst restore + // implications, but tampering won't be easy, and will have some UX. + if tokenV2.Status.BoundKeypair != nil { + return trace.BadParameter("cannot create a bound_keypair token with set status") + } + + // TODO: Populate initial_join_secret if needed. + + return trace.Wrap(a.CreateToken(ctx, tokenV2)) +} + +func (a *Server) UpsertBoundKeypairToken(ctx context.Context, token types.ProvisionToken) error { + if token.GetJoinMethod() != types.JoinMethodBoundKeypair { + return trace.BadParameter("must be called with a bound keypair token") + } + + tokenV2, ok := token.(*types.ProvisionTokenV2) + if !ok { + return trace.BadParameter("%v join method requires ProvisionTokenV2", types.JoinMethodOracle) + } + + spec := tokenV2.Spec.BoundKeypair + if spec == nil { + return trace.BadParameter("bound_keypair token requires non-nil spec.bound_keypair") + } + + if err := validateBoundKeypairTokenSpec(spec); err != nil { + return trace.Wrap(err) + } + + // TODO: Populate initial_join_secret if needed, but only if no previous + // resource exists. + + // TODO: Probably won't want to tweak status here; that's best done during + // the join ceremony. + // if tokenV2.Status == nil { + // tokenV2.Status = &types.ProvisionTokenStatusV2{} + // } + + // if tokenV2.Status.BoundKeypair == nil { + // tokenV2.Status.BoundKeypair = a.initialBoundKeypairStatus(spec) + // } + + // TODO: Follow up changes to include: + // - Compare and swap / conditional updates + // - Proper checking for previous resource + return trace.Wrap(a.UpsertToken(ctx, token)) +} + +func (a *Server) issueBoundKeypairChallenge( + ctx context.Context, + marshalledKey string, + challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc, +) error { + key, err := libsshutils.CryptoPublicKey([]byte(marshalledKey)) + if err != nil { + return trace.Wrap(err, "parsing bound public key") + } + + // The particular subject value doesn't strictly need to be the name of the + // bot or node (which may not be known, yet). Instead, we'll use the key ID, + // which could at least be useful for the client to know which key the + // challenge should be signed with. + keyID, err := jwt.KeyID(key) + if err != nil { + return trace.Wrap(err, "determining the key ID") + } + + clusterName, err := a.GetClusterName(ctx) + if err != nil { + return trace.Wrap(err) + } + + a.logger.DebugContext(ctx, "Server.issueBoundKeypairChallenge(): preflight complete, issuing challenge", "pk", marshalledKey, "id", keyID) + + validator, err := boundkeypair.NewChallengeValidator(keyID, clusterName.GetClusterName(), key) + if err != nil { + return trace.Wrap(err) + } + + challenge, err := validator.IssueChallenge() + if err != nil { + return trace.Wrap(err, "generating a challenge document") + } + + a.logger.DebugContext(ctx, "Server.issueBoundKeypairChallenge(): issued new challenge", "challenge", challenge) + + marshalledChallenge, err := json.Marshal(challenge) + if err != nil { + return trace.Wrap(err) + } + + a.logger.InfoContext(ctx, "requesting signed bound keypair joining challenge") + + response, err := challengeResponse(marshalledKey, string(marshalledChallenge)) + if err != nil { + return trace.Wrap(err, "requesting a signed challenge") + } + + a.logger.DebugContext(ctx, "Server.issueBoundKeypairChallenge(): challenge complete, verifying", "response", response) + + if err := validator.ValidateChallengeResponse(challenge.Nonce, string(response.Solution)); err != nil { + // TODO: access denied instead? + return trace.Wrap(err, "validating challenge response") + } + + a.logger.InfoContext(ctx, "bound keypair challenge response verified successfully") + + return nil +} + +// boundKeypairStatusMutator is a function called to mutate a bound keypair +// status during a call to PatchProvisionToken(). These functions may be called +// repeatedly if e.g. revision checks fail. To ensure invariants remain in +// place, mutator functions may make assertions +type boundKeypairStatusMutator func(*types.ProvisionTokenStatusV2BoundKeypair) error + +func mutateStatusConsumeJoin(unlimited bool, expectRemainingJoins uint32) boundKeypairStatusMutator { + now := time.Now() + + return func(status *types.ProvisionTokenStatusV2BoundKeypair) error { + // Ensure we have the expected number of rejoins left to prevent going + // below zero. + // TODO: this could be >=? would avoid breaking if this happens to + // collide with a user incrementing TotalJoins. + if status.RemainingJoins != expectRemainingJoins { + return trace.AccessDenied("unexpected backend state") + } + + status.JoinCount += 1 + status.LastJoinedAt = &now + + if !unlimited { + // TODO: decrement remaining joins (not yet implemented.) + return trace.NotImplemented("only unlimited rejoining is currently supported") + } + + return nil + } +} + +func mutateStatusBoundPublicKey(newPublicKey, expectPreviousKey string) boundKeypairStatusMutator { + return func(status *types.ProvisionTokenStatusV2BoundKeypair) error { + if status.BoundPublicKey != expectPreviousKey { + return trace.AccessDenied("unexpected backend state") + } + + status.BoundPublicKey = newPublicKey + + return nil + } +} + +func mutateStatusBoundBotInstance(newBotInstance, expectPreviousBotInstance string) boundKeypairStatusMutator { + return func(status *types.ProvisionTokenStatusV2BoundKeypair) error { + if status.BoundBotInstanceID != expectPreviousBotInstance { + return trace.AccessDenied("unexpected backend state") + } + + status.BoundBotInstanceID = newBotInstance + + return nil + } +} + +func (a *Server) RegisterUsingBoundKeypairMethod( + ctx context.Context, + req *proto.RegisterUsingBoundKeypairInitialRequest, + challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc, +) (_ *proto.Certs, err error) { + a.logger.DebugContext(ctx, "Server.RegisterUsingBoundKeypairMethod()", "req", req) + + var provisionToken types.ProvisionToken + var joinFailureMetadata any + defer func() { + // Emit a log message and audit event on join failure. + if err != nil { + a.handleJoinFailure( + ctx, err, provisionToken, joinFailureMetadata, req.JoinRequest, + ) + } + }() + + // First, check the specified token exists, and is a bound keypair-type join + // token. + if err := req.JoinRequest.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + // Only bot joining is supported at the moment - unique ID verification is + // required and this is currently only implemented for bots. + if req.JoinRequest.Role != types.RoleBot { + return nil, trace.BadParameter("bound keypair joining is only supported for bots") + } + + provisionToken, err = a.checkTokenJoinRequestCommon(ctx, req.JoinRequest) + if err != nil { + return nil, trace.Wrap(err) + } + ptv2, ok := provisionToken.(*types.ProvisionTokenV2) + if !ok { + return nil, trace.BadParameter("expected *types.ProvisionTokenV2, got %T", provisionToken) + } + if ptv2.Spec.JoinMethod != types.JoinMethodBoundKeypair { + return nil, trace.BadParameter("specified join token is not for `%s` method", types.JoinMethodBoundKeypair) + } + + if ptv2.Status == nil { + ptv2.Status = &types.ProvisionTokenStatusV2{} + } + if ptv2.Status.BoundKeypair == nil { + ptv2.Status.BoundKeypair = &types.ProvisionTokenStatusV2BoundKeypair{} + } + + spec := ptv2.Spec.BoundKeypair + status := ptv2.Status.BoundKeypair + hasBoundPublicKey := status.BoundPublicKey != "" + hasBoundBotInstance := status.BoundBotInstanceID != "" + hasIncomingBotInstance := req.JoinRequest.BotInstanceID != "" + expectNewBotInstance := false + + // Mutators to use during the token resource status patch at the end. + var mutators []boundKeypairStatusMutator + + switch { + case !hasBoundPublicKey && !hasIncomingBotInstance: + // Normal initial join attempt. No bound key, and no incoming bot + // instance. Consumes a rejoin. + if spec.Onboarding.InitialJoinSecret != "" { + return nil, trace.NotImplemented("initial joining secrets are not yet supported") + } + + if spec.Onboarding.InitialPublicKey == "" { + return nil, trace.BadParameter("an initial public key is required") + } + + if !spec.Joining.Unlimited && status.RemainingJoins == 0 { + return nil, trace.AccessDenied("no rejoins remaining") + } + + if err := a.issueBoundKeypairChallenge( + ctx, + spec.Onboarding.InitialPublicKey, + challengeResponse, + ); err != nil { + return nil, trace.Wrap(err) + } + + // Now that we've confirmed the key, we can consider it bound. + mutators = append( + mutators, + mutateStatusBoundPublicKey(spec.Onboarding.InitialPublicKey, ""), + mutateStatusConsumeJoin(spec.Joining.Unlimited, status.RemainingJoins), + ) + + expectNewBotInstance = true + case !hasBoundPublicKey && hasIncomingBotInstance: + // Not allowed, at least at the moment. This would imply e.g. trying to + // change auth methods. + return nil, trace.BadParameter("cannot perform first bound keypair join with existing credentials") + case hasBoundPublicKey && !hasBoundBotInstance: + // TODO: Bad backend state, or maybe an incomplete previous join + // attempt. This shouldn't be possible state, but we should handle it + // sanely anyway. + return nil, trace.BadParameter("bad backend state, please recreate the join token") + case hasBoundPublicKey && hasBoundBotInstance && hasIncomingBotInstance: + // Standard rejoin case, does not consume a rejoin. + if status.BoundBotInstanceID != req.JoinRequest.BotInstanceID { + return nil, trace.AccessDenied("bot instance mismatch") + } + + if err := a.issueBoundKeypairChallenge( + ctx, + spec.Onboarding.InitialPublicKey, + challengeResponse, + ); err != nil { + return nil, trace.Wrap(err) + } + + // Nothing else to do, no key change + case hasBoundPublicKey && hasBoundBotInstance && !hasIncomingBotInstance: + // Hard rejoin case, the client identity expired and a new bot instance + // is required. Consumes a rejoin. + if !spec.Joining.Unlimited && status.RemainingJoins == 0 { + return nil, trace.AccessDenied("no rejoins remaining") + } + + if err := a.issueBoundKeypairChallenge( + ctx, + status.BoundPublicKey, + challengeResponse, + ); err != nil { + return nil, trace.Wrap(err) + } + + mutators = append( + mutators, + mutateStatusConsumeJoin(spec.Joining.Unlimited, status.RemainingJoins), + ) + + // TODO: decrement remaining joins + + expectNewBotInstance = true + default: + a.logger.ErrorContext( + ctx, "unexpected state", + "hasBoundPublicKey", hasBoundPublicKey, + "hasBoundBotInstance", hasBoundBotInstance, + "hasIncomingBotInstance", hasIncomingBotInstance, + "spec", spec, + "status", status, + ) + return nil, trace.BadParameter("unexpected state") + } + + if req.NewPublicKey != "" { + // TODO + return nil, trace.NotImplemented("key rotation not yet implemented") + } + + a.logger.DebugContext(ctx, "Server.RegisterUsingBoundKeypairMethod(): challenge verified, issuing certs") + + certs, botInstanceID, err := a.generateCertsBot( + ctx, + ptv2, + req.JoinRequest, + nil, // TODO: extended claims for this type? + nil, // TODO: workload id claims + ) + + if expectNewBotInstance { + mutators = append( + mutators, + mutateStatusBoundBotInstance(botInstanceID, status.BoundBotInstanceID), + ) + } + + if len(mutators) > 0 { + if _, err := a.PatchToken(ctx, ptv2.GetName(), func(token types.ProvisionToken) (types.ProvisionToken, error) { + ptv2, ok := provisionToken.(*types.ProvisionTokenV2) + if !ok { + return nil, trace.BadParameter("expected *types.ProvisionTokenV2, got %T", provisionToken) + } + + // Apply all mutators. Individual mutators may make additional + // assertions to ensure invariants haven't changed. + for _, mutator := range mutators { + if err := mutator(ptv2.Status.BoundKeypair); err != nil { + return nil, trace.Wrap(err, "applying status mutator") + } + } + + return ptv2, nil + }); err != nil { + return nil, trace.Wrap(err, "commiting updated token state, please try again") + } + } + + return certs, trace.Wrap(err) +} diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go index 0f0bddbb2899b..d20a3ddcb6b38 100644 --- a/lib/auth/join_iam.go +++ b/lib/auth/join_iam.go @@ -407,7 +407,7 @@ func (a *Server) RegisterUsingIAMMethodWithOpts( } if req.RegisterUsingTokenRequest.Role == types.RoleBot { - certs, err := a.generateCertsBot( + certs, _, err := a.generateCertsBot( ctx, provisionToken, req.RegisterUsingTokenRequest, diff --git a/lib/auth/join_oracle.go b/lib/auth/join_oracle.go index 457d238a437e7..95349ed8c7056 100644 --- a/lib/auth/join_oracle.go +++ b/lib/auth/join_oracle.go @@ -83,7 +83,7 @@ func (a *Server) registerUsingOracleMethod( } if tokenReq.Role == types.RoleBot { - certs, err := a.generateCertsBot( + certs, _, err := a.generateCertsBot( ctx, provisionToken, tokenReq, diff --git a/lib/auth/join_tpm.go b/lib/auth/join_tpm.go index df2e6b4e4cbcc..31334855ddd61 100644 --- a/lib/auth/join_tpm.go +++ b/lib/auth/join_tpm.go @@ -112,7 +112,7 @@ func (a *Server) RegisterUsingTPMMethod( } if initReq.JoinRequest.Role == types.RoleBot { - certs, err := a.generateCertsBot( + certs, _, err := a.generateCertsBot( ctx, ptv2, initReq.JoinRequest, diff --git a/lib/auth/machineid/machineidv1/bot_service.go b/lib/auth/machineid/machineidv1/bot_service.go index 139612f54b056..ca3aeeb0b451a 100644 --- a/lib/auth/machineid/machineidv1/bot_service.go +++ b/lib/auth/machineid/machineidv1/bot_service.go @@ -58,6 +58,7 @@ var SupportedJoinMethods = []types.JoinMethod{ types.JoinMethodTPM, types.JoinMethodTerraformCloud, types.JoinMethodBitbucket, + types.JoinMethodBoundKeypair, } // BotResourceName returns the default name for resources associated with the diff --git a/lib/boundkeypair/bound_keypair.go b/lib/boundkeypair/bound_keypair.go new file mode 100644 index 0000000000000..bffc4a7116ded --- /dev/null +++ b/lib/boundkeypair/bound_keypair.go @@ -0,0 +1,119 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package boundkeypair + +import ( + "crypto" + "crypto/subtle" + "time" + + "github.com/go-jose/go-jose/v3/jwt" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" +) + +const ( + challengeExpiration time.Duration = time.Minute +) + +type ChallengeDocument struct { + *jwt.Claims + + // Nonce is a secure random string, unique to a particular challenge + Nonce string `json:"nonce"` +} + +type ChallengeValidator struct { + clock clockwork.Clock + + subject string + clusterName string + publicKey crypto.PublicKey +} + +func NewChallengeValidator( + subject string, + clusterName string, + publicKey crypto.PublicKey, +) (*ChallengeValidator, error) { + return &ChallengeValidator{ + clock: clockwork.NewRealClock(), + + subject: subject, + clusterName: clusterName, + publicKey: publicKey, // TODO: API design issue, public key will change during rotation. should a new validator be created, or can we design this better? + }, nil +} + +func (v *ChallengeValidator) IssueChallenge() (*ChallengeDocument, error) { + // Implementation note: these challenges are only ever sent to a single + // client once, and we expect a valid reply as the next exchange in the + // join ceremony. There is not an opportunity for reuse, multiple attempts, + // or for clients to select their own nonce, so we won't bother storing it. + nonce, err := utils.CryptoRandomHex(defaults.TokenLenBytes) + if err != nil { + return nil, trace.Wrap(err, "generating nonce") + } + + return &ChallengeDocument{ + Claims: &jwt.Claims{ + Issuer: v.clusterName, + Audience: jwt.Audience{v.clusterName}, // the cluster is both the issuer and audience + NotBefore: jwt.NewNumericDate(v.clock.Now().Add(-10 * time.Second)), + IssuedAt: jwt.NewNumericDate(v.clock.Now()), + Expiry: jwt.NewNumericDate(v.clock.Now().Add(challengeExpiration)), + Subject: v.subject, + }, + Nonce: nonce, + }, nil +} + +func (v *ChallengeValidator) ValidateChallengeResponse(nonce string, compactResponse string) error { + token, err := jwt.ParseSigned(compactResponse) + if err != nil { + return trace.Wrap(err, "parsing signed response") + } + + var document ChallengeDocument + if err := token.Claims(v.publicKey, &document); err != nil { + return trace.Wrap(err) + } + + // TODO: this doesn't actually validate that the time-based fields are still + // what we assigned above; a hostile client could set their own values here. + // This may not be a realistic problem, but we might want to check it + // anyway. + const leeway time.Duration = time.Minute + if err := document.Claims.ValidateWithLeeway(jwt.Expected{ + Issuer: v.clusterName, + Subject: v.subject, + Audience: jwt.Audience{v.clusterName}, + Time: v.clock.Now(), + }, leeway); err != nil { + return trace.Wrap(err, "validating challenge claims") + } + + if subtle.ConstantTimeCompare([]byte(nonce), []byte(document.Nonce)) == 0 { + return trace.AccessDenied("invalid nonce") + } + + return nil +} diff --git a/lib/boundkeypair/experiment/experiment.go b/lib/boundkeypair/experiment/experiment.go new file mode 100644 index 0000000000000..739618e8bf46c --- /dev/null +++ b/lib/boundkeypair/experiment/experiment.go @@ -0,0 +1,40 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package experiment + +import ( + "os" + "sync" +) + +var mu sync.Mutex + +var experimentEnabled = os.Getenv("TELEPORT_BOUND_KEYPAIR_JOINING_EXPERIMENT") == "1" + +// Enabled returns true if the bound keypair joining experiment is enabled. +func Enabled() bool { + mu.Lock() + defer mu.Unlock() + return experimentEnabled +} + +// SetEnabled sets the experiment enabled flag. +func SetEnabled(enabled bool) { + mu.Lock() + defer mu.Unlock() + experimentEnabled = enabled +} diff --git a/lib/cryptosuites/suites.go b/lib/cryptosuites/suites.go index f82f2b93ab48e..17e1f56fdf834 100644 --- a/lib/cryptosuites/suites.go +++ b/lib/cryptosuites/suites.go @@ -121,6 +121,10 @@ const ( // AWSRACATLS represents the TLS key for the AWS IAM Roles Anywhere CA. AWSRACATLS + // BoundKeypairJoining represents a key used for the bound keypair joining + // identity. + BoundKeypairJoining + // keyPurposeMax is 1 greater than the last valid key purpose, used to test that all values less than this // are valid for each suite. keyPurposeMax @@ -194,9 +198,10 @@ var ( ProxyToDatabaseAgent: RSA2048, ProxyKubeClient: RSA2048, // EC2InstanceConnect has always used Ed25519 by default. - EC2InstanceConnect: Ed25519, - GitClient: Ed25519, - AWSRACATLS: ECDSAP256, + EC2InstanceConnect: Ed25519, + GitClient: Ed25519, + AWSRACATLS: ECDSAP256, + BoundKeypairJoining: Ed25519, } // balancedV1 strikes a balance between security, compatibility, and @@ -229,6 +234,7 @@ var ( EC2InstanceConnect: Ed25519, GitClient: Ed25519, AWSRACATLS: ECDSAP256, + BoundKeypairJoining: Ed25519, } // fipsv1 is an algorithm suite tailored for FIPS compliance. It is based on @@ -262,6 +268,7 @@ var ( EC2InstanceConnect: ECDSAP256, GitClient: ECDSAP256, AWSRACATLS: ECDSAP256, + BoundKeypairJoining: ECDSAP256, } // hsmv1 in an algorithm suite tailored for clusters using an HSM or KMS @@ -297,6 +304,7 @@ var ( EC2InstanceConnect: Ed25519, GitClient: Ed25519, AWSRACATLS: ECDSAP256, + BoundKeypairJoining: Ed25519, } allSuites = map[types.SignatureAlgorithmSuite]suite{ diff --git a/lib/joinserver/joinserver.go b/lib/joinserver/joinserver.go index b978bb8d61ef8..ab53c71469712 100644 --- a/lib/joinserver/joinserver.go +++ b/lib/joinserver/joinserver.go @@ -55,6 +55,11 @@ type joinServiceClient interface { tokenReq *types.RegisterUsingTokenRequest, challengeResponse client.RegisterOracleChallengeResponseFunc, ) (*proto.Certs, error) + RegisterUsingBoundKeypairMethod( + ctx context.Context, + req *proto.RegisterUsingBoundKeypairInitialRequest, + challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc, + ) (*proto.Certs, error) RegisterUsingToken( ctx context.Context, req *types.RegisterUsingTokenRequest, @@ -369,6 +374,99 @@ func (s *JoinServiceGRPCServer) registerUsingOracleMethod(srv proto.JoinService_ })) } +// RegisterUsingBoundKeypairMethod registers the client using the bound-keypair +// join method, and if successful, returns a signed cert bundle for +// authenticated cluster access. +func (s *JoinServiceGRPCServer) RegisterUsingBoundKeypairMethod( + srv proto.JoinService_RegisterUsingBoundKeypairMethodServer, +) error { + slog.DebugContext(srv.Context(), "RegisterUsingBoundKeypairMethod()") + + return trace.Wrap(s.handleStreamingRegistration(srv.Context(), types.JoinMethodBoundKeypair, func() error { + return trace.Wrap(s.registerUsingBoundKeypair(srv)) + })) +} + +func (s *JoinServiceGRPCServer) registerUsingBoundKeypair(srv proto.JoinService_RegisterUsingBoundKeypairMethodServer) error { + ctx := srv.Context() + + slog.DebugContext(srv.Context(), "registerUsingBoundKeypair()") + + // Get initial payload from the client + req, err := srv.Recv() + if err != nil { + return trace.Wrap(err, "receiving initial payload") + } + initReq := req.GetInit() + if initReq == nil { + return trace.BadParameter("expected non-nil Init payload") + } + + if initReq.InitialJoinSecret != "" { + // TODO: not supported yet. + return trace.NotImplemented("initial join secrets are not yet supported") + } + + if initReq.JoinRequest == nil { + return trace.BadParameter( + "expected JoinRequest in RegisterUsingBoundKeypairInitialRequest, got nil", + ) + } + if err := setClientRemoteAddr(ctx, initReq.JoinRequest); err != nil { + return trace.Wrap(err, "setting client address") + } + + slog.DebugContext(srv.Context(), "registerUsingBoundKeypair(): preflight complete, attempting to relay challenges") + + setBotParameters(ctx, initReq.JoinRequest) + + certs, err := s.joinServiceClient.RegisterUsingBoundKeypairMethod(ctx, initReq, func(publicKey string, challenge string) (*proto.RegisterUsingBoundKeypairChallengeResponse, error) { + // First, forward the challenge from Auth to the client. + err := srv.Send(&proto.RegisterUsingBoundKeypairMethodResponse{ + Response: &proto.RegisterUsingBoundKeypairMethodResponse_Challenge{ + Challenge: &proto.RegisterUsingBoundKeypairChallenge{ + PublicKey: publicKey, + Challenge: challenge, + }, + }, + }) + if err != nil { + return nil, trace.Wrap( + err, "forwarding challenge to client", + ) + } + // Get response from Client + req, err := srv.Recv() + if err != nil { + return nil, trace.Wrap( + err, "receiving challenge solution from client", + ) + } + challengeResponse := req.GetChallengeResponse() + if challengeResponse == nil { + return nil, trace.BadParameter( + "expected non-nil ChallengeResponse payload", + ) + } + + return challengeResponse, nil + }) + if err != nil { + return trace.Wrap(err) + } + + slog.DebugContext(srv.Context(), "registerUsingBoundKeypair(): challenge ceremony complete, sending cert bundle") + + // finally, send the certs on the response stream + return trace.Wrap(srv.Send(&proto.RegisterUsingBoundKeypairMethodResponse{ + Response: &proto.RegisterUsingBoundKeypairMethodResponse_Certs{ + Certs: &proto.RegisterUsingBoundKeypairCertificates{ + Certs: certs, + }, + }, + })) +} + // RegisterUsingToken allows nodes and proxies to join the cluster using // legacy join methods which do not yet have their own RPC. // On the Auth server, this method will call the auth.Server's diff --git a/lib/jwt/jwt.go b/lib/jwt/jwt.go index 8bf50ba3b8f5e..e85354fed9766 100644 --- a/lib/jwt/jwt.go +++ b/lib/jwt/jwt.go @@ -169,7 +169,7 @@ func (k *Key) getSigner(opts *jose.SignerOptions) (jose.Signer, error) { default: signer = cryptosigner.Opaque(k.config.PrivateKey) } - algorithm, err := joseAlgorithm(k.config.PrivateKey.Public()) + algorithm, err := AlgorithmForPublicKey(k.config.PrivateKey.Public()) if err != nil { return nil, trace.Wrap(err) } @@ -189,7 +189,8 @@ func (k *Key) getSigner(opts *jose.SignerOptions) (jose.Signer, error) { return sig, nil } -func joseAlgorithm(pub crypto.PublicKey) (jose.SignatureAlgorithm, error) { +// AlgorithmForPublicKey returns a jose algorithm for the given public key. +func AlgorithmForPublicKey(pub crypto.PublicKey) (jose.SignatureAlgorithm, error) { switch pub.(type) { case *rsa.PublicKey: return jose.RS256, nil diff --git a/lib/services/local/provisioning.go b/lib/services/local/provisioning.go index 4d9eeac954a41..9a41ae091a162 100644 --- a/lib/services/local/provisioning.go +++ b/lib/services/local/provisioning.go @@ -52,6 +52,62 @@ func (s *ProvisioningService) UpsertToken(ctx context.Context, p types.Provision return nil } +// PatchToken uses the supplied function to attempt to patch a token resource. +// Up to 3 update attempts will be made if the conditional update fails due to +// a revision comparison failure. +func (s *ProvisioningService) PatchToken( + ctx context.Context, + tokenName string, + updateFn func(types.ProvisionToken) (types.ProvisionToken, error), +) (types.ProvisionToken, error) { + const iterLimit = 3 + + for i := 0; i < iterLimit; i++ { + existing, err := s.GetToken(ctx, tokenName) + if err != nil { + return nil, trace.Wrap(err) + } + + // Note: CloneProvisionToken only supports ProvisionTokenV2. + clone, err := services.CloneProvisionToken(existing) + if err != nil { + return nil, trace.Wrap(err) + } + + updated, err := updateFn(clone) + if err != nil { + return nil, trace.Wrap(err) + } + + updatedMetadata := updated.GetMetadata() + existingMetadata := existing.GetMetadata() + + switch { + case updatedMetadata.GetName() != existingMetadata.GetName(): + return nil, trace.BadParameter("metadata.name: cannot be patched") + case updatedMetadata.GetRevision() != existingMetadata.GetRevision(): + return nil, trace.BadParameter("metadata.revision: cannot be patched") + } + + item, err := s.tokenToItem(updated) + if err != nil { + return nil, trace.Wrap(err) + } + + lease, err := s.ConditionalUpdate(ctx, *item) + if trace.IsCompareFailed(err) { + continue + } else if err != nil { + return nil, trace.Wrap(err) + } + + updated.SetRevision(lease.Revision) + return updated, nil + } + + return nil, trace.CompareFailed("failed to update provision token within %v iterations", iterLimit) +} + // CreateToken creates a new token for the auth server func (s *ProvisioningService) CreateToken(ctx context.Context, p types.ProvisionToken) error { item, err := s.tokenToItem(p) diff --git a/lib/services/provisioning.go b/lib/services/provisioning.go index 8e17507fec536..185787a6c63ec 100644 --- a/lib/services/provisioning.go +++ b/lib/services/provisioning.go @@ -25,6 +25,7 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/utils" ) @@ -48,6 +49,14 @@ type Provisioner interface { // GetTokens returns all non-expired tokens GetTokens(ctx context.Context) ([]types.ProvisionToken, error) + + // PatchToken performs a conditional update on the named token using + // `updateFn`, retrying internally if a comparison failure occurs. + PatchToken( + ctx context.Context, + token string, + updateFn func(types.ProvisionToken) (types.ProvisionToken, error), + ) (types.ProvisionToken, error) } // MustCreateProvisionToken returns a new valid provision token @@ -127,3 +136,16 @@ func MarshalProvisionToken(provisionToken types.ProvisionToken, opts ...MarshalO return nil, trace.BadParameter("unrecognized provision token version %T", provisionToken) } } + +// CloneProvisionToken returns a deep copy of the given provision token, per +// `apiutils.CloneProtoMsg()`. Fields in the clone may be modified without +// affecting the original. Only V2 is supported. +func CloneProvisionToken(provisionToken types.ProvisionToken) (types.ProvisionToken, error) { + switch provisionToken := provisionToken.(type) { + case *types.ProvisionTokenV2: + clone := apiutils.CloneProtoMsg(provisionToken) + return clone, nil + default: + return nil, trace.BadParameter("cannot clone unsupported provision token version %T", provisionToken) + } +} From 75f4529ed260d04431ff4b66b538b22a4547689a Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Thu, 1 May 2025 21:20:51 -0600 Subject: [PATCH 02/17] Refactor challenge response function, rebase on updated protos branch This includes a number of changes: - Rebases on the latest protos branch. This includes removal of the new keypair field on initial join, and adds messages for interactive keypair rotation. - Per the rebase, remaining_joins is removed in favor of using join_count for all calculations. The registration method and validatity checks have been updated to reference that instead. - Refactors challenge response function to allow for keypair rotation. We still don't implement rotation but the handler now receives the full proto message and produces a full proto response, so that we can easily handle the rotation case in the future. - Challenge validation checks time fields explicitly to ensure the client didn't tamper with them. - Added some missing docstrings --- api/client/joinservice.go | 31 ++++---- lib/auth/join_bound_keypair.go | 118 +++++++++++++++++------------- lib/boundkeypair/bound_keypair.go | 29 ++++++-- lib/joinserver/joinserver.go | 28 ++----- 4 files changed, 109 insertions(+), 97 deletions(-) diff --git a/api/client/joinservice.go b/api/client/joinservice.go index dace70202bfeb..277c32ff2a311 100644 --- a/api/client/joinservice.go +++ b/api/client/joinservice.go @@ -62,9 +62,9 @@ type RegisterTPMChallengeResponseFunc func(challenge *proto.TPMEncryptedCredenti type RegisterOracleChallengeResponseFunc func(challenge string) (*proto.OracleSignedRequest, error) // RegisterUsingBoundKeypairChallengeResponseFunc is a function to be passed to -// RegisterUsingBoundKeypair. It must return a message containing a signed -// response for the given challenge, or an error. -type RegisterUsingBoundKeypairChallengeResponseFunc func(publicKey string, challenge string) (*proto.RegisterUsingBoundKeypairChallengeResponse, error) +// RegisterUsingBoundKeypair. It must return a new follow-up request for the +// server response, or an error. +type RegisterUsingBoundKeypairChallengeResponseFunc func(challenge *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error) // RegisterUsingIAMMethod registers the caller using the IAM join method and // returns signed certs to join the cluster. @@ -303,21 +303,8 @@ func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod( } switch kind := res.GetResponse().(type) { - case *proto.RegisterUsingBoundKeypairMethodResponse_Challenge: - solution, err := challengeFunc(kind.Challenge.PublicKey, kind.Challenge.Challenge) - if err != nil { - return nil, trace.Wrap(err, "solving challenge") - } - - err = stream.Send(&proto.RegisterUsingBoundKeypairMethodRequest{ - Payload: &proto.RegisterUsingBoundKeypairMethodRequest_ChallengeResponse{ - ChallengeResponse: solution, - }, - }) - if err != nil { - return nil, trace.Wrap(err, "sending solution") - } case *proto.RegisterUsingBoundKeypairMethodResponse_Certs: + // If we get certs, we're done, so just return the result. certs := kind.Certs.GetCerts() if certs == nil { return nil, trace.BadParameter("expected Certs, got %T", kind.Certs.Certs) @@ -332,7 +319,15 @@ func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod( return certs, nil default: - return nil, trace.BadParameter("received unexpected challenge response: %v", res.GetResponse()) + // Forward all other responses to the challenge handler. + nextRequest, err := challengeFunc(res) + if err != nil { + return nil, trace.Wrap(err, "solving challenge") + } + + if err := stream.Send(nextRequest); err != nil { + return nil, trace.Wrap(err, "sending solution") + } } } diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index a2044d65338cd..42e3cf9e7f920 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -1,3 +1,4 @@ +// join_bound_keypair.go /* * Teleport * Copyright (C) 2025 Gravitational, Inc. @@ -132,25 +133,16 @@ func (a *Server) UpsertBoundKeypairToken(ctx context.Context, token types.Provis return trace.Wrap(err) } - // TODO: Populate initial_join_secret if needed, but only if no previous - // resource exists. + // TODO: Follow up with proper checking for a preexisting resource so + // generated fields are handled properly, i.e. initial secret generation. - // TODO: Probably won't want to tweak status here; that's best done during - // the join ceremony. - // if tokenV2.Status == nil { - // tokenV2.Status = &types.ProvisionTokenStatusV2{} - // } - - // if tokenV2.Status.BoundKeypair == nil { - // tokenV2.Status.BoundKeypair = a.initialBoundKeypairStatus(spec) - // } - - // TODO: Follow up changes to include: - // - Compare and swap / conditional updates - // - Proper checking for previous resource return trace.Wrap(a.UpsertToken(ctx, token)) } +// issueBoundKeypairChallenge creates a new challenge for the given marshalled +// public key in ssh authorized_keys format, requests a solution from the +// client using the given `challengeResponse` function, and validates the +// response. func (a *Server) issueBoundKeypairChallenge( ctx context.Context, marshalledKey string, @@ -175,7 +167,7 @@ func (a *Server) issueBoundKeypairChallenge( return trace.Wrap(err) } - a.logger.DebugContext(ctx, "Server.issueBoundKeypairChallenge(): preflight complete, issuing challenge", "pk", marshalledKey, "id", keyID) + a.logger.DebugContext(ctx, "issuing bound keypair challenge", "keyID", keyID) validator, err := boundkeypair.NewChallengeValidator(keyID, clusterName.GetClusterName(), key) if err != nil { @@ -187,28 +179,37 @@ func (a *Server) issueBoundKeypairChallenge( return trace.Wrap(err, "generating a challenge document") } - a.logger.DebugContext(ctx, "Server.issueBoundKeypairChallenge(): issued new challenge", "challenge", challenge) - marshalledChallenge, err := json.Marshal(challenge) if err != nil { return trace.Wrap(err) } - a.logger.InfoContext(ctx, "requesting signed bound keypair joining challenge") - - response, err := challengeResponse(marshalledKey, string(marshalledChallenge)) + response, err := challengeResponse(&proto.RegisterUsingBoundKeypairMethodResponse{ + Response: &proto.RegisterUsingBoundKeypairMethodResponse_Challenge{ + Challenge: &proto.RegisterUsingBoundKeypairChallenge{ + PublicKey: marshalledKey, + Challenge: string(marshalledChallenge), + }, + }, + }) if err != nil { return trace.Wrap(err, "requesting a signed challenge") } - a.logger.DebugContext(ctx, "Server.issueBoundKeypairChallenge(): challenge complete, verifying", "response", response) + solutionResponse, ok := response.Payload.(*proto.RegisterUsingBoundKeypairMethodRequest_ChallengeResponse) + if !ok { + return trace.BadParameter("client provided unexpected challenge response type %T", response.Payload) + } - if err := validator.ValidateChallengeResponse(challenge.Nonce, string(response.Solution)); err != nil { + if err := validator.ValidateChallengeResponse( + challenge, + string(solutionResponse.ChallengeResponse.Solution), + ); err != nil { // TODO: access denied instead? return trace.Wrap(err, "validating challenge response") } - a.logger.InfoContext(ctx, "bound keypair challenge response verified successfully") + a.logger.InfoContext(ctx, "bound keypair challenge response verified successfully", "keyID", keyID) return nil } @@ -216,35 +217,45 @@ func (a *Server) issueBoundKeypairChallenge( // boundKeypairStatusMutator is a function called to mutate a bound keypair // status during a call to PatchProvisionToken(). These functions may be called // repeatedly if e.g. revision checks fail. To ensure invariants remain in -// place, mutator functions may make assertions -type boundKeypairStatusMutator func(*types.ProvisionTokenStatusV2BoundKeypair) error - -func mutateStatusConsumeJoin(unlimited bool, expectRemainingJoins uint32) boundKeypairStatusMutator { +// place, mutator functions may make assertions to ensure the provided backend +// state is still sane before the update is committed. +type boundKeypairStatusMutator func(*types.ProvisionTokenSpecV2BoundKeypair, *types.ProvisionTokenStatusV2BoundKeypair) error + +// mutateStatusConsumeJoin consumes a "hard" join on the backend, incrementing +// the join counter. This verifies that the backend join count has not changed, +// and that total join count is at least the value when the mutator was created. +func mutateStatusConsumeJoin(unlimited bool, expectJoinCount uint32, expectMinTotalJoins uint32) boundKeypairStatusMutator { now := time.Now() - return func(status *types.ProvisionTokenStatusV2BoundKeypair) error { + return func(spec *types.ProvisionTokenSpecV2BoundKeypair, status *types.ProvisionTokenStatusV2BoundKeypair) error { // Ensure we have the expected number of rejoins left to prevent going // below zero. - // TODO: this could be >=? would avoid breaking if this happens to - // collide with a user incrementing TotalJoins. - if status.RemainingJoins != expectRemainingJoins { + if status.JoinCount != expectJoinCount { return trace.AccessDenied("unexpected backend state") } - status.JoinCount += 1 - status.LastJoinedAt = &now + // Ensure the allowed join count has at least not decreased, but allow + // for collision with potentially increased values. + if spec.Joining.TotalJoins < expectMinTotalJoins { + return trace.AccessDenied("unexpected backend state") + } if !unlimited { - // TODO: decrement remaining joins (not yet implemented.) return trace.NotImplemented("only unlimited rejoining is currently supported") } + status.JoinCount += 1 + status.LastJoinedAt = &now + return nil } } +// mutateStatusBoundPublicKey is a mutator that updates the bound public key +// value. It ensures the backend public key is still the expected value before +// performing the update. func mutateStatusBoundPublicKey(newPublicKey, expectPreviousKey string) boundKeypairStatusMutator { - return func(status *types.ProvisionTokenStatusV2BoundKeypair) error { + return func(_ *types.ProvisionTokenSpecV2BoundKeypair, status *types.ProvisionTokenStatusV2BoundKeypair) error { if status.BoundPublicKey != expectPreviousKey { return trace.AccessDenied("unexpected backend state") } @@ -255,8 +266,11 @@ func mutateStatusBoundPublicKey(newPublicKey, expectPreviousKey string) boundKey } } +// mutateStatusBoundBotInstance updates the bot instance ID currently bound to +// this token. It ensures the expected previous ID is still the bound value +// before performing the update. func mutateStatusBoundBotInstance(newBotInstance, expectPreviousBotInstance string) boundKeypairStatusMutator { - return func(status *types.ProvisionTokenStatusV2BoundKeypair) error { + return func(_ *types.ProvisionTokenSpecV2BoundKeypair, status *types.ProvisionTokenStatusV2BoundKeypair) error { if status.BoundBotInstanceID != expectPreviousBotInstance { return trace.AccessDenied("unexpected backend state") } @@ -267,13 +281,13 @@ func mutateStatusBoundBotInstance(newBotInstance, expectPreviousBotInstance stri } } +// RegisterUsingBoundKeypairMethod handles joining requests for the bound +// keypair join method. func (a *Server) RegisterUsingBoundKeypairMethod( ctx context.Context, req *proto.RegisterUsingBoundKeypairInitialRequest, challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc, ) (_ *proto.Certs, err error) { - a.logger.DebugContext(ctx, "Server.RegisterUsingBoundKeypairMethod()", "req", req) - var provisionToken types.ProvisionToken var joinFailureMetadata any defer func() { @@ -321,6 +335,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( hasBoundPublicKey := status.BoundPublicKey != "" hasBoundBotInstance := status.BoundBotInstanceID != "" hasIncomingBotInstance := req.JoinRequest.BotInstanceID != "" + hasJoinsRemaining := status.JoinCount < spec.Joining.TotalJoins expectNewBotInstance := false // Mutators to use during the token resource status patch at the end. @@ -338,8 +353,8 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, trace.BadParameter("an initial public key is required") } - if !spec.Joining.Unlimited && status.RemainingJoins == 0 { - return nil, trace.AccessDenied("no rejoins remaining") + if !spec.Joining.Unlimited && !hasJoinsRemaining { + return nil, trace.AccessDenied("no joins remaining") } if err := a.issueBoundKeypairChallenge( @@ -354,7 +369,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( mutators = append( mutators, mutateStatusBoundPublicKey(spec.Onboarding.InitialPublicKey, ""), - mutateStatusConsumeJoin(spec.Joining.Unlimited, status.RemainingJoins), + mutateStatusConsumeJoin(spec.Joining.Unlimited, status.JoinCount, spec.Joining.TotalJoins), ) expectNewBotInstance = true @@ -385,7 +400,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( case hasBoundPublicKey && hasBoundBotInstance && !hasIncomingBotInstance: // Hard rejoin case, the client identity expired and a new bot instance // is required. Consumes a rejoin. - if !spec.Joining.Unlimited && status.RemainingJoins == 0 { + if !spec.Joining.Unlimited && !hasJoinsRemaining { return nil, trace.AccessDenied("no rejoins remaining") } @@ -399,11 +414,9 @@ func (a *Server) RegisterUsingBoundKeypairMethod( mutators = append( mutators, - mutateStatusConsumeJoin(spec.Joining.Unlimited, status.RemainingJoins), + mutateStatusConsumeJoin(spec.Joining.Unlimited, status.JoinCount, spec.Joining.TotalJoins), ) - // TODO: decrement remaining joins - expectNewBotInstance = true default: a.logger.ErrorContext( @@ -417,13 +430,18 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, trace.BadParameter("unexpected state") } - if req.NewPublicKey != "" { - // TODO - return nil, trace.NotImplemented("key rotation not yet implemented") + if spec.RotateOnNextRenewal { + // TODO, to be implemented in a future PR + return nil, trace.NotImplemented("key rotation not yet supported") } a.logger.DebugContext(ctx, "Server.RegisterUsingBoundKeypairMethod(): challenge verified, issuing certs") + // TODO: We should pass along the previous bot instance ID - if any - based + // on the join state, once that is implemented. It will need to be passed + // either via extended claims, or by a new protected field in the join + // request like the current bot instance ID, i.e. cleared when set by an + // untrusted source. certs, botInstanceID, err := a.generateCertsBot( ctx, ptv2, @@ -449,7 +467,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( // Apply all mutators. Individual mutators may make additional // assertions to ensure invariants haven't changed. for _, mutator := range mutators { - if err := mutator(ptv2.Status.BoundKeypair); err != nil { + if err := mutator(ptv2.Spec.BoundKeypair, ptv2.Status.BoundKeypair); err != nil { return nil, trace.Wrap(err, "applying status mutator") } } diff --git a/lib/boundkeypair/bound_keypair.go b/lib/boundkeypair/bound_keypair.go index bffc4a7116ded..6f3317163597e 100644 --- a/lib/boundkeypair/bound_keypair.go +++ b/lib/boundkeypair/bound_keypair.go @@ -34,6 +34,9 @@ const ( challengeExpiration time.Duration = time.Minute ) +// ChallengeDocument is a bound keypair challenge document. These documents are +// sent in JSON form to clients attempting to authenticate, and are expected to +// be sent back signed with a known public key. type ChallengeDocument struct { *jwt.Claims @@ -41,6 +44,8 @@ type ChallengeDocument struct { Nonce string `json:"nonce"` } +// ChallengeValidator is used to issue and validate bound keypair challenges for +// a given public key. type ChallengeValidator struct { clock clockwork.Clock @@ -86,7 +91,10 @@ func (v *ChallengeValidator) IssueChallenge() (*ChallengeDocument, error) { }, nil } -func (v *ChallengeValidator) ValidateChallengeResponse(nonce string, compactResponse string) error { +// ValidateChallengeResponse validates a signed challenge document, ensuring the +// signature matches the requested public key, and that the claims pass JWT +// validation. +func (v *ChallengeValidator) ValidateChallengeResponse(issued *ChallengeDocument, compactResponse string) error { token, err := jwt.ParseSigned(compactResponse) if err != nil { return trace.Wrap(err, "parsing signed response") @@ -97,10 +105,7 @@ func (v *ChallengeValidator) ValidateChallengeResponse(nonce string, compactResp return trace.Wrap(err) } - // TODO: this doesn't actually validate that the time-based fields are still - // what we assigned above; a hostile client could set their own values here. - // This may not be a realistic problem, but we might want to check it - // anyway. + // Validate the challenge document claims per JWT rules. const leeway time.Duration = time.Minute if err := document.Claims.ValidateWithLeeway(jwt.Expected{ Issuer: v.clusterName, @@ -111,7 +116,19 @@ func (v *ChallengeValidator) ValidateChallengeResponse(nonce string, compactResp return trace.Wrap(err, "validating challenge claims") } - if subtle.ConstantTimeCompare([]byte(nonce), []byte(document.Nonce)) == 0 { + // JWT validation won't check equality on the time fields, so we'll manually + // check them. + if issued.Claims.IssuedAt != document.Claims.IssuedAt { + return trace.AccessDenied("invalid challenge document") + } + if issued.Claims.Expiry != document.Claims.Expiry { + return trace.AccessDenied("invalid challenge document") + } + if issued.Claims.NotBefore != document.Claims.NotBefore { + return trace.AccessDenied("invalid challenge document") + } + + if subtle.ConstantTimeCompare([]byte(issued.Nonce), []byte(document.Nonce)) == 0 { return trace.AccessDenied("invalid nonce") } diff --git a/lib/joinserver/joinserver.go b/lib/joinserver/joinserver.go index ab53c71469712..ea57b746b21fa 100644 --- a/lib/joinserver/joinserver.go +++ b/lib/joinserver/joinserver.go @@ -380,8 +380,6 @@ func (s *JoinServiceGRPCServer) registerUsingOracleMethod(srv proto.JoinService_ func (s *JoinServiceGRPCServer) RegisterUsingBoundKeypairMethod( srv proto.JoinService_RegisterUsingBoundKeypairMethodServer, ) error { - slog.DebugContext(srv.Context(), "RegisterUsingBoundKeypairMethod()") - return trace.Wrap(s.handleStreamingRegistration(srv.Context(), types.JoinMethodBoundKeypair, func() error { return trace.Wrap(s.registerUsingBoundKeypair(srv)) })) @@ -390,8 +388,6 @@ func (s *JoinServiceGRPCServer) RegisterUsingBoundKeypairMethod( func (s *JoinServiceGRPCServer) registerUsingBoundKeypair(srv proto.JoinService_RegisterUsingBoundKeypairMethodServer) error { ctx := srv.Context() - slog.DebugContext(srv.Context(), "registerUsingBoundKeypair()") - // Get initial payload from the client req, err := srv.Recv() if err != nil { @@ -416,25 +412,17 @@ func (s *JoinServiceGRPCServer) registerUsingBoundKeypair(srv proto.JoinService_ return trace.Wrap(err, "setting client address") } - slog.DebugContext(srv.Context(), "registerUsingBoundKeypair(): preflight complete, attempting to relay challenges") - setBotParameters(ctx, initReq.JoinRequest) - certs, err := s.joinServiceClient.RegisterUsingBoundKeypairMethod(ctx, initReq, func(publicKey string, challenge string) (*proto.RegisterUsingBoundKeypairChallengeResponse, error) { + certs, err := s.joinServiceClient.RegisterUsingBoundKeypairMethod(ctx, initReq, func(resp *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error) { // First, forward the challenge from Auth to the client. - err := srv.Send(&proto.RegisterUsingBoundKeypairMethodResponse{ - Response: &proto.RegisterUsingBoundKeypairMethodResponse_Challenge{ - Challenge: &proto.RegisterUsingBoundKeypairChallenge{ - PublicKey: publicKey, - Challenge: challenge, - }, - }, - }) + err := srv.Send(resp) if err != nil { return nil, trace.Wrap( err, "forwarding challenge to client", ) } + // Get response from Client req, err := srv.Recv() if err != nil { @@ -442,20 +430,14 @@ func (s *JoinServiceGRPCServer) registerUsingBoundKeypair(srv proto.JoinService_ err, "receiving challenge solution from client", ) } - challengeResponse := req.GetChallengeResponse() - if challengeResponse == nil { - return nil, trace.BadParameter( - "expected non-nil ChallengeResponse payload", - ) - } - return challengeResponse, nil + return req, nil }) if err != nil { return trace.Wrap(err) } - slog.DebugContext(srv.Context(), "registerUsingBoundKeypair(): challenge ceremony complete, sending cert bundle") + slog.DebugContext(srv.Context(), "challenge ceremony complete, sending cert bundle") // finally, send the certs on the response stream return trace.Wrap(srv.Send(&proto.RegisterUsingBoundKeypairMethodResponse{ From 96cb98e4407456e83f8748fb106795864f981bf3 Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Thu, 1 May 2025 21:21:58 -0600 Subject: [PATCH 03/17] Add joinserver test --- lib/joinserver/joinserver_test.go | 152 ++++++++++++++++++++++++++++-- 1 file changed, 144 insertions(+), 8 deletions(-) diff --git a/lib/joinserver/joinserver_test.go b/lib/joinserver/joinserver_test.go index f9bd3e6d8e7dc..2462d6d1fd36a 100644 --- a/lib/joinserver/joinserver_test.go +++ b/lib/joinserver/joinserver_test.go @@ -42,14 +42,17 @@ import ( ) type mockJoinServiceClient struct { - sendChallenge string - returnCerts *proto.Certs - returnError error - gotIAMChallengeResponse *proto.RegisterUsingIAMMethodRequest - gotAzureChallengeResponse *proto.RegisterUsingAzureMethodRequest - gotTPMChallengeResponse *proto.RegisterUsingTPMMethodChallengeResponse - gotTPMInitReq *proto.RegisterUsingTPMMethodInitialRequest - gotRegisterUsingTokenReq *types.RegisterUsingTokenRequest + sendChallenge string + boundKeypairPublicKey string + returnCerts *proto.Certs + returnError error + gotIAMChallengeResponse *proto.RegisterUsingIAMMethodRequest + gotAzureChallengeResponse *proto.RegisterUsingAzureMethodRequest + gotTPMChallengeResponse *proto.RegisterUsingTPMMethodChallengeResponse + gotTPMInitReq *proto.RegisterUsingTPMMethodInitialRequest + gotBoundKeypairInitReq *proto.RegisterUsingBoundKeypairInitialRequest + gotBoundKeypairChallengeResponse *proto.RegisterUsingBoundKeypairMethodRequest + gotRegisterUsingTokenReq *types.RegisterUsingTokenRequest } func (c *mockJoinServiceClient) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) { @@ -86,6 +89,29 @@ func (c *mockJoinServiceClient) RegisterUsingTPMMethod( return c.returnCerts, c.returnError } +func (c *mockJoinServiceClient) RegisterUsingBoundKeypairMethod( + ctx context.Context, + req *proto.RegisterUsingBoundKeypairInitialRequest, + challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc, +) (*proto.Certs, error) { + c.gotBoundKeypairInitReq = req + resp, err := challengeResponse(&proto.RegisterUsingBoundKeypairMethodResponse{ + Response: &proto.RegisterUsingBoundKeypairMethodResponse_Challenge{ + Challenge: &proto.RegisterUsingBoundKeypairChallenge{ + PublicKey: c.boundKeypairPublicKey, + Challenge: c.sendChallenge, + }, + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + c.gotBoundKeypairChallengeResponse = resp + + return c.returnCerts, c.returnError +} + func (c *mockJoinServiceClient) RegisterUsingOracleMethod( ctx context.Context, tokenReq *types.RegisterUsingTokenRequest, @@ -540,6 +566,116 @@ func TestJoinServiceGRPCServer_RegisterUsingTPMMethod(t *testing.T) { } } +// TestJoinServiceGRPCServer_RegisterUsingBoundKeypairMethodSimple tests the +// simplest bound keypair joining path, with no keypair registration or +// rotation. +func TestJoinServiceGRPCServer_RegisterUsingBoundKeypairMethodSimple(t *testing.T) { + t.Parallel() + testPack := newTestPack(t) + + standardResponse := &proto.RegisterUsingBoundKeypairMethodRequest{ + Payload: &proto.RegisterUsingBoundKeypairMethodRequest_ChallengeResponse{ + ChallengeResponse: &proto.RegisterUsingBoundKeypairChallengeResponse{ + Solution: []byte("header.payload.signature"), + }, + }, + } + + testCases := []struct { + desc string + publicKey string + challenge string + req *proto.RegisterUsingBoundKeypairInitialRequest + challengeResponse *proto.RegisterUsingBoundKeypairMethodRequest + challengeResponseErr error + authErr error + certs *proto.Certs + }{ + { + desc: "success case", + challenge: "foo", + req: &proto.RegisterUsingBoundKeypairInitialRequest{ + JoinRequest: &types.RegisterUsingTokenRequest{}, + }, + challengeResponse: standardResponse, + certs: &proto.Certs{SSH: []byte("qux")}, + }, + { + desc: "auth error", + challenge: "foo", + req: &proto.RegisterUsingBoundKeypairInitialRequest{ + JoinRequest: &types.RegisterUsingTokenRequest{}, + }, + challengeResponse: standardResponse, + authErr: trace.AccessDenied("not allowed"), + }, + { + desc: "challenge response error", + challenge: "foo", + req: &proto.RegisterUsingBoundKeypairInitialRequest{ + JoinRequest: &types.RegisterUsingTokenRequest{}, + }, + challengeResponse: nil, + challengeResponseErr: trace.BadParameter("testing error"), + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + testPack.mockAuthServer.sendChallenge = tc.challenge + testPack.mockAuthServer.returnCerts = tc.certs + testPack.mockAuthServer.returnError = tc.authErr + + challengeResponder := func( + challenge *proto.RegisterUsingBoundKeypairMethodResponse, + ) (*proto.RegisterUsingBoundKeypairMethodRequest, error) { + assert.Equal(t, &proto.RegisterUsingBoundKeypairMethodResponse{ + Response: &proto.RegisterUsingBoundKeypairMethodResponse_Challenge{ + Challenge: &proto.RegisterUsingBoundKeypairChallenge{ + PublicKey: tc.publicKey, + Challenge: tc.challenge, + }, + }, + }, challenge) + + return tc.challengeResponse, tc.challengeResponseErr + } + + for suffix, clt := range map[string]*client.JoinServiceClient{ + "_auth": testPack.authClient, + "_proxy": testPack.proxyClient, + } { + t.Run(tc.desc+suffix, func(t *testing.T) { + certs, err := clt.RegisterUsingBoundKeypairMethod( + context.Background(), tc.req, challengeResponder, + ) + if tc.challengeResponseErr != nil { + require.ErrorIs(t, err, tc.challengeResponseErr) + return + } + if tc.authErr != nil { + require.Error(t, err) + require.Contains(t, err.Error(), tc.authErr.Error()) + return + } + require.NoError(t, err) + require.Equal(t, tc.certs, certs) + + expectedInitReq := tc.req + expectedInitReq.JoinRequest.RemoteAddr = "bufconn" + assert.Equal(t, expectedInitReq, testPack.mockAuthServer.gotBoundKeypairInitReq) + + assert.Equal( + t, + tc.challengeResponse, + testPack.mockAuthServer.gotBoundKeypairChallengeResponse, + ) + }) + } + }) + } +} + func TestTimeout(t *testing.T) { t.Parallel() From a128fe62d00181d34c79261cd7aa0375d6f52f07 Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Thu, 1 May 2025 21:27:58 -0600 Subject: [PATCH 04/17] Fix lint error and add docstring --- api/client/joinservice.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/api/client/joinservice.go b/api/client/joinservice.go index 277c32ff2a311..3452c27480b2a 100644 --- a/api/client/joinservice.go +++ b/api/client/joinservice.go @@ -18,6 +18,7 @@ package client import ( "context" + "errors" "io" "github.com/gravitational/trace" @@ -268,6 +269,10 @@ func (c *JoinServiceClient) RegisterUsingOracleMethod( return certs, nil } +// RegisterUsingBoundKeypairMethod attempts to register the caller using +// bound-keypair join method. If successful, a certificate bundle is returned, +// or an error. Clients must provide a callback to handle interactive challenges +// and keypair rotation requests. func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod( ctx context.Context, initReq *proto.RegisterUsingBoundKeypairInitialRequest, @@ -296,7 +301,7 @@ func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod( // here instead to ensure we handle everything. for { res, err := stream.Recv() - if err == io.EOF { + if errors.Is(err, io.EOF) { break } else if err != nil { return nil, trace.Wrap(err, "receiving intermediate bound keypair join response") From 645fab9f68412547b8be4ba575e3384d9539ed41 Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Thu, 1 May 2025 22:30:03 -0600 Subject: [PATCH 05/17] Add tests for bound keypair challenge validation --- lib/auth/join_bound_keypair.go | 2 +- lib/boundkeypair/bound_keypair.go | 23 ++- lib/boundkeypair/bound_keypair_test.go | 265 +++++++++++++++++++++++++ 3 files changed, 284 insertions(+), 6 deletions(-) create mode 100644 lib/boundkeypair/bound_keypair_test.go diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index 42e3cf9e7f920..fafdd007dcb31 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -205,7 +205,7 @@ func (a *Server) issueBoundKeypairChallenge( challenge, string(solutionResponse.ChallengeResponse.Solution), ); err != nil { - // TODO: access denied instead? + // TODO: Consider access denied instead? return trace.Wrap(err, "validating challenge response") } diff --git a/lib/boundkeypair/bound_keypair.go b/lib/boundkeypair/bound_keypair.go index 6f3317163597e..21d5f5c25c8f3 100644 --- a/lib/boundkeypair/bound_keypair.go +++ b/lib/boundkeypair/bound_keypair.go @@ -31,7 +31,13 @@ import ( ) const ( + // challengeExpiration is the TTL for a bound keypair challenge document + // once generated by the server, used to calculate the JWT `exp` field. challengeExpiration time.Duration = time.Minute + + // challengeNotBeforeOffset is the offset applied to the `nbf` field to + // allow for a small amount of clock drift. + challengeNotBeforeOffset = -10 * time.Second ) // ChallengeDocument is a bound keypair challenge document. These documents are @@ -54,6 +60,9 @@ type ChallengeValidator struct { publicKey crypto.PublicKey } +// NewChallengeValidator creates a new challenge validation helper using the +// given subject, cluster name, and public key. Subjects are arbitrary but a +// public key ID is recommended. func NewChallengeValidator( subject string, clusterName string, @@ -64,7 +73,11 @@ func NewChallengeValidator( subject: subject, clusterName: clusterName, - publicKey: publicKey, // TODO: API design issue, public key will change during rotation. should a new validator be created, or can we design this better? + + // TODO: API design issue to consider when implementing rotation: the + // public key will change during rotation. Should a new validator be + // created, or can we design this better? + publicKey: publicKey, }, nil } @@ -82,7 +95,7 @@ func (v *ChallengeValidator) IssueChallenge() (*ChallengeDocument, error) { Claims: &jwt.Claims{ Issuer: v.clusterName, Audience: jwt.Audience{v.clusterName}, // the cluster is both the issuer and audience - NotBefore: jwt.NewNumericDate(v.clock.Now().Add(-10 * time.Second)), + NotBefore: jwt.NewNumericDate(v.clock.Now().Add(challengeNotBeforeOffset)), IssuedAt: jwt.NewNumericDate(v.clock.Now()), Expiry: jwt.NewNumericDate(v.clock.Now().Add(challengeExpiration)), Subject: v.subject, @@ -118,13 +131,13 @@ func (v *ChallengeValidator) ValidateChallengeResponse(issued *ChallengeDocument // JWT validation won't check equality on the time fields, so we'll manually // check them. - if issued.Claims.IssuedAt != document.Claims.IssuedAt { + if !issued.Claims.IssuedAt.Time().Equal(document.Claims.IssuedAt.Time()) { return trace.AccessDenied("invalid challenge document") } - if issued.Claims.Expiry != document.Claims.Expiry { + if !issued.Claims.Expiry.Time().Equal(document.Claims.Expiry.Time()) { return trace.AccessDenied("invalid challenge document") } - if issued.Claims.NotBefore != document.Claims.NotBefore { + if !issued.Claims.NotBefore.Time().Equal(document.Claims.NotBefore.Time()) { return trace.AccessDenied("invalid challenge document") } diff --git a/lib/boundkeypair/bound_keypair_test.go b/lib/boundkeypair/bound_keypair_test.go new file mode 100644 index 0000000000000..81eea79d2e18f --- /dev/null +++ b/lib/boundkeypair/bound_keypair_test.go @@ -0,0 +1,265 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package boundkeypair + +import ( + "context" + "crypto" + "encoding/json" + "testing" + "time" + + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cryptosuites" + libjwt "github.com/gravitational/teleport/lib/jwt" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +func newTestKeypair(t *testing.T) crypto.Signer { + key, err := cryptosuites.GenerateKey(context.Background(), func(ctx context.Context) (types.SignatureAlgorithmSuite, error) { + return types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_BALANCED_V1, nil + }, cryptosuites.BoundKeypairJoining) + require.NoError(t, err) + + return key +} + +func TestChallengeValidator_IssueChallenge(t *testing.T) { + t.Parallel() + + key := newTestKeypair(t) + + clock := clockwork.NewFakeClock() + now := clock.Now() + + const clusterName = "example.teleport.sh" + validator, err := NewChallengeValidator("subject", clusterName, key.Public()) + require.NoError(t, err) + + validator.clock = clock + + challenge, err := validator.IssueChallenge() + require.NoError(t, err) + + require.NotEmpty(t, challenge.Nonce) + require.Equal(t, clusterName, challenge.Issuer) + require.Equal(t, jwt.Audience{clusterName}, challenge.Audience) + require.Equal(t, "subject", challenge.Subject) + require.Equal(t, jwt.NewNumericDate(now), challenge.IssuedAt) + require.Equal(t, jwt.NewNumericDate(now.Add(challengeNotBeforeOffset)), challenge.NotBefore) + require.Equal(t, jwt.NewNumericDate(now.Add(challengeExpiration)), challenge.Expiry) + + newChallenge, err := validator.IssueChallenge() + require.NoError(t, err) + + require.NotEqual(t, challenge.Nonce, newChallenge.Nonce, "nonces must be random") +} + +func signChallenge(t *testing.T, challenge string, signer crypto.Signer) string { + alg, err := libjwt.AlgorithmForPublicKey(signer.Public()) + require.NoError(t, err) + + opts := (&jose.SignerOptions{}).WithType("JWT") + key := jose.SigningKey{ + Algorithm: alg, + Key: signer, + } + + joseSigner, err := jose.NewSigner(key, opts) + require.NoError(t, err) + + jws, err := joseSigner.Sign([]byte(challenge)) + require.NoError(t, err) + + serialized, err := jws.CompactSerialize() + require.NoError(t, err) + + return serialized +} + +func jsonClone(t *testing.T, c *ChallengeDocument) *ChallengeDocument { + bytes, err := json.Marshal(c) + require.NoError(t, err) + + var cloned ChallengeDocument + require.NoError(t, json.Unmarshal(bytes, &cloned)) + + return &cloned +} + +func TestChallengeValidator_ValidateChallengeResponse(t *testing.T) { + t.Parallel() + + correctKey := newTestKeypair(t) + incorrectKey := newTestKeypair(t) + + const clusterName = "example.teleport.sh" + + tests := []struct { + name string + key crypto.Signer + assert require.ErrorAssertionFunc + clockFn func(clock *clockwork.FakeClock) + manipulateFn func(doc *ChallengeDocument, now time.Time) + }{ + { + name: "success", + key: correctKey, + assert: require.NoError, + }, + { + name: "wrong key", + key: incorrectKey, + assert: require.Error, + }, + { + name: "waited too long", + key: correctKey, + clockFn: func(clock *clockwork.FakeClock) { + clock.Advance(challengeExpiration * 10) + }, + assert: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "token is expired") + }, + }, + { + name: "too early", + key: correctKey, + clockFn: func(clock *clockwork.FakeClock) { + clock.Advance(challengeNotBeforeOffset * 10) + }, + assert: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "token not valid yet") + }, + }, + { + name: "tampered with iat", + key: correctKey, + manipulateFn: func(doc *ChallengeDocument, now time.Time) { + doc.IssuedAt = jwt.NewNumericDate(now.Add(time.Minute)) + }, + assert: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "invalid challenge document") + }, + }, + { + name: "tampered with exp", + key: correctKey, + manipulateFn: func(doc *ChallengeDocument, now time.Time) { + doc.Expiry = jwt.NewNumericDate(now.Add(time.Hour)) + }, + assert: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "invalid challenge document") + }, + }, + { + name: "tampered with nbf", + key: correctKey, + manipulateFn: func(doc *ChallengeDocument, now time.Time) { + doc.NotBefore = jwt.NewNumericDate(now.Add(time.Minute)) + }, + assert: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "invalid challenge document") + }, + }, + { + name: "tampered with nonce", + key: correctKey, + manipulateFn: func(doc *ChallengeDocument, now time.Time) { + doc.Nonce = "abcd" + }, + assert: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "invalid nonce") + }, + }, + { + name: "tampered with subject", + key: correctKey, + manipulateFn: func(doc *ChallengeDocument, now time.Time) { + doc.Subject = "abcd" + }, + assert: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "invalid subject claim") + }, + }, + { + name: "tampered with issuer", + key: correctKey, + manipulateFn: func(doc *ChallengeDocument, now time.Time) { + doc.Issuer = "abcd" + }, + assert: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "invalid issuer claim") + }, + }, + { + name: "tampered with audience", + key: correctKey, + manipulateFn: func(doc *ChallengeDocument, now time.Time) { + doc.Audience = jwt.Audience{"abcd"} + }, + assert: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "invalid audience claim") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clock := clockwork.NewFakeClock() + + validator, err := NewChallengeValidator("subject", clusterName, correctKey.Public()) + require.NoError(t, err) + + validator.clock = clock + + challenge, err := validator.IssueChallenge() + require.NoError(t, err) + + cloned := jsonClone(t, challenge) + if tt.manipulateFn != nil { + tt.manipulateFn(cloned, clock.Now()) + } + + challengeString, err := json.Marshal(cloned) + require.NoError(t, err) + + signed := signChallenge(t, string(challengeString), tt.key) + + if tt.clockFn != nil { + tt.clockFn(clock) + } + + err = validator.ValidateChallengeResponse(challenge, signed) + tt.assert(t, err) + }) + } +} From 1073c5be5083bb56a5603dcf459bdc042272d414 Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Fri, 2 May 2025 18:06:45 -0600 Subject: [PATCH 06/17] Remove client side package intended for other PR --- lib/auth/join/boundkeypair/boundkeypair.go | 213 --------------------- 1 file changed, 213 deletions(-) delete mode 100644 lib/auth/join/boundkeypair/boundkeypair.go diff --git a/lib/auth/join/boundkeypair/boundkeypair.go b/lib/auth/join/boundkeypair/boundkeypair.go deleted file mode 100644 index 545c1dbef8f0e..0000000000000 --- a/lib/auth/join/boundkeypair/boundkeypair.go +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Teleport - * Copyright (C) 2025 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package boundkeypair - -import ( - "context" - "os" - "path/filepath" - - "github.com/gravitational/teleport/api/utils/keys" - "github.com/gravitational/teleport/lib/auth/join" - "github.com/gravitational/teleport/lib/cryptosuites" - "github.com/gravitational/trace" - "golang.org/x/crypto/ssh" -) - -const ( - PrivateKeyPath = "id_bkp" - PublicKeyPath = PrivateKeyPath + ".pub" - JoinStatePath = "bkp_state" - - StandardFileWriteMode = 0600 -) - -// ClientState contains state parameters stored on disk needed to complete the -// bound keypair join process. -type ClientState struct { - // PrivateKey is the parsed private key. - PrivateKey *keys.PrivateKey - - // PrivateKeyBytes contains the private key bytes. This value should always - // be nonempty. - PrivateKeyBytes []byte - - // PublicKeyBytes contains the public key bytes. This value is not used at - // runtime, and is only set when a public key should be written to disk, - // like on first creation or during rotation. To consistently access the - // public key, use `.PrivateKey.Public()`. - PublicKeyBytes []byte - - // JoinStateBytes contains join state bytes. This value will be empty if - // this client has not yet joined. - JoinStateBytes []byte -} - -// ToJoinParams creates joining parameters for use with `join.Register()` from -// this client state. -func (c *ClientState) ToJoinParams(initialJoinSecret string) *join.BoundKeypairParams { - if len(c.JoinStateBytes) > 0 { - // This identity has been bound, so don't pass along the join secret (if - // any) - initialJoinSecret = "" - } - - return &join.BoundKeypairParams{ - // Note: pass the internal signer because go-jose does type assertions - // on the standard library types. - CurrentKey: c.PrivateKey.Signer, - PreviousJoinState: c.JoinStateBytes, - InitialJoinSecret: initialJoinSecret, - } -} - -// ToPublicKeyBytes returns the public key bytes in ssh authorized_keys format. -func (c *ClientState) ToPublicKeyBytes() ([]byte, error) { - sshPubKey, err := ssh.NewPublicKey(c.PrivateKey.Public()) - if err != nil { - return nil, trace.Wrap(err, "creating ssh public key") - } - - return ssh.MarshalAuthorizedKey(sshPubKey), nil -} - -type FS interface { - Read(ctx context.Context, name string) ([]byte, error) - Write(ctx context.Context, name string, data []byte) error -} - -type StandardFS struct { - parentDir string -} - -func (f *StandardFS) Read(ctx context.Context, name string) ([]byte, error) { - data, err := os.ReadFile(name) - if err != nil { - return nil, trace.Wrap(err) - } - - return data, nil -} - -func (f *StandardFS) Write(ctx context.Context, name string, data []byte) error { - path := filepath.Join(f.parentDir, name) - - return trace.Wrap(os.WriteFile(path, data, StandardFileWriteMode)) -} - -// NewStandardFS creates a new standard FS implementation. -func NewStandardFS(parentDir string) FS { - return &StandardFS{ - parentDir: parentDir, - } -} - -// LoadClientState attempts to load bound keypair client state from the given -// filesystem implementation. Callers should expect to handle NotFound errors -// returned here if a private key is not found; this indicates no prior client -// state exists and initial secret joining should be attempted if possible. If -// a keypair has been pregenerated, no prior join state will exist, and the -// join state will be empty; any corresponding errors while reading nonexistent -// join state documents will be ignored. -func LoadClientState(ctx context.Context, fs FS) (*ClientState, error) { - privateKeyBytes, err := fs.Read(ctx, PrivateKeyPath) - if err != nil { - return nil, trace.Wrap(err, "reading private key") - } - - joinStateBytes, err := fs.Read(ctx, JoinStatePath) - if trace.IsNotFound(err) { - // Join state doesn't exist, this is allowed. - } else if err != nil { - return nil, trace.Wrap(err, "reading previous join state") - } - - pk, err := keys.ParsePrivateKey(privateKeyBytes) - if err != nil { - return nil, trace.Wrap(err, "parsing private key") - } - - return &ClientState{ - PrivateKey: pk, - - PrivateKeyBytes: privateKeyBytes, - JoinStateBytes: joinStateBytes, - }, nil -} - -// StoreClientState writes bound keypair client state to the given filesystem -// wrapper. Public keys and join state will only be written if -func StoreClientState(ctx context.Context, fs FS, state *ClientState) error { - if err := fs.Write(ctx, PrivateKeyPath, state.PrivateKeyBytes); err != nil { - return trace.Wrap(err, "writing private key") - } - - // TODO: maybe consider just not writing the public key at all. End users - // aren't really meant to look in the internal storage, and we can just - // derive the public key whenever we want. - - // Only write the public key if it was explicitly provided. This helps save - // an unnecessary file write. - if len(state.PublicKeyBytes) > 0 { - if err := fs.Write(ctx, PublicKeyPath, state.PublicKeyBytes); err != nil { - return trace.Wrap(err, "writing public key") - } - } - - if len(state.JoinStateBytes) > 0 { - if err := fs.Write(ctx, JoinStatePath, state.JoinStateBytes); err != nil { - return trace.Wrap(err, "writing previous join state") - } - } - - return nil -} - -// NewUnboundClientState creates a new client state that has not yet been bound, -// i.e. a new keypair that has not been registered with Auth, and no prior join -// state. -func NewUnboundClientState(ctx context.Context, getSuite cryptosuites.GetSuiteFunc) (*ClientState, error) { - key, err := cryptosuites.GenerateKey(ctx, getSuite, cryptosuites.BoundKeypairJoining) - if err != nil { - return nil, trace.Wrap(err, "generating keypair") - } - - privateKeyBytes, err := keys.MarshalPrivateKey(key) - if err != nil { - return nil, trace.Wrap(err, "marshalling private key") - } - - sshPubKey, err := ssh.NewPublicKey(key.Public()) - if err != nil { - return nil, trace.Wrap(err, "creating ssh public key") - } - - publicKeyBytes := ssh.MarshalAuthorizedKey(sshPubKey) - - pk, err := keys.NewPrivateKey(key, privateKeyBytes) - if err != nil { - return nil, trace.Wrap(err) - } - - return &ClientState{ - PrivateKeyBytes: privateKeyBytes, - PublicKeyBytes: publicKeyBytes, - PrivateKey: pk, - }, nil -} From 9bfd290358527ed2fac8da5121750a42a727dd52 Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Fri, 2 May 2025 18:20:14 -0600 Subject: [PATCH 07/17] Fix various lints --- lib/auth/join_bound_keypair.go | 27 +++++++------------ lib/auth/join_bound_keypair_test.go | 37 ++++++++++++++++++++++++++ lib/boundkeypair/bound_keypair.go | 5 ++-- lib/boundkeypair/bound_keypair_test.go | 5 ++-- 4 files changed, 52 insertions(+), 22 deletions(-) create mode 100644 lib/auth/join_bound_keypair_test.go diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index fafdd007dcb31..60d3aa0588aae 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -1,4 +1,3 @@ -// join_bound_keypair.go /* * Teleport * Copyright (C) 2025 Gravitational, Inc. @@ -24,6 +23,8 @@ import ( "encoding/json" "time" + "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" @@ -31,7 +32,6 @@ import ( "github.com/gravitational/teleport/lib/boundkeypair/experiment" "github.com/gravitational/teleport/lib/jwt" libsshutils "github.com/gravitational/teleport/lib/sshutils" - "github.com/gravitational/trace" ) // validateBoundKeypairTokenSpec performs some basic validation checks on a @@ -67,13 +67,6 @@ func validateBoundKeypairTokenSpec(spec *types.ProvisionTokenSpecV2BoundKeypair) return nil } -func (a *Server) initialBoundKeypairStatus(spec *types.ProvisionTokenSpecV2BoundKeypair) *types.ProvisionTokenStatusV2BoundKeypair { - return &types.ProvisionTokenStatusV2BoundKeypair{ - InitialJoinSecret: spec.Onboarding.InitialJoinSecret, - BoundPublicKey: spec.Onboarding.InitialPublicKey, - } -} - func (a *Server) CreateBoundKeypairToken(ctx context.Context, token types.ProvisionToken) error { if token.GetJoinMethod() != types.JoinMethodBoundKeypair { return trace.BadParameter("must be called with a bound keypair token") @@ -139,7 +132,7 @@ func (a *Server) UpsertBoundKeypairToken(ctx context.Context, token types.Provis return trace.Wrap(a.UpsertToken(ctx, token)) } -// issueBoundKeypairChallenge creates a new challenge for the given marshalled +// issueBoundKeypairChallenge creates a new challenge for the given marshaled // public key in ssh authorized_keys format, requests a solution from the // client using the given `challengeResponse` function, and validates the // response. @@ -167,7 +160,7 @@ func (a *Server) issueBoundKeypairChallenge( return trace.Wrap(err) } - a.logger.DebugContext(ctx, "issuing bound keypair challenge", "keyID", keyID) + a.logger.DebugContext(ctx, "issuing bound keypair challenge", "key_id", keyID) validator, err := boundkeypair.NewChallengeValidator(keyID, clusterName.GetClusterName(), key) if err != nil { @@ -209,7 +202,7 @@ func (a *Server) issueBoundKeypairChallenge( return trace.Wrap(err, "validating challenge response") } - a.logger.InfoContext(ctx, "bound keypair challenge response verified successfully", "keyID", keyID) + a.logger.InfoContext(ctx, "bound keypair challenge response verified successfully", "key_id", keyID) return nil } @@ -421,9 +414,9 @@ func (a *Server) RegisterUsingBoundKeypairMethod( default: a.logger.ErrorContext( ctx, "unexpected state", - "hasBoundPublicKey", hasBoundPublicKey, - "hasBoundBotInstance", hasBoundBotInstance, - "hasIncomingBotInstance", hasIncomingBotInstance, + "has_bound_public_key", hasBoundPublicKey, + "has_bound_bot_instance", hasBoundBotInstance, + "has_incoming_bot_instance", hasIncomingBotInstance, "spec", spec, "status", status, ) @@ -435,8 +428,6 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, trace.NotImplemented("key rotation not yet supported") } - a.logger.DebugContext(ctx, "Server.RegisterUsingBoundKeypairMethod(): challenge verified, issuing certs") - // TODO: We should pass along the previous bot instance ID - if any - based // on the join state, once that is implemented. It will need to be passed // either via extended claims, or by a new protected field in the join @@ -474,7 +465,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return ptv2, nil }); err != nil { - return nil, trace.Wrap(err, "commiting updated token state, please try again") + return nil, trace.Wrap(err, "committing updated token state, please try again") } } diff --git a/lib/auth/join_bound_keypair_test.go b/lib/auth/join_bound_keypair_test.go new file mode 100644 index 0000000000000..3f4613e0a3c3b --- /dev/null +++ b/lib/auth/join_bound_keypair_test.go @@ -0,0 +1,37 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package auth + +import "testing" + +func TestAuth_RegisterUsingToken_BoundKeypair(t *testing.T) { + tests := []struct { + name string + }{ + { + name: "success", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + }) + } +} diff --git a/lib/boundkeypair/bound_keypair.go b/lib/boundkeypair/bound_keypair.go index 21d5f5c25c8f3..d6e634e79d7b4 100644 --- a/lib/boundkeypair/bound_keypair.go +++ b/lib/boundkeypair/bound_keypair.go @@ -24,10 +24,11 @@ import ( "time" "github.com/go-jose/go-jose/v3/jwt" - "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/utils" ) const ( diff --git a/lib/boundkeypair/bound_keypair_test.go b/lib/boundkeypair/bound_keypair_test.go index 81eea79d2e18f..7dd58c63943d1 100644 --- a/lib/boundkeypair/bound_keypair_test.go +++ b/lib/boundkeypair/bound_keypair_test.go @@ -27,11 +27,12 @@ import ( "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cryptosuites" libjwt "github.com/gravitational/teleport/lib/jwt" - "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/require" ) func newTestKeypair(t *testing.T) crypto.Signer { From 0289b79c7eef5a5a6fdc48f5a7d0e33ea94fafd9 Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Fri, 2 May 2025 21:32:21 -0600 Subject: [PATCH 08/17] Add tests for RegisterUsingBoundKeypairMethod() --- api/types/provisioning.go | 20 ++ lib/auth/auth.go | 11 + lib/auth/join_bound_keypair.go | 14 +- lib/auth/join_bound_keypair_test.go | 336 +++++++++++++++++++++++++++- 4 files changed, 375 insertions(+), 6 deletions(-) diff --git a/api/types/provisioning.go b/api/types/provisioning.go index 3c49fe11bb4f0..d1b0dfabfaa85 100644 --- a/api/types/provisioning.go +++ b/api/types/provisioning.go @@ -193,6 +193,26 @@ func NewProvisionTokenFromSpec(token string, expires time.Time, spec ProvisionTo return t, nil } +// NewProvisionTokenFromSpecAndStatus returns a new provision token with the given spec. +func NewProvisionTokenFromSpecAndStatus( + token string, expires time.Time, + spec ProvisionTokenSpecV2, + status *ProvisionTokenStatusV2, +) (ProvisionToken, error) { + t := &ProvisionTokenV2{ + Metadata: Metadata{ + Name: token, + Expires: &expires, + }, + Spec: spec, + Status: status, + } + if err := t.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return t, nil +} + // MustCreateProvisionToken returns a new valid provision token // or panics, used in tests func MustCreateProvisionToken(token string, roles SystemRoles, expires time.Time) ProvisionToken { diff --git a/lib/auth/auth.go b/lib/auth/auth.go index d33d8365fa08a..da71c0d4cb0f6 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -96,6 +96,7 @@ import ( "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/bitbucket" + "github.com/gravitational/teleport/lib/boundkeypair" "github.com/gravitational/teleport/lib/cache" "github.com/gravitational/teleport/lib/circleci" "github.com/gravitational/teleport/lib/cryptosuites" @@ -699,6 +700,12 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { as.bitbucketIDTokenValidator = bitbucket.NewIDTokenValidator(as.clock) } + if as.createBoundKeypairValidator == nil { + as.createBoundKeypairValidator = func(subject, clusterName string, publicKey crypto.PublicKey) (boundKeypairValidator, error) { + return boundkeypair.NewChallengeValidator(subject, clusterName, publicKey) + } + } + // Add in a login hook for generating state during user login. as.ulsGenerator, err = userloginstate.NewGenerator(userloginstate.GeneratorConfig{ Log: as.logger, @@ -1137,6 +1144,10 @@ type Server struct { bitbucketIDTokenValidator bitbucketIDTokenValidator + // createBoundKeypairValidator is a helper to create new bound keypair + // challenge validators. Used to override the implementation used in tests. + createBoundKeypairValidator createBoundKeypairValidator + // loadAllCAs tells tsh to load the host CAs for all clusters when trying to ssh into a node. loadAllCAs bool diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index 60d3aa0588aae..242e18eca00fd 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -20,6 +20,7 @@ package auth import ( "context" + "crypto" "encoding/json" "time" @@ -34,6 +35,13 @@ import ( libsshutils "github.com/gravitational/teleport/lib/sshutils" ) +type boundKeypairValidator interface { + IssueChallenge() (*boundkeypair.ChallengeDocument, error) + ValidateChallengeResponse(issued *boundkeypair.ChallengeDocument, compactResponse string) error +} + +type createBoundKeypairValidator func(subject string, clusterName string, publicKey crypto.PublicKey) (boundKeypairValidator, error) + // validateBoundKeypairTokenSpec performs some basic validation checks on a // bound_keypair-type join token. func validateBoundKeypairTokenSpec(spec *types.ProvisionTokenSpecV2BoundKeypair) error { @@ -162,7 +170,7 @@ func (a *Server) issueBoundKeypairChallenge( a.logger.DebugContext(ctx, "issuing bound keypair challenge", "key_id", keyID) - validator, err := boundkeypair.NewChallengeValidator(keyID, clusterName.GetClusterName(), key) + validator, err := a.createBoundKeypairValidator(keyID, clusterName.GetClusterName(), key) if err != nil { return trace.Wrap(err) } @@ -372,7 +380,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, trace.BadParameter("cannot perform first bound keypair join with existing credentials") case hasBoundPublicKey && !hasBoundBotInstance: // TODO: Bad backend state, or maybe an incomplete previous join - // attempt. This shouldn't be possible state, but we should handle it + // attempt. This shouldn't be a possible state, but we should handle it // sanely anyway. return nil, trace.BadParameter("bad backend state, please recreate the join token") case hasBoundPublicKey && hasBoundBotInstance && hasIncomingBotInstance: @@ -469,5 +477,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( } } + // TODO: need to return public key here for inclusion in the certs response. + return certs, trace.Wrap(err) } diff --git a/lib/auth/join_bound_keypair_test.go b/lib/auth/join_bound_keypair_test.go index 3f4613e0a3c3b..8f2bb61f34c41 100644 --- a/lib/auth/join_bound_keypair_test.go +++ b/lib/auth/join_bound_keypair_test.go @@ -18,20 +18,348 @@ package auth -import "testing" +import ( + "context" + "crypto" + "testing" + "time" + + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/boundkeypair" + "github.com/gravitational/teleport/lib/cryptosuites" + "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +type mockBoundKeypairValidator struct { + subject string + clusterName string + publicKey crypto.PublicKey +} + +func (v *mockBoundKeypairValidator) IssueChallenge() (*boundkeypair.ChallengeDocument, error) { + return &boundkeypair.ChallengeDocument{ + Nonce: "fake", + }, nil +} + +func (v *mockBoundKeypairValidator) ValidateChallengeResponse(issued *boundkeypair.ChallengeDocument, compactResponse string) error { + // For testing, the solver will just reply with the marshaled public key, so + // we'll parse and compare it. + key, err := sshutils.CryptoPublicKey([]byte(compactResponse)) + if err != nil { + return trace.Wrap(err, "parsing bound public key") + } + + equal, ok := v.publicKey.(interface { + Equal(x crypto.PublicKey) bool + }) + if !ok { + return trace.BadParameter("unsupported public key type %T", key) + } + + if !equal.Equal(key) { + return trace.AccessDenied("incorrect public key") + } + + return nil +} + +func testBoundKeypair(t *testing.T) (crypto.Signer, string) { + key, err := cryptosuites.GeneratePrivateKeyWithAlgorithm(cryptosuites.ECDSAP256) + require.NoError(t, err) + + return key.Signer, string(key.MarshalSSHPublicKey()) +} + +func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { + ctx := context.Background() + + _, correctPublicKey := testBoundKeypair(t) + _, incorrectPublicKey := testBoundKeypair(t) + + srv := newTestTLSServer(t) + auth := srv.Auth() + auth.createBoundKeypairValidator = func(subject, clusterName string, publicKey crypto.PublicKey) (boundKeypairValidator, error) { + return &mockBoundKeypairValidator{ + //correctPublicKey: correctSigner.Public(), + + subject: subject, + clusterName: clusterName, + publicKey: publicKey, + }, nil + } + + _, err := CreateRole(ctx, auth, "example", types.RoleSpecV6{}) + require.NoError(t, err) + + adminClient, err := srv.NewClient(TestAdmin()) + _, err = adminClient.BotServiceClient().CreateBot(ctx, &machineidv1pb.CreateBotRequest{ + Bot: &machineidv1pb.Bot{ + Kind: types.KindBot, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &machineidv1pb.BotSpec{ + Roles: []string{"example"}, + }, + }, + }) + require.NoError(t, err) + + sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair() + require.NoError(t, err) + tlsPublicKey, err := PrivateKeyToPublicKeyTLS(sshPrivateKey) + require.NoError(t, err) + + makeToken := func(mutators ...func(v2 *types.ProvisionTokenV2)) types.ProvisionTokenV2 { + token := types.ProvisionTokenV2{ + Spec: types.ProvisionTokenSpecV2{ + JoinMethod: types.JoinMethodBoundKeypair, + Roles: []types.SystemRole{types.RoleBot}, + BotName: "test", + BoundKeypair: &types.ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ + InitialPublicKey: correctPublicKey, + }, + Joining: &types.ProvisionTokenSpecV2BoundKeypair_JoiningSpec{ + // Only unlimited and insecure is supported for now, so + // we'll hard code it. + Unlimited: true, + Insecure: true, + }, + }, + }, + Status: &types.ProvisionTokenStatusV2{ + BoundKeypair: &types.ProvisionTokenStatusV2BoundKeypair{}, + }, + } + for _, mutator := range mutators { + mutator(&token) + } + return token + } + + makeInitReq := func(mutators ...func(r *proto.RegisterUsingBoundKeypairInitialRequest)) *proto.RegisterUsingBoundKeypairInitialRequest { + req := &proto.RegisterUsingBoundKeypairInitialRequest{ + JoinRequest: &types.RegisterUsingTokenRequest{ + HostID: "host-id", + Role: types.RoleBot, + PublicTLSKey: tlsPublicKey, + PublicSSHKey: sshPublicKey, + }, + } + for _, mutator := range mutators { + mutator(req) + } + return req + } + + makeSolver := func(publicKey string) client.RegisterUsingBoundKeypairChallengeResponseFunc { + return func(challenge *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error) { + switch r := challenge.Response.(type) { + case *proto.RegisterUsingBoundKeypairMethodResponse_Challenge: + if r.Challenge.PublicKey != publicKey { + return nil, trace.BadParameter("wrong public key") + } + + return &proto.RegisterUsingBoundKeypairMethodRequest{ + Payload: &proto.RegisterUsingBoundKeypairMethodRequest_ChallengeResponse{ + ChallengeResponse: &proto.RegisterUsingBoundKeypairChallengeResponse{ + // For testing purposes, we'll just reply with the + // public key, to avoid needing to parse the JWT. + Solution: []byte(publicKey), + }, + }, + }, nil + default: + return nil, trace.BadParameter("invalid response type") + } + } + } -func TestAuth_RegisterUsingToken_BoundKeypair(t *testing.T) { tests := []struct { name string + + token types.ProvisionTokenV2 + initReq *proto.RegisterUsingBoundKeypairInitialRequest + solver client.RegisterUsingBoundKeypairChallengeResponseFunc + + assertError require.ErrorAssertionFunc + assertSuccess func(t *testing.T, v2 *types.ProvisionTokenV2) }{ { - name: "success", + // no bound key, no bound bot instance, aka initial join without + // secret + name: "initial-join-success", + + token: makeToken(), + initReq: makeInitReq(), + solver: makeSolver(correctPublicKey), + + assertError: require.NoError, + assertSuccess: func(t *testing.T, v2 *types.ProvisionTokenV2) { + // join count should be incremented + require.Equal(t, uint32(1), v2.Status.BoundKeypair.JoinCount) + require.NotEmpty(t, v2.Status.BoundKeypair.BoundBotInstanceID) + require.NotEmpty(t, v2.Status.BoundKeypair.BoundPublicKey) + }, + }, + { + // no bound key, no bound bot instance, aka initial join without + // secret + name: "initial-join-with-wrong-key", + + token: makeToken(), + initReq: makeInitReq(), + solver: makeSolver(incorrectPublicKey), + + assertError: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "wrong public key") + }, + }, + { + // bound key, valid bound bot instance, aka "soft join" + name: "reauth-success", + + token: makeToken(func(v2 *types.ProvisionTokenV2) { + v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + v2.Status.BoundKeypair.BoundBotInstanceID = "asdf" + }), + initReq: makeInitReq(func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + r.JoinRequest.BotInstanceID = "asdf" + }), + solver: makeSolver(correctPublicKey), + + assertError: require.NoError, + assertSuccess: func(t *testing.T, v2 *types.ProvisionTokenV2) { + // join count should not be incremented + require.Equal(t, uint32(0), v2.Status.BoundKeypair.JoinCount) + }, + }, + { + // bound key, seemingly valid bot instance, but wrong key + // (should be impossible, but should fail anyway) + name: "reauth-with-wrong-key", + + token: makeToken(func(v2 *types.ProvisionTokenV2) { + v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + v2.Status.BoundKeypair.BoundBotInstanceID = "asdf" + }), + initReq: makeInitReq(func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + r.JoinRequest.BotInstanceID = "asdf" + }), + solver: makeSolver(incorrectPublicKey), + + assertError: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "wrong public key") + }, + }, + { + // bound key but no valid incoming bot instance, i.e. the certs + // expired and triggered a hard rejoin + name: "rejoin-success", + + token: makeToken(func(v2 *types.ProvisionTokenV2) { + v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + v2.Status.BoundKeypair.BoundBotInstanceID = "asdf" + }), + initReq: makeInitReq(), + solver: makeSolver(correctPublicKey), + + assertError: require.NoError, + assertSuccess: func(t *testing.T, v2 *types.ProvisionTokenV2) { + require.Equal(t, uint32(1), v2.Status.BoundKeypair.JoinCount) + + // Should generate a new bot instance + require.NotEmpty(t, v2.Status.BoundKeypair.BoundBotInstanceID) + require.NotEqual(t, "asdf", v2.Status.BoundKeypair.BoundBotInstanceID) + }, + }, + { + // Bad state: somehow a key was registered without a bot instance. + // This should fail and prompt the user to recreate the token. + name: "bound-key-no-instance", + + token: makeToken(func(v2 *types.ProvisionTokenV2) { + v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + }), + initReq: makeInitReq(), + solver: makeSolver(correctPublicKey), + + assertError: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "bad backend state") + }, + }, + { + // The client somehow presents certs that refer to a different + // instance, maybe tried switching auth methods. + name: "bound-key-wrong-instance", + + token: makeToken(func(v2 *types.ProvisionTokenV2) { + v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + v2.Status.BoundKeypair.BoundBotInstanceID = "qwerty" + }), + initReq: makeInitReq(func(r *proto.RegisterUsingBoundKeypairInitialRequest) { + r.JoinRequest.BotInstanceID = "asdf" + }), + solver: makeSolver(correctPublicKey), + + assertError: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "bot instance mismatch") + }, + }, + { + // TODO: rotation is not yet implemented. + name: "rotation-requested", + + token: makeToken(func(v2 *types.ProvisionTokenV2) { + v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey + v2.Status.BoundKeypair.BoundBotInstanceID = "asdf" + v2.Spec.BoundKeypair.RotateOnNextRenewal = true + }), + initReq: makeInitReq(), + solver: makeSolver(correctPublicKey), + + assertError: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(tt, err) + require.ErrorContains(tt, err, "key rotation not yet supported") + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - + token, err := types.NewProvisionTokenFromSpecAndStatus( + tt.name, time.Now().Add(time.Minute), tt.token.Spec, tt.token.Status, + ) + require.NoError(t, err) + require.NoError(t, auth.CreateToken(ctx, token)) + tt.initReq.JoinRequest.Token = tt.name + + _, err = auth.RegisterUsingBoundKeypairMethod(ctx, tt.initReq, tt.solver) + tt.assertError(t, err) + + if tt.assertSuccess != nil { + pt, err := auth.GetToken(ctx, tt.name) + require.NoError(t, err) + + ptv2, ok := pt.(*types.ProvisionTokenV2) + require.True(t, ok) + + tt.assertSuccess(t, ptv2) + } }) } } From b9727488fba48a7f4f699d3b3ffa9aceb5a79970 Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Mon, 5 May 2025 19:10:01 -0600 Subject: [PATCH 09/17] Fix lints --- lib/auth/join_bound_keypair_test.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/auth/join_bound_keypair_test.go b/lib/auth/join_bound_keypair_test.go index 8f2bb61f34c41..b01ccab45ed7d 100644 --- a/lib/auth/join_bound_keypair_test.go +++ b/lib/auth/join_bound_keypair_test.go @@ -24,6 +24,9 @@ import ( "testing" "time" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" @@ -33,8 +36,6 @@ import ( "github.com/gravitational/teleport/lib/boundkeypair" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/sshutils" - "github.com/gravitational/trace" - "github.com/stretchr/testify/require" ) type mockBoundKeypairValidator struct { @@ -100,6 +101,8 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { require.NoError(t, err) adminClient, err := srv.NewClient(TestAdmin()) + require.NoError(t, err) + _, err = adminClient.BotServiceClient().CreateBot(ctx, &machineidv1pb.CreateBotRequest{ Bot: &machineidv1pb.Bot{ Kind: types.KindBot, From 1628093db0e462c7baaaec5e80e8c3546e72d451 Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Mon, 5 May 2025 19:41:14 -0600 Subject: [PATCH 10/17] Add basic provisioning token CheckAndSetDefaults() tests --- api/types/provisioning_test.go | 48 ++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/api/types/provisioning_test.go b/api/types/provisioning_test.go index d4679432fa334..155a8834298b7 100644 --- a/api/types/provisioning_test.go +++ b/api/types/provisioning_test.go @@ -1277,6 +1277,54 @@ func TestProvisionTokenV2_CheckAndSetDefaults(t *testing.T) { }, wantErr: true, }, + { + desc: "minimal bound keypair with pregenerated key", + token: &ProvisionTokenV2{ + Metadata: Metadata{ + Name: "test", + }, + Spec: ProvisionTokenSpecV2{ + Roles: []SystemRole{RoleNode}, + JoinMethod: JoinMethodBoundKeypair, + BoundKeypair: &ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ + InitialPublicKey: "asdf", + }, + }, + }, + }, + expected: &ProvisionTokenV2{ + Kind: "token", + Version: "v2", + Metadata: Metadata{ + Name: "test", + Namespace: "default", + }, + Spec: ProvisionTokenSpecV2{ + Roles: []SystemRole{RoleNode}, + JoinMethod: JoinMethodBoundKeypair, + BoundKeypair: &ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ + InitialPublicKey: "asdf", + }, + }, + }, + }, + }, + { + desc: "bound keypair missing onboarding config", + token: &ProvisionTokenV2{ + Metadata: Metadata{ + Name: "test", + }, + Spec: ProvisionTokenSpecV2{ + Roles: []SystemRole{RoleNode}, + JoinMethod: JoinMethodBoundKeypair, + BoundKeypair: &ProvisionTokenSpecV2BoundKeypair{}, + }, + }, + wantErr: true, + }, } for _, tc := range testcases { From d937c663a8dd0c1e158aeb6700ed4bab1a717e9b Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Mon, 5 May 2025 20:47:26 -0600 Subject: [PATCH 11/17] Include bound public key in RegisterUsingBoundKeypairMethod return This is passed back to clients as part of the proto certs message as confirmation that rotation succeeded, so the value needed to be plumbed through. --- api/client/joinservice.go | 24 ++++++------- lib/auth/join_bound_keypair.go | 52 ++++++++++++++++------------- lib/auth/join_bound_keypair_test.go | 2 +- lib/joinserver/joinserver.go | 7 ++-- lib/joinserver/joinserver_test.go | 10 +++--- 5 files changed, 52 insertions(+), 43 deletions(-) diff --git a/api/client/joinservice.go b/api/client/joinservice.go index 3452c27480b2a..32b0f60dfe5db 100644 --- a/api/client/joinservice.go +++ b/api/client/joinservice.go @@ -270,20 +270,20 @@ func (c *JoinServiceClient) RegisterUsingOracleMethod( } // RegisterUsingBoundKeypairMethod attempts to register the caller using -// bound-keypair join method. If successful, a certificate bundle is returned, -// or an error. Clients must provide a callback to handle interactive challenges -// and keypair rotation requests. +// bound-keypair join method. If successful, the public key registered with auth +// and a certificate bundle is returned, or an error. Clients must provide a +// callback to handle interactive challenges and keypair rotation requests. func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod( ctx context.Context, initReq *proto.RegisterUsingBoundKeypairInitialRequest, challengeFunc RegisterUsingBoundKeypairChallengeResponseFunc, -) (*proto.Certs, error) { +) (*proto.Certs, string, error) { ctx, cancel := context.WithCancel(ctx) defer cancel() stream, err := c.grpcClient.RegisterUsingBoundKeypairMethod(ctx) if err != nil { - return nil, trace.Wrap(err) + return nil, "", trace.Wrap(err) } defer stream.CloseSend() @@ -293,7 +293,7 @@ func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod( }, }) if err != nil { - return nil, trace.Wrap(err, "sending initial request") + return nil, "", trace.Wrap(err, "sending initial request") } // Unlike other methods, the server may send multiple challenges, @@ -304,7 +304,7 @@ func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod( if errors.Is(err, io.EOF) { break } else if err != nil { - return nil, trace.Wrap(err, "receiving intermediate bound keypair join response") + return nil, "", trace.Wrap(err, "receiving intermediate bound keypair join response") } switch kind := res.GetResponse().(type) { @@ -312,7 +312,7 @@ func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod( // If we get certs, we're done, so just return the result. certs := kind.Certs.GetCerts() if certs == nil { - return nil, trace.BadParameter("expected Certs, got %T", kind.Certs.Certs) + return nil, "", trace.BadParameter("expected Certs, got %T", kind.Certs.Certs) } // If we receive a cert bundle, we can return early. Even if we @@ -322,23 +322,23 @@ func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod( // raise an error if rotation fails or is otherwise skipped or not // allowed. - return certs, nil + return certs, kind.Certs.GetPublicKey(), nil default: // Forward all other responses to the challenge handler. nextRequest, err := challengeFunc(res) if err != nil { - return nil, trace.Wrap(err, "solving challenge") + return nil, "", trace.Wrap(err, "solving challenge") } if err := stream.Send(nextRequest); err != nil { - return nil, trace.Wrap(err, "sending solution") + return nil, "", trace.Wrap(err, "sending solution") } } } // Ideally the server will emit a proper error instead of just hanging up on // us. - return nil, trace.AccessDenied("server declined to send certs during bound-keypair join attempt") + return nil, "", trace.AccessDenied("server declined to send certs during bound-keypair join attempt") } // RegisterUsingToken registers the caller using a token and returns signed diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index 242e18eca00fd..276a48e861d1a 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -288,7 +288,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( ctx context.Context, req *proto.RegisterUsingBoundKeypairInitialRequest, challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc, -) (_ *proto.Certs, err error) { +) (_ *proto.Certs, _ string, err error) { var provisionToken types.ProvisionToken var joinFailureMetadata any defer func() { @@ -303,25 +303,25 @@ func (a *Server) RegisterUsingBoundKeypairMethod( // First, check the specified token exists, and is a bound keypair-type join // token. if err := req.JoinRequest.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) + return nil, "", trace.Wrap(err) } // Only bot joining is supported at the moment - unique ID verification is // required and this is currently only implemented for bots. if req.JoinRequest.Role != types.RoleBot { - return nil, trace.BadParameter("bound keypair joining is only supported for bots") + return nil, "", trace.BadParameter("bound keypair joining is only supported for bots") } provisionToken, err = a.checkTokenJoinRequestCommon(ctx, req.JoinRequest) if err != nil { - return nil, trace.Wrap(err) + return nil, "", trace.Wrap(err) } ptv2, ok := provisionToken.(*types.ProvisionTokenV2) if !ok { - return nil, trace.BadParameter("expected *types.ProvisionTokenV2, got %T", provisionToken) + return nil, "", trace.BadParameter("expected *types.ProvisionTokenV2, got %T", provisionToken) } if ptv2.Spec.JoinMethod != types.JoinMethodBoundKeypair { - return nil, trace.BadParameter("specified join token is not for `%s` method", types.JoinMethodBoundKeypair) + return nil, "", trace.BadParameter("specified join token is not for `%s` method", types.JoinMethodBoundKeypair) } if ptv2.Status == nil { @@ -337,8 +337,14 @@ func (a *Server) RegisterUsingBoundKeypairMethod( hasBoundBotInstance := status.BoundBotInstanceID != "" hasIncomingBotInstance := req.JoinRequest.BotInstanceID != "" hasJoinsRemaining := status.JoinCount < spec.Joining.TotalJoins + + // if set, the bound bot instance will be updated in the backend expectNewBotInstance := false + // the bound public key; may change during initial join or rotation. used to + // inform the returned public key value. + boundPublicKey := status.BoundPublicKey + // Mutators to use during the token resource status patch at the end. var mutators []boundKeypairStatusMutator @@ -347,15 +353,15 @@ func (a *Server) RegisterUsingBoundKeypairMethod( // Normal initial join attempt. No bound key, and no incoming bot // instance. Consumes a rejoin. if spec.Onboarding.InitialJoinSecret != "" { - return nil, trace.NotImplemented("initial joining secrets are not yet supported") + return nil, "", trace.NotImplemented("initial joining secrets are not yet supported") } if spec.Onboarding.InitialPublicKey == "" { - return nil, trace.BadParameter("an initial public key is required") + return nil, "", trace.BadParameter("an initial public key is required") } if !spec.Joining.Unlimited && !hasJoinsRemaining { - return nil, trace.AccessDenied("no joins remaining") + return nil, "", trace.AccessDenied("no joins remaining") } if err := a.issueBoundKeypairChallenge( @@ -363,7 +369,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( spec.Onboarding.InitialPublicKey, challengeResponse, ); err != nil { - return nil, trace.Wrap(err) + return nil, "", trace.Wrap(err) } // Now that we've confirmed the key, we can consider it bound. @@ -374,19 +380,20 @@ func (a *Server) RegisterUsingBoundKeypairMethod( ) expectNewBotInstance = true + boundPublicKey = spec.Onboarding.InitialPublicKey case !hasBoundPublicKey && hasIncomingBotInstance: // Not allowed, at least at the moment. This would imply e.g. trying to // change auth methods. - return nil, trace.BadParameter("cannot perform first bound keypair join with existing credentials") + return nil, "", trace.BadParameter("cannot perform first bound keypair join with existing credentials") case hasBoundPublicKey && !hasBoundBotInstance: // TODO: Bad backend state, or maybe an incomplete previous join // attempt. This shouldn't be a possible state, but we should handle it // sanely anyway. - return nil, trace.BadParameter("bad backend state, please recreate the join token") + return nil, "", trace.BadParameter("bad backend state, please recreate the join token") case hasBoundPublicKey && hasBoundBotInstance && hasIncomingBotInstance: // Standard rejoin case, does not consume a rejoin. if status.BoundBotInstanceID != req.JoinRequest.BotInstanceID { - return nil, trace.AccessDenied("bot instance mismatch") + return nil, "", trace.AccessDenied("bot instance mismatch") } if err := a.issueBoundKeypairChallenge( @@ -394,7 +401,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( spec.Onboarding.InitialPublicKey, challengeResponse, ); err != nil { - return nil, trace.Wrap(err) + return nil, "", trace.Wrap(err) } // Nothing else to do, no key change @@ -402,7 +409,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( // Hard rejoin case, the client identity expired and a new bot instance // is required. Consumes a rejoin. if !spec.Joining.Unlimited && !hasJoinsRemaining { - return nil, trace.AccessDenied("no rejoins remaining") + return nil, "", trace.AccessDenied("no rejoins remaining") } if err := a.issueBoundKeypairChallenge( @@ -410,7 +417,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( status.BoundPublicKey, challengeResponse, ); err != nil { - return nil, trace.Wrap(err) + return nil, "", trace.Wrap(err) } mutators = append( @@ -428,12 +435,13 @@ func (a *Server) RegisterUsingBoundKeypairMethod( "spec", spec, "status", status, ) - return nil, trace.BadParameter("unexpected state") + return nil, "", trace.BadParameter("unexpected state") } if spec.RotateOnNextRenewal { - // TODO, to be implemented in a future PR - return nil, trace.NotImplemented("key rotation not yet supported") + // TODO, to be implemented in a future PR. `boundPublicKey` will need to + // be updated. + return nil, "", trace.NotImplemented("key rotation not yet supported") } // TODO: We should pass along the previous bot instance ID - if any - based @@ -473,11 +481,9 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return ptv2, nil }); err != nil { - return nil, trace.Wrap(err, "committing updated token state, please try again") + return nil, "", trace.Wrap(err, "committing updated token state, please try again") } } - // TODO: need to return public key here for inclusion in the certs response. - - return certs, trace.Wrap(err) + return certs, boundPublicKey, trace.Wrap(err) } diff --git a/lib/auth/join_bound_keypair_test.go b/lib/auth/join_bound_keypair_test.go index b01ccab45ed7d..ee8eae7204906 100644 --- a/lib/auth/join_bound_keypair_test.go +++ b/lib/auth/join_bound_keypair_test.go @@ -351,7 +351,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { require.NoError(t, auth.CreateToken(ctx, token)) tt.initReq.JoinRequest.Token = tt.name - _, err = auth.RegisterUsingBoundKeypairMethod(ctx, tt.initReq, tt.solver) + _, _, err = auth.RegisterUsingBoundKeypairMethod(ctx, tt.initReq, tt.solver) tt.assertError(t, err) if tt.assertSuccess != nil { diff --git a/lib/joinserver/joinserver.go b/lib/joinserver/joinserver.go index ea57b746b21fa..1d596995defad 100644 --- a/lib/joinserver/joinserver.go +++ b/lib/joinserver/joinserver.go @@ -59,7 +59,7 @@ type joinServiceClient interface { ctx context.Context, req *proto.RegisterUsingBoundKeypairInitialRequest, challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc, - ) (*proto.Certs, error) + ) (*proto.Certs, string, error) RegisterUsingToken( ctx context.Context, req *types.RegisterUsingTokenRequest, @@ -414,7 +414,7 @@ func (s *JoinServiceGRPCServer) registerUsingBoundKeypair(srv proto.JoinService_ setBotParameters(ctx, initReq.JoinRequest) - certs, err := s.joinServiceClient.RegisterUsingBoundKeypairMethod(ctx, initReq, func(resp *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error) { + certs, pubKey, err := s.joinServiceClient.RegisterUsingBoundKeypairMethod(ctx, initReq, func(resp *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error) { // First, forward the challenge from Auth to the client. err := srv.Send(resp) if err != nil { @@ -443,7 +443,8 @@ func (s *JoinServiceGRPCServer) registerUsingBoundKeypair(srv proto.JoinService_ return trace.Wrap(srv.Send(&proto.RegisterUsingBoundKeypairMethodResponse{ Response: &proto.RegisterUsingBoundKeypairMethodResponse_Certs{ Certs: &proto.RegisterUsingBoundKeypairCertificates{ - Certs: certs, + Certs: certs, + PublicKey: pubKey, }, }, })) diff --git a/lib/joinserver/joinserver_test.go b/lib/joinserver/joinserver_test.go index 2462d6d1fd36a..dcd954f8d815c 100644 --- a/lib/joinserver/joinserver_test.go +++ b/lib/joinserver/joinserver_test.go @@ -93,7 +93,7 @@ func (c *mockJoinServiceClient) RegisterUsingBoundKeypairMethod( ctx context.Context, req *proto.RegisterUsingBoundKeypairInitialRequest, challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc, -) (*proto.Certs, error) { +) (*proto.Certs, string, error) { c.gotBoundKeypairInitReq = req resp, err := challengeResponse(&proto.RegisterUsingBoundKeypairMethodResponse{ Response: &proto.RegisterUsingBoundKeypairMethodResponse_Challenge{ @@ -104,12 +104,12 @@ func (c *mockJoinServiceClient) RegisterUsingBoundKeypairMethod( }, }) if err != nil { - return nil, trace.Wrap(err) + return nil, "", trace.Wrap(err) } c.gotBoundKeypairChallengeResponse = resp - return c.returnCerts, c.returnError + return c.returnCerts, c.boundKeypairPublicKey, c.returnError } func (c *mockJoinServiceClient) RegisterUsingOracleMethod( @@ -646,7 +646,7 @@ func TestJoinServiceGRPCServer_RegisterUsingBoundKeypairMethodSimple(t *testing. "_proxy": testPack.proxyClient, } { t.Run(tc.desc+suffix, func(t *testing.T) { - certs, err := clt.RegisterUsingBoundKeypairMethod( + certs, pubKey, err := clt.RegisterUsingBoundKeypairMethod( context.Background(), tc.req, challengeResponder, ) if tc.challengeResponseErr != nil { @@ -670,6 +670,8 @@ func TestJoinServiceGRPCServer_RegisterUsingBoundKeypairMethodSimple(t *testing. tc.challengeResponse, testPack.mockAuthServer.gotBoundKeypairChallengeResponse, ) + + require.Equal(t, tc.publicKey, pubKey) }) } }) From a2ca7e144577dc93910b1a365a93c1387a028c17 Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Thu, 8 May 2025 21:06:37 -0600 Subject: [PATCH 12/17] Fixes after upstream proto change We renamed and tweaked a number of proto fields, so this updates field references. --- api/types/provisioning.go | 2 +- lib/auth/join_bound_keypair.go | 46 ++++++++++++++--------------- lib/auth/join_bound_keypair_test.go | 18 +++++------ lib/boundkeypair/bound_keypair.go | 26 ++++++++++++++++ 4 files changed, 58 insertions(+), 34 deletions(-) diff --git a/api/types/provisioning.go b/api/types/provisioning.go index d1b0dfabfaa85..6f7a76f501952 100644 --- a/api/types/provisioning.go +++ b/api/types/provisioning.go @@ -993,7 +993,7 @@ func (a *ProvisionTokenSpecV2BoundKeypair) checkAndSetDefaults() error { return trace.BadParameter("spec.bound_keypair.onboarding is required") } - if a.Onboarding.InitialJoinSecret == "" && a.Onboarding.InitialPublicKey == "" { + if a.Onboarding.RegistrationSecret == "" && a.Onboarding.InitialPublicKey == "" { return trace.BadParameter("at least one of [initial_join_secret, " + "initial_public_key] is required in spec.bound_keypair.onboarding") } diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index 276a48e861d1a..75dc190b1b591 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -52,24 +52,20 @@ func validateBoundKeypairTokenSpec(spec *types.ProvisionTokenSpecV2BoundKeypair) return trace.BadParameter("bound keypair joining experiment is not enabled") } - if spec.RotateOnNextRenewal { - return trace.NotImplemented("spec.bound_keypair.rotate_on_next_renewal is not yet implemented") + if spec.RotateAfter != nil { + return trace.NotImplemented("spec.bound_keypair.rotate_after is not yet implemented") } - if spec.Onboarding.InitialJoinSecret != "" { - return trace.NotImplemented("spec.bound_keypair.initial_join_secret is not yet implemented") + if spec.Onboarding.RegistrationSecret != "" { + return trace.NotImplemented("spec.bound_keypair.onboarding.registration_secret is not yet implemented") } if spec.Onboarding.InitialPublicKey == "" { - return trace.NotImplemented("spec.bound_keypair.initial_public_key is currently required") + return trace.NotImplemented("spec.bound_keypair.onboarding.initial_public_key is currently required") } - if !spec.Joining.Unlimited { - return trace.NotImplemented("spec.bound_keypair.joining.unlimited cannot currently be `false`") - } - - if !spec.Joining.Insecure { - return trace.NotImplemented("spec.bound_keypair.joining.insecure cannot currently be `false`") + if spec.Recovery.Mode != boundkeypair.RecoveryModeInsecure { + return trace.NotImplemented("spec.bound_keypair.recovery.mode currently must be %s", boundkeypair.RecoveryModeInsecure) } return nil @@ -110,7 +106,7 @@ func (a *Server) CreateBoundKeypairToken(ctx context.Context, token types.Provis return trace.BadParameter("cannot create a bound_keypair token with set status") } - // TODO: Populate initial_join_secret if needed. + // TODO (follow up PR): Populate initial_join_secret if needed. return trace.Wrap(a.CreateToken(ctx, tokenV2)) } @@ -225,27 +221,28 @@ type boundKeypairStatusMutator func(*types.ProvisionTokenSpecV2BoundKeypair, *ty // mutateStatusConsumeJoin consumes a "hard" join on the backend, incrementing // the join counter. This verifies that the backend join count has not changed, // and that total join count is at least the value when the mutator was created. -func mutateStatusConsumeJoin(unlimited bool, expectJoinCount uint32, expectMinTotalJoins uint32) boundKeypairStatusMutator { +func mutateStatusConsumeJoin(mode boundkeypair.RecoveryMode, expectRecoveryCount uint32, expectMinRecoveryLimit uint32) boundKeypairStatusMutator { now := time.Now() return func(spec *types.ProvisionTokenSpecV2BoundKeypair, status *types.ProvisionTokenStatusV2BoundKeypair) error { // Ensure we have the expected number of rejoins left to prevent going // below zero. - if status.JoinCount != expectJoinCount { + if status.RecoveryCount != expectRecoveryCount { return trace.AccessDenied("unexpected backend state") } // Ensure the allowed join count has at least not decreased, but allow // for collision with potentially increased values. - if spec.Joining.TotalJoins < expectMinTotalJoins { + if spec.Recovery.Limit < expectMinRecoveryLimit { return trace.AccessDenied("unexpected backend state") } - if !unlimited { + if mode == boundkeypair.RecoveryModeStandard { + // TODO: to be removed in a future PR return trace.NotImplemented("only unlimited rejoining is currently supported") } - status.JoinCount += 1 + status.RecoveryCount += 1 status.LastJoinedAt = &now return nil @@ -336,7 +333,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( hasBoundPublicKey := status.BoundPublicKey != "" hasBoundBotInstance := status.BoundBotInstanceID != "" hasIncomingBotInstance := req.JoinRequest.BotInstanceID != "" - hasJoinsRemaining := status.JoinCount < spec.Joining.TotalJoins + hasJoinsRemaining := status.RecoveryCount < spec.Recovery.Limit // if set, the bound bot instance will be updated in the backend expectNewBotInstance := false @@ -352,7 +349,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( case !hasBoundPublicKey && !hasIncomingBotInstance: // Normal initial join attempt. No bound key, and no incoming bot // instance. Consumes a rejoin. - if spec.Onboarding.InitialJoinSecret != "" { + if spec.Onboarding.RegistrationSecret != "" { return nil, "", trace.NotImplemented("initial joining secrets are not yet supported") } @@ -360,7 +357,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, "", trace.BadParameter("an initial public key is required") } - if !spec.Joining.Unlimited && !hasJoinsRemaining { + if spec.Recovery.Mode == string(boundkeypair.RecoveryModeStandard) && !hasJoinsRemaining { return nil, "", trace.AccessDenied("no joins remaining") } @@ -376,7 +373,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( mutators = append( mutators, mutateStatusBoundPublicKey(spec.Onboarding.InitialPublicKey, ""), - mutateStatusConsumeJoin(spec.Joining.Unlimited, status.JoinCount, spec.Joining.TotalJoins), + mutateStatusConsumeJoin(boundkeypair.RecoveryMode(spec.Recovery.Mode), status.RecoveryCount, spec.Recovery.Limit), ) expectNewBotInstance = true @@ -408,7 +405,8 @@ func (a *Server) RegisterUsingBoundKeypairMethod( case hasBoundPublicKey && hasBoundBotInstance && !hasIncomingBotInstance: // Hard rejoin case, the client identity expired and a new bot instance // is required. Consumes a rejoin. - if !spec.Joining.Unlimited && !hasJoinsRemaining { + if spec.Recovery.Mode == string(boundkeypair.RecoveryModeStandard) && !hasJoinsRemaining { + // Recovery limit only applies in "standard" mode. return nil, "", trace.AccessDenied("no rejoins remaining") } @@ -422,7 +420,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( mutators = append( mutators, - mutateStatusConsumeJoin(spec.Joining.Unlimited, status.JoinCount, spec.Joining.TotalJoins), + mutateStatusConsumeJoin(boundkeypair.RecoveryMode(spec.Recovery.Mode), status.RecoveryCount, spec.Recovery.Limit), ) expectNewBotInstance = true @@ -438,7 +436,7 @@ func (a *Server) RegisterUsingBoundKeypairMethod( return nil, "", trace.BadParameter("unexpected state") } - if spec.RotateOnNextRenewal { + if spec.RotateAfter != nil { // TODO, to be implemented in a future PR. `boundPublicKey` will need to // be updated. return nil, "", trace.NotImplemented("key rotation not yet supported") diff --git a/lib/auth/join_bound_keypair_test.go b/lib/auth/join_bound_keypair_test.go index ee8eae7204906..837648f43c8c7 100644 --- a/lib/auth/join_bound_keypair_test.go +++ b/lib/auth/join_bound_keypair_test.go @@ -132,11 +132,9 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ InitialPublicKey: correctPublicKey, }, - Joining: &types.ProvisionTokenSpecV2BoundKeypair_JoiningSpec{ - // Only unlimited and insecure is supported for now, so - // we'll hard code it. - Unlimited: true, - Insecure: true, + Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{ + // Only insecure is supported for now. + Mode: boundkeypair.RecoveryModeInsecure, }, }, }, @@ -210,7 +208,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { assertError: require.NoError, assertSuccess: func(t *testing.T, v2 *types.ProvisionTokenV2) { // join count should be incremented - require.Equal(t, uint32(1), v2.Status.BoundKeypair.JoinCount) + require.Equal(t, uint32(1), v2.Status.BoundKeypair.RecoveryCount) require.NotEmpty(t, v2.Status.BoundKeypair.BoundBotInstanceID) require.NotEmpty(t, v2.Status.BoundKeypair.BoundPublicKey) }, @@ -245,7 +243,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { assertError: require.NoError, assertSuccess: func(t *testing.T, v2 *types.ProvisionTokenV2) { // join count should not be incremented - require.Equal(t, uint32(0), v2.Status.BoundKeypair.JoinCount) + require.Equal(t, uint32(0), v2.Status.BoundKeypair.RecoveryCount) }, }, { @@ -281,7 +279,7 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { assertError: require.NoError, assertSuccess: func(t *testing.T, v2 *types.ProvisionTokenV2) { - require.Equal(t, uint32(1), v2.Status.BoundKeypair.JoinCount) + require.Equal(t, uint32(1), v2.Status.BoundKeypair.RecoveryCount) // Should generate a new bot instance require.NotEmpty(t, v2.Status.BoundKeypair.BoundBotInstanceID) @@ -328,9 +326,11 @@ func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) { name: "rotation-requested", token: makeToken(func(v2 *types.ProvisionTokenV2) { + t := time.Now() v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey v2.Status.BoundKeypair.BoundBotInstanceID = "asdf" - v2.Spec.BoundKeypair.RotateOnNextRenewal = true + v2.Spec.BoundKeypair.RotateAfter = &t + // TODO: test clock? }), initReq: makeInitReq(), solver: makeSolver(correctPublicKey), diff --git a/lib/boundkeypair/bound_keypair.go b/lib/boundkeypair/bound_keypair.go index d6e634e79d7b4..e032ace67336d 100644 --- a/lib/boundkeypair/bound_keypair.go +++ b/lib/boundkeypair/bound_keypair.go @@ -41,6 +41,23 @@ const ( challengeNotBeforeOffset = -10 * time.Second ) +// RecoveryMode is a recovery configuration mode +type RecoveryMode string + +const ( + // RecoveryModeStandard is the standard recovery mode, and enforces the + // recovery count limit and verifies client state. + RecoveryModeStandard RecoveryMode = "standard" + + // RecoveryModeRelaxed does not enforce the recovery count limit, but still + // verifies client state. + RecoveryModeRelaxed = "relaxed" + + // RecoveryModeInsecure does enforces neither the recovery count limit nor + // the client state. + RecoveryModeInsecure = "insecure" +) + // ChallengeDocument is a bound keypair challenge document. These documents are // sent in JSON form to clients attempting to authenticate, and are expected to // be sent back signed with a known public key. @@ -148,3 +165,12 @@ func (v *ChallengeValidator) ValidateChallengeResponse(issued *ChallengeDocument return nil } + +// RecoveryModes returns a list of all supported recovery modes +func RecoveryModes() []RecoveryMode { + return []RecoveryMode{ + RecoveryModeStandard, + RecoveryModeRelaxed, + RecoveryModeInsecure, + } +} From 8b94d7acdf5a3b0da7409633dd38b3f904be2f34 Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Thu, 8 May 2025 21:12:49 -0600 Subject: [PATCH 13/17] Apply suggestions from code review Co-authored-by: Dan Upton --- lib/boundkeypair/bound_keypair.go | 2 +- lib/services/local/provisioning.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/boundkeypair/bound_keypair.go b/lib/boundkeypair/bound_keypair.go index e032ace67336d..bb4202ebb34cc 100644 --- a/lib/boundkeypair/bound_keypair.go +++ b/lib/boundkeypair/bound_keypair.go @@ -60,7 +60,7 @@ const ( // ChallengeDocument is a bound keypair challenge document. These documents are // sent in JSON form to clients attempting to authenticate, and are expected to -// be sent back signed with a known public key. +// be sent back signed with the private counterpart of a known public key. type ChallengeDocument struct { *jwt.Claims diff --git a/lib/services/local/provisioning.go b/lib/services/local/provisioning.go index 9a41ae091a162..e2a018a988f8f 100644 --- a/lib/services/local/provisioning.go +++ b/lib/services/local/provisioning.go @@ -86,7 +86,7 @@ func (s *ProvisioningService) PatchToken( case updatedMetadata.GetName() != existingMetadata.GetName(): return nil, trace.BadParameter("metadata.name: cannot be patched") case updatedMetadata.GetRevision() != existingMetadata.GetRevision(): - return nil, trace.BadParameter("metadata.revision: cannot be patched") + return nil, trace.BadParameter("metadata.revision: cannot be patched") } item, err := s.tokenToItem(updated) From c17de88f6fceaf1611e25a1795532e99adfb6e9c Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Thu, 8 May 2025 21:15:13 -0600 Subject: [PATCH 14/17] Remove TODO --- lib/auth/join_bound_keypair.go | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index 75dc190b1b591..18ef45013913c 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -90,21 +90,10 @@ func (a *Server) CreateBoundKeypairToken(ctx context.Context, token types.Provis return trace.Wrap(err) } - // TODO: Is this wise? End users _shouldn't_ modify this, but this could - // interfere with cluster backup/restore. Options seem to be: - // - Let users create/update with status fields. They can break things, but - // maybe that's okay. No backup/restore implications. - // - Ignore status fields during creation and update. Any set value will be - // discarded here, and during update. This would still have consequences - // during cluster restores, but wouldn't raise errors, and the status - // field would otherwise be protected from easy tampering. Users might be - // confused as no user-visible errors would be raised if they used - // `tctl edit`. - // - Raise an error if status fields are changed. Worst restore - // implications, but tampering won't be easy, and will have some UX. - if tokenV2.Status.BoundKeypair != nil { - return trace.BadParameter("cannot create a bound_keypair token with set status") - } + // Not as much to do here - ideally we'd like to prevent users from + // tampering with the status field, but we don't have a good mechanism to + // stop that that wouldn't also break backup and restore. For now, it's + // simpler and easier to just tell users not to edit those fields. // TODO (follow up PR): Populate initial_join_secret if needed. From 082b1b26710c5ecfa119a085164d3a0c600c8e3b Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Fri, 9 May 2025 20:33:13 -0600 Subject: [PATCH 15/17] Fix missed field rename --- lib/auth/join_bound_keypair.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index 18ef45013913c..930a36a6eff5a 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -232,7 +232,7 @@ func mutateStatusConsumeJoin(mode boundkeypair.RecoveryMode, expectRecoveryCount } status.RecoveryCount += 1 - status.LastJoinedAt = &now + status.LastRecoveredAt = &now return nil } From e663aa0f4f34a31efc046911a66bb84974a890ce Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Mon, 12 May 2025 20:10:32 -0600 Subject: [PATCH 16/17] Fix broken test --- lib/web/join_tokens_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/lib/web/join_tokens_test.go b/lib/web/join_tokens_test.go index 5ac264b99a369..1ceec4724d18c 100644 --- a/lib/web/join_tokens_test.go +++ b/lib/web/join_tokens_test.go @@ -47,6 +47,8 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/automaticupgrades" + "github.com/gravitational/teleport/lib/boundkeypair" + "github.com/gravitational/teleport/lib/boundkeypair/experiment" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/modules" @@ -479,6 +481,9 @@ func TestCreateTokenExpiry(t *testing.T) { }, }) + // TODO: Remove this once bound keypair experiment flag is removed. + experiment.SetEnabled(true) + ctx := context.Background() username := "test-user@example.com" env := newWebPack(t, 1) @@ -614,6 +619,15 @@ func setMinimalConfigForMethod(spec *types.ProvisionTokenSpecV2, method types.Jo }, }, } + case types.JoinMethodBoundKeypair: + spec.BoundKeypair = &types.ProvisionTokenSpecV2BoundKeypair{ + Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{ + InitialPublicKey: "abcd", + }, + Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{ + Mode: boundkeypair.RecoveryModeInsecure, + }, + } } } From 21455de51d6f28d74ea4c82d66b50dc440cfb860 Mon Sep 17 00:00:00 2001 From: Tim Buckley Date: Tue, 13 May 2025 20:36:27 -0600 Subject: [PATCH 17/17] Fix lurking nil pointer deref after field rename --- lib/auth/join_bound_keypair.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go index 273a687215119..462053aadbbdc 100644 --- a/lib/auth/join_bound_keypair.go +++ b/lib/auth/join_bound_keypair.go @@ -64,6 +64,10 @@ func validateBoundKeypairTokenSpec(spec *types.ProvisionTokenSpecV2BoundKeypair) return trace.NotImplemented("spec.bound_keypair.onboarding.initial_public_key is currently required") } + if spec.Recovery == nil { + return trace.BadParameter("spec.recovery: field is required") + } + if spec.Recovery.Mode != boundkeypair.RecoveryModeInsecure { return trace.NotImplemented("spec.bound_keypair.recovery.mode currently must be %s", boundkeypair.RecoveryModeInsecure) }