Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"math"
"math/big"
insecurerand "math/rand"
Expand Down Expand Up @@ -108,6 +109,7 @@ import (
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/tpm"
usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/interval"
Expand Down Expand Up @@ -451,6 +453,9 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
)
}
}
if as.tpmValidator == nil {
as.tpmValidator = tpm.Validate
}
if as.k8sTokenReviewValidator == nil {
as.k8sTokenReviewValidator = &kubernetestoken.TokenReviewValidator{}
}
Expand Down Expand Up @@ -813,6 +818,12 @@ type Server struct {
// the auth server. It can be overridden for the purpose of tests.
gitlabIDTokenValidator gitlabIDTokenValidator

// tpmValidator allows TPMs to be validated by the auth server. It can be
// overridden for the purpose of tests.
tpmValidator func(
ctx context.Context, log *slog.Logger, params tpm.ValidateParams,
) (*tpm.ValidatedTPM, error)

// circleCITokenValidate allows ID tokens from CircleCI to be validated by
// the auth server. It can be overridden for the purpose of tests.
circleCITokenValidate func(ctx context.Context, organizationID, token string) (*circleci.IDTokenClaims, error)
Expand Down
31 changes: 27 additions & 4 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,24 @@ func (a *ServerWithRoles) RegisterUsingAzureMethod(ctx context.Context, challeng
return certs, trace.Wrap(err)
}

// RegisterUsingTPMMethod registers the caller using the TPM join method and
// returns signed certs to join the cluster.
//
// See (*Server).RegisterUsingTPMMethod for further documentation.
//
// This wrapper does not do any extra authz checks, as the register method has
// its own authz mechanism.
func (a *ServerWithRoles) RegisterUsingTPMMethod(
ctx context.Context,
initReq *proto.RegisterUsingTPMMethodInitialRequest,
solveChallenge client.RegisterTPMChallengeResponseFunc,
) (*proto.Certs, error) {
certs, err := a.authServer.registerUsingTPMMethod(
ctx, initReq, solveChallenge,
)
return certs, trace.Wrap(err)
}

// GenerateHostCerts generates new host certificates (signed
// by the host certificate authority) for a node.
func (a *ServerWithRoles) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequest) (*proto.Certs, error) {
Expand Down Expand Up @@ -2217,15 +2235,20 @@ func enforceEnterpriseJoinMethodCreation(token types.ProvisionToken) error {
switch v.Spec.JoinMethod {
case types.JoinMethodGitHub:
if v.Spec.GitHub != nil && v.Spec.GitHub.EnterpriseServerHost != "" {
return fmt.Errorf(
"github enterprise server joining: %w",
return trace.Wrap(
ErrRequiresEnterprise,
"github enterprise server joining",
)
}
case types.JoinMethodSpacelift:
return fmt.Errorf(
"spacelift joining: %w",
return trace.Wrap(
ErrRequiresEnterprise,
"spacelift joining",
)
case types.JoinMethodTPM:
return trace.Wrap(
ErrRequiresEnterprise,
"tpm joining",
)
}

Expand Down
1 change: 1 addition & 0 deletions lib/auth/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ var supportedBotJoinMethods = []types.JoinMethod{
types.JoinMethodKubernetes,
types.JoinMethodSpacelift,
types.JoinMethodToken,
types.JoinMethodTPM,
}

// checkOrCreateBotToken checks the existing token if given, or creates a new
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin
if err := a.checkEC2JoinRequest(ctx, req); err != nil {
return nil, trace.Wrap(err)
}
case types.JoinMethodIAM, types.JoinMethodAzure:
case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM:
// IAM and Azure join methods must use gRPC register methods
return nil, trace.AccessDenied("this token is only valid for the %s "+
"join method but the node has connected to the wrong endpoint, make "+
Expand Down
137 changes: 137 additions & 0 deletions lib/auth/join_tpm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package auth

import (
"context"
"crypto/x509"
"log/slog"

"github.com/google/go-attestation/attest"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/tpm"
)

func (a *Server) registerUsingTPMMethod(
ctx context.Context,
initReq *proto.RegisterUsingTPMMethodInitialRequest,
solveChallenge client.RegisterTPMChallengeResponseFunc,
) (_ *proto.Certs, err error) {
var provisionToken types.ProvisionToken
var attributeSrc joinAttributeSourcer
defer func() {
// Emit a log message and audit event on join failure.
if err != nil {
a.handleJoinFailure(
err, provisionToken, attributeSrc, initReq.JoinRequest,
)
}
}()

// First, check the specified token exists, and is a TPM-type join token.
if err := initReq.JoinRequest.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
provisionToken, err = a.checkTokenJoinRequestCommon(ctx, initReq.JoinRequest)
if err != nil {
return nil, trace.Wrap(err)
}
ptv2, ok := provisionToken.(*types.ProvisionTokenV2)
if !ok {
return nil, trace.BadParameter("expected *types.ProvisionTokenV2, got %T", provisionToken)
}
if ptv2.Spec.JoinMethod != types.JoinMethodTPM {
return nil, trace.BadParameter("specified join token is not for `tpm` method")
}

if modules.GetModules().BuildType() != modules.BuildEnterprise {
return nil, trace.Wrap(
ErrRequiresEnterprise,
"tpm joining",
)
}

// Convert configured CAs to a CAPool
var certPool *x509.CertPool
if len(ptv2.Spec.TPM.EKCertAllowedCAs) > 0 {
certPool = x509.NewCertPool()
for i, ca := range ptv2.Spec.TPM.EKCertAllowedCAs {
if ok := certPool.AppendCertsFromPEM([]byte(ca)); !ok {
return nil, trace.BadParameter(
"ekcert_allowed_cas[%d] has an invalid or malformed PEM", i,
)
}
}
}

// TODO(noah): Use logger from TeleportProcess.
validatedEK, err := a.tpmValidator(ctx, slog.Default(), tpm.ValidateParams{
EKCert: initReq.GetEkCert(),
EKKey: initReq.GetEkKey(),
AttestParams: tpm.AttestationParametersFromProto(initReq.AttestationParams),
AllowedCAs: certPool,
Solve: func(ec *attest.EncryptedCredential) ([]byte, error) {
solution, err := solveChallenge(tpm.EncryptedCredentialToProto(ec))
if err != nil {
return nil, trace.Wrap(err)
}
return solution.Solution, nil
},
})
if err != nil {
return nil, trace.Wrap(err, "validating TPM EK")
}
attributeSrc = validatedEK

if err := checkTPMAllowRules(validatedEK, ptv2.Spec.TPM.Allow); err != nil {
return nil, trace.Wrap(err)
}

if initReq.JoinRequest.Role == types.RoleBot {
certs, err := a.generateCertsBot(
ctx, ptv2, initReq.JoinRequest, validatedEK,
)
return certs, trace.Wrap(err, "generating certs for bot")
}
certs, err := a.generateCerts(
ctx, ptv2, initReq.JoinRequest, validatedEK,
)
return certs, trace.Wrap(err, "generating certs for host")
}

func checkTPMAllowRules(tpm *tpm.ValidatedTPM, rules []*types.ProvisionTokenSpecV2TPM_Rule) error {
// If a single rule passes, accept the TPM
for _, rule := range rules {
if rule.EKPublicHash != "" && tpm.EKPubHash != rule.EKPublicHash {
continue
}
if rule.EKCertificateSerial != "" && tpm.EKCertSerial != rule.EKCertificateSerial {
continue
}

// All rules met.
return nil
}
return trace.AccessDenied("validated tpm attributes did not match any allow rules")
}
Loading