diff --git a/api/client/joinservice.go b/api/client/joinservice.go
index bc30c8a541e6c..32b0f60dfe5db 100644
--- a/api/client/joinservice.go
+++ b/api/client/joinservice.go
@@ -18,6 +18,8 @@ package client
import (
"context"
+ "errors"
+ "io"
"github.com/gravitational/trace"
@@ -60,6 +62,11 @@ type RegisterTPMChallengeResponseFunc func(challenge *proto.TPMEncryptedCredenti
// *proto.OracleSignedRequest for a given challenge, or an error.
type RegisterOracleChallengeResponseFunc func(challenge string) (*proto.OracleSignedRequest, error)
+// RegisterUsingBoundKeypairChallengeResponseFunc is a function to be passed to
+// RegisterUsingBoundKeypair. It must return a new follow-up request for the
+// server response, or an error.
+type RegisterUsingBoundKeypairChallengeResponseFunc func(challenge *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error)
+
// RegisterUsingIAMMethod registers the caller using the IAM join method and
// returns signed certs to join the cluster.
//
@@ -262,6 +269,78 @@ func (c *JoinServiceClient) RegisterUsingOracleMethod(
return certs, nil
}
+// RegisterUsingBoundKeypairMethod attempts to register the caller using
+// bound-keypair join method. If successful, the public key registered with auth
+// and a certificate bundle is returned, or an error. Clients must provide a
+// callback to handle interactive challenges and keypair rotation requests.
+func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod(
+ ctx context.Context,
+ initReq *proto.RegisterUsingBoundKeypairInitialRequest,
+ challengeFunc RegisterUsingBoundKeypairChallengeResponseFunc,
+) (*proto.Certs, string, error) {
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ stream, err := c.grpcClient.RegisterUsingBoundKeypairMethod(ctx)
+ if err != nil {
+ return nil, "", trace.Wrap(err)
+ }
+ defer stream.CloseSend()
+
+ err = stream.Send(&proto.RegisterUsingBoundKeypairMethodRequest{
+ Payload: &proto.RegisterUsingBoundKeypairMethodRequest_Init{
+ Init: initReq,
+ },
+ })
+ if err != nil {
+ return nil, "", trace.Wrap(err, "sending initial request")
+ }
+
+ // Unlike other methods, the server may send multiple challenges,
+ // particularly during keypair rotation. We'll iterate through all responses
+ // here instead to ensure we handle everything.
+ for {
+ res, err := stream.Recv()
+ if errors.Is(err, io.EOF) {
+ break
+ } else if err != nil {
+ return nil, "", trace.Wrap(err, "receiving intermediate bound keypair join response")
+ }
+
+ switch kind := res.GetResponse().(type) {
+ case *proto.RegisterUsingBoundKeypairMethodResponse_Certs:
+ // If we get certs, we're done, so just return the result.
+ certs := kind.Certs.GetCerts()
+ if certs == nil {
+ return nil, "", trace.BadParameter("expected Certs, got %T", kind.Certs.Certs)
+ }
+
+ // If we receive a cert bundle, we can return early. Even if we
+ // logically should have expected to receive a 2nd challenge if we
+ // e.g. requested keypair rotation, skipping it just means the new
+ // keypair won't be stored. That said, we'll rely on the server to
+ // raise an error if rotation fails or is otherwise skipped or not
+ // allowed.
+
+ return certs, kind.Certs.GetPublicKey(), nil
+ default:
+ // Forward all other responses to the challenge handler.
+ nextRequest, err := challengeFunc(res)
+ if err != nil {
+ return nil, "", trace.Wrap(err, "solving challenge")
+ }
+
+ if err := stream.Send(nextRequest); err != nil {
+ return nil, "", trace.Wrap(err, "sending solution")
+ }
+ }
+ }
+
+ // Ideally the server will emit a proper error instead of just hanging up on
+ // us.
+ return nil, "", trace.AccessDenied("server declined to send certs during bound-keypair join attempt")
+}
+
// RegisterUsingToken registers the caller using a token and returns signed
// certs.
// This is used where a more specific RPC has not been introduced for the join
diff --git a/api/types/provisioning.go b/api/types/provisioning.go
index 0ebe818037906..82b2dd9452442 100644
--- a/api/types/provisioning.go
+++ b/api/types/provisioning.go
@@ -83,6 +83,9 @@ const (
// JoinMethodAzureDevops indicates that the node will join using the Azure
// Devops join method.
JoinMethodAzureDevops JoinMethod = "azure_devops"
+ // JoinMethodBoundKeypair indicates the node will join using the Bound
+ // Keypair join method. See lib/boundkeypair for more.
+ JoinMethodBoundKeypair JoinMethod = "bound_keypair"
)
var JoinMethods = []JoinMethod{
@@ -101,6 +104,7 @@ var JoinMethods = []JoinMethod{
JoinMethodTPM,
JoinMethodTerraformCloud,
JoinMethodOracle,
+ JoinMethodBoundKeypair,
}
func ValidateJoinMethod(method JoinMethod) error {
@@ -193,6 +197,26 @@ func NewProvisionTokenFromSpec(token string, expires time.Time, spec ProvisionTo
return t, nil
}
+// NewProvisionTokenFromSpecAndStatus returns a new provision token with the given spec.
+func NewProvisionTokenFromSpecAndStatus(
+ token string, expires time.Time,
+ spec ProvisionTokenSpecV2,
+ status *ProvisionTokenStatusV2,
+) (ProvisionToken, error) {
+ t := &ProvisionTokenV2{
+ Metadata: Metadata{
+ Name: token,
+ Expires: &expires,
+ },
+ Spec: spec,
+ Status: status,
+ }
+ if err := t.CheckAndSetDefaults(); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return t, nil
+}
+
// MustCreateProvisionToken returns a new valid provision token
// or panics, used in tests
func MustCreateProvisionToken(token string, roles SystemRoles, expires time.Time) ProvisionToken {
@@ -416,6 +440,18 @@ func (p *ProvisionTokenV2) CheckAndSetDefaults() error {
if err := providerCfg.checkAndSetDefaults(); err != nil {
return trace.Wrap(err, "spec.azure_devops: failed validation")
}
+ case JoinMethodBoundKeypair:
+ providerCfg := p.Spec.BoundKeypair
+ if providerCfg == nil {
+ return trace.BadParameter(
+ "spec.bound_keypair: must be configured for the join method %q",
+ JoinMethodBoundKeypair,
+ )
+ }
+
+ if err := providerCfg.checkAndSetDefaults(); err != nil {
+ return trace.Wrap(err, "spec.bound_keypair: failed validation")
+ }
default:
return trace.BadParameter("unknown join method %q", p.Spec.JoinMethod)
}
@@ -994,3 +1030,16 @@ func (a *ProvisionTokenSpecV2AzureDevops) checkAndSetDefaults() error {
}
return nil
}
+
+func (a *ProvisionTokenSpecV2BoundKeypair) checkAndSetDefaults() error {
+ if a.Onboarding == nil {
+ return trace.BadParameter("spec.bound_keypair.onboarding is required")
+ }
+
+ if a.Onboarding.RegistrationSecret == "" && a.Onboarding.InitialPublicKey == "" {
+ return trace.BadParameter("at least one of [initial_join_secret, " +
+ "initial_public_key] is required in spec.bound_keypair.onboarding")
+ }
+
+ return nil
+}
diff --git a/api/types/provisioning_test.go b/api/types/provisioning_test.go
index dae665bb67e1e..cc62a40cab80f 100644
--- a/api/types/provisioning_test.go
+++ b/api/types/provisioning_test.go
@@ -1369,6 +1369,54 @@ func TestProvisionTokenV2_CheckAndSetDefaults(t *testing.T) {
},
wantErr: true,
},
+ {
+ desc: "minimal bound keypair with pregenerated key",
+ token: &ProvisionTokenV2{
+ Metadata: Metadata{
+ Name: "test",
+ },
+ Spec: ProvisionTokenSpecV2{
+ Roles: []SystemRole{RoleNode},
+ JoinMethod: JoinMethodBoundKeypair,
+ BoundKeypair: &ProvisionTokenSpecV2BoundKeypair{
+ Onboarding: &ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{
+ InitialPublicKey: "asdf",
+ },
+ },
+ },
+ },
+ expected: &ProvisionTokenV2{
+ Kind: "token",
+ Version: "v2",
+ Metadata: Metadata{
+ Name: "test",
+ Namespace: "default",
+ },
+ Spec: ProvisionTokenSpecV2{
+ Roles: []SystemRole{RoleNode},
+ JoinMethod: JoinMethodBoundKeypair,
+ BoundKeypair: &ProvisionTokenSpecV2BoundKeypair{
+ Onboarding: &ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{
+ InitialPublicKey: "asdf",
+ },
+ },
+ },
+ },
+ },
+ {
+ desc: "bound keypair missing onboarding config",
+ token: &ProvisionTokenV2{
+ Metadata: Metadata{
+ Name: "test",
+ },
+ Spec: ProvisionTokenSpecV2{
+ Roles: []SystemRole{RoleNode},
+ JoinMethod: JoinMethodBoundKeypair,
+ BoundKeypair: &ProvisionTokenSpecV2BoundKeypair{},
+ },
+ },
+ wantErr: true,
+ },
}
for _, tc := range testcases {
diff --git a/lib/auth/auth.go b/lib/auth/auth.go
index cfcb7507ce800..66953ba1804c9 100644
--- a/lib/auth/auth.go
+++ b/lib/auth/auth.go
@@ -96,6 +96,7 @@ import (
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/bitbucket"
+ "github.com/gravitational/teleport/lib/boundkeypair"
"github.com/gravitational/teleport/lib/cache"
"github.com/gravitational/teleport/lib/circleci"
"github.com/gravitational/teleport/lib/cryptosuites"
@@ -706,6 +707,12 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
as.bitbucketIDTokenValidator = bitbucket.NewIDTokenValidator(as.clock)
}
+ if as.createBoundKeypairValidator == nil {
+ as.createBoundKeypairValidator = func(subject, clusterName string, publicKey crypto.PublicKey) (boundKeypairValidator, error) {
+ return boundkeypair.NewChallengeValidator(subject, clusterName, publicKey)
+ }
+ }
+
// Add in a login hook for generating state during user login.
as.ulsGenerator, err = userloginstate.NewGenerator(userloginstate.GeneratorConfig{
Log: as.logger,
@@ -1145,6 +1152,10 @@ type Server struct {
bitbucketIDTokenValidator bitbucketIDTokenValidator
+ // createBoundKeypairValidator is a helper to create new bound keypair
+ // challenge validators. Used to override the implementation used in tests.
+ createBoundKeypairValidator createBoundKeypairValidator
+
// loadAllCAs tells tsh to load the host CAs for all clusters when trying to ssh into a node.
loadAllCAs bool
diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go
index 72acbfbf1db80..00a417708c614 100644
--- a/lib/auth/auth_with_roles.go
+++ b/lib/auth/auth_with_roles.go
@@ -2440,6 +2440,12 @@ func (a *ServerWithRoles) UpsertToken(ctx context.Context, token types.Provision
return trace.Wrap(err)
}
+ // bound_keypair tokens have special creation/update logic and are handled
+ // separately
+ if token.GetJoinMethod() == types.JoinMethodBoundKeypair {
+ return trace.Wrap(a.authServer.UpsertBoundKeypairToken(ctx, token))
+ }
+
if err := a.authServer.UpsertToken(ctx, token); err != nil {
return trace.Wrap(err)
}
@@ -2465,6 +2471,12 @@ func (a *ServerWithRoles) CreateToken(ctx context.Context, token types.Provision
return trace.Wrap(err)
}
+ // bound_keypair tokens have special creation/update logic and are handled
+ // separately
+ if token.GetJoinMethod() == types.JoinMethodBoundKeypair {
+ return trace.Wrap(a.authServer.CreateBoundKeypairToken(ctx, token))
+ }
+
if err := a.authServer.CreateToken(ctx, token); err != nil {
return trace.Wrap(err)
}
diff --git a/lib/auth/authclient/clt.go b/lib/auth/authclient/clt.go
index 7cebd85ce9895..5e5da68ee9ded 100644
--- a/lib/auth/authclient/clt.go
+++ b/lib/auth/authclient/clt.go
@@ -305,6 +305,15 @@ func (c *Client) DeleteAllTokens() error {
return trace.NotImplemented(notImplementedMessage)
}
+// PatchToken not implemented: can only be called locally
+func (c *Client) PatchToken(
+ ctx context.Context,
+ token string,
+ updateFn func(types.ProvisionToken) (types.ProvisionToken, error),
+) (types.ProvisionToken, error) {
+ return nil, trace.NotImplemented(notImplementedMessage)
+}
+
// AddUserLoginAttempt logs user login attempt
func (c *Client) AddUserLoginAttempt(user string, attempt services.LoginAttempt, ttl time.Duration) error {
panic("not implemented")
@@ -1235,6 +1244,14 @@ type ProvisioningService interface {
// CreateToken creates a new provision token for the auth server
CreateToken(ctx context.Context, token types.ProvisionToken) error
+ // PatchToken performs a conditional update on the named token using
+ // `updateFn`, retrying internally if a comparison failure occurs.
+ PatchToken(
+ ctx context.Context,
+ token string,
+ updateFn func(types.ProvisionToken) (types.ProvisionToken, error),
+ ) (types.ProvisionToken, error)
+
// RegisterUsingToken calls the auth service API to register a new node via registration token
// which has been previously issued via GenerateToken
RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error)
diff --git a/lib/auth/join.go b/lib/auth/join.go
index be71948c6d1ce..853da5688acc5 100644
--- a/lib/auth/join.go
+++ b/lib/auth/join.go
@@ -226,7 +226,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin
if err := a.checkEC2JoinRequest(ctx, req); err != nil {
return nil, trace.Wrap(err)
}
- case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM, types.JoinMethodOracle:
+ case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM,
+ types.JoinMethodOracle, types.JoinMethodBoundKeypair:
// Some join methods require use of a specific RPC - reject those here.
// This would generally be a developer error - but can be triggered if
// the user has configured the wrong join method on the client-side.
@@ -317,7 +318,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin
// With all elements of the token validated, we can now generate & return
// certificates.
if req.Role == types.RoleBot {
- certs, err = a.generateCertsBot(
+ certs, _, err = a.generateCertsBot(
ctx,
provisionToken,
req,
@@ -336,7 +337,7 @@ func (a *Server) generateCertsBot(
req *types.RegisterUsingTokenRequest,
rawJoinClaims any,
attrs *workloadidentityv1pb.JoinAttrs,
-) (*proto.Certs, error) {
+) (*proto.Certs, string, error) {
// bots use this endpoint but get a user cert
// botResourceName must be set, enforced in CheckAndSetDefaults
botName := provisionToken.GetBotName()
@@ -344,7 +345,7 @@ func (a *Server) generateCertsBot(
// Check this is a join method for bots we support.
if !slices.Contains(machineidv1.SupportedJoinMethods, joinMethod) {
- return nil, trace.BadParameter(
+ return nil, "", trace.BadParameter(
"unsupported join method %q for bot", joinMethod,
)
}
@@ -439,7 +440,7 @@ func (a *Server) generateCertsBot(
attrs,
)
if err != nil {
- return nil, trace.Wrap(err)
+ return nil, "", trace.Wrap(err)
}
joinEvent.BotInstanceID = botInstanceID
@@ -461,7 +462,7 @@ func (a *Server) generateCertsBot(
if err := a.emitter.EmitAuditEvent(ctx, joinEvent); err != nil {
a.logger.WarnContext(ctx, "Failed to emit bot join event", "error", err)
}
- return certs, nil
+ return certs, botInstanceID, nil
}
func (a *Server) generateCerts(
diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go
index 9461a412ed7c5..02db610c710c5 100644
--- a/lib/auth/join_azure.go
+++ b/lib/auth/join_azure.go
@@ -489,7 +489,7 @@ func (a *Server) RegisterUsingAzureMethodWithOpts(
}
if req.RegisterUsingTokenRequest.Role == types.RoleBot {
- certs, err := a.generateCertsBot(
+ certs, _, err := a.generateCertsBot(
ctx,
provisionToken,
req.RegisterUsingTokenRequest,
diff --git a/lib/auth/join_bound_keypair.go b/lib/auth/join_bound_keypair.go
new file mode 100644
index 0000000000000..462053aadbbdc
--- /dev/null
+++ b/lib/auth/join_bound_keypair.go
@@ -0,0 +1,480 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package auth
+
+import (
+ "context"
+ "crypto"
+ "encoding/json"
+ "time"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/api/client"
+ "github.com/gravitational/teleport/api/client/proto"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/boundkeypair"
+ "github.com/gravitational/teleport/lib/boundkeypair/boundkeypairexperiment"
+ "github.com/gravitational/teleport/lib/jwt"
+ libsshutils "github.com/gravitational/teleport/lib/sshutils"
+)
+
+type boundKeypairValidator interface {
+ IssueChallenge() (*boundkeypair.ChallengeDocument, error)
+ ValidateChallengeResponse(issued *boundkeypair.ChallengeDocument, compactResponse string) error
+}
+
+type createBoundKeypairValidator func(subject string, clusterName string, publicKey crypto.PublicKey) (boundKeypairValidator, error)
+
+// validateBoundKeypairTokenSpec performs some basic validation checks on a
+// bound_keypair-type join token.
+func validateBoundKeypairTokenSpec(spec *types.ProvisionTokenSpecV2BoundKeypair) error {
+ // Various constant checks, shared between creation and update. Many of
+ // these checks are temporary and will be removed alongside the experiment
+ // flag.
+ if !boundkeypairexperiment.Enabled() {
+ return trace.BadParameter("bound keypair joining experiment is not enabled")
+ }
+
+ if spec.RotateAfter != nil {
+ return trace.NotImplemented("spec.bound_keypair.rotate_after is not yet implemented")
+ }
+
+ if spec.Onboarding.RegistrationSecret != "" {
+ return trace.NotImplemented("spec.bound_keypair.onboarding.registration_secret is not yet implemented")
+ }
+
+ if spec.Onboarding.InitialPublicKey == "" {
+ return trace.NotImplemented("spec.bound_keypair.onboarding.initial_public_key is currently required")
+ }
+
+ if spec.Recovery == nil {
+ return trace.BadParameter("spec.recovery: field is required")
+ }
+
+ if spec.Recovery.Mode != boundkeypair.RecoveryModeInsecure {
+ return trace.NotImplemented("spec.bound_keypair.recovery.mode currently must be %s", boundkeypair.RecoveryModeInsecure)
+ }
+
+ return nil
+}
+
+func (a *Server) CreateBoundKeypairToken(ctx context.Context, token types.ProvisionToken) error {
+ if token.GetJoinMethod() != types.JoinMethodBoundKeypair {
+ return trace.BadParameter("must be called with a bound keypair token")
+ }
+
+ tokenV2, ok := token.(*types.ProvisionTokenV2)
+ if !ok {
+ return trace.BadParameter("%v join method requires ProvisionTokenV2", types.JoinMethodOracle)
+ }
+
+ spec := tokenV2.Spec.BoundKeypair
+ if spec == nil {
+ return trace.BadParameter("bound_keypair token requires non-nil spec.bound_keypair")
+ }
+
+ if err := validateBoundKeypairTokenSpec(spec); err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Not as much to do here - ideally we'd like to prevent users from
+ // tampering with the status field, but we don't have a good mechanism to
+ // stop that that wouldn't also break backup and restore. For now, it's
+ // simpler and easier to just tell users not to edit those fields.
+
+ // TODO (follow up PR): Populate initial_join_secret if needed.
+
+ return trace.Wrap(a.CreateToken(ctx, tokenV2))
+}
+
+func (a *Server) UpsertBoundKeypairToken(ctx context.Context, token types.ProvisionToken) error {
+ if token.GetJoinMethod() != types.JoinMethodBoundKeypair {
+ return trace.BadParameter("must be called with a bound keypair token")
+ }
+
+ tokenV2, ok := token.(*types.ProvisionTokenV2)
+ if !ok {
+ return trace.BadParameter("%v join method requires ProvisionTokenV2", types.JoinMethodOracle)
+ }
+
+ spec := tokenV2.Spec.BoundKeypair
+ if spec == nil {
+ return trace.BadParameter("bound_keypair token requires non-nil spec.bound_keypair")
+ }
+
+ if err := validateBoundKeypairTokenSpec(spec); err != nil {
+ return trace.Wrap(err)
+ }
+
+ // TODO: Follow up with proper checking for a preexisting resource so
+ // generated fields are handled properly, i.e. initial secret generation.
+
+ return trace.Wrap(a.UpsertToken(ctx, token))
+}
+
+// issueBoundKeypairChallenge creates a new challenge for the given marshaled
+// public key in ssh authorized_keys format, requests a solution from the
+// client using the given `challengeResponse` function, and validates the
+// response.
+func (a *Server) issueBoundKeypairChallenge(
+ ctx context.Context,
+ marshalledKey string,
+ challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc,
+) error {
+ key, err := libsshutils.CryptoPublicKey([]byte(marshalledKey))
+ if err != nil {
+ return trace.Wrap(err, "parsing bound public key")
+ }
+
+ // The particular subject value doesn't strictly need to be the name of the
+ // bot or node (which may not be known, yet). Instead, we'll use the key ID,
+ // which could at least be useful for the client to know which key the
+ // challenge should be signed with.
+ keyID, err := jwt.KeyID(key)
+ if err != nil {
+ return trace.Wrap(err, "determining the key ID")
+ }
+
+ clusterName, err := a.GetClusterName(ctx)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ a.logger.DebugContext(ctx, "issuing bound keypair challenge", "key_id", keyID)
+
+ validator, err := a.createBoundKeypairValidator(keyID, clusterName.GetClusterName(), key)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ challenge, err := validator.IssueChallenge()
+ if err != nil {
+ return trace.Wrap(err, "generating a challenge document")
+ }
+
+ marshalledChallenge, err := json.Marshal(challenge)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ response, err := challengeResponse(&proto.RegisterUsingBoundKeypairMethodResponse{
+ Response: &proto.RegisterUsingBoundKeypairMethodResponse_Challenge{
+ Challenge: &proto.RegisterUsingBoundKeypairChallenge{
+ PublicKey: marshalledKey,
+ Challenge: string(marshalledChallenge),
+ },
+ },
+ })
+ if err != nil {
+ return trace.Wrap(err, "requesting a signed challenge")
+ }
+
+ solutionResponse, ok := response.Payload.(*proto.RegisterUsingBoundKeypairMethodRequest_ChallengeResponse)
+ if !ok {
+ return trace.BadParameter("client provided unexpected challenge response type %T", response.Payload)
+ }
+
+ if err := validator.ValidateChallengeResponse(
+ challenge,
+ string(solutionResponse.ChallengeResponse.Solution),
+ ); err != nil {
+ // TODO: Consider access denied instead?
+ return trace.Wrap(err, "validating challenge response")
+ }
+
+ a.logger.InfoContext(ctx, "bound keypair challenge response verified successfully", "key_id", keyID)
+
+ return nil
+}
+
+// boundKeypairStatusMutator is a function called to mutate a bound keypair
+// status during a call to PatchProvisionToken(). These functions may be called
+// repeatedly if e.g. revision checks fail. To ensure invariants remain in
+// place, mutator functions may make assertions to ensure the provided backend
+// state is still sane before the update is committed.
+type boundKeypairStatusMutator func(*types.ProvisionTokenSpecV2BoundKeypair, *types.ProvisionTokenStatusV2BoundKeypair) error
+
+// mutateStatusConsumeJoin consumes a "hard" join on the backend, incrementing
+// the join counter. This verifies that the backend join count has not changed,
+// and that total join count is at least the value when the mutator was created.
+func mutateStatusConsumeJoin(mode boundkeypair.RecoveryMode, expectRecoveryCount uint32, expectMinRecoveryLimit uint32) boundKeypairStatusMutator {
+ now := time.Now()
+
+ return func(spec *types.ProvisionTokenSpecV2BoundKeypair, status *types.ProvisionTokenStatusV2BoundKeypair) error {
+ // Ensure we have the expected number of rejoins left to prevent going
+ // below zero.
+ if status.RecoveryCount != expectRecoveryCount {
+ return trace.AccessDenied("unexpected backend state")
+ }
+
+ // Ensure the allowed join count has at least not decreased, but allow
+ // for collision with potentially increased values.
+ if spec.Recovery.Limit < expectMinRecoveryLimit {
+ return trace.AccessDenied("unexpected backend state")
+ }
+
+ if mode == boundkeypair.RecoveryModeStandard {
+ // TODO: to be removed in a future PR
+ return trace.NotImplemented("only unlimited rejoining is currently supported")
+ }
+
+ status.RecoveryCount += 1
+ status.LastRecoveredAt = &now
+
+ return nil
+ }
+}
+
+// mutateStatusBoundPublicKey is a mutator that updates the bound public key
+// value. It ensures the backend public key is still the expected value before
+// performing the update.
+func mutateStatusBoundPublicKey(newPublicKey, expectPreviousKey string) boundKeypairStatusMutator {
+ return func(_ *types.ProvisionTokenSpecV2BoundKeypair, status *types.ProvisionTokenStatusV2BoundKeypair) error {
+ if status.BoundPublicKey != expectPreviousKey {
+ return trace.AccessDenied("unexpected backend state")
+ }
+
+ status.BoundPublicKey = newPublicKey
+
+ return nil
+ }
+}
+
+// mutateStatusBoundBotInstance updates the bot instance ID currently bound to
+// this token. It ensures the expected previous ID is still the bound value
+// before performing the update.
+func mutateStatusBoundBotInstance(newBotInstance, expectPreviousBotInstance string) boundKeypairStatusMutator {
+ return func(_ *types.ProvisionTokenSpecV2BoundKeypair, status *types.ProvisionTokenStatusV2BoundKeypair) error {
+ if status.BoundBotInstanceID != expectPreviousBotInstance {
+ return trace.AccessDenied("unexpected backend state")
+ }
+
+ status.BoundBotInstanceID = newBotInstance
+
+ return nil
+ }
+}
+
+// RegisterUsingBoundKeypairMethod handles joining requests for the bound
+// keypair join method.
+func (a *Server) RegisterUsingBoundKeypairMethod(
+ ctx context.Context,
+ req *proto.RegisterUsingBoundKeypairInitialRequest,
+ challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc,
+) (_ *proto.Certs, _ string, err error) {
+ var provisionToken types.ProvisionToken
+ var joinFailureMetadata any
+ defer func() {
+ // Emit a log message and audit event on join failure.
+ if err != nil {
+ a.handleJoinFailure(
+ ctx, err, provisionToken, joinFailureMetadata, req.JoinRequest,
+ )
+ }
+ }()
+
+ // First, check the specified token exists, and is a bound keypair-type join
+ // token.
+ if err := req.JoinRequest.CheckAndSetDefaults(); err != nil {
+ return nil, "", trace.Wrap(err)
+ }
+
+ // Only bot joining is supported at the moment - unique ID verification is
+ // required and this is currently only implemented for bots.
+ if req.JoinRequest.Role != types.RoleBot {
+ return nil, "", trace.BadParameter("bound keypair joining is only supported for bots")
+ }
+
+ provisionToken, err = a.checkTokenJoinRequestCommon(ctx, req.JoinRequest)
+ if err != nil {
+ return nil, "", trace.Wrap(err)
+ }
+ ptv2, ok := provisionToken.(*types.ProvisionTokenV2)
+ if !ok {
+ return nil, "", trace.BadParameter("expected *types.ProvisionTokenV2, got %T", provisionToken)
+ }
+ if ptv2.Spec.JoinMethod != types.JoinMethodBoundKeypair {
+ return nil, "", trace.BadParameter("specified join token is not for `%s` method", types.JoinMethodBoundKeypair)
+ }
+
+ if ptv2.Status == nil {
+ ptv2.Status = &types.ProvisionTokenStatusV2{}
+ }
+ if ptv2.Status.BoundKeypair == nil {
+ ptv2.Status.BoundKeypair = &types.ProvisionTokenStatusV2BoundKeypair{}
+ }
+
+ spec := ptv2.Spec.BoundKeypair
+ status := ptv2.Status.BoundKeypair
+ hasBoundPublicKey := status.BoundPublicKey != ""
+ hasBoundBotInstance := status.BoundBotInstanceID != ""
+ hasIncomingBotInstance := req.JoinRequest.BotInstanceID != ""
+ hasJoinsRemaining := status.RecoveryCount < spec.Recovery.Limit
+
+ // if set, the bound bot instance will be updated in the backend
+ expectNewBotInstance := false
+
+ // the bound public key; may change during initial join or rotation. used to
+ // inform the returned public key value.
+ boundPublicKey := status.BoundPublicKey
+
+ // Mutators to use during the token resource status patch at the end.
+ var mutators []boundKeypairStatusMutator
+
+ switch {
+ case !hasBoundPublicKey && !hasIncomingBotInstance:
+ // Normal initial join attempt. No bound key, and no incoming bot
+ // instance. Consumes a rejoin.
+ if spec.Onboarding.RegistrationSecret != "" {
+ return nil, "", trace.NotImplemented("initial joining secrets are not yet supported")
+ }
+
+ if spec.Onboarding.InitialPublicKey == "" {
+ return nil, "", trace.BadParameter("an initial public key is required")
+ }
+
+ if spec.Recovery.Mode == string(boundkeypair.RecoveryModeStandard) && !hasJoinsRemaining {
+ return nil, "", trace.AccessDenied("no joins remaining")
+ }
+
+ if err := a.issueBoundKeypairChallenge(
+ ctx,
+ spec.Onboarding.InitialPublicKey,
+ challengeResponse,
+ ); err != nil {
+ return nil, "", trace.Wrap(err)
+ }
+
+ // Now that we've confirmed the key, we can consider it bound.
+ mutators = append(
+ mutators,
+ mutateStatusBoundPublicKey(spec.Onboarding.InitialPublicKey, ""),
+ mutateStatusConsumeJoin(boundkeypair.RecoveryMode(spec.Recovery.Mode), status.RecoveryCount, spec.Recovery.Limit),
+ )
+
+ expectNewBotInstance = true
+ boundPublicKey = spec.Onboarding.InitialPublicKey
+ case !hasBoundPublicKey && hasIncomingBotInstance:
+ // Not allowed, at least at the moment. This would imply e.g. trying to
+ // change auth methods.
+ return nil, "", trace.BadParameter("cannot perform first bound keypair join with existing credentials")
+ case hasBoundPublicKey && !hasBoundBotInstance:
+ // TODO: Bad backend state, or maybe an incomplete previous join
+ // attempt. This shouldn't be a possible state, but we should handle it
+ // sanely anyway.
+ return nil, "", trace.BadParameter("bad backend state, please recreate the join token")
+ case hasBoundPublicKey && hasBoundBotInstance && hasIncomingBotInstance:
+ // Standard rejoin case, does not consume a rejoin.
+ if status.BoundBotInstanceID != req.JoinRequest.BotInstanceID {
+ return nil, "", trace.AccessDenied("bot instance mismatch")
+ }
+
+ if err := a.issueBoundKeypairChallenge(
+ ctx,
+ spec.Onboarding.InitialPublicKey,
+ challengeResponse,
+ ); err != nil {
+ return nil, "", trace.Wrap(err)
+ }
+
+ // Nothing else to do, no key change
+ case hasBoundPublicKey && hasBoundBotInstance && !hasIncomingBotInstance:
+ // Hard rejoin case, the client identity expired and a new bot instance
+ // is required. Consumes a rejoin.
+ if spec.Recovery.Mode == string(boundkeypair.RecoveryModeStandard) && !hasJoinsRemaining {
+ // Recovery limit only applies in "standard" mode.
+ return nil, "", trace.AccessDenied("no rejoins remaining")
+ }
+
+ if err := a.issueBoundKeypairChallenge(
+ ctx,
+ status.BoundPublicKey,
+ challengeResponse,
+ ); err != nil {
+ return nil, "", trace.Wrap(err)
+ }
+
+ mutators = append(
+ mutators,
+ mutateStatusConsumeJoin(boundkeypair.RecoveryMode(spec.Recovery.Mode), status.RecoveryCount, spec.Recovery.Limit),
+ )
+
+ expectNewBotInstance = true
+ default:
+ a.logger.ErrorContext(
+ ctx, "unexpected state",
+ "has_bound_public_key", hasBoundPublicKey,
+ "has_bound_bot_instance", hasBoundBotInstance,
+ "has_incoming_bot_instance", hasIncomingBotInstance,
+ "spec", spec,
+ "status", status,
+ )
+ return nil, "", trace.BadParameter("unexpected state")
+ }
+
+ if spec.RotateAfter != nil {
+ // TODO, to be implemented in a future PR. `boundPublicKey` will need to
+ // be updated.
+ return nil, "", trace.NotImplemented("key rotation not yet supported")
+ }
+
+ // TODO: We should pass along the previous bot instance ID - if any - based
+ // on the join state, once that is implemented. It will need to be passed
+ // either via extended claims, or by a new protected field in the join
+ // request like the current bot instance ID, i.e. cleared when set by an
+ // untrusted source.
+ certs, botInstanceID, err := a.generateCertsBot(
+ ctx,
+ ptv2,
+ req.JoinRequest,
+ nil, // TODO: extended claims for this type?
+ nil, // TODO: workload id claims
+ )
+
+ if expectNewBotInstance {
+ mutators = append(
+ mutators,
+ mutateStatusBoundBotInstance(botInstanceID, status.BoundBotInstanceID),
+ )
+ }
+
+ if len(mutators) > 0 {
+ if _, err := a.PatchToken(ctx, ptv2.GetName(), func(token types.ProvisionToken) (types.ProvisionToken, error) {
+ ptv2, ok := provisionToken.(*types.ProvisionTokenV2)
+ if !ok {
+ return nil, trace.BadParameter("expected *types.ProvisionTokenV2, got %T", provisionToken)
+ }
+
+ // Apply all mutators. Individual mutators may make additional
+ // assertions to ensure invariants haven't changed.
+ for _, mutator := range mutators {
+ if err := mutator(ptv2.Spec.BoundKeypair, ptv2.Status.BoundKeypair); err != nil {
+ return nil, trace.Wrap(err, "applying status mutator")
+ }
+ }
+
+ return ptv2, nil
+ }); err != nil {
+ return nil, "", trace.Wrap(err, "committing updated token state, please try again")
+ }
+ }
+
+ return certs, boundPublicKey, trace.Wrap(err)
+}
diff --git a/lib/auth/join_bound_keypair_test.go b/lib/auth/join_bound_keypair_test.go
new file mode 100644
index 0000000000000..837648f43c8c7
--- /dev/null
+++ b/lib/auth/join_bound_keypair_test.go
@@ -0,0 +1,368 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package auth
+
+import (
+ "context"
+ "crypto"
+ "testing"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/api/client"
+ "github.com/gravitational/teleport/api/client/proto"
+ headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
+ machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/auth/testauthority"
+ "github.com/gravitational/teleport/lib/boundkeypair"
+ "github.com/gravitational/teleport/lib/cryptosuites"
+ "github.com/gravitational/teleport/lib/sshutils"
+)
+
+type mockBoundKeypairValidator struct {
+ subject string
+ clusterName string
+ publicKey crypto.PublicKey
+}
+
+func (v *mockBoundKeypairValidator) IssueChallenge() (*boundkeypair.ChallengeDocument, error) {
+ return &boundkeypair.ChallengeDocument{
+ Nonce: "fake",
+ }, nil
+}
+
+func (v *mockBoundKeypairValidator) ValidateChallengeResponse(issued *boundkeypair.ChallengeDocument, compactResponse string) error {
+ // For testing, the solver will just reply with the marshaled public key, so
+ // we'll parse and compare it.
+ key, err := sshutils.CryptoPublicKey([]byte(compactResponse))
+ if err != nil {
+ return trace.Wrap(err, "parsing bound public key")
+ }
+
+ equal, ok := v.publicKey.(interface {
+ Equal(x crypto.PublicKey) bool
+ })
+ if !ok {
+ return trace.BadParameter("unsupported public key type %T", key)
+ }
+
+ if !equal.Equal(key) {
+ return trace.AccessDenied("incorrect public key")
+ }
+
+ return nil
+}
+
+func testBoundKeypair(t *testing.T) (crypto.Signer, string) {
+ key, err := cryptosuites.GeneratePrivateKeyWithAlgorithm(cryptosuites.ECDSAP256)
+ require.NoError(t, err)
+
+ return key.Signer, string(key.MarshalSSHPublicKey())
+}
+
+func TestServer_RegisterUsingBoundKeypairMethod(t *testing.T) {
+ ctx := context.Background()
+
+ _, correctPublicKey := testBoundKeypair(t)
+ _, incorrectPublicKey := testBoundKeypair(t)
+
+ srv := newTestTLSServer(t)
+ auth := srv.Auth()
+ auth.createBoundKeypairValidator = func(subject, clusterName string, publicKey crypto.PublicKey) (boundKeypairValidator, error) {
+ return &mockBoundKeypairValidator{
+ //correctPublicKey: correctSigner.Public(),
+
+ subject: subject,
+ clusterName: clusterName,
+ publicKey: publicKey,
+ }, nil
+ }
+
+ _, err := CreateRole(ctx, auth, "example", types.RoleSpecV6{})
+ require.NoError(t, err)
+
+ adminClient, err := srv.NewClient(TestAdmin())
+ require.NoError(t, err)
+
+ _, err = adminClient.BotServiceClient().CreateBot(ctx, &machineidv1pb.CreateBotRequest{
+ Bot: &machineidv1pb.Bot{
+ Kind: types.KindBot,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "test",
+ },
+ Spec: &machineidv1pb.BotSpec{
+ Roles: []string{"example"},
+ },
+ },
+ })
+ require.NoError(t, err)
+
+ sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair()
+ require.NoError(t, err)
+ tlsPublicKey, err := PrivateKeyToPublicKeyTLS(sshPrivateKey)
+ require.NoError(t, err)
+
+ makeToken := func(mutators ...func(v2 *types.ProvisionTokenV2)) types.ProvisionTokenV2 {
+ token := types.ProvisionTokenV2{
+ Spec: types.ProvisionTokenSpecV2{
+ JoinMethod: types.JoinMethodBoundKeypair,
+ Roles: []types.SystemRole{types.RoleBot},
+ BotName: "test",
+ BoundKeypair: &types.ProvisionTokenSpecV2BoundKeypair{
+ Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{
+ InitialPublicKey: correctPublicKey,
+ },
+ Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{
+ // Only insecure is supported for now.
+ Mode: boundkeypair.RecoveryModeInsecure,
+ },
+ },
+ },
+ Status: &types.ProvisionTokenStatusV2{
+ BoundKeypair: &types.ProvisionTokenStatusV2BoundKeypair{},
+ },
+ }
+ for _, mutator := range mutators {
+ mutator(&token)
+ }
+ return token
+ }
+
+ makeInitReq := func(mutators ...func(r *proto.RegisterUsingBoundKeypairInitialRequest)) *proto.RegisterUsingBoundKeypairInitialRequest {
+ req := &proto.RegisterUsingBoundKeypairInitialRequest{
+ JoinRequest: &types.RegisterUsingTokenRequest{
+ HostID: "host-id",
+ Role: types.RoleBot,
+ PublicTLSKey: tlsPublicKey,
+ PublicSSHKey: sshPublicKey,
+ },
+ }
+ for _, mutator := range mutators {
+ mutator(req)
+ }
+ return req
+ }
+
+ makeSolver := func(publicKey string) client.RegisterUsingBoundKeypairChallengeResponseFunc {
+ return func(challenge *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error) {
+ switch r := challenge.Response.(type) {
+ case *proto.RegisterUsingBoundKeypairMethodResponse_Challenge:
+ if r.Challenge.PublicKey != publicKey {
+ return nil, trace.BadParameter("wrong public key")
+ }
+
+ return &proto.RegisterUsingBoundKeypairMethodRequest{
+ Payload: &proto.RegisterUsingBoundKeypairMethodRequest_ChallengeResponse{
+ ChallengeResponse: &proto.RegisterUsingBoundKeypairChallengeResponse{
+ // For testing purposes, we'll just reply with the
+ // public key, to avoid needing to parse the JWT.
+ Solution: []byte(publicKey),
+ },
+ },
+ }, nil
+ default:
+ return nil, trace.BadParameter("invalid response type")
+ }
+ }
+ }
+
+ tests := []struct {
+ name string
+
+ token types.ProvisionTokenV2
+ initReq *proto.RegisterUsingBoundKeypairInitialRequest
+ solver client.RegisterUsingBoundKeypairChallengeResponseFunc
+
+ assertError require.ErrorAssertionFunc
+ assertSuccess func(t *testing.T, v2 *types.ProvisionTokenV2)
+ }{
+ {
+ // no bound key, no bound bot instance, aka initial join without
+ // secret
+ name: "initial-join-success",
+
+ token: makeToken(),
+ initReq: makeInitReq(),
+ solver: makeSolver(correctPublicKey),
+
+ assertError: require.NoError,
+ assertSuccess: func(t *testing.T, v2 *types.ProvisionTokenV2) {
+ // join count should be incremented
+ require.Equal(t, uint32(1), v2.Status.BoundKeypair.RecoveryCount)
+ require.NotEmpty(t, v2.Status.BoundKeypair.BoundBotInstanceID)
+ require.NotEmpty(t, v2.Status.BoundKeypair.BoundPublicKey)
+ },
+ },
+ {
+ // no bound key, no bound bot instance, aka initial join without
+ // secret
+ name: "initial-join-with-wrong-key",
+
+ token: makeToken(),
+ initReq: makeInitReq(),
+ solver: makeSolver(incorrectPublicKey),
+
+ assertError: func(tt require.TestingT, err error, i ...interface{}) {
+ require.Error(tt, err)
+ require.ErrorContains(tt, err, "wrong public key")
+ },
+ },
+ {
+ // bound key, valid bound bot instance, aka "soft join"
+ name: "reauth-success",
+
+ token: makeToken(func(v2 *types.ProvisionTokenV2) {
+ v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey
+ v2.Status.BoundKeypair.BoundBotInstanceID = "asdf"
+ }),
+ initReq: makeInitReq(func(r *proto.RegisterUsingBoundKeypairInitialRequest) {
+ r.JoinRequest.BotInstanceID = "asdf"
+ }),
+ solver: makeSolver(correctPublicKey),
+
+ assertError: require.NoError,
+ assertSuccess: func(t *testing.T, v2 *types.ProvisionTokenV2) {
+ // join count should not be incremented
+ require.Equal(t, uint32(0), v2.Status.BoundKeypair.RecoveryCount)
+ },
+ },
+ {
+ // bound key, seemingly valid bot instance, but wrong key
+ // (should be impossible, but should fail anyway)
+ name: "reauth-with-wrong-key",
+
+ token: makeToken(func(v2 *types.ProvisionTokenV2) {
+ v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey
+ v2.Status.BoundKeypair.BoundBotInstanceID = "asdf"
+ }),
+ initReq: makeInitReq(func(r *proto.RegisterUsingBoundKeypairInitialRequest) {
+ r.JoinRequest.BotInstanceID = "asdf"
+ }),
+ solver: makeSolver(incorrectPublicKey),
+
+ assertError: func(tt require.TestingT, err error, i ...interface{}) {
+ require.Error(tt, err)
+ require.ErrorContains(tt, err, "wrong public key")
+ },
+ },
+ {
+ // bound key but no valid incoming bot instance, i.e. the certs
+ // expired and triggered a hard rejoin
+ name: "rejoin-success",
+
+ token: makeToken(func(v2 *types.ProvisionTokenV2) {
+ v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey
+ v2.Status.BoundKeypair.BoundBotInstanceID = "asdf"
+ }),
+ initReq: makeInitReq(),
+ solver: makeSolver(correctPublicKey),
+
+ assertError: require.NoError,
+ assertSuccess: func(t *testing.T, v2 *types.ProvisionTokenV2) {
+ require.Equal(t, uint32(1), v2.Status.BoundKeypair.RecoveryCount)
+
+ // Should generate a new bot instance
+ require.NotEmpty(t, v2.Status.BoundKeypair.BoundBotInstanceID)
+ require.NotEqual(t, "asdf", v2.Status.BoundKeypair.BoundBotInstanceID)
+ },
+ },
+ {
+ // Bad state: somehow a key was registered without a bot instance.
+ // This should fail and prompt the user to recreate the token.
+ name: "bound-key-no-instance",
+
+ token: makeToken(func(v2 *types.ProvisionTokenV2) {
+ v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey
+ }),
+ initReq: makeInitReq(),
+ solver: makeSolver(correctPublicKey),
+
+ assertError: func(tt require.TestingT, err error, i ...interface{}) {
+ require.Error(tt, err)
+ require.ErrorContains(tt, err, "bad backend state")
+ },
+ },
+ {
+ // The client somehow presents certs that refer to a different
+ // instance, maybe tried switching auth methods.
+ name: "bound-key-wrong-instance",
+
+ token: makeToken(func(v2 *types.ProvisionTokenV2) {
+ v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey
+ v2.Status.BoundKeypair.BoundBotInstanceID = "qwerty"
+ }),
+ initReq: makeInitReq(func(r *proto.RegisterUsingBoundKeypairInitialRequest) {
+ r.JoinRequest.BotInstanceID = "asdf"
+ }),
+ solver: makeSolver(correctPublicKey),
+
+ assertError: func(tt require.TestingT, err error, i ...interface{}) {
+ require.Error(tt, err)
+ require.ErrorContains(tt, err, "bot instance mismatch")
+ },
+ },
+ {
+ // TODO: rotation is not yet implemented.
+ name: "rotation-requested",
+
+ token: makeToken(func(v2 *types.ProvisionTokenV2) {
+ t := time.Now()
+ v2.Status.BoundKeypair.BoundPublicKey = correctPublicKey
+ v2.Status.BoundKeypair.BoundBotInstanceID = "asdf"
+ v2.Spec.BoundKeypair.RotateAfter = &t
+ // TODO: test clock?
+ }),
+ initReq: makeInitReq(),
+ solver: makeSolver(correctPublicKey),
+
+ assertError: func(tt require.TestingT, err error, i ...interface{}) {
+ require.Error(tt, err)
+ require.ErrorContains(tt, err, "key rotation not yet supported")
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ token, err := types.NewProvisionTokenFromSpecAndStatus(
+ tt.name, time.Now().Add(time.Minute), tt.token.Spec, tt.token.Status,
+ )
+ require.NoError(t, err)
+ require.NoError(t, auth.CreateToken(ctx, token))
+ tt.initReq.JoinRequest.Token = tt.name
+
+ _, _, err = auth.RegisterUsingBoundKeypairMethod(ctx, tt.initReq, tt.solver)
+ tt.assertError(t, err)
+
+ if tt.assertSuccess != nil {
+ pt, err := auth.GetToken(ctx, tt.name)
+ require.NoError(t, err)
+
+ ptv2, ok := pt.(*types.ProvisionTokenV2)
+ require.True(t, ok)
+
+ tt.assertSuccess(t, ptv2)
+ }
+ })
+ }
+}
diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go
index 0f0bddbb2899b..d20a3ddcb6b38 100644
--- a/lib/auth/join_iam.go
+++ b/lib/auth/join_iam.go
@@ -407,7 +407,7 @@ func (a *Server) RegisterUsingIAMMethodWithOpts(
}
if req.RegisterUsingTokenRequest.Role == types.RoleBot {
- certs, err := a.generateCertsBot(
+ certs, _, err := a.generateCertsBot(
ctx,
provisionToken,
req.RegisterUsingTokenRequest,
diff --git a/lib/auth/join_oracle.go b/lib/auth/join_oracle.go
index 457d238a437e7..95349ed8c7056 100644
--- a/lib/auth/join_oracle.go
+++ b/lib/auth/join_oracle.go
@@ -83,7 +83,7 @@ func (a *Server) registerUsingOracleMethod(
}
if tokenReq.Role == types.RoleBot {
- certs, err := a.generateCertsBot(
+ certs, _, err := a.generateCertsBot(
ctx,
provisionToken,
tokenReq,
diff --git a/lib/auth/join_tpm.go b/lib/auth/join_tpm.go
index df2e6b4e4cbcc..31334855ddd61 100644
--- a/lib/auth/join_tpm.go
+++ b/lib/auth/join_tpm.go
@@ -112,7 +112,7 @@ func (a *Server) RegisterUsingTPMMethod(
}
if initReq.JoinRequest.Role == types.RoleBot {
- certs, err := a.generateCertsBot(
+ certs, _, err := a.generateCertsBot(
ctx,
ptv2,
initReq.JoinRequest,
diff --git a/lib/auth/machineid/machineidv1/bot_service.go b/lib/auth/machineid/machineidv1/bot_service.go
index 139612f54b056..ca3aeeb0b451a 100644
--- a/lib/auth/machineid/machineidv1/bot_service.go
+++ b/lib/auth/machineid/machineidv1/bot_service.go
@@ -58,6 +58,7 @@ var SupportedJoinMethods = []types.JoinMethod{
types.JoinMethodTPM,
types.JoinMethodTerraformCloud,
types.JoinMethodBitbucket,
+ types.JoinMethodBoundKeypair,
}
// BotResourceName returns the default name for resources associated with the
diff --git a/lib/joinserver/joinserver.go b/lib/joinserver/joinserver.go
index b978bb8d61ef8..1d596995defad 100644
--- a/lib/joinserver/joinserver.go
+++ b/lib/joinserver/joinserver.go
@@ -55,6 +55,11 @@ type joinServiceClient interface {
tokenReq *types.RegisterUsingTokenRequest,
challengeResponse client.RegisterOracleChallengeResponseFunc,
) (*proto.Certs, error)
+ RegisterUsingBoundKeypairMethod(
+ ctx context.Context,
+ req *proto.RegisterUsingBoundKeypairInitialRequest,
+ challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc,
+ ) (*proto.Certs, string, error)
RegisterUsingToken(
ctx context.Context,
req *types.RegisterUsingTokenRequest,
@@ -369,6 +374,82 @@ func (s *JoinServiceGRPCServer) registerUsingOracleMethod(srv proto.JoinService_
}))
}
+// RegisterUsingBoundKeypairMethod registers the client using the bound-keypair
+// join method, and if successful, returns a signed cert bundle for
+// authenticated cluster access.
+func (s *JoinServiceGRPCServer) RegisterUsingBoundKeypairMethod(
+ srv proto.JoinService_RegisterUsingBoundKeypairMethodServer,
+) error {
+ return trace.Wrap(s.handleStreamingRegistration(srv.Context(), types.JoinMethodBoundKeypair, func() error {
+ return trace.Wrap(s.registerUsingBoundKeypair(srv))
+ }))
+}
+
+func (s *JoinServiceGRPCServer) registerUsingBoundKeypair(srv proto.JoinService_RegisterUsingBoundKeypairMethodServer) error {
+ ctx := srv.Context()
+
+ // Get initial payload from the client
+ req, err := srv.Recv()
+ if err != nil {
+ return trace.Wrap(err, "receiving initial payload")
+ }
+ initReq := req.GetInit()
+ if initReq == nil {
+ return trace.BadParameter("expected non-nil Init payload")
+ }
+
+ if initReq.InitialJoinSecret != "" {
+ // TODO: not supported yet.
+ return trace.NotImplemented("initial join secrets are not yet supported")
+ }
+
+ if initReq.JoinRequest == nil {
+ return trace.BadParameter(
+ "expected JoinRequest in RegisterUsingBoundKeypairInitialRequest, got nil",
+ )
+ }
+ if err := setClientRemoteAddr(ctx, initReq.JoinRequest); err != nil {
+ return trace.Wrap(err, "setting client address")
+ }
+
+ setBotParameters(ctx, initReq.JoinRequest)
+
+ certs, pubKey, err := s.joinServiceClient.RegisterUsingBoundKeypairMethod(ctx, initReq, func(resp *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error) {
+ // First, forward the challenge from Auth to the client.
+ err := srv.Send(resp)
+ if err != nil {
+ return nil, trace.Wrap(
+ err, "forwarding challenge to client",
+ )
+ }
+
+ // Get response from Client
+ req, err := srv.Recv()
+ if err != nil {
+ return nil, trace.Wrap(
+ err, "receiving challenge solution from client",
+ )
+ }
+
+ return req, nil
+ })
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ slog.DebugContext(srv.Context(), "challenge ceremony complete, sending cert bundle")
+
+ // finally, send the certs on the response stream
+ return trace.Wrap(srv.Send(&proto.RegisterUsingBoundKeypairMethodResponse{
+ Response: &proto.RegisterUsingBoundKeypairMethodResponse_Certs{
+ Certs: &proto.RegisterUsingBoundKeypairCertificates{
+ Certs: certs,
+ PublicKey: pubKey,
+ },
+ },
+ }))
+}
+
// RegisterUsingToken allows nodes and proxies to join the cluster using
// legacy join methods which do not yet have their own RPC.
// On the Auth server, this method will call the auth.Server's
diff --git a/lib/joinserver/joinserver_test.go b/lib/joinserver/joinserver_test.go
index f9bd3e6d8e7dc..dcd954f8d815c 100644
--- a/lib/joinserver/joinserver_test.go
+++ b/lib/joinserver/joinserver_test.go
@@ -42,14 +42,17 @@ import (
)
type mockJoinServiceClient struct {
- sendChallenge string
- returnCerts *proto.Certs
- returnError error
- gotIAMChallengeResponse *proto.RegisterUsingIAMMethodRequest
- gotAzureChallengeResponse *proto.RegisterUsingAzureMethodRequest
- gotTPMChallengeResponse *proto.RegisterUsingTPMMethodChallengeResponse
- gotTPMInitReq *proto.RegisterUsingTPMMethodInitialRequest
- gotRegisterUsingTokenReq *types.RegisterUsingTokenRequest
+ sendChallenge string
+ boundKeypairPublicKey string
+ returnCerts *proto.Certs
+ returnError error
+ gotIAMChallengeResponse *proto.RegisterUsingIAMMethodRequest
+ gotAzureChallengeResponse *proto.RegisterUsingAzureMethodRequest
+ gotTPMChallengeResponse *proto.RegisterUsingTPMMethodChallengeResponse
+ gotTPMInitReq *proto.RegisterUsingTPMMethodInitialRequest
+ gotBoundKeypairInitReq *proto.RegisterUsingBoundKeypairInitialRequest
+ gotBoundKeypairChallengeResponse *proto.RegisterUsingBoundKeypairMethodRequest
+ gotRegisterUsingTokenReq *types.RegisterUsingTokenRequest
}
func (c *mockJoinServiceClient) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) {
@@ -86,6 +89,29 @@ func (c *mockJoinServiceClient) RegisterUsingTPMMethod(
return c.returnCerts, c.returnError
}
+func (c *mockJoinServiceClient) RegisterUsingBoundKeypairMethod(
+ ctx context.Context,
+ req *proto.RegisterUsingBoundKeypairInitialRequest,
+ challengeResponse client.RegisterUsingBoundKeypairChallengeResponseFunc,
+) (*proto.Certs, string, error) {
+ c.gotBoundKeypairInitReq = req
+ resp, err := challengeResponse(&proto.RegisterUsingBoundKeypairMethodResponse{
+ Response: &proto.RegisterUsingBoundKeypairMethodResponse_Challenge{
+ Challenge: &proto.RegisterUsingBoundKeypairChallenge{
+ PublicKey: c.boundKeypairPublicKey,
+ Challenge: c.sendChallenge,
+ },
+ },
+ })
+ if err != nil {
+ return nil, "", trace.Wrap(err)
+ }
+
+ c.gotBoundKeypairChallengeResponse = resp
+
+ return c.returnCerts, c.boundKeypairPublicKey, c.returnError
+}
+
func (c *mockJoinServiceClient) RegisterUsingOracleMethod(
ctx context.Context,
tokenReq *types.RegisterUsingTokenRequest,
@@ -540,6 +566,118 @@ func TestJoinServiceGRPCServer_RegisterUsingTPMMethod(t *testing.T) {
}
}
+// TestJoinServiceGRPCServer_RegisterUsingBoundKeypairMethodSimple tests the
+// simplest bound keypair joining path, with no keypair registration or
+// rotation.
+func TestJoinServiceGRPCServer_RegisterUsingBoundKeypairMethodSimple(t *testing.T) {
+ t.Parallel()
+ testPack := newTestPack(t)
+
+ standardResponse := &proto.RegisterUsingBoundKeypairMethodRequest{
+ Payload: &proto.RegisterUsingBoundKeypairMethodRequest_ChallengeResponse{
+ ChallengeResponse: &proto.RegisterUsingBoundKeypairChallengeResponse{
+ Solution: []byte("header.payload.signature"),
+ },
+ },
+ }
+
+ testCases := []struct {
+ desc string
+ publicKey string
+ challenge string
+ req *proto.RegisterUsingBoundKeypairInitialRequest
+ challengeResponse *proto.RegisterUsingBoundKeypairMethodRequest
+ challengeResponseErr error
+ authErr error
+ certs *proto.Certs
+ }{
+ {
+ desc: "success case",
+ challenge: "foo",
+ req: &proto.RegisterUsingBoundKeypairInitialRequest{
+ JoinRequest: &types.RegisterUsingTokenRequest{},
+ },
+ challengeResponse: standardResponse,
+ certs: &proto.Certs{SSH: []byte("qux")},
+ },
+ {
+ desc: "auth error",
+ challenge: "foo",
+ req: &proto.RegisterUsingBoundKeypairInitialRequest{
+ JoinRequest: &types.RegisterUsingTokenRequest{},
+ },
+ challengeResponse: standardResponse,
+ authErr: trace.AccessDenied("not allowed"),
+ },
+ {
+ desc: "challenge response error",
+ challenge: "foo",
+ req: &proto.RegisterUsingBoundKeypairInitialRequest{
+ JoinRequest: &types.RegisterUsingTokenRequest{},
+ },
+ challengeResponse: nil,
+ challengeResponseErr: trace.BadParameter("testing error"),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ testPack.mockAuthServer.sendChallenge = tc.challenge
+ testPack.mockAuthServer.returnCerts = tc.certs
+ testPack.mockAuthServer.returnError = tc.authErr
+
+ challengeResponder := func(
+ challenge *proto.RegisterUsingBoundKeypairMethodResponse,
+ ) (*proto.RegisterUsingBoundKeypairMethodRequest, error) {
+ assert.Equal(t, &proto.RegisterUsingBoundKeypairMethodResponse{
+ Response: &proto.RegisterUsingBoundKeypairMethodResponse_Challenge{
+ Challenge: &proto.RegisterUsingBoundKeypairChallenge{
+ PublicKey: tc.publicKey,
+ Challenge: tc.challenge,
+ },
+ },
+ }, challenge)
+
+ return tc.challengeResponse, tc.challengeResponseErr
+ }
+
+ for suffix, clt := range map[string]*client.JoinServiceClient{
+ "_auth": testPack.authClient,
+ "_proxy": testPack.proxyClient,
+ } {
+ t.Run(tc.desc+suffix, func(t *testing.T) {
+ certs, pubKey, err := clt.RegisterUsingBoundKeypairMethod(
+ context.Background(), tc.req, challengeResponder,
+ )
+ if tc.challengeResponseErr != nil {
+ require.ErrorIs(t, err, tc.challengeResponseErr)
+ return
+ }
+ if tc.authErr != nil {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tc.authErr.Error())
+ return
+ }
+ require.NoError(t, err)
+ require.Equal(t, tc.certs, certs)
+
+ expectedInitReq := tc.req
+ expectedInitReq.JoinRequest.RemoteAddr = "bufconn"
+ assert.Equal(t, expectedInitReq, testPack.mockAuthServer.gotBoundKeypairInitReq)
+
+ assert.Equal(
+ t,
+ tc.challengeResponse,
+ testPack.mockAuthServer.gotBoundKeypairChallengeResponse,
+ )
+
+ require.Equal(t, tc.publicKey, pubKey)
+ })
+ }
+ })
+ }
+}
+
func TestTimeout(t *testing.T) {
t.Parallel()
diff --git a/lib/services/local/provisioning.go b/lib/services/local/provisioning.go
index 4d9eeac954a41..e2a018a988f8f 100644
--- a/lib/services/local/provisioning.go
+++ b/lib/services/local/provisioning.go
@@ -52,6 +52,62 @@ func (s *ProvisioningService) UpsertToken(ctx context.Context, p types.Provision
return nil
}
+// PatchToken uses the supplied function to attempt to patch a token resource.
+// Up to 3 update attempts will be made if the conditional update fails due to
+// a revision comparison failure.
+func (s *ProvisioningService) PatchToken(
+ ctx context.Context,
+ tokenName string,
+ updateFn func(types.ProvisionToken) (types.ProvisionToken, error),
+) (types.ProvisionToken, error) {
+ const iterLimit = 3
+
+ for i := 0; i < iterLimit; i++ {
+ existing, err := s.GetToken(ctx, tokenName)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ // Note: CloneProvisionToken only supports ProvisionTokenV2.
+ clone, err := services.CloneProvisionToken(existing)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ updated, err := updateFn(clone)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ updatedMetadata := updated.GetMetadata()
+ existingMetadata := existing.GetMetadata()
+
+ switch {
+ case updatedMetadata.GetName() != existingMetadata.GetName():
+ return nil, trace.BadParameter("metadata.name: cannot be patched")
+ case updatedMetadata.GetRevision() != existingMetadata.GetRevision():
+ return nil, trace.BadParameter("metadata.revision: cannot be patched")
+ }
+
+ item, err := s.tokenToItem(updated)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ lease, err := s.ConditionalUpdate(ctx, *item)
+ if trace.IsCompareFailed(err) {
+ continue
+ } else if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ updated.SetRevision(lease.Revision)
+ return updated, nil
+ }
+
+ return nil, trace.CompareFailed("failed to update provision token within %v iterations", iterLimit)
+}
+
// CreateToken creates a new token for the auth server
func (s *ProvisioningService) CreateToken(ctx context.Context, p types.ProvisionToken) error {
item, err := s.tokenToItem(p)
diff --git a/lib/services/provisioning.go b/lib/services/provisioning.go
index 8e17507fec536..185787a6c63ec 100644
--- a/lib/services/provisioning.go
+++ b/lib/services/provisioning.go
@@ -25,6 +25,7 @@ import (
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/types"
+ apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/utils"
)
@@ -48,6 +49,14 @@ type Provisioner interface {
// GetTokens returns all non-expired tokens
GetTokens(ctx context.Context) ([]types.ProvisionToken, error)
+
+ // PatchToken performs a conditional update on the named token using
+ // `updateFn`, retrying internally if a comparison failure occurs.
+ PatchToken(
+ ctx context.Context,
+ token string,
+ updateFn func(types.ProvisionToken) (types.ProvisionToken, error),
+ ) (types.ProvisionToken, error)
}
// MustCreateProvisionToken returns a new valid provision token
@@ -127,3 +136,16 @@ func MarshalProvisionToken(provisionToken types.ProvisionToken, opts ...MarshalO
return nil, trace.BadParameter("unrecognized provision token version %T", provisionToken)
}
}
+
+// CloneProvisionToken returns a deep copy of the given provision token, per
+// `apiutils.CloneProtoMsg()`. Fields in the clone may be modified without
+// affecting the original. Only V2 is supported.
+func CloneProvisionToken(provisionToken types.ProvisionToken) (types.ProvisionToken, error) {
+ switch provisionToken := provisionToken.(type) {
+ case *types.ProvisionTokenV2:
+ clone := apiutils.CloneProtoMsg(provisionToken)
+ return clone, nil
+ default:
+ return nil, trace.BadParameter("cannot clone unsupported provision token version %T", provisionToken)
+ }
+}
diff --git a/lib/web/join_tokens_test.go b/lib/web/join_tokens_test.go
index 39aa9511abef7..d605984443c07 100644
--- a/lib/web/join_tokens_test.go
+++ b/lib/web/join_tokens_test.go
@@ -47,6 +47,8 @@ import (
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/automaticupgrades"
+ "github.com/gravitational/teleport/lib/boundkeypair"
+ "github.com/gravitational/teleport/lib/boundkeypair/boundkeypairexperiment"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/modules"
@@ -479,6 +481,9 @@ func TestCreateTokenExpiry(t *testing.T) {
},
})
+ // TODO: Remove this once bound keypair experiment flag is removed.
+ boundkeypairexperiment.SetEnabled(true)
+
ctx := context.Background()
username := "test-user@example.com"
env := newWebPack(t, 1)
@@ -623,6 +628,15 @@ func setMinimalConfigForMethod(spec *types.ProvisionTokenSpecV2, method types.Jo
},
},
}
+ case types.JoinMethodBoundKeypair:
+ spec.BoundKeypair = &types.ProvisionTokenSpecV2BoundKeypair{
+ Onboarding: &types.ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{
+ InitialPublicKey: "abcd",
+ },
+ Recovery: &types.ProvisionTokenSpecV2BoundKeypair_RecoverySpec{
+ Mode: boundkeypair.RecoveryModeInsecure,
+ },
+ }
}
}