Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions api/client/joinservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package client

import (
"context"
"errors"
"io"

"github.com/gravitational/trace"

Expand Down Expand Up @@ -60,6 +62,11 @@ type RegisterTPMChallengeResponseFunc func(challenge *proto.TPMEncryptedCredenti
// *proto.OracleSignedRequest for a given challenge, or an error.
type RegisterOracleChallengeResponseFunc func(challenge string) (*proto.OracleSignedRequest, error)

// RegisterUsingBoundKeypairChallengeResponseFunc is a function to be passed to
// RegisterUsingBoundKeypair. It must return a new follow-up request for the
// server response, or an error.
type RegisterUsingBoundKeypairChallengeResponseFunc func(challenge *proto.RegisterUsingBoundKeypairMethodResponse) (*proto.RegisterUsingBoundKeypairMethodRequest, error)

// RegisterUsingIAMMethod registers the caller using the IAM join method and
// returns signed certs to join the cluster.
//
Expand Down Expand Up @@ -262,6 +269,78 @@ func (c *JoinServiceClient) RegisterUsingOracleMethod(
return certs, nil
}

// RegisterUsingBoundKeypairMethod attempts to register the caller using
// bound-keypair join method. If successful, the public key registered with auth
// and a certificate bundle is returned, or an error. Clients must provide a
// callback to handle interactive challenges and keypair rotation requests.
func (c *JoinServiceClient) RegisterUsingBoundKeypairMethod(
ctx context.Context,
initReq *proto.RegisterUsingBoundKeypairInitialRequest,
challengeFunc RegisterUsingBoundKeypairChallengeResponseFunc,
) (*proto.Certs, string, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

stream, err := c.grpcClient.RegisterUsingBoundKeypairMethod(ctx)
if err != nil {
return nil, "", trace.Wrap(err)
}
defer stream.CloseSend()

err = stream.Send(&proto.RegisterUsingBoundKeypairMethodRequest{
Payload: &proto.RegisterUsingBoundKeypairMethodRequest_Init{
Init: initReq,
},
})
if err != nil {
return nil, "", trace.Wrap(err, "sending initial request")
}

// Unlike other methods, the server may send multiple challenges,
// particularly during keypair rotation. We'll iterate through all responses
// here instead to ensure we handle everything.
for {
res, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return nil, "", trace.Wrap(err, "receiving intermediate bound keypair join response")
}

switch kind := res.GetResponse().(type) {
case *proto.RegisterUsingBoundKeypairMethodResponse_Certs:
// If we get certs, we're done, so just return the result.
certs := kind.Certs.GetCerts()
if certs == nil {
return nil, "", trace.BadParameter("expected Certs, got %T", kind.Certs.Certs)
}

// If we receive a cert bundle, we can return early. Even if we
// logically should have expected to receive a 2nd challenge if we
// e.g. requested keypair rotation, skipping it just means the new
// keypair won't be stored. That said, we'll rely on the server to
// raise an error if rotation fails or is otherwise skipped or not
// allowed.

return certs, kind.Certs.GetPublicKey(), nil
default:
// Forward all other responses to the challenge handler.
nextRequest, err := challengeFunc(res)
if err != nil {
return nil, "", trace.Wrap(err, "solving challenge")
}

if err := stream.Send(nextRequest); err != nil {
return nil, "", trace.Wrap(err, "sending solution")
}
}
}

// Ideally the server will emit a proper error instead of just hanging up on
// us.
return nil, "", trace.AccessDenied("server declined to send certs during bound-keypair join attempt")
}

// RegisterUsingToken registers the caller using a token and returns signed
// certs.
// This is used where a more specific RPC has not been introduced for the join
Expand Down
49 changes: 49 additions & 0 deletions api/types/provisioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ const (
// JoinMethodAzureDevops indicates that the node will join using the Azure
// Devops join method.
JoinMethodAzureDevops JoinMethod = "azure_devops"
// JoinMethodBoundKeypair indicates the node will join using the Bound
// Keypair join method. See lib/boundkeypair for more.
JoinMethodBoundKeypair JoinMethod = "bound_keypair"
Comment thread
timothyb89 marked this conversation as resolved.
)

var JoinMethods = []JoinMethod{
Expand All @@ -101,6 +104,7 @@ var JoinMethods = []JoinMethod{
JoinMethodTPM,
JoinMethodTerraformCloud,
JoinMethodOracle,
JoinMethodBoundKeypair,
}

func ValidateJoinMethod(method JoinMethod) error {
Expand Down Expand Up @@ -193,6 +197,26 @@ func NewProvisionTokenFromSpec(token string, expires time.Time, spec ProvisionTo
return t, nil
}

// NewProvisionTokenFromSpecAndStatus returns a new provision token with the given spec.
func NewProvisionTokenFromSpecAndStatus(
token string, expires time.Time,
spec ProvisionTokenSpecV2,
status *ProvisionTokenStatusV2,
) (ProvisionToken, error) {
t := &ProvisionTokenV2{
Metadata: Metadata{
Name: token,
Expires: &expires,
},
Spec: spec,
Status: status,
}
if err := t.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
return t, nil
}

// MustCreateProvisionToken returns a new valid provision token
// or panics, used in tests
func MustCreateProvisionToken(token string, roles SystemRoles, expires time.Time) ProvisionToken {
Expand Down Expand Up @@ -416,6 +440,18 @@ func (p *ProvisionTokenV2) CheckAndSetDefaults() error {
if err := providerCfg.checkAndSetDefaults(); err != nil {
return trace.Wrap(err, "spec.azure_devops: failed validation")
}
case JoinMethodBoundKeypair:
providerCfg := p.Spec.BoundKeypair
if providerCfg == nil {
return trace.BadParameter(
"spec.bound_keypair: must be configured for the join method %q",
JoinMethodBoundKeypair,
)
}

if err := providerCfg.checkAndSetDefaults(); err != nil {
return trace.Wrap(err, "spec.bound_keypair: failed validation")
}
default:
return trace.BadParameter("unknown join method %q", p.Spec.JoinMethod)
}
Expand Down Expand Up @@ -994,3 +1030,16 @@ func (a *ProvisionTokenSpecV2AzureDevops) checkAndSetDefaults() error {
}
return nil
}

func (a *ProvisionTokenSpecV2BoundKeypair) checkAndSetDefaults() error {
if a.Onboarding == nil {
return trace.BadParameter("spec.bound_keypair.onboarding is required")
}

if a.Onboarding.RegistrationSecret == "" && a.Onboarding.InitialPublicKey == "" {
return trace.BadParameter("at least one of [initial_join_secret, " +
"initial_public_key] is required in spec.bound_keypair.onboarding")
}

return nil
}
48 changes: 48 additions & 0 deletions api/types/provisioning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,54 @@ func TestProvisionTokenV2_CheckAndSetDefaults(t *testing.T) {
},
wantErr: true,
},
{
desc: "minimal bound keypair with pregenerated key",
token: &ProvisionTokenV2{
Metadata: Metadata{
Name: "test",
},
Spec: ProvisionTokenSpecV2{
Roles: []SystemRole{RoleNode},
JoinMethod: JoinMethodBoundKeypair,
BoundKeypair: &ProvisionTokenSpecV2BoundKeypair{
Onboarding: &ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{
InitialPublicKey: "asdf",
},
},
},
},
expected: &ProvisionTokenV2{
Kind: "token",
Version: "v2",
Metadata: Metadata{
Name: "test",
Namespace: "default",
},
Spec: ProvisionTokenSpecV2{
Roles: []SystemRole{RoleNode},
JoinMethod: JoinMethodBoundKeypair,
BoundKeypair: &ProvisionTokenSpecV2BoundKeypair{
Onboarding: &ProvisionTokenSpecV2BoundKeypair_OnboardingSpec{
InitialPublicKey: "asdf",
},
},
},
},
},
{
desc: "bound keypair missing onboarding config",
token: &ProvisionTokenV2{
Metadata: Metadata{
Name: "test",
},
Spec: ProvisionTokenSpecV2{
Roles: []SystemRole{RoleNode},
JoinMethod: JoinMethodBoundKeypair,
BoundKeypair: &ProvisionTokenSpecV2BoundKeypair{},
},
},
wantErr: true,
},
}

for _, tc := range testcases {
Expand Down
11 changes: 11 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ import (
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/bitbucket"
"github.com/gravitational/teleport/lib/boundkeypair"
"github.com/gravitational/teleport/lib/cache"
"github.com/gravitational/teleport/lib/circleci"
"github.com/gravitational/teleport/lib/cryptosuites"
Expand Down Expand Up @@ -706,6 +707,12 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
as.bitbucketIDTokenValidator = bitbucket.NewIDTokenValidator(as.clock)
}

if as.createBoundKeypairValidator == nil {
as.createBoundKeypairValidator = func(subject, clusterName string, publicKey crypto.PublicKey) (boundKeypairValidator, error) {
return boundkeypair.NewChallengeValidator(subject, clusterName, publicKey)
}
}

// Add in a login hook for generating state during user login.
as.ulsGenerator, err = userloginstate.NewGenerator(userloginstate.GeneratorConfig{
Log: as.logger,
Expand Down Expand Up @@ -1145,6 +1152,10 @@ type Server struct {

bitbucketIDTokenValidator bitbucketIDTokenValidator

// createBoundKeypairValidator is a helper to create new bound keypair
// challenge validators. Used to override the implementation used in tests.
createBoundKeypairValidator createBoundKeypairValidator

// loadAllCAs tells tsh to load the host CAs for all clusters when trying to ssh into a node.
loadAllCAs bool

Expand Down
12 changes: 12 additions & 0 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -2440,6 +2440,12 @@ func (a *ServerWithRoles) UpsertToken(ctx context.Context, token types.Provision
return trace.Wrap(err)
}

// bound_keypair tokens have special creation/update logic and are handled
// separately
if token.GetJoinMethod() == types.JoinMethodBoundKeypair {
return trace.Wrap(a.authServer.UpsertBoundKeypairToken(ctx, token))
}

if err := a.authServer.UpsertToken(ctx, token); err != nil {
return trace.Wrap(err)
}
Expand All @@ -2465,6 +2471,12 @@ func (a *ServerWithRoles) CreateToken(ctx context.Context, token types.Provision
return trace.Wrap(err)
}

// bound_keypair tokens have special creation/update logic and are handled
// separately
if token.GetJoinMethod() == types.JoinMethodBoundKeypair {
return trace.Wrap(a.authServer.CreateBoundKeypairToken(ctx, token))
}

if err := a.authServer.CreateToken(ctx, token); err != nil {
return trace.Wrap(err)
}
Expand Down
17 changes: 17 additions & 0 deletions lib/auth/authclient/clt.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,15 @@ func (c *Client) DeleteAllTokens() error {
return trace.NotImplemented(notImplementedMessage)
}

// PatchToken not implemented: can only be called locally
func (c *Client) PatchToken(
ctx context.Context,
token string,
updateFn func(types.ProvisionToken) (types.ProvisionToken, error),
) (types.ProvisionToken, error) {
return nil, trace.NotImplemented(notImplementedMessage)
}

// AddUserLoginAttempt logs user login attempt
func (c *Client) AddUserLoginAttempt(user string, attempt services.LoginAttempt, ttl time.Duration) error {
panic("not implemented")
Expand Down Expand Up @@ -1235,6 +1244,14 @@ type ProvisioningService interface {
// CreateToken creates a new provision token for the auth server
CreateToken(ctx context.Context, token types.ProvisionToken) error

// PatchToken performs a conditional update on the named token using
// `updateFn`, retrying internally if a comparison failure occurs.
PatchToken(
ctx context.Context,
token string,
updateFn func(types.ProvisionToken) (types.ProvisionToken, error),
) (types.ProvisionToken, error)

// RegisterUsingToken calls the auth service API to register a new node via registration token
// which has been previously issued via GenerateToken
RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error)
Expand Down
13 changes: 7 additions & 6 deletions lib/auth/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin
if err := a.checkEC2JoinRequest(ctx, req); err != nil {
return nil, trace.Wrap(err)
}
case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM, types.JoinMethodOracle:
case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM,
types.JoinMethodOracle, types.JoinMethodBoundKeypair:
// Some join methods require use of a specific RPC - reject those here.
// This would generally be a developer error - but can be triggered if
// the user has configured the wrong join method on the client-side.
Expand Down Expand Up @@ -317,7 +318,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin
// With all elements of the token validated, we can now generate & return
// certificates.
if req.Role == types.RoleBot {
certs, err = a.generateCertsBot(
certs, _, err = a.generateCertsBot(
ctx,
provisionToken,
req,
Expand All @@ -336,15 +337,15 @@ func (a *Server) generateCertsBot(
req *types.RegisterUsingTokenRequest,
rawJoinClaims any,
attrs *workloadidentityv1pb.JoinAttrs,
) (*proto.Certs, error) {
) (*proto.Certs, string, error) {
// bots use this endpoint but get a user cert
// botResourceName must be set, enforced in CheckAndSetDefaults
botName := provisionToken.GetBotName()
joinMethod := provisionToken.GetJoinMethod()

// Check this is a join method for bots we support.
if !slices.Contains(machineidv1.SupportedJoinMethods, joinMethod) {
return nil, trace.BadParameter(
return nil, "", trace.BadParameter(
"unsupported join method %q for bot", joinMethod,
)
}
Expand Down Expand Up @@ -439,7 +440,7 @@ func (a *Server) generateCertsBot(
attrs,
)
if err != nil {
return nil, trace.Wrap(err)
return nil, "", trace.Wrap(err)
}
joinEvent.BotInstanceID = botInstanceID

Expand All @@ -461,7 +462,7 @@ func (a *Server) generateCertsBot(
if err := a.emitter.EmitAuditEvent(ctx, joinEvent); err != nil {
a.logger.WarnContext(ctx, "Failed to emit bot join event", "error", err)
}
return certs, nil
return certs, botInstanceID, nil
}

func (a *Server) generateCerts(
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/join_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ func (a *Server) RegisterUsingAzureMethodWithOpts(
}

if req.RegisterUsingTokenRequest.Role == types.RoleBot {
certs, err := a.generateCertsBot(
certs, _, err := a.generateCertsBot(
ctx,
provisionToken,
req.RegisterUsingTokenRequest,
Expand Down
Loading
Loading