diff --git a/lib/auth/auth.go b/lib/auth/auth.go index b28cea169933c..d766ae1d5f2f6 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -35,6 +35,7 @@ import ( "errors" "fmt" "io" + "log/slog" "math" "math/big" insecurerand "math/rand" @@ -113,6 +114,7 @@ import ( "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/tpm" usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/interval" @@ -497,6 +499,9 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { ) } } + if as.tpmValidator == nil { + as.tpmValidator = tpm.Validate + } if as.k8sTokenReviewValidator == nil { as.k8sTokenReviewValidator = &kubernetestoken.TokenReviewValidator{} } @@ -891,6 +896,12 @@ type Server struct { // the auth server. It can be overridden for the purpose of tests. gitlabIDTokenValidator gitlabIDTokenValidator + // tpmValidator allows TPMs to be validated by the auth server. It can be + // overridden for the purpose of tests. + tpmValidator func( + ctx context.Context, log *slog.Logger, params tpm.ValidateParams, + ) (*tpm.ValidatedTPM, error) + // circleCITokenValidate allows ID tokens from CircleCI to be validated by // the auth server. It can be overridden for the purpose of tests. circleCITokenValidate func(ctx context.Context, organizationID, token string) (*circleci.IDTokenClaims, error) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index b86ca0f21ab39..68fd7ebe55bea 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -631,6 +631,24 @@ func (a *ServerWithRoles) RegisterUsingAzureMethod(ctx context.Context, challeng return certs, trace.Wrap(err) } +// RegisterUsingTPMMethod registers the caller using the TPM join method and +// returns signed certs to join the cluster. +// +// See (*Server).RegisterUsingTPMMethod for further documentation. +// +// This wrapper does not do any extra authz checks, as the register method has +// its own authz mechanism. +func (a *ServerWithRoles) RegisterUsingTPMMethod( + ctx context.Context, + initReq *proto.RegisterUsingTPMMethodInitialRequest, + solveChallenge client.RegisterTPMChallengeResponseFunc, +) (*proto.Certs, error) { + certs, err := a.authServer.registerUsingTPMMethod( + ctx, initReq, solveChallenge, + ) + return certs, trace.Wrap(err) +} + // GenerateHostCerts generates new host certificates (signed // by the host certificate authority) for a node. func (a *ServerWithRoles) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequest) (*proto.Certs, error) { @@ -2079,15 +2097,20 @@ func enforceEnterpriseJoinMethodCreation(token types.ProvisionToken) error { switch v.Spec.JoinMethod { case types.JoinMethodGitHub: if v.Spec.GitHub != nil && v.Spec.GitHub.EnterpriseServerHost != "" { - return fmt.Errorf( - "github enterprise server joining: %w", + return trace.Wrap( ErrRequiresEnterprise, + "github enterprise server joining", ) } case types.JoinMethodSpacelift: - return fmt.Errorf( - "spacelift joining: %w", + return trace.Wrap( + ErrRequiresEnterprise, + "spacelift joining", + ) + case types.JoinMethodTPM: + return trace.Wrap( ErrRequiresEnterprise, + "tpm joining", ) } diff --git a/lib/auth/join.go b/lib/auth/join.go index 7ed55f055a6bf..7507c7bb6b244 100644 --- a/lib/auth/join.go +++ b/lib/auth/join.go @@ -227,7 +227,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin if err := a.checkEC2JoinRequest(ctx, req); err != nil { return nil, trace.Wrap(err) } - case types.JoinMethodIAM, types.JoinMethodAzure: + case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM: // IAM and Azure join methods must use gRPC register methods return nil, trace.AccessDenied("this token is only valid for the %s "+ "join method but the node has connected to the wrong endpoint, make "+ diff --git a/lib/auth/join_tpm.go b/lib/auth/join_tpm.go new file mode 100644 index 0000000000000..4304628b9d07c --- /dev/null +++ b/lib/auth/join_tpm.go @@ -0,0 +1,137 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package auth + +import ( + "context" + "crypto/x509" + "log/slog" + + "github.com/google/go-attestation/attest" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/tpm" +) + +func (a *Server) registerUsingTPMMethod( + ctx context.Context, + initReq *proto.RegisterUsingTPMMethodInitialRequest, + solveChallenge client.RegisterTPMChallengeResponseFunc, +) (_ *proto.Certs, err error) { + var provisionToken types.ProvisionToken + var attributeSrc joinAttributeSourcer + defer func() { + // Emit a log message and audit event on join failure. + if err != nil { + a.handleJoinFailure( + err, provisionToken, attributeSrc, initReq.JoinRequest, + ) + } + }() + + // First, check the specified token exists, and is a TPM-type join token. + if err := initReq.JoinRequest.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + provisionToken, err = a.checkTokenJoinRequestCommon(ctx, initReq.JoinRequest) + if err != nil { + return nil, trace.Wrap(err) + } + ptv2, ok := provisionToken.(*types.ProvisionTokenV2) + if !ok { + return nil, trace.BadParameter("expected *types.ProvisionTokenV2, got %T", provisionToken) + } + if ptv2.Spec.JoinMethod != types.JoinMethodTPM { + return nil, trace.BadParameter("specified join token is not for `tpm` method") + } + + if modules.GetModules().BuildType() != modules.BuildEnterprise { + return nil, trace.Wrap( + ErrRequiresEnterprise, + "tpm joining", + ) + } + + // Convert configured CAs to a CAPool + var certPool *x509.CertPool + if len(ptv2.Spec.TPM.EKCertAllowedCAs) > 0 { + certPool = x509.NewCertPool() + for i, ca := range ptv2.Spec.TPM.EKCertAllowedCAs { + if ok := certPool.AppendCertsFromPEM([]byte(ca)); !ok { + return nil, trace.BadParameter( + "ekcert_allowed_cas[%d] has an invalid or malformed PEM", i, + ) + } + } + } + + // TODO(noah): Use logger from TeleportProcess. + validatedEK, err := a.tpmValidator(ctx, slog.Default(), tpm.ValidateParams{ + EKCert: initReq.GetEkCert(), + EKKey: initReq.GetEkKey(), + AttestParams: tpm.AttestationParametersFromProto(initReq.AttestationParams), + AllowedCAs: certPool, + Solve: func(ec *attest.EncryptedCredential) ([]byte, error) { + solution, err := solveChallenge(tpm.EncryptedCredentialToProto(ec)) + if err != nil { + return nil, trace.Wrap(err) + } + return solution.Solution, nil + }, + }) + if err != nil { + return nil, trace.Wrap(err, "validating TPM EK") + } + attributeSrc = validatedEK + + if err := checkTPMAllowRules(validatedEK, ptv2.Spec.TPM.Allow); err != nil { + return nil, trace.Wrap(err) + } + + if initReq.JoinRequest.Role == types.RoleBot { + certs, err := a.generateCertsBot( + ctx, ptv2, initReq.JoinRequest, validatedEK, + ) + return certs, trace.Wrap(err, "generating certs for bot") + } + certs, err := a.generateCerts( + ctx, ptv2, initReq.JoinRequest, validatedEK, + ) + return certs, trace.Wrap(err, "generating certs for host") +} + +func checkTPMAllowRules(tpm *tpm.ValidatedTPM, rules []*types.ProvisionTokenSpecV2TPM_Rule) error { + // If a single rule passes, accept the TPM + for _, rule := range rules { + if rule.EKPublicHash != "" && tpm.EKPubHash != rule.EKPublicHash { + continue + } + if rule.EKCertificateSerial != "" && tpm.EKCertSerial != rule.EKCertificateSerial { + continue + } + + // All rules met. + return nil + } + return trace.AccessDenied("validated tpm attributes did not match any allow rules") +} diff --git a/lib/auth/join_tpm_test.go b/lib/auth/join_tpm_test.go new file mode 100644 index 0000000000000..8210a2f975185 --- /dev/null +++ b/lib/auth/join_tpm_test.go @@ -0,0 +1,358 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package auth + +import ( + "bytes" + "context" + "crypto/x509" + "errors" + "log/slog" + "testing" + "time" + + "github.com/google/go-attestation/attest" + gocmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/client/proto" + apifixtures "github.com/gravitational/teleport/api/fixtures" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/tpm" +) + +type mockTPMValidator struct { + lastCalledParams *tpm.ValidateParams + returnErr error + returnValidatedTPM *tpm.ValidatedTPM +} + +func (m *mockTPMValidator) setup(returns *tpm.ValidatedTPM, err error) { + m.lastCalledParams = nil + m.returnErr = err + m.returnValidatedTPM = returns +} + +func (m *mockTPMValidator) validate( + _ context.Context, _ *slog.Logger, params tpm.ValidateParams, +) (*tpm.ValidatedTPM, error) { + m.lastCalledParams = ¶ms + + solution, err := params.Solve(&attest.EncryptedCredential{ + Secret: []byte("mock-secret"), + Credential: []byte("mock-credential"), + }) + if err != nil { + return nil, trace.Wrap(err) + } + if !bytes.Equal(solution, []byte("mock-solution")) { + return nil, trace.AccessDenied("invalid solution") + } + + return m.returnValidatedTPM, m.returnErr +} + +func TestServer_RegisterUsingTPMMethod(t *testing.T) { + ctx := context.Background() + mockValidator := &mockTPMValidator{} + p, err := newTestPack(ctx, t.TempDir(), func(server *Server) error { + server.tpmValidator = mockValidator.validate + return nil + }) + require.NoError(t, err) + auth := p.a + + sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair() + require.NoError(t, err) + tlsPublicKey, err := PrivateKeyToPublicKeyTLS(sshPrivateKey) + require.NoError(t, err) + + attParams := &proto.TPMAttestationParameters{ + Public: []byte("mock-public"), + } + + const ( + goodEKPubHash = "mock-ekpub-hashed" + goodEKCertSerial = "mock-ekcert-serial" + goodEKPubHashAlt = "mock-ekpub-hashed-alt" + goodEKCertSerialAlt = "mock-ekcert-serial-alt" + ) + tokenSpec := func(mutate func(v2 *types.ProvisionTokenSpecV2)) types.ProvisionTokenSpecV2 { + spec := types.ProvisionTokenSpecV2{ + JoinMethod: types.JoinMethodTPM, + Roles: []types.SystemRole{types.RoleNode}, + TPM: &types.ProvisionTokenSpecV2TPM{ + Allow: []*types.ProvisionTokenSpecV2TPM_Rule{ + { + Description: "ekpub only", + EKPublicHash: goodEKPubHash, + }, + { + Description: "ekcert only", + EKCertificateSerial: goodEKCertSerial, + }, + { + Description: "both", + EKPublicHash: goodEKPubHashAlt, + EKCertificateSerial: goodEKCertSerialAlt, + }, + }, + }, + } + if mutate != nil { + mutate(&spec) + } + return spec + } + joinRequest := func() *types.RegisterUsingTokenRequest { + return &types.RegisterUsingTokenRequest{ + HostID: "host-id", + Role: types.RoleNode, + PublicTLSKey: tlsPublicKey, + PublicSSHKey: sshPublicKey, + } + } + + caPool := x509.NewCertPool() + require.True(t, caPool.AppendCertsFromPEM([]byte(apifixtures.TLSCACertPEM))) + + allowRulesNotMatched := require.ErrorAssertionFunc(func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "validated tpm attributes did not match any allow rules") + require.True(t, trace.IsAccessDenied(err)) + }) + tests := []struct { + name string + setOSS bool + + tokenSpec types.ProvisionTokenSpecV2 + + validateReturnTPM *tpm.ValidatedTPM + validateReturnErr error + + initReq *proto.RegisterUsingTPMMethodInitialRequest + wantParams *tpm.ValidateParams + + assertError require.ErrorAssertionFunc + }{ + { + name: "success, ekpub", + assertError: require.NoError, + + initReq: &proto.RegisterUsingTPMMethodInitialRequest{ + JoinRequest: joinRequest(), + Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkKey{ + EkKey: []byte("mock-ekpub"), + }, + AttestationParams: attParams, + }, + wantParams: &tpm.ValidateParams{ + EKKey: []byte("mock-ekpub"), + AttestParams: tpm.AttestationParametersFromProto(attParams), + }, + + tokenSpec: tokenSpec(nil), + validateReturnTPM: &tpm.ValidatedTPM{ + EKPubHash: goodEKPubHash, + }, + }, + { + name: "success, ekcert", + assertError: require.NoError, + + initReq: &proto.RegisterUsingTPMMethodInitialRequest{ + JoinRequest: joinRequest(), + Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkCert{ + EkCert: []byte("mock-ekcert"), + }, + AttestationParams: attParams, + }, + wantParams: &tpm.ValidateParams{ + EKCert: []byte("mock-ekcert"), + AttestParams: tpm.AttestationParametersFromProto(attParams), + AllowedCAs: caPool, + }, + + tokenSpec: tokenSpec(func(v2 *types.ProvisionTokenSpecV2) { + v2.TPM.EKCertAllowedCAs = []string{apifixtures.TLSCACertPEM} + }), + validateReturnTPM: &tpm.ValidatedTPM{ + EKCertSerial: goodEKCertSerial, + EKCertVerified: true, + }, + }, + { + name: "success, both ek cert serial and ek pub hash match", + assertError: require.NoError, + + initReq: &proto.RegisterUsingTPMMethodInitialRequest{ + JoinRequest: joinRequest(), + Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkCert{ + EkCert: []byte("mock-ekcert"), + }, + AttestationParams: attParams, + }, + wantParams: &tpm.ValidateParams{ + EKCert: []byte("mock-ekcert"), + AttestParams: tpm.AttestationParametersFromProto(attParams), + }, + + tokenSpec: tokenSpec(nil), + validateReturnTPM: &tpm.ValidatedTPM{ + EKCertSerial: goodEKCertSerialAlt, + EKPubHash: goodEKPubHashAlt, + EKCertVerified: true, + }, + }, + { + name: "failure, mismatched ekpub", + assertError: allowRulesNotMatched, + + initReq: &proto.RegisterUsingTPMMethodInitialRequest{ + JoinRequest: joinRequest(), + Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkKey{ + EkKey: []byte("mock-ekpub"), + }, + AttestationParams: attParams, + }, + wantParams: &tpm.ValidateParams{ + EKKey: []byte("mock-ekpub"), + AttestParams: tpm.AttestationParametersFromProto(attParams), + }, + + tokenSpec: tokenSpec(nil), + validateReturnTPM: &tpm.ValidatedTPM{ + EKPubHash: "mock-ekpub-hashed-mismatched!", + }, + }, + { + name: "failure, mismatched ekcert", + assertError: allowRulesNotMatched, + + initReq: &proto.RegisterUsingTPMMethodInitialRequest{ + JoinRequest: joinRequest(), + Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkCert{ + EkCert: []byte("mock-ekcert"), + }, + AttestationParams: attParams, + }, + wantParams: &tpm.ValidateParams{ + EKCert: []byte("mock-ekcert"), + AttestParams: tpm.AttestationParametersFromProto(attParams), + }, + + tokenSpec: tokenSpec(nil), + validateReturnTPM: &tpm.ValidatedTPM{ + EKCertSerial: "mock-ekcert-serial-mismatched!", + }, + }, + { + name: "failure, verification", + assertError: func(t require.TestingT, err error, i ...interface{}) { + assert.ErrorContains(t, err, "capacitor overcharged") + }, + + initReq: &proto.RegisterUsingTPMMethodInitialRequest{ + JoinRequest: joinRequest(), + Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkCert{ + EkCert: []byte("mock-ekcert"), + }, + AttestationParams: attParams, + }, + wantParams: &tpm.ValidateParams{ + EKCert: []byte("mock-ekcert"), + AttestParams: tpm.AttestationParametersFromProto(attParams), + }, + + tokenSpec: tokenSpec(nil), + validateReturnTPM: &tpm.ValidatedTPM{ + EKCertSerial: goodEKCertSerial, + }, + validateReturnErr: errors.New("capacitor overcharged"), + }, + { + name: "failure, no enterprise", + setOSS: true, + assertError: func(t require.TestingT, err error, i ...interface{}) { + assert.ErrorIs(t, err, ErrRequiresEnterprise) + }, + + initReq: &proto.RegisterUsingTPMMethodInitialRequest{ + JoinRequest: joinRequest(), + Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkCert{ + EkCert: []byte("mock-ekcert"), + }, + AttestationParams: attParams, + }, + + tokenSpec: tokenSpec(nil), + }, + } + + solver := func(t *testing.T) func(ec *proto.TPMEncryptedCredential) ( + *proto.RegisterUsingTPMMethodChallengeResponse, error, + ) { + return func(ec *proto.TPMEncryptedCredential) ( + *proto.RegisterUsingTPMMethodChallengeResponse, error, + ) { + assert.Equal(t, []byte("mock-secret"), ec.Secret) + assert.Equal(t, []byte("mock-credential"), ec.CredentialBlob) + return &proto.RegisterUsingTPMMethodChallengeResponse{ + Solution: []byte("mock-solution"), + }, nil + } + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockValidator.setup(tt.validateReturnTPM, tt.validateReturnErr) + if !tt.setOSS { + modules.SetTestModules( + t, + &modules.TestModules{TestBuildType: modules.BuildEnterprise}, + ) + } + + token, err := types.NewProvisionTokenFromSpec( + tt.name, time.Now().Add(time.Minute), tt.tokenSpec, + ) + require.NoError(t, err) + require.NoError(t, auth.CreateToken(ctx, token)) + tt.initReq.JoinRequest.Token = tt.name + + _, err = auth.registerUsingTPMMethod( + ctx, + tt.initReq, + solver(t)) + tt.assertError(t, err) + + assert.Empty(t, + gocmp.Diff( + tt.wantParams, + mockValidator.lastCalledParams, + cmpopts.IgnoreFields(tpm.ValidateParams{}, "Solve"), + ), + ) + }) + } +} diff --git a/lib/auth/machineid/machineidv1/bot_service.go b/lib/auth/machineid/machineidv1/bot_service.go index 893c2e291449c..ee898d6614cf0 100644 --- a/lib/auth/machineid/machineidv1/bot_service.go +++ b/lib/auth/machineid/machineidv1/bot_service.go @@ -53,6 +53,7 @@ var SupportedJoinMethods = []types.JoinMethod{ types.JoinMethodKubernetes, types.JoinMethodSpacelift, types.JoinMethodToken, + types.JoinMethodTPM, } // BotResourceName returns the default name for resources associated with the diff --git a/lib/joinserver/joinserver.go b/lib/joinserver/joinserver.go index 09fcb270e9522..82aebfb65968a 100644 --- a/lib/joinserver/joinserver.go +++ b/lib/joinserver/joinserver.go @@ -22,6 +22,7 @@ package joinserver import ( "context" + "log/slog" "net" "slices" "time" @@ -41,15 +42,25 @@ import ( const ( iamJoinRequestTimeout = time.Minute azureJoinRequestTimeout = time.Minute + tpmJoinRequestTimeout = time.Minute ) type joinServiceClient interface { RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) RegisterUsingAzureMethod(ctx context.Context, challengeResponse client.RegisterAzureChallengeResponseFunc) (*proto.Certs, error) + RegisterUsingTPMMethod( + ctx context.Context, + initReq *proto.RegisterUsingTPMMethodInitialRequest, + solveChallenge client.RegisterTPMChallengeResponseFunc, + ) (*proto.Certs, error) } // JoinServiceGRPCServer implements proto.JoinServiceServer and is designed // to run on both the Teleport Proxy and Auth servers. +// +// On the Proxy, this uses a gRPC client to forward the request to the Auth +// server. On the Auth Server, this is passed to auth.ServerWithRoles and +// through to auth.Server to be handled. type JoinServiceGRPCServer struct { *proto.UnimplementedJoinServiceServer @@ -221,3 +232,111 @@ func (s *JoinServiceGRPCServer) registerUsingAzureMethod(srv proto.JoinService_R Certs: certs, })) } + +// RegisterUsingTPMMethod allows nodes and bots to join the cluster using the +// TPM join method. +// +// When running on the Auth server, this method will call the +// auth.ServerWithRoles's RegisterUsingTPMMethod method. When running on the +// Proxy, this method will forward the request to the Auth server. +func (s *JoinServiceGRPCServer) RegisterUsingTPMMethod(srv proto.JoinService_RegisterUsingTPMMethodServer) error { + ctx := srv.Context() + + // Enforce a timeout on the entire RPC so that misbehaving clients cannot + // hold connections open indefinitely. + timeout := s.clock.After(tpmJoinRequestTimeout) + + // The only way to cancel a blocked Send or Recv on the server side without + // adding an interceptor to the entire gRPC service is to return from the + // handler https://github.com/grpc/grpc-go/issues/465#issuecomment-179414474 + errCh := make(chan error, 1) + go func() { + errCh <- s.registerUsingTPMMethod(ctx, srv) + }() + select { + case err := <-errCh: + // Completed before the deadline, return the error (may be nil). + return trace.Wrap(err) + case <-timeout: + nodeAddr := "" + if peerInfo, ok := peer.FromContext(ctx); ok { + nodeAddr = peerInfo.Addr.String() + } + slog.WarnContext( + srv.Context(), + "TPM join attempt timed out, node is misbehaving or did not close the connection after encountering an error", + "node_addr", nodeAddr, + ) + // Returning here should cancel any blocked Send or Recv operations. + return trace.LimitExceeded( + "RegisterUsingTPMMethod timed out after %s, terminating the stream on the server", + tpmJoinRequestTimeout, + ) + case <-ctx.Done(): + return trace.Wrap(ctx.Err()) + } +} + +func (s *JoinServiceGRPCServer) registerUsingTPMMethod( + ctx context.Context, srv proto.JoinService_RegisterUsingTPMMethodServer, +) error { + // Get initial payload from the client + req, err := srv.Recv() + if err != nil { + return trace.Wrap(err, "receiving initial payload") + } + initReq := req.GetInit() + if initReq == nil { + return trace.BadParameter("expected non-nil Init payload") + } + if initReq.JoinRequest == nil { + return trace.BadParameter( + "expected JoinRequest in RegisterUsingTPMMethodRequest_Init, got nil", + ) + } + if err := setClientRemoteAddr(ctx, initReq.JoinRequest); err != nil { + return trace.Wrap(err, "setting client address") + } + + certs, err := s.joinServiceClient.RegisterUsingTPMMethod( + ctx, + initReq, + func(challenge *proto.TPMEncryptedCredential, + ) (*proto.RegisterUsingTPMMethodChallengeResponse, error) { + // First, forward the challenge from Auth to the client. + err := srv.Send(&proto.RegisterUsingTPMMethodResponse{ + Payload: &proto.RegisterUsingTPMMethodResponse_ChallengeRequest{ + ChallengeRequest: challenge, + }, + }) + if err != nil { + return nil, trace.Wrap( + err, "forwarding challenge to client", + ) + } + // Get response from Client + req, err := srv.Recv() + if err != nil { + return nil, trace.Wrap( + err, "receiving challenge solution from client", + ) + } + challengeResponse := req.GetChallengeResponse() + if challengeResponse == nil { + return nil, trace.BadParameter( + "expected non-nil ChallengeResponse payload", + ) + } + return challengeResponse, nil + }) + if err != nil { + return trace.Wrap(err) + } + + // finally, send the certs on the response stream + return trace.Wrap(srv.Send(&proto.RegisterUsingTPMMethodResponse{ + Payload: &proto.RegisterUsingTPMMethodResponse_Certs{ + Certs: certs, + }, + })) +} diff --git a/lib/joinserver/joinserver_test.go b/lib/joinserver/joinserver_test.go index 82028b1b6fc67..0189c5db0945a 100644 --- a/lib/joinserver/joinserver_test.go +++ b/lib/joinserver/joinserver_test.go @@ -20,6 +20,7 @@ package joinserver import ( "context" + "fmt" "net" "sync" "sync/atomic" @@ -28,6 +29,7 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -45,6 +47,8 @@ type mockJoinServiceClient struct { returnError error gotIAMChallengeResponse *proto.RegisterUsingIAMMethodRequest gotAzureChallengeResponse *proto.RegisterUsingAzureMethodRequest + gotTPMChallengeResponse *proto.RegisterUsingTPMMethodChallengeResponse + gotTPMInitReq *proto.RegisterUsingTPMMethodInitialRequest } func (c *mockJoinServiceClient) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) { @@ -65,6 +69,22 @@ func (c *mockJoinServiceClient) RegisterUsingAzureMethod(ctx context.Context, ch return c.returnCerts, c.returnError } +func (c *mockJoinServiceClient) RegisterUsingTPMMethod( + ctx context.Context, + initReq *proto.RegisterUsingTPMMethodInitialRequest, + challengeResponse client.RegisterTPMChallengeResponseFunc, +) (*proto.Certs, error) { + c.gotTPMInitReq = initReq + resp, err := challengeResponse(&proto.TPMEncryptedCredential{ + Secret: []byte(c.sendChallenge), + }) + if err != nil { + return nil, trace.Wrap(err) + } + c.gotTPMChallengeResponse = resp + return c.returnCerts, c.returnError +} + func ConnectionCountingStreamInterceptor(count *atomic.Int32) grpc.StreamServerInterceptor { return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { count.Add(1) @@ -303,6 +323,124 @@ func TestJoinServiceGRPCServer_RegisterUsingAzureMethod(t *testing.T) { } } +func TestJoinServiceGRPCServer_RegisterUsingTPMMethod(t *testing.T) { + t.Parallel() + testPack := newTestPack(t) + + testCases := []struct { + desc string + challenge string + initReq *proto.RegisterUsingTPMMethodInitialRequest + challengeResponse *proto.RegisterUsingTPMMethodChallengeResponse + challengeResponseErr error + authErr string + certs *proto.Certs + }{ + { + desc: "pass case", + initReq: &proto.RegisterUsingTPMMethodInitialRequest{ + Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkKey{ + EkKey: []byte("llama"), + }, + JoinRequest: &types.RegisterUsingTokenRequest{ + Token: "xyzzy", + }, + }, + challenge: "foo", + challengeResponse: &proto.RegisterUsingTPMMethodChallengeResponse{ + Solution: []byte("bar"), + }, + certs: &proto.Certs{SSH: []byte("qux")}, + }, + { + desc: "auth error", + initReq: &proto.RegisterUsingTPMMethodInitialRequest{ + Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkKey{ + EkKey: []byte("llama"), + }, + JoinRequest: &types.RegisterUsingTokenRequest{ + Token: "xyzzy", + }, + }, + challenge: "foo", + challengeResponse: &proto.RegisterUsingTPMMethodChallengeResponse{ + Solution: []byte("bar"), + }, + authErr: "test auth error", + }, + { + desc: "challenge response error", + initReq: &proto.RegisterUsingTPMMethodInitialRequest{ + Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkKey{ + EkKey: []byte("llama"), + }, + JoinRequest: &types.RegisterUsingTokenRequest{ + Token: "xyzzy", + }, + }, + challenge: "foo", + challengeResponseErr: trace.BadParameter("test challenge error"), + }, + { + desc: "missing join request", + initReq: &proto.RegisterUsingTPMMethodInitialRequest{ + Ek: &proto.RegisterUsingTPMMethodInitialRequest_EkKey{ + EkKey: []byte("llama"), + }, + JoinRequest: nil, + }, + challenge: "foo", + authErr: "expected JoinRequest in RegisterUsingTPMMethodRequest_Init, got nil", + }, + } + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + testPack.mockAuthServer.sendChallenge = tc.challenge + testPack.mockAuthServer.returnCerts = tc.certs + if tc.authErr != "" { + testPack.mockAuthServer.returnError = fmt.Errorf(tc.authErr) + } + challengeResponder := func( + challenge *proto.TPMEncryptedCredential, + ) (*proto.RegisterUsingTPMMethodChallengeResponse, error) { + assert.Equal(t, &proto.TPMEncryptedCredential{ + Secret: []byte(tc.challenge), + }, challenge) + return tc.challengeResponse, tc.challengeResponseErr + } + + for suffix, clt := range map[string]*client.JoinServiceClient{ + "_auth": testPack.authClient, + "_proxy": testPack.proxyClient, + } { + t.Run(tc.desc+suffix, func(t *testing.T) { + certs, err := clt.RegisterUsingTPMMethod( + context.Background(), tc.initReq, challengeResponder, + ) + if tc.challengeResponseErr != nil { + require.ErrorIs(t, err, tc.challengeResponseErr) + return + } + if tc.authErr != "" { + require.ErrorContains(t, err, tc.authErr, "authErr mismatch") + return + } + if assert.NoError(t, err) { + assert.Equal(t, tc.certs, certs) + } + expectedInitReq := tc.initReq + expectedInitReq.JoinRequest.RemoteAddr = "bufconn" + assert.Equal( + t, + expectedInitReq, + testPack.mockAuthServer.gotTPMInitReq, + ) + }) + } + }) + } +} + func TestTimeout(t *testing.T) { t.Parallel() diff --git a/lib/tpm/validate.go b/lib/tpm/validate.go index d0c9b74736342..268857d35e4ff 100644 --- a/lib/tpm/validate.go +++ b/lib/tpm/validate.go @@ -184,6 +184,8 @@ func verifyEKCert( return trace.BadParameter("tpm did not provide an EKCert to validate against allowed CAs") } + StripSANExtensionOIDs(ekCert) + // Validate EKCert against CA pool _, err := ekCert.Verify(x509.VerifyOptions{ Roots: allowedCAs, @@ -199,3 +201,28 @@ func verifyEKCert( } return nil } + +var sanExtensionOID = []int{2, 5, 29, 17} + +// StripSANExtensionOIDs removes the SAN Extension OID from the specified +// cert. This method may re-assign the remaining extensions out of order. +// +// This is necessary because the EKCert may contain additional data +// bundled within the SAN extension. This ext is also sometimes marked +// critical. This causes the Verify() to reject the cert because not all data +// within a critical extension has been handled. We mark this as OK here by +// stripping the SAN Extension OID out of UnhandledCriticalExtensions. +func StripSANExtensionOIDs(cert *x509.Certificate) { + for i := 0; i < len(cert.UnhandledCriticalExtensions); i++ { + ext := cert.UnhandledCriticalExtensions[i] + if !ext.Equal(sanExtensionOID) { + continue + } + // Swap ext with the last index and remove it. + last := len(cert.UnhandledCriticalExtensions) - 1 + cert.UnhandledCriticalExtensions[i] = cert.UnhandledCriticalExtensions[last] + cert.UnhandledCriticalExtensions[last] = nil // "Release" extension + cert.UnhandledCriticalExtensions = cert.UnhandledCriticalExtensions[:last] + i-- + } +} diff --git a/lib/tpm/validate_test.go b/lib/tpm/validate_test.go new file mode 100644 index 0000000000000..3821574fefb9d --- /dev/null +++ b/lib/tpm/validate_test.go @@ -0,0 +1,68 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package tpm + +import ( + "crypto/x509" + "encoding/pem" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// tpmEKCertPEM is the real RSA 2048 EK certificate. This was captured from +// Noah's Infineon SLB9665 TPM. +const tpmEKCertPEM = `-----BEGIN CERTIFICATE----- +MIIElTCCA32gAwIBAgIEXs1fjjANBgkqhkiG9w0BAQsFADCBgzELMAkGA1UEBhMC +REUxITAfBgNVBAoMGEluZmluZW9uIFRlY2hub2xvZ2llcyBBRzEaMBgGA1UECwwR +T1BUSUdBKFRNKSBUUE0yLjAxNTAzBgNVBAMMLEluZmluZW9uIE9QVElHQShUTSkg +UlNBIE1hbnVmYWN0dXJpbmcgQ0EgMDM2MB4XDTIzMDExNzA4MDY0MVoXDTM3MDgy +MDIzNTk1OVowADCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAJ0qMMu+ +SRyCrhKcEQvKXk+Md2rdZC317Nqmhjf7rXJ527DX051XMTKfy+SXfalkqdT8IQkd +aUPYC/m8XWz7/J/9781dVt7rOw1CJsEk9DFoaInQmL2E5dUDgsA8Em942o2r1x7K +NdigHrLRQetn/CJkODYeBnHmmQUpU9syZ86Dhxl5tK1Sq2ddCm5Z/RCy+LIRBrpl +qstrTsY3Wyj0aqt/Opikq3geSkW+viG9ipk/D5J3i/qbdHQHSWZqD6ImixTmqIZf +I3u9QCVftoeWuJxrXgUPHdyyO6lMXDUgW918912Ihr6ZRBY6jUZT2Y8II+T1IRGT +/ymr26W7Mf7qfYkCAwEAAaOCAZEwggGNMFsGCCsGAQUFBwEBBE8wTTBLBggrBgEF +BQcwAoY/aHR0cDovL3BraS5pbmZpbmVvbi5jb20vT3B0aWdhUnNhTWZyQ0EwMzYv +T3B0aWdhUnNhTWZyQ0EwMzYuY3J0MA4GA1UdDwEB/wQEAwIAIDBRBgNVHREBAf8E +RzBFpEMwQTEWMBQGBWeBBQIBDAtpZDo0OTQ2NTgwMDETMBEGBWeBBQICDAhTTEIg +OTY2NTESMBAGBWeBBQIDDAdpZDowNTNmMAwGA1UdEwEB/wQCMAAwUAYDVR0fBEkw +RzBFoEOgQYY/aHR0cDovL3BraS5pbmZpbmVvbi5jb20vT3B0aWdhUnNhTWZyQ0Ew +MzYvT3B0aWdhUnNhTWZyQ0EwMzYuY3JsMBUGA1UdIAQOMAwwCgYIKoIUAEQBFAEw +HwYDVR0jBBgwFoAUfLS3jmiGFL5EIcWFjxW5bV6rUe4wEAYDVR0lBAkwBwYFZ4EF +CAEwIQYDVR0JBBowGDAWBgVngQUCEDENMAsMAzIuMAIBAAIBdDANBgkqhkiG9w0B +AQsFAAOCAQEACSSM+6o4INqV7mJ+aD5kPH6BkbEPhJBsYRA6vka+911Th7JfGZA7 +4C1ig4EjD1qUaRvkwNoDbGr3MRiNPHan3PLJkBy+WSERWglnBlXooJnkncWsNGwm +lzCTAYPOKZSTLiZiijvzW1XO+VqaTCMTkTpegO3MnE6xXZhXSyXQs8ro7qY6cTBd +whcEtufTT4khxMhjRTUBocqlLlN8PifG6xL2GD6xAW/PplL0uQFLUgnY0U4WQ/oP +nK/NX7N02p23JzlhDgcOdRrF3hYi8huKuoe3YfVJCTJLfSwxHytgXqJdpiRSHvut +Upw0EsfcY7cSbnlMWx2n4c7ptVvqiLXTkg== +-----END CERTIFICATE-----` + +func TestStripSANExtensionsOIDs(t *testing.T) { + b, _ := pem.Decode([]byte(tpmEKCertPEM)) + require.NotNil(t, b) + c, err := x509.ParseCertificate(b.Bytes) + require.NoError(t, err) + + assert.Len(t, c.UnhandledCriticalExtensions, 1) + StripSANExtensionOIDs(c) + assert.Empty(t, c.UnhandledCriticalExtensions) +}