diff --git a/lib/auth/auth.go b/lib/auth/auth.go
index b28cea169933c..d766ae1d5f2f6 100644
--- a/lib/auth/auth.go
+++ b/lib/auth/auth.go
@@ -35,6 +35,7 @@ import (
"errors"
"fmt"
"io"
+ "log/slog"
"math"
"math/big"
insecurerand "math/rand"
@@ -113,6 +114,7 @@ import (
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/tlsca"
+ "github.com/gravitational/teleport/lib/tpm"
usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/interval"
@@ -497,6 +499,9 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
)
}
}
+ if as.tpmValidator == nil {
+ as.tpmValidator = tpm.Validate
+ }
if as.k8sTokenReviewValidator == nil {
as.k8sTokenReviewValidator = &kubernetestoken.TokenReviewValidator{}
}
@@ -891,6 +896,12 @@ type Server struct {
// the auth server. It can be overridden for the purpose of tests.
gitlabIDTokenValidator gitlabIDTokenValidator
+ // tpmValidator allows TPMs to be validated by the auth server. It can be
+ // overridden for the purpose of tests.
+ tpmValidator func(
+ ctx context.Context, log *slog.Logger, params tpm.ValidateParams,
+ ) (*tpm.ValidatedTPM, error)
+
// circleCITokenValidate allows ID tokens from CircleCI to be validated by
// the auth server. It can be overridden for the purpose of tests.
circleCITokenValidate func(ctx context.Context, organizationID, token string) (*circleci.IDTokenClaims, error)
diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go
index b86ca0f21ab39..68fd7ebe55bea 100644
--- a/lib/auth/auth_with_roles.go
+++ b/lib/auth/auth_with_roles.go
@@ -631,6 +631,24 @@ func (a *ServerWithRoles) RegisterUsingAzureMethod(ctx context.Context, challeng
return certs, trace.Wrap(err)
}
+// RegisterUsingTPMMethod registers the caller using the TPM join method and
+// returns signed certs to join the cluster.
+//
+// See (*Server).RegisterUsingTPMMethod for further documentation.
+//
+// This wrapper does not do any extra authz checks, as the register method has
+// its own authz mechanism.
+func (a *ServerWithRoles) RegisterUsingTPMMethod(
+ ctx context.Context,
+ initReq *proto.RegisterUsingTPMMethodInitialRequest,
+ solveChallenge client.RegisterTPMChallengeResponseFunc,
+) (*proto.Certs, error) {
+ certs, err := a.authServer.registerUsingTPMMethod(
+ ctx, initReq, solveChallenge,
+ )
+ return certs, trace.Wrap(err)
+}
+
// GenerateHostCerts generates new host certificates (signed
// by the host certificate authority) for a node.
func (a *ServerWithRoles) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequest) (*proto.Certs, error) {
@@ -2079,15 +2097,20 @@ func enforceEnterpriseJoinMethodCreation(token types.ProvisionToken) error {
switch v.Spec.JoinMethod {
case types.JoinMethodGitHub:
if v.Spec.GitHub != nil && v.Spec.GitHub.EnterpriseServerHost != "" {
- return fmt.Errorf(
- "github enterprise server joining: %w",
+ return trace.Wrap(
ErrRequiresEnterprise,
+ "github enterprise server joining",
)
}
case types.JoinMethodSpacelift:
- return fmt.Errorf(
- "spacelift joining: %w",
+ return trace.Wrap(
+ ErrRequiresEnterprise,
+ "spacelift joining",
+ )
+ case types.JoinMethodTPM:
+ return trace.Wrap(
ErrRequiresEnterprise,
+ "tpm joining",
)
}
diff --git a/lib/auth/join.go b/lib/auth/join.go
index 7ed55f055a6bf..7507c7bb6b244 100644
--- a/lib/auth/join.go
+++ b/lib/auth/join.go
@@ -227,7 +227,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin
if err := a.checkEC2JoinRequest(ctx, req); err != nil {
return nil, trace.Wrap(err)
}
- case types.JoinMethodIAM, types.JoinMethodAzure:
+ case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM:
// IAM and Azure join methods must use gRPC register methods
return nil, trace.AccessDenied("this token is only valid for the %s "+
"join method but the node has connected to the wrong endpoint, make "+
diff --git a/lib/auth/join_tpm.go b/lib/auth/join_tpm.go
new file mode 100644
index 0000000000000..4304628b9d07c
--- /dev/null
+++ b/lib/auth/join_tpm.go
@@ -0,0 +1,137 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package auth
+
+import (
+ "context"
+ "crypto/x509"
+ "log/slog"
+
+ "github.com/google/go-attestation/attest"
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/api/client"
+ "github.com/gravitational/teleport/api/client/proto"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/modules"
+ "github.com/gravitational/teleport/lib/tpm"
+)
+
+func (a *Server) registerUsingTPMMethod(
+ ctx context.Context,
+ initReq *proto.RegisterUsingTPMMethodInitialRequest,
+ solveChallenge client.RegisterTPMChallengeResponseFunc,
+) (_ *proto.Certs, err error) {
+ var provisionToken types.ProvisionToken
+ var attributeSrc joinAttributeSourcer
+ defer func() {
+ // Emit a log message and audit event on join failure.
+ if err != nil {
+ a.handleJoinFailure(
+ err, provisionToken, attributeSrc, initReq.JoinRequest,
+ )
+ }
+ }()
+
+ // First, check the specified token exists, and is a TPM-type join token.
+ if err := initReq.JoinRequest.CheckAndSetDefaults(); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ provisionToken, err = a.checkTokenJoinRequestCommon(ctx, initReq.JoinRequest)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ ptv2, ok := provisionToken.(*types.ProvisionTokenV2)
+ if !ok {
+ return nil, trace.BadParameter("expected *types.ProvisionTokenV2, got %T", provisionToken)
+ }
+ if ptv2.Spec.JoinMethod != types.JoinMethodTPM {
+ return nil, trace.BadParameter("specified join token is not for `tpm` method")
+ }
+
+ if modules.GetModules().BuildType() != modules.BuildEnterprise {
+ return nil, trace.Wrap(
+ ErrRequiresEnterprise,
+ "tpm joining",
+ )
+ }
+
+ // Convert configured CAs to a CAPool
+ var certPool *x509.CertPool
+ if len(ptv2.Spec.TPM.EKCertAllowedCAs) > 0 {
+ certPool = x509.NewCertPool()
+ for i, ca := range ptv2.Spec.TPM.EKCertAllowedCAs {
+ if ok := certPool.AppendCertsFromPEM([]byte(ca)); !ok {
+ return nil, trace.BadParameter(
+ "ekcert_allowed_cas[%d] has an invalid or malformed PEM", i,
+ )
+ }
+ }
+ }
+
+ // TODO(noah): Use logger from TeleportProcess.
+ validatedEK, err := a.tpmValidator(ctx, slog.Default(), tpm.ValidateParams{
+ EKCert: initReq.GetEkCert(),
+ EKKey: initReq.GetEkKey(),
+ AttestParams: tpm.AttestationParametersFromProto(initReq.AttestationParams),
+ AllowedCAs: certPool,
+ Solve: func(ec *attest.EncryptedCredential) ([]byte, error) {
+ solution, err := solveChallenge(tpm.EncryptedCredentialToProto(ec))
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return solution.Solution, nil
+ },
+ })
+ if err != nil {
+ return nil, trace.Wrap(err, "validating TPM EK")
+ }
+ attributeSrc = validatedEK
+
+ if err := checkTPMAllowRules(validatedEK, ptv2.Spec.TPM.Allow); err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ if initReq.JoinRequest.Role == types.RoleBot {
+ certs, err := a.generateCertsBot(
+ ctx, ptv2, initReq.JoinRequest, validatedEK,
+ )
+ return certs, trace.Wrap(err, "generating certs for bot")
+ }
+ certs, err := a.generateCerts(
+ ctx, ptv2, initReq.JoinRequest, validatedEK,
+ )
+ return certs, trace.Wrap(err, "generating certs for host")
+}
+
+func checkTPMAllowRules(tpm *tpm.ValidatedTPM, rules []*types.ProvisionTokenSpecV2TPM_Rule) error {
+ // If a single rule passes, accept the TPM
+ for _, rule := range rules {
+ if rule.EKPublicHash != "" && tpm.EKPubHash != rule.EKPublicHash {
+ continue
+ }
+ if rule.EKCertificateSerial != "" && tpm.EKCertSerial != rule.EKCertificateSerial {
+ continue
+ }
+
+ // All rules met.
+ return nil
+ }
+ return trace.AccessDenied("validated tpm attributes did not match any allow rules")
+}
diff --git a/lib/auth/join_tpm_test.go b/lib/auth/join_tpm_test.go
new file mode 100644
index 0000000000000..8210a2f975185
--- /dev/null
+++ b/lib/auth/join_tpm_test.go
@@ -0,0 +1,358 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package auth
+
+import (
+ "bytes"
+ "context"
+ "crypto/x509"
+ "errors"
+ "log/slog"
+ "testing"
+ "time"
+
+ "github.com/google/go-attestation/attest"
+ gocmp "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/api/client/proto"
+ apifixtures "github.com/gravitational/teleport/api/fixtures"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/auth/testauthority"
+ "github.com/gravitational/teleport/lib/modules"
+ "github.com/gravitational/teleport/lib/tpm"
+)
+
+type mockTPMValidator struct {
+ lastCalledParams *tpm.ValidateParams
+ returnErr error
+ returnValidatedTPM *tpm.ValidatedTPM
+}
+
+func (m *mockTPMValidator) setup(returns *tpm.ValidatedTPM, err error) {
+ m.lastCalledParams = nil
+ m.returnErr = err
+ m.returnValidatedTPM = returns
+}
+
+func (m *mockTPMValidator) validate(
+ _ context.Context, _ *slog.Logger, params tpm.ValidateParams,
+) (*tpm.ValidatedTPM, error) {
+ m.lastCalledParams = ¶ms
+
+ solution, err := params.Solve(&attest.EncryptedCredential{
+ Secret: []byte("mock-secret"),
+ Credential: []byte("mock-credential"),
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ if !bytes.Equal(solution, []byte("mock-solution")) {
+ return nil, trace.AccessDenied("invalid solution")
+ }
+
+ return m.returnValidatedTPM, m.returnErr
+}
+
+func TestServer_RegisterUsingTPMMethod(t *testing.T) {
+ ctx := context.Background()
+ mockValidator := &mockTPMValidator{}
+ p, err := newTestPack(ctx, t.TempDir(), func(server *Server) error {
+ server.tpmValidator = mockValidator.validate
+ return nil
+ })
+ require.NoError(t, err)
+ auth := p.a
+
+ sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair()
+ require.NoError(t, err)
+ tlsPublicKey, err := PrivateKeyToPublicKeyTLS(sshPrivateKey)
+ require.NoError(t, err)
+
+ attParams := &proto.TPMAttestationParameters{
+ Public: []byte("mock-public"),
+ }
+
+ const (
+ goodEKPubHash = "mock-ekpub-hashed"
+ goodEKCertSerial = "mock-ekcert-serial"
+ goodEKPubHashAlt = "mock-ekpub-hashed-alt"
+ goodEKCertSerialAlt = "mock-ekcert-serial-alt"
+ )
+ tokenSpec := func(mutate func(v2 *types.ProvisionTokenSpecV2)) types.ProvisionTokenSpecV2 {
+ spec := types.ProvisionTokenSpecV2{
+ JoinMethod: types.JoinMethodTPM,
+ Roles: []types.SystemRole{types.RoleNode},
+ TPM: &types.ProvisionTokenSpecV2TPM{
+ Allow: []*types.ProvisionTokenSpecV2TPM_Rule{
+ {
+ Description: "ekpub only",
+ EKPublicHash: goodEKPubHash,
+ },
+ {
+ Description: "ekcert only",
+ EKCertificateSerial: goodEKCertSerial,
+ },
+ {
+ Description: "both",
+ EKPublicHash: goodEKPubHashAlt,
+ EKCertificateSerial: goodEKCertSerialAlt,
+ },
+ },
+ },
+ }
+ if mutate != nil {
+ mutate(&spec)
+ }
+ return spec
+ }
+ joinRequest := func() *types.RegisterUsingTokenRequest {
+ return &types.RegisterUsingTokenRequest{
+ HostID: "host-id",
+ Role: types.RoleNode,
+ PublicTLSKey: tlsPublicKey,
+ PublicSSHKey: sshPublicKey,
+ }
+ }
+
+ caPool := x509.NewCertPool()
+ require.True(t, caPool.AppendCertsFromPEM([]byte(apifixtures.TLSCACertPEM)))
+
+ allowRulesNotMatched := require.ErrorAssertionFunc(func(t require.TestingT, err error, i ...interface{}) {
+ require.ErrorContains(t, err, "validated tpm attributes did not match any allow rules")
+ require.True(t, trace.IsAccessDenied(err))
+ })
+ tests := []struct {
+ name string
+ setOSS bool
+
+ tokenSpec types.ProvisionTokenSpecV2
+
+ validateReturnTPM *tpm.ValidatedTPM
+ validateReturnErr error
+
+ initReq *proto.RegisterUsingTPMMethodInitialRequest
+ wantParams *tpm.ValidateParams
+
+ assertError require.ErrorAssertionFunc
+ }{
+ {
+ name: "success, ekpub",
+ assertError: require.NoError,
+
+ initReq: &proto.RegisterUsingTPMMethodInitialRequest{
+ JoinRequest: joinRequest(),
+ Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkKey{
+ EkKey: []byte("mock-ekpub"),
+ },
+ AttestationParams: attParams,
+ },
+ wantParams: &tpm.ValidateParams{
+ EKKey: []byte("mock-ekpub"),
+ AttestParams: tpm.AttestationParametersFromProto(attParams),
+ },
+
+ tokenSpec: tokenSpec(nil),
+ validateReturnTPM: &tpm.ValidatedTPM{
+ EKPubHash: goodEKPubHash,
+ },
+ },
+ {
+ name: "success, ekcert",
+ assertError: require.NoError,
+
+ initReq: &proto.RegisterUsingTPMMethodInitialRequest{
+ JoinRequest: joinRequest(),
+ Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkCert{
+ EkCert: []byte("mock-ekcert"),
+ },
+ AttestationParams: attParams,
+ },
+ wantParams: &tpm.ValidateParams{
+ EKCert: []byte("mock-ekcert"),
+ AttestParams: tpm.AttestationParametersFromProto(attParams),
+ AllowedCAs: caPool,
+ },
+
+ tokenSpec: tokenSpec(func(v2 *types.ProvisionTokenSpecV2) {
+ v2.TPM.EKCertAllowedCAs = []string{apifixtures.TLSCACertPEM}
+ }),
+ validateReturnTPM: &tpm.ValidatedTPM{
+ EKCertSerial: goodEKCertSerial,
+ EKCertVerified: true,
+ },
+ },
+ {
+ name: "success, both ek cert serial and ek pub hash match",
+ assertError: require.NoError,
+
+ initReq: &proto.RegisterUsingTPMMethodInitialRequest{
+ JoinRequest: joinRequest(),
+ Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkCert{
+ EkCert: []byte("mock-ekcert"),
+ },
+ AttestationParams: attParams,
+ },
+ wantParams: &tpm.ValidateParams{
+ EKCert: []byte("mock-ekcert"),
+ AttestParams: tpm.AttestationParametersFromProto(attParams),
+ },
+
+ tokenSpec: tokenSpec(nil),
+ validateReturnTPM: &tpm.ValidatedTPM{
+ EKCertSerial: goodEKCertSerialAlt,
+ EKPubHash: goodEKPubHashAlt,
+ EKCertVerified: true,
+ },
+ },
+ {
+ name: "failure, mismatched ekpub",
+ assertError: allowRulesNotMatched,
+
+ initReq: &proto.RegisterUsingTPMMethodInitialRequest{
+ JoinRequest: joinRequest(),
+ Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkKey{
+ EkKey: []byte("mock-ekpub"),
+ },
+ AttestationParams: attParams,
+ },
+ wantParams: &tpm.ValidateParams{
+ EKKey: []byte("mock-ekpub"),
+ AttestParams: tpm.AttestationParametersFromProto(attParams),
+ },
+
+ tokenSpec: tokenSpec(nil),
+ validateReturnTPM: &tpm.ValidatedTPM{
+ EKPubHash: "mock-ekpub-hashed-mismatched!",
+ },
+ },
+ {
+ name: "failure, mismatched ekcert",
+ assertError: allowRulesNotMatched,
+
+ initReq: &proto.RegisterUsingTPMMethodInitialRequest{
+ JoinRequest: joinRequest(),
+ Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkCert{
+ EkCert: []byte("mock-ekcert"),
+ },
+ AttestationParams: attParams,
+ },
+ wantParams: &tpm.ValidateParams{
+ EKCert: []byte("mock-ekcert"),
+ AttestParams: tpm.AttestationParametersFromProto(attParams),
+ },
+
+ tokenSpec: tokenSpec(nil),
+ validateReturnTPM: &tpm.ValidatedTPM{
+ EKCertSerial: "mock-ekcert-serial-mismatched!",
+ },
+ },
+ {
+ name: "failure, verification",
+ assertError: func(t require.TestingT, err error, i ...interface{}) {
+ assert.ErrorContains(t, err, "capacitor overcharged")
+ },
+
+ initReq: &proto.RegisterUsingTPMMethodInitialRequest{
+ JoinRequest: joinRequest(),
+ Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkCert{
+ EkCert: []byte("mock-ekcert"),
+ },
+ AttestationParams: attParams,
+ },
+ wantParams: &tpm.ValidateParams{
+ EKCert: []byte("mock-ekcert"),
+ AttestParams: tpm.AttestationParametersFromProto(attParams),
+ },
+
+ tokenSpec: tokenSpec(nil),
+ validateReturnTPM: &tpm.ValidatedTPM{
+ EKCertSerial: goodEKCertSerial,
+ },
+ validateReturnErr: errors.New("capacitor overcharged"),
+ },
+ {
+ name: "failure, no enterprise",
+ setOSS: true,
+ assertError: func(t require.TestingT, err error, i ...interface{}) {
+ assert.ErrorIs(t, err, ErrRequiresEnterprise)
+ },
+
+ initReq: &proto.RegisterUsingTPMMethodInitialRequest{
+ JoinRequest: joinRequest(),
+ Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkCert{
+ EkCert: []byte("mock-ekcert"),
+ },
+ AttestationParams: attParams,
+ },
+
+ tokenSpec: tokenSpec(nil),
+ },
+ }
+
+ solver := func(t *testing.T) func(ec *proto.TPMEncryptedCredential) (
+ *proto.RegisterUsingTPMMethodChallengeResponse, error,
+ ) {
+ return func(ec *proto.TPMEncryptedCredential) (
+ *proto.RegisterUsingTPMMethodChallengeResponse, error,
+ ) {
+ assert.Equal(t, []byte("mock-secret"), ec.Secret)
+ assert.Equal(t, []byte("mock-credential"), ec.CredentialBlob)
+ return &proto.RegisterUsingTPMMethodChallengeResponse{
+ Solution: []byte("mock-solution"),
+ }, nil
+ }
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockValidator.setup(tt.validateReturnTPM, tt.validateReturnErr)
+ if !tt.setOSS {
+ modules.SetTestModules(
+ t,
+ &modules.TestModules{TestBuildType: modules.BuildEnterprise},
+ )
+ }
+
+ token, err := types.NewProvisionTokenFromSpec(
+ tt.name, time.Now().Add(time.Minute), tt.tokenSpec,
+ )
+ require.NoError(t, err)
+ require.NoError(t, auth.CreateToken(ctx, token))
+ tt.initReq.JoinRequest.Token = tt.name
+
+ _, err = auth.registerUsingTPMMethod(
+ ctx,
+ tt.initReq,
+ solver(t))
+ tt.assertError(t, err)
+
+ assert.Empty(t,
+ gocmp.Diff(
+ tt.wantParams,
+ mockValidator.lastCalledParams,
+ cmpopts.IgnoreFields(tpm.ValidateParams{}, "Solve"),
+ ),
+ )
+ })
+ }
+}
diff --git a/lib/auth/machineid/machineidv1/bot_service.go b/lib/auth/machineid/machineidv1/bot_service.go
index 893c2e291449c..ee898d6614cf0 100644
--- a/lib/auth/machineid/machineidv1/bot_service.go
+++ b/lib/auth/machineid/machineidv1/bot_service.go
@@ -53,6 +53,7 @@ var SupportedJoinMethods = []types.JoinMethod{
types.JoinMethodKubernetes,
types.JoinMethodSpacelift,
types.JoinMethodToken,
+ types.JoinMethodTPM,
}
// BotResourceName returns the default name for resources associated with the
diff --git a/lib/joinserver/joinserver.go b/lib/joinserver/joinserver.go
index 09fcb270e9522..82aebfb65968a 100644
--- a/lib/joinserver/joinserver.go
+++ b/lib/joinserver/joinserver.go
@@ -22,6 +22,7 @@ package joinserver
import (
"context"
+ "log/slog"
"net"
"slices"
"time"
@@ -41,15 +42,25 @@ import (
const (
iamJoinRequestTimeout = time.Minute
azureJoinRequestTimeout = time.Minute
+ tpmJoinRequestTimeout = time.Minute
)
type joinServiceClient interface {
RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error)
RegisterUsingAzureMethod(ctx context.Context, challengeResponse client.RegisterAzureChallengeResponseFunc) (*proto.Certs, error)
+ RegisterUsingTPMMethod(
+ ctx context.Context,
+ initReq *proto.RegisterUsingTPMMethodInitialRequest,
+ solveChallenge client.RegisterTPMChallengeResponseFunc,
+ ) (*proto.Certs, error)
}
// JoinServiceGRPCServer implements proto.JoinServiceServer and is designed
// to run on both the Teleport Proxy and Auth servers.
+//
+// On the Proxy, this uses a gRPC client to forward the request to the Auth
+// server. On the Auth Server, this is passed to auth.ServerWithRoles and
+// through to auth.Server to be handled.
type JoinServiceGRPCServer struct {
*proto.UnimplementedJoinServiceServer
@@ -221,3 +232,111 @@ func (s *JoinServiceGRPCServer) registerUsingAzureMethod(srv proto.JoinService_R
Certs: certs,
}))
}
+
+// RegisterUsingTPMMethod allows nodes and bots to join the cluster using the
+// TPM join method.
+//
+// When running on the Auth server, this method will call the
+// auth.ServerWithRoles's RegisterUsingTPMMethod method. When running on the
+// Proxy, this method will forward the request to the Auth server.
+func (s *JoinServiceGRPCServer) RegisterUsingTPMMethod(srv proto.JoinService_RegisterUsingTPMMethodServer) error {
+ ctx := srv.Context()
+
+ // Enforce a timeout on the entire RPC so that misbehaving clients cannot
+ // hold connections open indefinitely.
+ timeout := s.clock.After(tpmJoinRequestTimeout)
+
+ // The only way to cancel a blocked Send or Recv on the server side without
+ // adding an interceptor to the entire gRPC service is to return from the
+ // handler https://github.com/grpc/grpc-go/issues/465#issuecomment-179414474
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- s.registerUsingTPMMethod(ctx, srv)
+ }()
+ select {
+ case err := <-errCh:
+ // Completed before the deadline, return the error (may be nil).
+ return trace.Wrap(err)
+ case <-timeout:
+ nodeAddr := ""
+ if peerInfo, ok := peer.FromContext(ctx); ok {
+ nodeAddr = peerInfo.Addr.String()
+ }
+ slog.WarnContext(
+ srv.Context(),
+ "TPM join attempt timed out, node is misbehaving or did not close the connection after encountering an error",
+ "node_addr", nodeAddr,
+ )
+ // Returning here should cancel any blocked Send or Recv operations.
+ return trace.LimitExceeded(
+ "RegisterUsingTPMMethod timed out after %s, terminating the stream on the server",
+ tpmJoinRequestTimeout,
+ )
+ case <-ctx.Done():
+ return trace.Wrap(ctx.Err())
+ }
+}
+
+func (s *JoinServiceGRPCServer) registerUsingTPMMethod(
+ ctx context.Context, srv proto.JoinService_RegisterUsingTPMMethodServer,
+) error {
+ // Get initial payload from the client
+ req, err := srv.Recv()
+ if err != nil {
+ return trace.Wrap(err, "receiving initial payload")
+ }
+ initReq := req.GetInit()
+ if initReq == nil {
+ return trace.BadParameter("expected non-nil Init payload")
+ }
+ if initReq.JoinRequest == nil {
+ return trace.BadParameter(
+ "expected JoinRequest in RegisterUsingTPMMethodRequest_Init, got nil",
+ )
+ }
+ if err := setClientRemoteAddr(ctx, initReq.JoinRequest); err != nil {
+ return trace.Wrap(err, "setting client address")
+ }
+
+ certs, err := s.joinServiceClient.RegisterUsingTPMMethod(
+ ctx,
+ initReq,
+ func(challenge *proto.TPMEncryptedCredential,
+ ) (*proto.RegisterUsingTPMMethodChallengeResponse, error) {
+ // First, forward the challenge from Auth to the client.
+ err := srv.Send(&proto.RegisterUsingTPMMethodResponse{
+ Payload: &proto.RegisterUsingTPMMethodResponse_ChallengeRequest{
+ ChallengeRequest: challenge,
+ },
+ })
+ if err != nil {
+ return nil, trace.Wrap(
+ err, "forwarding challenge to client",
+ )
+ }
+ // Get response from Client
+ req, err := srv.Recv()
+ if err != nil {
+ return nil, trace.Wrap(
+ err, "receiving challenge solution from client",
+ )
+ }
+ challengeResponse := req.GetChallengeResponse()
+ if challengeResponse == nil {
+ return nil, trace.BadParameter(
+ "expected non-nil ChallengeResponse payload",
+ )
+ }
+ return challengeResponse, nil
+ })
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ // finally, send the certs on the response stream
+ return trace.Wrap(srv.Send(&proto.RegisterUsingTPMMethodResponse{
+ Payload: &proto.RegisterUsingTPMMethodResponse_Certs{
+ Certs: certs,
+ },
+ }))
+}
diff --git a/lib/joinserver/joinserver_test.go b/lib/joinserver/joinserver_test.go
index 82028b1b6fc67..0189c5db0945a 100644
--- a/lib/joinserver/joinserver_test.go
+++ b/lib/joinserver/joinserver_test.go
@@ -20,6 +20,7 @@ package joinserver
import (
"context"
+ "fmt"
"net"
"sync"
"sync/atomic"
@@ -28,6 +29,7 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
@@ -45,6 +47,8 @@ type mockJoinServiceClient struct {
returnError error
gotIAMChallengeResponse *proto.RegisterUsingIAMMethodRequest
gotAzureChallengeResponse *proto.RegisterUsingAzureMethodRequest
+ gotTPMChallengeResponse *proto.RegisterUsingTPMMethodChallengeResponse
+ gotTPMInitReq *proto.RegisterUsingTPMMethodInitialRequest
}
func (c *mockJoinServiceClient) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) {
@@ -65,6 +69,22 @@ func (c *mockJoinServiceClient) RegisterUsingAzureMethod(ctx context.Context, ch
return c.returnCerts, c.returnError
}
+func (c *mockJoinServiceClient) RegisterUsingTPMMethod(
+ ctx context.Context,
+ initReq *proto.RegisterUsingTPMMethodInitialRequest,
+ challengeResponse client.RegisterTPMChallengeResponseFunc,
+) (*proto.Certs, error) {
+ c.gotTPMInitReq = initReq
+ resp, err := challengeResponse(&proto.TPMEncryptedCredential{
+ Secret: []byte(c.sendChallenge),
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ c.gotTPMChallengeResponse = resp
+ return c.returnCerts, c.returnError
+}
+
func ConnectionCountingStreamInterceptor(count *atomic.Int32) grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
count.Add(1)
@@ -303,6 +323,124 @@ func TestJoinServiceGRPCServer_RegisterUsingAzureMethod(t *testing.T) {
}
}
+func TestJoinServiceGRPCServer_RegisterUsingTPMMethod(t *testing.T) {
+ t.Parallel()
+ testPack := newTestPack(t)
+
+ testCases := []struct {
+ desc string
+ challenge string
+ initReq *proto.RegisterUsingTPMMethodInitialRequest
+ challengeResponse *proto.RegisterUsingTPMMethodChallengeResponse
+ challengeResponseErr error
+ authErr string
+ certs *proto.Certs
+ }{
+ {
+ desc: "pass case",
+ initReq: &proto.RegisterUsingTPMMethodInitialRequest{
+ Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkKey{
+ EkKey: []byte("llama"),
+ },
+ JoinRequest: &types.RegisterUsingTokenRequest{
+ Token: "xyzzy",
+ },
+ },
+ challenge: "foo",
+ challengeResponse: &proto.RegisterUsingTPMMethodChallengeResponse{
+ Solution: []byte("bar"),
+ },
+ certs: &proto.Certs{SSH: []byte("qux")},
+ },
+ {
+ desc: "auth error",
+ initReq: &proto.RegisterUsingTPMMethodInitialRequest{
+ Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkKey{
+ EkKey: []byte("llama"),
+ },
+ JoinRequest: &types.RegisterUsingTokenRequest{
+ Token: "xyzzy",
+ },
+ },
+ challenge: "foo",
+ challengeResponse: &proto.RegisterUsingTPMMethodChallengeResponse{
+ Solution: []byte("bar"),
+ },
+ authErr: "test auth error",
+ },
+ {
+ desc: "challenge response error",
+ initReq: &proto.RegisterUsingTPMMethodInitialRequest{
+ Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkKey{
+ EkKey: []byte("llama"),
+ },
+ JoinRequest: &types.RegisterUsingTokenRequest{
+ Token: "xyzzy",
+ },
+ },
+ challenge: "foo",
+ challengeResponseErr: trace.BadParameter("test challenge error"),
+ },
+ {
+ desc: "missing join request",
+ initReq: &proto.RegisterUsingTPMMethodInitialRequest{
+ Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkKey{
+ EkKey: []byte("llama"),
+ },
+ JoinRequest: nil,
+ },
+ challenge: "foo",
+ authErr: "expected JoinRequest in RegisterUsingTPMMethodRequest_Init, got nil",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ testPack.mockAuthServer.sendChallenge = tc.challenge
+ testPack.mockAuthServer.returnCerts = tc.certs
+ if tc.authErr != "" {
+ testPack.mockAuthServer.returnError = fmt.Errorf(tc.authErr)
+ }
+ challengeResponder := func(
+ challenge *proto.TPMEncryptedCredential,
+ ) (*proto.RegisterUsingTPMMethodChallengeResponse, error) {
+ assert.Equal(t, &proto.TPMEncryptedCredential{
+ Secret: []byte(tc.challenge),
+ }, challenge)
+ return tc.challengeResponse, tc.challengeResponseErr
+ }
+
+ for suffix, clt := range map[string]*client.JoinServiceClient{
+ "_auth": testPack.authClient,
+ "_proxy": testPack.proxyClient,
+ } {
+ t.Run(tc.desc+suffix, func(t *testing.T) {
+ certs, err := clt.RegisterUsingTPMMethod(
+ context.Background(), tc.initReq, challengeResponder,
+ )
+ if tc.challengeResponseErr != nil {
+ require.ErrorIs(t, err, tc.challengeResponseErr)
+ return
+ }
+ if tc.authErr != "" {
+ require.ErrorContains(t, err, tc.authErr, "authErr mismatch")
+ return
+ }
+ if assert.NoError(t, err) {
+ assert.Equal(t, tc.certs, certs)
+ }
+ expectedInitReq := tc.initReq
+ expectedInitReq.JoinRequest.RemoteAddr = "bufconn"
+ assert.Equal(
+ t,
+ expectedInitReq,
+ testPack.mockAuthServer.gotTPMInitReq,
+ )
+ })
+ }
+ })
+ }
+}
+
func TestTimeout(t *testing.T) {
t.Parallel()
diff --git a/lib/tpm/validate.go b/lib/tpm/validate.go
index d0c9b74736342..268857d35e4ff 100644
--- a/lib/tpm/validate.go
+++ b/lib/tpm/validate.go
@@ -184,6 +184,8 @@ func verifyEKCert(
return trace.BadParameter("tpm did not provide an EKCert to validate against allowed CAs")
}
+ StripSANExtensionOIDs(ekCert)
+
// Validate EKCert against CA pool
_, err := ekCert.Verify(x509.VerifyOptions{
Roots: allowedCAs,
@@ -199,3 +201,28 @@ func verifyEKCert(
}
return nil
}
+
+var sanExtensionOID = []int{2, 5, 29, 17}
+
+// StripSANExtensionOIDs removes the SAN Extension OID from the specified
+// cert. This method may re-assign the remaining extensions out of order.
+//
+// This is necessary because the EKCert may contain additional data
+// bundled within the SAN extension. This ext is also sometimes marked
+// critical. This causes the Verify() to reject the cert because not all data
+// within a critical extension has been handled. We mark this as OK here by
+// stripping the SAN Extension OID out of UnhandledCriticalExtensions.
+func StripSANExtensionOIDs(cert *x509.Certificate) {
+ for i := 0; i < len(cert.UnhandledCriticalExtensions); i++ {
+ ext := cert.UnhandledCriticalExtensions[i]
+ if !ext.Equal(sanExtensionOID) {
+ continue
+ }
+ // Swap ext with the last index and remove it.
+ last := len(cert.UnhandledCriticalExtensions) - 1
+ cert.UnhandledCriticalExtensions[i] = cert.UnhandledCriticalExtensions[last]
+ cert.UnhandledCriticalExtensions[last] = nil // "Release" extension
+ cert.UnhandledCriticalExtensions = cert.UnhandledCriticalExtensions[:last]
+ i--
+ }
+}
diff --git a/lib/tpm/validate_test.go b/lib/tpm/validate_test.go
new file mode 100644
index 0000000000000..3821574fefb9d
--- /dev/null
+++ b/lib/tpm/validate_test.go
@@ -0,0 +1,68 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+package tpm
+
+import (
+ "crypto/x509"
+ "encoding/pem"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// tpmEKCertPEM is the real RSA 2048 EK certificate. This was captured from
+// Noah's Infineon SLB9665 TPM.
+const tpmEKCertPEM = `-----BEGIN CERTIFICATE-----
+MIIElTCCA32gAwIBAgIEXs1fjjANBgkqhkiG9w0BAQsFADCBgzELMAkGA1UEBhMC
+REUxITAfBgNVBAoMGEluZmluZW9uIFRlY2hub2xvZ2llcyBBRzEaMBgGA1UECwwR
+T1BUSUdBKFRNKSBUUE0yLjAxNTAzBgNVBAMMLEluZmluZW9uIE9QVElHQShUTSkg
+UlNBIE1hbnVmYWN0dXJpbmcgQ0EgMDM2MB4XDTIzMDExNzA4MDY0MVoXDTM3MDgy
+MDIzNTk1OVowADCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAJ0qMMu+
+SRyCrhKcEQvKXk+Md2rdZC317Nqmhjf7rXJ527DX051XMTKfy+SXfalkqdT8IQkd
+aUPYC/m8XWz7/J/9781dVt7rOw1CJsEk9DFoaInQmL2E5dUDgsA8Em942o2r1x7K
+NdigHrLRQetn/CJkODYeBnHmmQUpU9syZ86Dhxl5tK1Sq2ddCm5Z/RCy+LIRBrpl
+qstrTsY3Wyj0aqt/Opikq3geSkW+viG9ipk/D5J3i/qbdHQHSWZqD6ImixTmqIZf
+I3u9QCVftoeWuJxrXgUPHdyyO6lMXDUgW918912Ihr6ZRBY6jUZT2Y8II+T1IRGT
+/ymr26W7Mf7qfYkCAwEAAaOCAZEwggGNMFsGCCsGAQUFBwEBBE8wTTBLBggrBgEF
+BQcwAoY/aHR0cDovL3BraS5pbmZpbmVvbi5jb20vT3B0aWdhUnNhTWZyQ0EwMzYv
+T3B0aWdhUnNhTWZyQ0EwMzYuY3J0MA4GA1UdDwEB/wQEAwIAIDBRBgNVHREBAf8E
+RzBFpEMwQTEWMBQGBWeBBQIBDAtpZDo0OTQ2NTgwMDETMBEGBWeBBQICDAhTTEIg
+OTY2NTESMBAGBWeBBQIDDAdpZDowNTNmMAwGA1UdEwEB/wQCMAAwUAYDVR0fBEkw
+RzBFoEOgQYY/aHR0cDovL3BraS5pbmZpbmVvbi5jb20vT3B0aWdhUnNhTWZyQ0Ew
+MzYvT3B0aWdhUnNhTWZyQ0EwMzYuY3JsMBUGA1UdIAQOMAwwCgYIKoIUAEQBFAEw
+HwYDVR0jBBgwFoAUfLS3jmiGFL5EIcWFjxW5bV6rUe4wEAYDVR0lBAkwBwYFZ4EF
+CAEwIQYDVR0JBBowGDAWBgVngQUCEDENMAsMAzIuMAIBAAIBdDANBgkqhkiG9w0B
+AQsFAAOCAQEACSSM+6o4INqV7mJ+aD5kPH6BkbEPhJBsYRA6vka+911Th7JfGZA7
+4C1ig4EjD1qUaRvkwNoDbGr3MRiNPHan3PLJkBy+WSERWglnBlXooJnkncWsNGwm
+lzCTAYPOKZSTLiZiijvzW1XO+VqaTCMTkTpegO3MnE6xXZhXSyXQs8ro7qY6cTBd
+whcEtufTT4khxMhjRTUBocqlLlN8PifG6xL2GD6xAW/PplL0uQFLUgnY0U4WQ/oP
+nK/NX7N02p23JzlhDgcOdRrF3hYi8huKuoe3YfVJCTJLfSwxHytgXqJdpiRSHvut
+Upw0EsfcY7cSbnlMWx2n4c7ptVvqiLXTkg==
+-----END CERTIFICATE-----`
+
+func TestStripSANExtensionsOIDs(t *testing.T) {
+ b, _ := pem.Decode([]byte(tpmEKCertPEM))
+ require.NotNil(t, b)
+ c, err := x509.ParseCertificate(b.Bytes)
+ require.NoError(t, err)
+
+ assert.Len(t, c.UnhandledCriticalExtensions, 1)
+ StripSANExtensionOIDs(c)
+ assert.Empty(t, c.UnhandledCriticalExtensions)
+}