From b14b31064ce50bc5768b389986acd639166e5125 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Thu, 18 Apr 2024 10:20:42 +0100 Subject: [PATCH 01/11] Add clientside elements of TPM joining --- api/client/joinservice.go | 74 ++++++++++ api/types/provisioning.go | 58 ++++++++ api/types/provisioning_test.go | 118 +++++++++++++++ lib/auth/register.go | 136 ++++++++++++++---- lib/service/service.go | 3 +- lib/tbot/config/config.go | 1 + lib/tpm/testdata/TestPrintQuery/ekcert.golden | 6 +- .../TestPrintQuery/ekcert_debug.golden | 10 +- lib/tpm/testdata/TestPrintQuery/ekpub.golden | 4 +- .../TestPrintQuery/ekpub_debug.golden | 6 +- lib/tpm/tpm.go | 10 +- 11 files changed, 380 insertions(+), 46 deletions(-) diff --git a/api/client/joinservice.go b/api/client/joinservice.go index 0116fbb93b134..1c877a58ecf05 100644 --- a/api/client/joinservice.go +++ b/api/client/joinservice.go @@ -48,6 +48,12 @@ type RegisterIAMChallengeResponseFunc func(challenge string) (*proto.RegisterUsi // *proto.RegisterUsingAzureMethodRequest for a given challenge, or an error. type RegisterAzureChallengeResponseFunc func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) +// RegisterTPMChallengeResponseFunc is a function type meant to be passed to +// RegisterUsingTPMMethod. It must return a +// *proto.RegisterUsingTPMMethodChallengeResponse for a given challenge, or an +// error. +type RegisterTPMChallengeResponseFunc func(challenge *proto.TPMEncryptedCredential) (*proto.RegisterUsingTPMMethodChallengeResponse, error) + // RegisterUsingIAMMethod registers the caller using the IAM join method and // returns signed certs to join the cluster. // @@ -125,3 +131,71 @@ func (c *JoinServiceClient) RegisterUsingAzureMethod(ctx context.Context, challe } return certsResp.Certs, nil } + +// RegisterUsingTPMMethod registers the caller using the TPM join method and +// returns signed certs to join the cluster. The caller must provide a +// ChallengeResponseFunc which returns a *proto.RegisterUsingTPMMethodRequest +// for a given challenge, or an error. +func (c *JoinServiceClient) RegisterUsingTPMMethod( + ctx context.Context, + initReq *proto.RegisterUsingTPMMethodInitialRequest, + solveChallenge RegisterTPMChallengeResponseFunc, +) (*proto.Certs, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + tpmJoinClient, err := c.grpcClient.RegisterUsingTPMMethod(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + err = tpmJoinClient.Send(&proto.RegisterUsingTPMMethodRequest{ + Payload: &proto.RegisterUsingTPMMethodRequest_Init{ + Init: initReq, + }, + }) + if err != nil { + return nil, trace.Wrap(err, "sending initial request") + } + + challengeResp, err := tpmJoinClient.Recv() + if err != nil { + return nil, trace.Wrap(err, "receiving challenge") + } + + challenge, ok := challengeResp.Payload.(*proto.RegisterUsingTPMMethodResponse_ChallengeRequest) + if !ok { + return nil, trace.BadParameter( + "unexpected payload type %T, expected *RegisterUsingTPMMethodResponse_ChallengeRequest", + challengeResp.Payload, + ) + } + + solution, err := solveChallenge(challenge.ChallengeRequest) + if err != nil { + return nil, trace.Wrap(err, "calling solveChallenge") + } + + err = tpmJoinClient.Send(&proto.RegisterUsingTPMMethodRequest{ + Payload: &proto.RegisterUsingTPMMethodRequest_ChallengeResponse{ + ChallengeResponse: solution, + }, + }) + if err != nil { + return nil, trace.Wrap(err, "sending solution") + } + + certsResp, err := tpmJoinClient.Recv() + if err != nil { + return nil, trace.Wrap(err, "receiving certs") + } + certs, ok := certsResp.Payload.(*proto.RegisterUsingTPMMethodResponse_Certs) + if !ok { + return nil, trace.BadParameter( + "unexpected payload type %T, expected *RegisterUsingTPMMethodResponse_Certs", + certsResp.Payload, + ) + } + + return certs.Certs, nil +} diff --git a/api/types/provisioning.go b/api/types/provisioning.go index c421bf0a31f8c..030f307ff4e7e 100644 --- a/api/types/provisioning.go +++ b/api/types/provisioning.go @@ -17,6 +17,8 @@ limitations under the License. package types import ( + "crypto/x509" + "encoding/pem" "fmt" "slices" "strings" @@ -66,6 +68,9 @@ const ( // method. Documentation regarding implementation of this can be found in // lib/spacelift. JoinMethodSpacelift JoinMethod = "spacelift" + // JoinMethodTPM indicates that the node will join with the TPM join method. + // The core implementation of this join method can be found in lib/tpm. + JoinMethodTPM JoinMethod = "tpm" ) var JoinMethods = []JoinMethod{ @@ -79,6 +84,7 @@ var JoinMethods = []JoinMethod{ JoinMethodKubernetes, JoinMethodSpacelift, JoinMethodToken, + JoinMethodTPM, } func ValidateJoinMethod(method JoinMethod) error { @@ -328,6 +334,17 @@ func (p *ProvisionTokenV2) CheckAndSetDefaults() error { if err := providerCfg.checkAndSetDefaults(); err != nil { return trace.Wrap(err, "spec.spacelift: failed validation") } + case JoinMethodTPM: + providerCfg := p.Spec.TPM + if providerCfg == nil { + return trace.BadParameter( + `spec.tpm: must be configured for the join method %q`, + JoinMethodTPM, + ) + } + if err := providerCfg.checkAndSetDefaults(); err != nil { + return trace.Wrap(err, "spec.tpm: failed validation") + } default: return trace.BadParameter("unknown join method %q", p.Spec.JoinMethod) } @@ -754,3 +771,44 @@ func (a *ProvisionTokenSpecV2Spacelift) checkAndSetDefaults() error { } return nil } + +func (a *ProvisionTokenSpecV2TPM) checkAndSetDefaults() error { + for i, caData := range a.EKCertAllowedCAs { + p, _ := pem.Decode([]byte(caData)) + if p == nil { + return trace.BadParameter( + "ekcert_allowed_cas[%d]: no pem block found", + i, + ) + } + if p.Type != "CERTIFICATE" { + return trace.BadParameter( + "ekcert_allowed_cas[%d]: pem block is not 'CERTIFICATE' type", + i, + ) + } + if _, err := x509.ParseCertificate(p.Bytes); err != nil { + return trace.Wrap( + err, + "ekcert_allowed_cas[%d]: parsing certificate", + i, + ) + + } + } + + if len(a.Allow) == 0 { + return trace.BadParameter( + "allow: at least one rule must be set", + ) + } + for i, allowRule := range a.Allow { + if len(allowRule.EKPublicHash) == 0 && len(allowRule.EKCertificateSerial) == 0 { + return trace.BadParameter( + "allow[%d]: at least one of ['ek_public_hash', 'ek_certificate_serial'] must be set", + i, + ) + } + } + return nil +} diff --git a/api/types/provisioning_test.go b/api/types/provisioning_test.go index f201457447ba9..08dceb3d3b0d0 100644 --- a/api/types/provisioning_test.go +++ b/api/types/provisioning_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/fixtures" ) func TestProvisionTokenV2_CheckAndSetDefaults(t *testing.T) { @@ -901,6 +902,123 @@ func TestProvisionTokenV2_CheckAndSetDefaults(t *testing.T) { }, wantErr: true, }, + { + desc: "tpm success with CA", + token: &ProvisionTokenV2{ + Metadata: Metadata{ + Name: "test", + }, + Spec: ProvisionTokenSpecV2{ + Roles: []SystemRole{RoleNode}, + JoinMethod: JoinMethodTPM, + TPM: &ProvisionTokenSpecV2TPM{ + EKCertAllowedCAs: []string{fixtures.TLSCACertPEM}, + Allow: []*ProvisionTokenSpecV2TPM_Rule{ + { + Description: "my description", + EKPublicHash: "d4b45864d9d6fabfc568d74f26c35ababde2105337d7af9a6605e1c56c891aa6", + }, + { + EKCertificateSerial: "73:df:dc:bd:af:ef:8a:d8:15:2e:96:71:7a:3e:7f:a4", + }, + { + EKPublicHash: "d4b45864d9d6fabfc568d74f26c35ababde2105337d7af9a6605e1c56c891aa6", + EKCertificateSerial: "73:df:dc:bd:af:ef:8a:d8:15:2e:96:71:7a:3e:7f:a4", + }, + }, + }, + }, + }, + wantErr: false, + }, + { + desc: "tpm success without CA", + token: &ProvisionTokenV2{ + Metadata: Metadata{ + Name: "test", + }, + Spec: ProvisionTokenSpecV2{ + Roles: []SystemRole{RoleNode}, + JoinMethod: JoinMethodTPM, + TPM: &ProvisionTokenSpecV2TPM{ + Allow: []*ProvisionTokenSpecV2TPM_Rule{ + { + Description: "my description", + EKPublicHash: "d4b45864d9d6fabfc568d74f26c35ababde2105337d7af9a6605e1c56c891aa6", + }, + { + EKCertificateSerial: "73:df:dc:bd:af:ef:8a:d8:15:2e:96:71:7a:3e:7f:a4", + }, + { + EKPublicHash: "d4b45864d9d6fabfc568d74f26c35ababde2105337d7af9a6605e1c56c891aa6", + EKCertificateSerial: "73:df:dc:bd:af:ef:8a:d8:15:2e:96:71:7a:3e:7f:a4", + }, + }, + }, + }, + }, + wantErr: false, + }, + { + desc: "tpm corrupt CA", + token: &ProvisionTokenV2{ + Metadata: Metadata{ + Name: "test", + }, + Spec: ProvisionTokenSpecV2{ + Roles: []SystemRole{RoleNode}, + JoinMethod: JoinMethodTPM, + TPM: &ProvisionTokenSpecV2TPM{ + EKCertAllowedCAs: []string{"corrupt"}, + Allow: []*ProvisionTokenSpecV2TPM_Rule{ + { + Description: "my description", + EKPublicHash: "d4b45864d9d6fabfc568d74f26c35ababde2105337d7af9a6605e1c56c891aa6", + }, + }, + }, + }, + }, + wantErr: true, + }, + { + desc: "tpm missing rules", + token: &ProvisionTokenV2{ + Metadata: Metadata{ + Name: "test", + }, + Spec: ProvisionTokenSpecV2{ + Roles: []SystemRole{RoleNode}, + JoinMethod: JoinMethodTPM, + TPM: &ProvisionTokenSpecV2TPM{ + EKCertAllowedCAs: []string{}, + Allow: []*ProvisionTokenSpecV2TPM_Rule{}, + }, + }, + }, + wantErr: true, + }, + { + desc: "tpm rule without ekpubhash or ekcertserial", + token: &ProvisionTokenV2{ + Metadata: Metadata{ + Name: "test", + }, + Spec: ProvisionTokenSpecV2{ + Roles: []SystemRole{RoleNode}, + JoinMethod: JoinMethodTPM, + TPM: &ProvisionTokenSpecV2TPM{ + EKCertAllowedCAs: []string{}, + Allow: []*ProvisionTokenSpecV2TPM_Rule{ + { + Description: "my description", + }, + }, + }, + }, + }, + wantErr: true, + }, } for _, tc := range testcases { diff --git a/lib/auth/register.go b/lib/auth/register.go index 893323ee2f9e1..3e8b5d5baf88f 100644 --- a/lib/auth/register.go +++ b/lib/auth/register.go @@ -22,6 +22,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "log/slog" "os" "slices" "time" @@ -53,6 +54,7 @@ import ( "github.com/gravitational/teleport/lib/spacelift" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/tpm" "github.com/gravitational/teleport/lib/utils" ) @@ -328,7 +330,7 @@ func registerThroughProxy( defer func() { tracing.EndSpan(span, err) }() switch params.JoinMethod { - case types.JoinMethodIAM, types.JoinMethodAzure: + case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM: // IAM and Azure join methods require gRPC client conn, err := proxyJoinServiceConn(ctx, params, params.Insecure) if err != nil { @@ -337,10 +339,13 @@ func registerThroughProxy( defer conn.Close() joinServiceClient := client.NewJoinServiceClient(proto.NewJoinServiceClient(conn)) - if params.JoinMethod == types.JoinMethodIAM { + switch params.JoinMethod { + case types.JoinMethodIAM: certs, err = registerUsingIAMMethod(ctx, joinServiceClient, token, params) - } else { + case types.JoinMethodAzure: certs, err = registerUsingAzureMethod(ctx, joinServiceClient, token, params) + case types.JoinMethodTPM: + certs, err = registerUsingTPMMethod(ctx, joinServiceClient, token, params) } if err != nil { @@ -418,6 +423,8 @@ func registerThroughAuth( certs, err = registerUsingIAMMethod(ctx, client, token, params) case types.JoinMethodAzure: certs, err = registerUsingAzureMethod(ctx, client, token, params) + case types.JoinMethodTPM: + certs, err = registerUsingTPMMethod(ctx, client, token, params) default: // non-IAM join methods use HTTP endpoint // Get the SSH and X509 certificates for a node. @@ -670,6 +677,25 @@ func caPathRegisterClient(params RegisterParams) (*Client, error) { type joinServiceClient interface { RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) RegisterUsingAzureMethod(ctx context.Context, challengeResponse client.RegisterAzureChallengeResponseFunc) (*proto.Certs, error) + RegisterUsingTPMMethod( + ctx context.Context, + initReq *proto.RegisterUsingTPMMethodInitialRequest, + solveChallenge client.RegisterTPMChallengeResponseFunc, + ) (*proto.Certs, error) +} + +func registerUsingTokenRequestForParams(token string, params RegisterParams) *types.RegisterUsingTokenRequest { + return &types.RegisterUsingTokenRequest{ + Token: token, + HostID: params.ID.HostUUID, + NodeName: params.ID.NodeName, + Role: params.ID.Role, + AdditionalPrincipals: params.AdditionalPrincipals, + DNSNames: params.DNSNames, + PublicTLSKey: params.PublicTLSKey, + PublicSSHKey: params.PublicSSHKey, + Expires: params.Expires, + } } // registerUsingIAMMethod is used to register using the IAM join method. It is @@ -691,18 +717,8 @@ func registerUsingIAMMethod( // send the register request including the challenge response return &proto.RegisterUsingIAMMethodRequest{ - RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ - Token: token, - HostID: params.ID.HostUUID, - NodeName: params.ID.NodeName, - Role: params.ID.Role, - AdditionalPrincipals: params.AdditionalPrincipals, - DNSNames: params.DNSNames, - PublicTLSKey: params.PublicTLSKey, - PublicSSHKey: params.PublicSSHKey, - Expires: params.Expires, - }, - StsIdentityRequest: signedRequest, + RegisterUsingTokenRequest: registerUsingTokenRequestForParams(token, params), + StsIdentityRequest: signedRequest, }, nil }) if err != nil { @@ -734,23 +750,89 @@ func registerUsingAzureMethod( } return &proto.RegisterUsingAzureMethodRequest{ - RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ - Token: token, - HostID: params.ID.HostUUID, - NodeName: params.ID.NodeName, - Role: params.ID.Role, - AdditionalPrincipals: params.AdditionalPrincipals, - DNSNames: params.DNSNames, - PublicTLSKey: params.PublicTLSKey, - PublicSSHKey: params.PublicSSHKey, - }, - AttestedData: ad, - AccessToken: accessToken, + RegisterUsingTokenRequest: registerUsingTokenRequestForParams(token, params), + AttestedData: ad, + AccessToken: accessToken, }, nil }) return certs, trace.Wrap(err) } +// registerUsingTPMMethod is used to register using the TPM join method. It +// is able to register through a proxy or through the auth server directly. +func registerUsingTPMMethod( + ctx context.Context, + client joinServiceClient, + token string, + params RegisterParams, +) (*proto.Certs, error) { + log := slog.Default() + + initReq := &proto.RegisterUsingTPMMethodInitialRequest{ + JoinRequest: registerUsingTokenRequestForParams(token, params), + } + + attestation, close, err := tpm.Attest(ctx, log) + if err != nil { + return nil, trace.Wrap(err) + } + defer func() { + if err := close(); err != nil { + log.WarnContext(ctx, "Failed to close TPM", "error", err) + } + }() + + initReq.AttestationParams = tpm.AttestationParametersToProto( + attestation.AttestParams, + ) + // Get the EKKey or EKCert. We want to prefer the EKCert if it is available + // as this is signed by the manufacturer. + switch { + case attestation.Data.EKCert != nil: + log.DebugContext( + ctx, + "Using EKCert for TPM registration", + slog.String("ekcert_serial", attestation.Data.EKCert.SerialNumber), + ) + initReq.Ek = &proto.RegisterUsingTPMMethodInitialRequest_EkCert{ + EkCert: attestation.Data.EKCert.Raw, + } + case attestation.Data.EKPub != nil: + log.DebugContext( + ctx, + "Using EKKey for TPM registration", + slog.String("ekpub_hash", attestation.Data.EKPubHash), + ) + initReq.Ek = &proto.RegisterUsingTPMMethodInitialRequest_EkKey{ + EkKey: attestation.Data.EKPub, + } + default: + return nil, trace.BadParameter("tpm has neither ekkey or ekcert") + } + + // Submit initial request to the Auth Server. + certs, err := client.RegisterUsingTPMMethod( + ctx, + initReq, + func( + challenge *proto.TPMEncryptedCredential, + ) (*proto.RegisterUsingTPMMethodChallengeResponse, error) { + // Solve the encrypted credential with our AK to prove possession + // and obtain the solution we need to complete the ceremony. + solution, err := attestation.Solve(tpm.EncryptedCredentialFromProto( + challenge, + )) + if err != nil { + return nil, trace.Wrap(err, "activating credential") + } + return &proto.RegisterUsingTPMMethodChallengeResponse{ + Solution: solution, + }, nil + }, + ) + return certs, trace.Wrap(err) +} + // ReRegisterParams specifies parameters for re-registering // in the cluster (rotating certificates for existing members) type ReRegisterParams struct { diff --git a/lib/service/service.go b/lib/service/service.go index 0d33fc47bb79b..e3a036beeeebd 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -5900,7 +5900,8 @@ func readOrGenerateHostID(ctx context.Context, cfg *servicecfg.Config, kubeBacke types.JoinMethodGitHub, types.JoinMethodGitLab, types.JoinMethodAzure, - types.JoinMethodGCP: + types.JoinMethodGCP, + types.JoinMethodTPM: // Checking error instead of the usual uuid.New() in case uuid generation // fails due to not enough randomness. It's been known to happen happen when // Teleport starts very early in the node initialization cycle and /dev/urandom diff --git a/lib/tbot/config/config.go b/lib/tbot/config/config.go index bc032d42e9f8d..6dddbb848570e 100644 --- a/lib/tbot/config/config.go +++ b/lib/tbot/config/config.go @@ -58,6 +58,7 @@ var SupportedJoinMethods = []string{ string(types.JoinMethodKubernetes), string(types.JoinMethodSpacelift), string(types.JoinMethodToken), + string(types.JoinMethodTPM), } var log = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentTBot) diff --git a/lib/tpm/testdata/TestPrintQuery/ekcert.golden b/lib/tpm/testdata/TestPrintQuery/ekcert.golden index 99fa7d0c361c7..6612f1f53b0e4 100644 --- a/lib/tpm/testdata/TestPrintQuery/ekcert.golden +++ b/lib/tpm/testdata/TestPrintQuery/ekcert.golden @@ -1,4 +1,4 @@ TPM Information -EKPub Hash: aabbaabbcc -EKCert Detected: true -EKCert Serial: aa:bb:cc +EK Public Hash: aabbaabbcc +EK Certificate Detected: true +EK Certificate Serial: aa:bb:cc diff --git a/lib/tpm/testdata/TestPrintQuery/ekcert_debug.golden b/lib/tpm/testdata/TestPrintQuery/ekcert_debug.golden index fed739fac3690..2bb61022f4f0a 100644 --- a/lib/tpm/testdata/TestPrintQuery/ekcert_debug.golden +++ b/lib/tpm/testdata/TestPrintQuery/ekcert_debug.golden @@ -1,12 +1,12 @@ TPM Information -EKPub Hash: aabbaabbcc -EKCert Detected: true -EKCert Serial: aa:bb:cc -EKPub: +EK Public Hash: aabbaabbcc +EK Certificate Detected: true +EK Certificate Serial: aa:bb:cc +EK Public: -----BEGIN PUBLIC KEY----- ZWtwdWI= -----END PUBLIC KEY----- -EKCert: +EK Certificate: -----BEGIN CERTIFICATE----- ZWtjZXJ0 -----END CERTIFICATE----- diff --git a/lib/tpm/testdata/TestPrintQuery/ekpub.golden b/lib/tpm/testdata/TestPrintQuery/ekpub.golden index a9bb8c429df9f..381a300e56779 100644 --- a/lib/tpm/testdata/TestPrintQuery/ekpub.golden +++ b/lib/tpm/testdata/TestPrintQuery/ekpub.golden @@ -1,3 +1,3 @@ TPM Information -EKPub Hash: aabbaabbcc -EKCert Detected: false +EK Public Hash: aabbaabbcc +EK Certificate Detected: false diff --git a/lib/tpm/testdata/TestPrintQuery/ekpub_debug.golden b/lib/tpm/testdata/TestPrintQuery/ekpub_debug.golden index 1151728fc3f28..931fe7fd202de 100644 --- a/lib/tpm/testdata/TestPrintQuery/ekpub_debug.golden +++ b/lib/tpm/testdata/TestPrintQuery/ekpub_debug.golden @@ -1,7 +1,7 @@ TPM Information -EKPub Hash: aabbaabbcc -EKCert Detected: false -EKPub: +EK Public Hash: aabbaabbcc +EK Certificate Detected: false +EK Public: -----BEGIN PUBLIC KEY----- ZWtwdWI= -----END PUBLIC KEY----- diff --git a/lib/tpm/tpm.go b/lib/tpm/tpm.go index 4ab87298226be..b720df596a822 100644 --- a/lib/tpm/tpm.go +++ b/lib/tpm/tpm.go @@ -236,18 +236,18 @@ func AttestWithTPM(ctx context.Context, log *slog.Logger, tpm *attest.TPM) ( // specified io.Writer. func PrintQuery(data *QueryRes, debug bool, w io.Writer) { _, _ = fmt.Fprintf(w, "TPM Information\n") - _, _ = fmt.Fprintf(w, "EKPub Hash: %s\n", data.EKPubHash) - _, _ = fmt.Fprintf(w, "EKCert Detected: %t\n", data.EKCert != nil) + _, _ = fmt.Fprintf(w, "EK Public Hash: %s\n", data.EKPubHash) + _, _ = fmt.Fprintf(w, "EK Certificate Detected: %t\n", data.EKCert != nil) if data.EKCert != nil { - _, _ = fmt.Fprintf(w, "EKCert Serial: %s\n", data.EKCert.SerialNumber) + _, _ = fmt.Fprintf(w, "EK Certificate Serial: %s\n", data.EKCert.SerialNumber) } if debug { - _, _ = fmt.Fprintf(w, "EKPub:\n%s", pem.EncodeToMemory(&pem.Block{ + _, _ = fmt.Fprintf(w, "EK Public:\n%s", pem.EncodeToMemory(&pem.Block{ Type: "PUBLIC KEY", Bytes: data.EKPub, })) if data.EKCert != nil { - _, _ = fmt.Fprintf(w, "EKCert:\n%s", pem.EncodeToMemory(&pem.Block{ + _, _ = fmt.Fprintf(w, "EK Certificate:\n%s", pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: data.EKCert.Raw, })) From 3a6e0abcea373cceb5a538368b7937cd323ad3cb Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Thu, 18 Apr 2024 17:41:52 +0100 Subject: [PATCH 02/11] Update lib/auth/register.go Co-authored-by: Alan Parra --- lib/auth/register.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/auth/register.go b/lib/auth/register.go index 3e8b5d5baf88f..9858b84606338 100644 --- a/lib/auth/register.go +++ b/lib/auth/register.go @@ -792,7 +792,7 @@ func registerUsingTPMMethod( log.DebugContext( ctx, "Using EKCert for TPM registration", - slog.String("ekcert_serial", attestation.Data.EKCert.SerialNumber), + "ekcert_serial", attestation.Data.EKCert.SerialNumber, ) initReq.Ek = &proto.RegisterUsingTPMMethodInitialRequest_EkCert{ EkCert: attestation.Data.EKCert.Raw, From 7f68da41c88fc075f687dd03a9c356749d767534 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Thu, 18 Apr 2024 17:42:03 +0100 Subject: [PATCH 03/11] Update api/client/joinservice.go Co-authored-by: Alan Parra --- api/client/joinservice.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/client/joinservice.go b/api/client/joinservice.go index 1c877a58ecf05..4c900f41d6ef0 100644 --- a/api/client/joinservice.go +++ b/api/client/joinservice.go @@ -173,7 +173,7 @@ func (c *JoinServiceClient) RegisterUsingTPMMethod( solution, err := solveChallenge(challenge.ChallengeRequest) if err != nil { - return nil, trace.Wrap(err, "calling solveChallenge") + return nil, trace.Wrap(err, "solving challenge") } err = tpmJoinClient.Send(&proto.RegisterUsingTPMMethodRequest{ From 06cd3a82fe4fad22d71f0f51d07929bdba2e3117 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Thu, 18 Apr 2024 17:42:17 +0100 Subject: [PATCH 04/11] Update lib/auth/register.go Co-authored-by: Alan Parra --- lib/auth/register.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/auth/register.go b/lib/auth/register.go index 9858b84606338..0c6012c44fc3b 100644 --- a/lib/auth/register.go +++ b/lib/auth/register.go @@ -801,7 +801,7 @@ func registerUsingTPMMethod( log.DebugContext( ctx, "Using EKKey for TPM registration", - slog.String("ekpub_hash", attestation.Data.EKPubHash), + "ekpub_hash", attestation.Data.EKPubHash, ) initReq.Ek = &proto.RegisterUsingTPMMethodInitialRequest_EkKey{ EkKey: attestation.Data.EKPub, From 969b9288a302a91813cb15db0ba2f0df136c647b Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Thu, 18 Apr 2024 18:55:30 +0100 Subject: [PATCH 05/11] Tidy up RegisterUsingTPMMethod method --- api/client/joinservice.go | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/api/client/joinservice.go b/api/client/joinservice.go index 4c900f41d6ef0..71db117b4acf5 100644 --- a/api/client/joinservice.go +++ b/api/client/joinservice.go @@ -144,12 +144,13 @@ func (c *JoinServiceClient) RegisterUsingTPMMethod( ctx, cancel := context.WithCancel(ctx) defer cancel() - tpmJoinClient, err := c.grpcClient.RegisterUsingTPMMethod(ctx) + stream, err := c.grpcClient.RegisterUsingTPMMethod(ctx) if err != nil { return nil, trace.Wrap(err) } + defer stream.CloseSend() - err = tpmJoinClient.Send(&proto.RegisterUsingTPMMethodRequest{ + err = stream.Send(&proto.RegisterUsingTPMMethodRequest{ Payload: &proto.RegisterUsingTPMMethodRequest_Init{ Init: initReq, }, @@ -158,25 +159,25 @@ func (c *JoinServiceClient) RegisterUsingTPMMethod( return nil, trace.Wrap(err, "sending initial request") } - challengeResp, err := tpmJoinClient.Recv() + res, err := stream.Recv() if err != nil { return nil, trace.Wrap(err, "receiving challenge") } - challenge, ok := challengeResp.Payload.(*proto.RegisterUsingTPMMethodResponse_ChallengeRequest) - if !ok { + challenge := res.GetChallengeRequest() + if challenge == nil { return nil, trace.BadParameter( - "unexpected payload type %T, expected *RegisterUsingTPMMethodResponse_ChallengeRequest", - challengeResp.Payload, + "expected ChallengeRequest payload, got %T", + res.Payload, ) } - solution, err := solveChallenge(challenge.ChallengeRequest) + solution, err := solveChallenge(challenge) if err != nil { return nil, trace.Wrap(err, "solving challenge") } - err = tpmJoinClient.Send(&proto.RegisterUsingTPMMethodRequest{ + err = stream.Send(&proto.RegisterUsingTPMMethodRequest{ Payload: &proto.RegisterUsingTPMMethodRequest_ChallengeResponse{ ChallengeResponse: solution, }, @@ -185,17 +186,17 @@ func (c *JoinServiceClient) RegisterUsingTPMMethod( return nil, trace.Wrap(err, "sending solution") } - certsResp, err := tpmJoinClient.Recv() + res, err = stream.Recv() if err != nil { return nil, trace.Wrap(err, "receiving certs") } - certs, ok := certsResp.Payload.(*proto.RegisterUsingTPMMethodResponse_Certs) - if !ok { + certs := res.GetCerts() + if certs == nil { return nil, trace.BadParameter( - "unexpected payload type %T, expected *RegisterUsingTPMMethodResponse_Certs", - certsResp.Payload, + "expected Certs payload, got %T", + res.Payload, ) } - return certs.Certs, nil + return certs, nil } From 4590316b99a2b7d70d639ce68aef8592b94bac89 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Thu, 18 Apr 2024 18:58:36 +0100 Subject: [PATCH 06/11] Add default case --- lib/auth/register.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/auth/register.go b/lib/auth/register.go index 0c6012c44fc3b..fe29a827539d8 100644 --- a/lib/auth/register.go +++ b/lib/auth/register.go @@ -346,6 +346,8 @@ func registerThroughProxy( certs, err = registerUsingAzureMethod(ctx, joinServiceClient, token, params) case types.JoinMethodTPM: certs, err = registerUsingTPMMethod(ctx, joinServiceClient, token, params) + default: + return nil, trace.BadParameter("unhandled join method %q", params.JoinMethod) } if err != nil { From 17d41824b11a15d0e86bf87fe82e8433c5966133 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Thu, 18 Apr 2024 18:59:18 +0100 Subject: [PATCH 07/11] Rename CheckAndSetDefaults to validate --- api/types/provisioning.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/types/provisioning.go b/api/types/provisioning.go index 030f307ff4e7e..ea3aeeb4d551d 100644 --- a/api/types/provisioning.go +++ b/api/types/provisioning.go @@ -342,7 +342,7 @@ func (p *ProvisionTokenV2) CheckAndSetDefaults() error { JoinMethodTPM, ) } - if err := providerCfg.checkAndSetDefaults(); err != nil { + if err := providerCfg.validate(); err != nil { return trace.Wrap(err, "spec.tpm: failed validation") } default: @@ -772,7 +772,7 @@ func (a *ProvisionTokenSpecV2Spacelift) checkAndSetDefaults() error { return nil } -func (a *ProvisionTokenSpecV2TPM) checkAndSetDefaults() error { +func (a *ProvisionTokenSpecV2TPM) validate() error { for i, caData := range a.EKCertAllowedCAs { p, _ := pem.Decode([]byte(caData)) if p == nil { From ddff12ebbb10d92a402b56d5ec08697bf0c1ee7c Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 19 Apr 2024 12:07:32 +0100 Subject: [PATCH 08/11] Add basic success test for JoinServiceClient_RegisterUsingTPMMethod --- api/client/joinservice_test.go | 110 +++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 api/client/joinservice_test.go diff --git a/api/client/joinservice_test.go b/api/client/joinservice_test.go new file mode 100644 index 0000000000000..b873d9a80dfc2 --- /dev/null +++ b/api/client/joinservice_test.go @@ -0,0 +1,110 @@ +package client + +import ( + "context" + "net" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" +) + +type mockJoinServiceServer struct { + *proto.UnimplementedJoinServiceServer + registerUsingTPMMethod func(srv proto.JoinService_RegisterUsingTPMMethodServer) error +} + +func (m *mockJoinServiceServer) RegisterUsingTPMMethod(srv proto.JoinService_RegisterUsingTPMMethodServer) error { + return m.registerUsingTPMMethod(srv) +} + +func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + lis := bufconn.Listen(100) + t.Cleanup(func() { + assert.NoError(t, lis.Close()) + }) + + mockInitReq := &proto.RegisterUsingTPMMethodInitialRequest{ + JoinRequest: &types.RegisterUsingTokenRequest{ + Token: "token", + }, + } + mockChallenge := &proto.TPMEncryptedCredential{ + CredentialBlob: []byte("cred-blob"), + Secret: []byte("secret"), + } + mockChallengeResp := &proto.RegisterUsingTPMMethodChallengeResponse{ + Solution: []byte("solution"), + } + mockCerts := &proto.Certs{ + TLS: []byte("cert"), + } + mockService := &mockJoinServiceServer{ + registerUsingTPMMethod: func(srv proto.JoinService_RegisterUsingTPMMethodServer) error { + req, err := srv.Recv() + require.NoError(t, err) + assert.Empty(t, cmp.Diff(req.GetInit(), mockInitReq)) + + err = srv.Send(&proto.RegisterUsingTPMMethodResponse{ + Payload: &proto.RegisterUsingTPMMethodResponse_ChallengeRequest{ + ChallengeRequest: mockChallenge, + }, + }) + require.NoError(t, err) + + req, err = srv.Recv() + require.NoError(t, err) + assert.Empty(t, cmp.Diff(req.GetChallengeResponse(), mockChallengeResp)) + + err = srv.Send(&proto.RegisterUsingTPMMethodResponse{ + Payload: &proto.RegisterUsingTPMMethodResponse_Certs{ + Certs: mockCerts, + }, + }) + require.NoError(t, err) + return nil + }, + } + srv := grpc.NewServer() + t.Cleanup(func() { + srv.Stop() + }) + proto.RegisterJoinServiceServer(srv, mockService) + + go func() { + err := srv.Serve(lis) + assert.NotErrorIs(t, err, grpc.ErrServerStopped) + cancel() + }() + + c, err := grpc.NewClient("example.com", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + return lis.DialContext(ctx) + }), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + + joinClient := NewJoinServiceClient(proto.NewJoinServiceClient(c)) + + certs, err := joinClient.RegisterUsingTPMMethod( + ctx, + mockInitReq, + func(challenge *proto.TPMEncryptedCredential) (*proto.RegisterUsingTPMMethodChallengeResponse, error) { + assert.Empty(t, cmp.Diff(mockChallenge, challenge)) + return mockChallengeResp, nil + }, + ) + if assert.NoError(t, err) { + assert.Empty(t, cmp.Diff(mockCerts, certs)) + } +} From 7b6879a95ac1fa4fff847062a042935d45535b0f Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Mon, 22 Apr 2024 11:18:13 +0100 Subject: [PATCH 09/11] Add final touches to client joinservice test --- api/client/joinservice_test.go | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/api/client/joinservice_test.go b/api/client/joinservice_test.go index b873d9a80dfc2..e7e763ef92e60 100644 --- a/api/client/joinservice_test.go +++ b/api/client/joinservice_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "errors" "net" "testing" @@ -54,7 +55,9 @@ func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) { mockService := &mockJoinServiceServer{ registerUsingTPMMethod: func(srv proto.JoinService_RegisterUsingTPMMethodServer) error { req, err := srv.Recv() - require.NoError(t, err) + if !assert.NoError(t, err) { + return err + } assert.Empty(t, cmp.Diff(req.GetInit(), mockInitReq)) err = srv.Send(&proto.RegisterUsingTPMMethodResponse{ @@ -62,10 +65,14 @@ func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) { ChallengeRequest: mockChallenge, }, }) - require.NoError(t, err) + if !assert.NoError(t, err) { + return err + } req, err = srv.Recv() - require.NoError(t, err) + if !assert.NoError(t, err) { + return err + } assert.Empty(t, cmp.Diff(req.GetChallengeResponse(), mockChallengeResp)) err = srv.Send(&proto.RegisterUsingTPMMethodResponse{ @@ -73,7 +80,9 @@ func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) { Certs: mockCerts, }, }) - require.NoError(t, err) + if !assert.NoError(t, err) { + return err + } return nil }, } @@ -85,11 +94,13 @@ func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) { go func() { err := srv.Serve(lis) - assert.NotErrorIs(t, err, grpc.ErrServerStopped) + if err != nil && !errors.Is(err, grpc.ErrServerStopped) { + assert.NoError(t, err) + } cancel() }() - c, err := grpc.NewClient("example.com", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + c, err := grpc.NewClient("unused.com", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { return lis.DialContext(ctx) }), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) From 3c60121408bf149a5717d75996589d5d0e75499a Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Tue, 23 Apr 2024 09:25:55 +0100 Subject: [PATCH 10/11] Add license header to joinservice_test.go --- api/client/joinservice_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/api/client/joinservice_test.go b/api/client/joinservice_test.go index e7e763ef92e60..dec8138ea57d0 100644 --- a/api/client/joinservice_test.go +++ b/api/client/joinservice_test.go @@ -1,3 +1,19 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package client import ( From 6e92dc108cd9961390bf2c4461188ec8e5af482f Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Tue, 23 Apr 2024 15:22:20 +0100 Subject: [PATCH 11/11] Switch to grpc.Dial from grpc.NewClient --- api/client/joinservice_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/client/joinservice_test.go b/api/client/joinservice_test.go index dec8138ea57d0..04966142d79c3 100644 --- a/api/client/joinservice_test.go +++ b/api/client/joinservice_test.go @@ -116,7 +116,7 @@ func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) { cancel() }() - c, err := grpc.NewClient("unused.com", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + c, err := grpc.Dial("unused.com", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { return lis.DialContext(ctx) }), grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err)