diff --git a/api/client/client.go b/api/client/client.go index f69903807b00d..43a47458e9cd1 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -1687,8 +1687,9 @@ func (c *Client) SignDatabaseCSR(ctx context.Context, req *proto.DatabaseCSRRequ return resp, nil } -// GenerateDatabaseCert generates client certificate used by a database -// service to authenticate with the database instance. +// GenerateDatabaseCert generates a client certificate used by a database +// service to authenticate with the database instance, or a server certificate +// for configuring a self-hosted database, depending on the requester_name. func (c *Client) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { resp, err := c.grpc.GenerateDatabaseCert(ctx, req) if err != nil { diff --git a/api/types/constants.go b/api/types/constants.go index 1cc88d5f04c54..24d1935c00c67 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -394,7 +394,8 @@ const ( // KindConnectionDiagnostic is a resource that tracks the result of testing a connection KindConnectionDiagnostic = "connection_diagnostic" - // KindDatabaseCertificate is a resource to control Database Certificates generation + // KindDatabaseCertificate is a resource to control db CA cert + // generation. KindDatabaseCertificate = "database_certificate" // KindInstaller is a resource that holds a node installer script diff --git a/api/types/trust.go b/api/types/trust.go index 0f5b06ee44640..485b29910c361 100644 --- a/api/types/trust.go +++ b/api/types/trust.go @@ -38,8 +38,12 @@ const ( HostCA CertAuthType = "host" // UserCA identifies the key as a user certificate authority UserCA CertAuthType = "user" - // DatabaseCA is a certificate authority used in database access. + // DatabaseCA is a certificate authority used as a server CA in database + // access. DatabaseCA CertAuthType = "db" + // DatabaseClientCA is a certificate authority used as a client CA in + // database access. + DatabaseClientCA CertAuthType = "db_client" // OpenSSHCA is a certificate authority used when connecting to agentless nodes. OpenSSHCA CertAuthType = "openssh" // JWTSigner identifies type of certificate authority as JWT signer. In this @@ -57,7 +61,7 @@ const ( ) // CertAuthTypes lists all certificate authority types. -var CertAuthTypes = []CertAuthType{HostCA, UserCA, DatabaseCA, OpenSSHCA, JWTSigner, SAMLIDPCA, OIDCIdPCA} +var CertAuthTypes = []CertAuthType{HostCA, UserCA, DatabaseCA, DatabaseClientCA, OpenSSHCA, JWTSigner, SAMLIDPCA, OIDCIdPCA} // NewlyAdded should return true for CA types that were added in the current // major version, so that we can avoid erroring out when a potentially older @@ -73,12 +77,23 @@ func (c CertAuthType) addedInMajorVer() int64 { return 9 case OpenSSHCA, SAMLIDPCA, OIDCIdPCA: return 12 + case DatabaseClientCA: + return 15 default: // We don't care about other CAs added before v4.0.0 return 4 } } +// IsUnsupportedAuthorityErr returns whether an error is due to an unsupported +// CertAuthType. +func IsUnsupportedAuthorityErr(err error) bool { + return err != nil && trace.IsBadParameter(err) && + strings.Contains(err.Error(), authTypeNotSupported) +} + +const authTypeNotSupported string = "authority type is not supported" + // Check checks if certificate authority type value is correct func (c CertAuthType) Check() error { for _, caType := range CertAuthTypes { @@ -87,7 +102,7 @@ func (c CertAuthType) Check() error { } } - return trace.BadParameter("%q authority type is not supported", c) + return trace.BadParameter("%q %s", c, authTypeNotSupported) } // CertAuthID - id of certificate authority (it's type and domain name) diff --git a/integration/helpers/instance.go b/integration/helpers/instance.go index 4b5db25997872..4089f54ec92c1 100644 --- a/integration/helpers/instance.go +++ b/integration/helpers/instance.go @@ -192,6 +192,21 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) { return nil, trace.Wrap(err) } + dbClientCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.DatabaseClientCA, + ClusterName: s.SiteName, + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{ + Key: s.PrivKey, + KeyType: types.PrivateKeyType_RAW, + Cert: s.TLSCACert, + }}, + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + osshCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ Type: types.OpenSSHCA, ClusterName: s.SiteName, @@ -207,7 +222,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) { return nil, trace.Wrap(err) } - return []types.CertAuthority{hostCA, userCA, dbCA, osshCA}, nil + return []types.CertAuthority{hostCA, userCA, dbCA, dbClientCA, osshCA}, nil } func (s *InstanceSecrets) AllowedLogins() []string { diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 8f3f4ecda568c..a66967ba27739 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -6035,7 +6035,7 @@ func newKeySet(ctx context.Context, keyStore *keystore.Manager, caID types.CertA } keySet.SSH = append(keySet.SSH, sshKeyPair) keySet.TLS = append(keySet.TLS, tlsKeyPair) - case types.DatabaseCA: + case types.DatabaseCA, types.DatabaseClientCA: // Database CA only contains TLS cert. tlsKeyPair, err := keyStore.NewTLSKeyPair(ctx, caID.DomainName) if err != nil { diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 8a5b388261273..2325ca75f142c 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -4773,26 +4773,40 @@ func (a *ServerWithRoles) SignDatabaseCSR(ctx context.Context, req *proto.Databa return a.authServer.SignDatabaseCSR(ctx, req) } -// GenerateDatabaseCert generates a certificate used by a database service -// to authenticate with the database instance. +// GenerateDatabaseCert generates a client certificate used by a database +// service to authenticate with the database instance, or a server certificate +// for configuring a self-hosted database, depending on the requester_name. // // This certificate can be requested by: // // - Cluster administrator using "tctl auth sign --format=db" command locally // on the auth server to produce a certificate for configuring a self-hosted // database. -// - Remote user using "tctl auth sign --format=db" command with a remote -// proxy (e.g. Teleport Cloud), as long as they can impersonate system -// role Db. +// - Remote user using "tctl auth sign --format=db" command or +// /webapi/sites/:site/sign/db with a remote proxy (e.g. Teleport Cloud), +// as long as they can impersonate system role Db. // - Database service when initiating connection to a database instance to // produce a client certificate. -// - Proxy service when generating mTLS files to a database func (a *ServerWithRoles) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { - // Check if the User can `create` DatabaseCertificates - err := a.action(apidefaults.Namespace, types.KindDatabaseCertificate, types.VerbCreate) + err := a.checkAccessToGenerateDatabaseCert(types.KindDatabaseCertificate) + if err != nil { + return nil, trace.Wrap(err) + } + return a.authServer.GenerateDatabaseCert(ctx, req) +} + +// checkAccessToGenerateDatabaseCert is a helper for checking db cert gen authz. +// Requester must have at least one of: +// - create: database_certificate or database_client_certificate. +// - built-in Admin or DB role. +// - allowed to impersonate the built-in DB role. +func (a *ServerWithRoles) checkAccessToGenerateDatabaseCert(resourceKind string) error { + const verb = types.VerbCreate + // Check if the User can `create` Database Certificates + err := a.action(apidefaults.Namespace, resourceKind, verb) if err != nil { if !trace.IsAccessDenied(err) { - return nil, trace.Wrap(err) + return trace.Wrap(err) } // Err is access denied, trying the old way @@ -4802,12 +4816,13 @@ func (a *ServerWithRoles) GenerateDatabaseCert(ctx context.Context, req *proto.D if !a.hasBuiltinRole(types.RoleDatabase, types.RoleAdmin) { if err := a.canImpersonateBuiltinRole(types.RoleDatabase); err != nil { log.WithError(err).Warnf("User %v tried to generate database certificate but does not have '%s' permission for '%s' kind, nor is allowed to impersonate %q system role", - a.context.User.GetName(), types.VerbCreate, types.KindDatabaseCertificate, types.RoleDatabase) - return nil, trace.AccessDenied(fmt.Sprintf("access denied. User must have '%s' permission for '%s' kind to generate the certificate ", types.VerbCreate, types.KindDatabaseCertificate)) + a.context.User.GetName(), verb, resourceKind, types.RoleDatabase) + return trace.AccessDenied("access denied. User must have '%s' permission for '%s' kind to generate the certificate ", + verb, resourceKind) } } } - return a.authServer.GenerateDatabaseCert(ctx, req) + return nil } // GenerateSnowflakeJWT generates JWT in the Snowflake required format. @@ -6233,10 +6248,10 @@ func (a *ServerWithRoles) GetAccountRecoveryCodes(ctx context.Context, req *prot // // - Windows desktop service when updating the certificate authority contents // on LDAP. -// - Cluster administrator using "tctl auth crl --type=db" command locally +// - Cluster administrator using "tctl auth crl --type=db_client" command locally // on the auth server to produce revocation list used to be configured on // external services such as Windows certificate store. -// - Remote user using "tctl auth crl --type=db" command with a remote +// - Remote user using "tctl auth crl --type=db_client" command with a remote // proxy (e.g. Teleport Cloud), as long as they have permission to read // certificate authorities. func (a *ServerWithRoles) GenerateCertAuthorityCRL(ctx context.Context, caType types.CertAuthType) ([]byte, error) { diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 6079c5d352f81..c386feeaa2da7 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -1088,7 +1088,7 @@ func TestRoleRequestDenyReimpersonation(t *testing.T) { } // TestGenerateDatabaseCert makes sure users and services with appropriate -// permissions can generate certificates for self-hosted databases. +// permissions can generate database certificates. func TestGenerateDatabaseCert(t *testing.T) { t.Parallel() ctx := context.Background() @@ -1109,22 +1109,26 @@ func TestGenerateDatabaseCert(t *testing.T) { require.NoError(t, err) tests := []struct { - desc string - identity TestIdentity - err string + desc string + identity TestIdentity + requester proto.DatabaseCertRequest_Requester + err string }{ { - desc: "user can't sign database certs", - identity: TestUser(userWithoutAccess.GetName()), - err: "access denied", + desc: "user can't sign database certs", + identity: TestUser(userWithoutAccess.GetName()), + requester: proto.DatabaseCertRequest_TCTL, + err: "access denied", }, { - desc: "user can impersonate Db and sign database certs", - identity: TestUser(userImpersonateDb.GetName()), + desc: "user can impersonate Db and sign database certs", + identity: TestUser(userImpersonateDb.GetName()), + requester: proto.DatabaseCertRequest_TCTL, }, { - desc: "built-in admin can sign database certs", - identity: TestAdmin(), + desc: "built-in admin can sign database certs", + identity: TestAdmin(), + requester: proto.DatabaseCertRequest_TCTL, }, { desc: "database service can sign database certs", @@ -1144,7 +1148,7 @@ func TestGenerateDatabaseCert(t *testing.T) { client, err := srv.NewClient(test.identity) require.NoError(t, err) - _, err = client.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{CSR: csr}) + _, err = client.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{CSR: csr, RequesterName: test.requester}) if test.err != "" { require.ErrorContains(t, err, test.err) } else { diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 97d567ef4f66d..ad54fbeddd3f9 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -862,8 +862,9 @@ type ClientI interface { // sessions created by the SAML identity provider. CreateSAMLIdPSession(context.Context, types.CreateSAMLIdPSessionRequest) (types.WebSession, error) - // GenerateDatabaseCert generates client certificate used by a database - // service to authenticate with the database instance. + // GenerateDatabaseCert generates a client certificate used by a database + // service to authenticate with the database instance, or a server certificate + // for configuring a self-hosted database, depending on the requester_name. GenerateDatabaseCert(context.Context, *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) // GetWebSession queries the existing web session described with req. diff --git a/lib/auth/db.go b/lib/auth/db.go index cbeea66836f68..1716f5a2d91c5 100644 --- a/lib/auth/db.go +++ b/lib/auth/db.go @@ -48,32 +48,89 @@ import ( // GenerateDatabaseCert generates client certificate used by a database // service to authenticate with the database instance. func (a *Server) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { - csr, err := tlsca.ParseCertificateRequestPEM(req.CSR) + if req.RequesterName == proto.DatabaseCertRequest_TCTL { + // tctl/web cert request needs to generate a db server cert and trust + // the db client CA. + return a.generateDatabaseServerCert(ctx, req) + } + // db service needs to generate a db client cert and trust the db server CA. + return a.generateDatabaseClientCert(ctx, req) +} + +// generateDatabaseServerCert generates database server certificate used by a +// database to authenticate itself to a database service. +func (a *Server) generateDatabaseServerCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { + clusterName, err := a.GetClusterName() if err != nil { return nil, trace.Wrap(err) } - clusterName, err := a.GetClusterName() + // databases should be configured to trust the DatabaseClientCA when + // clients connect so return DatabaseClientCA in the response. + dbClientCA, err := a.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseClientCA, + DomainName: clusterName.GetClusterName(), + }, false) if err != nil { return nil, trace.Wrap(err) } - databaseCA, err := a.GetCertAuthority(ctx, types.CertAuthID{ + dbServerCA, err := a.GetCertAuthority(ctx, types.CertAuthID{ Type: types.DatabaseCA, DomainName: clusterName.GetClusterName(), }, true) if err != nil { - if trace.IsNotFound(err) { - // Database CA doesn't exist. Fallback to Host CA. - // https://github.com/gravitational/teleport/issues/5029 - databaseCA, err = a.GetCertAuthority(ctx, types.CertAuthID{ - Type: types.HostCA, - DomainName: clusterName.GetClusterName(), - }, true) - } - if err != nil { - return nil, trace.Wrap(err) - } + return nil, trace.Wrap(err) + } + + cert, err := a.generateDatabaseCert(ctx, req, dbServerCA) + if err != nil { + return nil, trace.Wrap(err) + } + return &proto.DatabaseCertResponse{ + Cert: cert, + CACerts: services.GetTLSCerts(dbClientCA), + }, nil +} + +// generateDatabaseClientCert generates client certificate used by a database +// service to authenticate with the database instance. +func (a *Server) generateDatabaseClientCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { + clusterName, err := a.GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + dbClientCA, err := a.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseClientCA, + DomainName: clusterName.GetClusterName(), + }, true) + if err != nil { + return nil, trace.Wrap(err) + } + + cert, err := a.generateDatabaseCert(ctx, req, dbClientCA) + if err != nil { + return nil, trace.Wrap(err) + } + // db clients should trust the Database Server CA when establishing + // connection to a database, so return that CA's certs in the response. + dbServerCA, err := a.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseCA, + DomainName: clusterName.GetClusterName(), + }, false) + if err != nil { + return nil, trace.Wrap(err) + } + return &proto.DatabaseCertResponse{ + Cert: cert, + CACerts: services.GetTLSCerts(dbServerCA), + }, nil +} + +func (a *Server) generateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest, ca types.CertAuthority) ([]byte, error) { + csr, err := tlsca.ParseCertificateRequestPEM(req.CSR) + if err != nil { + return nil, trace.Wrap(err) } - caCert, signer, err := getCAandSigner(ctx, a.GetKeyStore(), databaseCA, req) + caCert, signer, err := getCAandSigner(ctx, a.GetKeyStore(), ca, req) if err != nil { return nil, trace.Wrap(err) } @@ -100,15 +157,21 @@ func (a *Server) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCe // has been deprecated since Go 1.15: // https://golang.org/doc/go1.15#commonname certReq.DNSNames = getServerNames(req) + + // The windows smartcard cert req already does the same in + // lib/auth/windows/windows.go, along with another ExtKeyUsage for + // smartcard logon that we don't want to override above. + switch ca.GetType() { + case types.DatabaseCA: + // Override ExtKeyUsage to ExtKeyUsageServerAuth. + certReq.ExtraExtensions = append(certReq.ExtraExtensions, extKeyUsageServerAuthExtension) + case types.DatabaseClientCA: + // Override ExtKeyUsage to ExtKeyUsageClientAuth. + certReq.ExtraExtensions = append(certReq.ExtraExtensions, extKeyUsageClientAuthExtension) + } } cert, err := tlsCA.GenerateCertificate(certReq) - if err != nil { - return nil, trace.Wrap(err) - } - return &proto.DatabaseCertResponse{ - Cert: cert, - CACerts: services.GetTLSCerts(databaseCA), - }, nil + return cert, trace.Wrap(err) } // getCAandSigner returns correct signer and CA that should be used when generating database certificate. @@ -118,6 +181,7 @@ func (a *Server) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCe func getCAandSigner(ctx context.Context, keyStore *keystore.Manager, databaseCA types.CertAuthority, req *proto.DatabaseCertRequest, ) ([]byte, crypto.Signer, error) { if req.RequesterName == proto.DatabaseCertRequest_TCTL && + databaseCA.GetType() == types.DatabaseCA && databaseCA.GetRotation().Phase == types.RotationPhaseInit { return keyStore.GetAdditionalTrustedTLSCertAndSigner(ctx, databaseCA) } @@ -236,11 +300,21 @@ func (a *Server) GenerateSnowflakeJWT(ctx context.Context, req *proto.SnowflakeJ return nil, trace.Wrap(err) } ca, err := a.GetCertAuthority(ctx, types.CertAuthID{ - Type: types.DatabaseCA, + Type: types.DatabaseClientCA, DomainName: clusterName.GetClusterName(), }, true) if err != nil { - return nil, trace.Wrap(err) + if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + // DatabaseClientCA doesn't exist, fallback to DatabaseCA. + ca, err = a.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseCA, + DomainName: clusterName.GetClusterName(), + }, true) + if err != nil { + return nil, trace.Wrap(err) + } } if len(ca.GetActiveKeys().TLS) == 0 { @@ -324,7 +398,31 @@ func filterExtensions(extensions []pkix.Extension, oids ...asn1.ObjectIdentifier return filtered } +// TODO(gavin): move OIDs from here and in lib/auth/windows to tlsca package. var ( oidExtKeyUsage = asn1.ObjectIdentifier{2, 5, 29, 37} oidSubjectAltName = asn1.ObjectIdentifier{2, 5, 29, 17} + + oidExtKeyUsageServerAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 1} + oidExtKeyUsageClientAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 2} + extKeyUsageServerAuthExtension = pkix.Extension{ + Id: oidExtKeyUsage, + Value: func() []byte { + val, err := asn1.Marshal([]asn1.ObjectIdentifier{oidExtKeyUsageServerAuth}) + if err != nil { + panic(err) + } + return val + }(), + } + extKeyUsageClientAuthExtension = pkix.Extension{ + Id: oidExtKeyUsage, + Value: func() []byte { + val, err := asn1.Marshal([]asn1.ObjectIdentifier{oidExtKeyUsageClientAuth}) + if err != nil { + panic(err) + } + return val + }(), + } ) diff --git a/lib/auth/db_test.go b/lib/auth/db_test.go index b362d0ec79b63..4127b9a2ed570 100644 --- a/lib/auth/db_test.go +++ b/lib/auth/db_test.go @@ -19,9 +19,19 @@ package auth import ( + "context" + "crypto/x509" + "crypto/x509/pkix" "testing" + "time" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/tlsca" ) func Test_getSnowflakeJWTParams(t *testing.T) { @@ -87,3 +97,113 @@ func Test_getSnowflakeJWTParams(t *testing.T) { }) } } + +func TestDBCertSigning(t *testing.T) { + t.Parallel() + authServer, err := NewTestAuthServer(TestAuthServerConfig{ + Clock: clockwork.NewFakeClockAt(time.Now()), + ClusterName: "local.me", + Dir: t.TempDir(), + }) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, authServer.Close()) }) + + ctx := context.Background() + + privateKey, err := testauthority.New().GeneratePrivateKey() + require.NoError(t, err) + + csr, err := tlsca.GenerateCertificateRequestPEM(pkix.Name{ + CommonName: "localhost", + }, privateKey) + require.NoError(t, err) + + // Set rotation to init phase. New CA will be generated. + // DB service should use active key to sign certificates. + // tctl should use new key to sign certificates. + err = authServer.AuthServer.RotateCertAuthority(ctx, types.RotateRequest{ + Type: types.DatabaseCA, + TargetPhase: types.RotationPhaseInit, + Mode: types.RotationModeManual, + }) + require.NoError(t, err) + err = authServer.AuthServer.RotateCertAuthority(ctx, types.RotateRequest{ + Type: types.DatabaseClientCA, + TargetPhase: types.RotationPhaseInit, + Mode: types.RotationModeManual, + }) + require.NoError(t, err) + + dbCAs, err := authServer.AuthServer.GetCertAuthorities(ctx, types.DatabaseCA, false) + require.NoError(t, err) + require.Len(t, dbCAs, 1) + require.Len(t, dbCAs[0].GetActiveKeys().TLS, 1) + require.Len(t, dbCAs[0].GetAdditionalTrustedKeys().TLS, 1) + activeDBCACert := dbCAs[0].GetActiveKeys().TLS[0].Cert + newDBCACert := dbCAs[0].GetAdditionalTrustedKeys().TLS[0].Cert + + dbClientCAs, err := authServer.AuthServer.GetCertAuthorities(ctx, types.DatabaseClientCA, false) + require.NoError(t, err) + require.Len(t, dbClientCAs, 1) + require.Len(t, dbClientCAs[0].GetActiveKeys().TLS, 1) + require.Len(t, dbClientCAs[0].GetAdditionalTrustedKeys().TLS, 1) + activeDBClientCACert := dbClientCAs[0].GetActiveKeys().TLS[0].Cert + newDBClientCACert := dbClientCAs[0].GetAdditionalTrustedKeys().TLS[0].Cert + + tests := []struct { + name string + requester proto.DatabaseCertRequest_Requester + wantCertSigner []byte + wantCACerts [][]byte + wantKeyUsage []x509.ExtKeyUsage + }{ + { + name: "DB service request is signed by active db client CA and trusts db CAs", + wantCertSigner: activeDBClientCACert, + wantCACerts: [][]byte{activeDBCACert, newDBCACert}, + wantKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }, + { + name: "tctl request is signed by new db CA and trusts db client CAs", + requester: proto.DatabaseCertRequest_TCTL, + wantCertSigner: newDBCACert, + wantCACerts: [][]byte{activeDBClientCACert, newDBClientCACert}, + wantKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + certResp, err := authServer.AuthServer.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{ + CSR: csr, + ServerName: "localhost", + TTL: proto.Duration(time.Hour), + RequesterName: tt.requester, + }) + require.NoError(t, err) + require.Equal(t, tt.wantCACerts, certResp.CACerts) + + // verify that the response cert is a DB CA cert. + mustVerifyCert(t, tt.wantCertSigner, certResp.Cert, tt.wantKeyUsage...) + }) + } +} + +// mustVerifyCert is a helper func that verifies leaf cert with root cert. +func mustVerifyCert(t *testing.T, rootPEM, leafPEM []byte, keyUsages ...x509.ExtKeyUsage) { + t.Helper() + leafCert, err := tlsca.ParseCertificatePEM(leafPEM) + require.NoError(t, err) + + certPool := x509.NewCertPool() + ok := certPool.AppendCertsFromPEM(rootPEM) + require.True(t, ok) + opts := x509.VerifyOptions{ + Roots: certPool, + KeyUsages: keyUsages, + } + // Verify if the generated certificate can be verified with the correct CA. + _, err = leafCert.Verify(opts) + require.NoError(t, err) +} diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 5173a43520c63..ca6fdbce1c5fc 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -436,6 +436,7 @@ func WatchEvents(watch *authpb.Watch, stream WatchEvent, componentName string, a AllowPartialSuccess: watch.AllowPartialSuccess, } + opInitHandler := maybeFilterCertAuthorityWatches(stream.Context(), &servicesWatch) events, err := auth.NewStream(stream.Context(), servicesWatch) if err != nil { return trace.Wrap(err) @@ -450,6 +451,9 @@ func WatchEvents(watch *authpb.Watch, stream WatchEvent, componentName string, a for events.Next() { event := events.Item() + if event.Type == types.OpInit && opInitHandler != nil { + opInitHandler(&event) + } if role, ok := event.Resource.(*types.RoleV6); ok { downgraded, err := maybeDowngradeRole(stream.Context(), role) if err != nil { @@ -475,6 +479,123 @@ func WatchEvents(watch *authpb.Watch, stream WatchEvent, componentName string, a return nil } +// dbClientCAVersionCutoff is the version starting from which we stop +// injecting a filter that drops DatabaseClientCA events. +var dbClientCACutoffVersion = semver.Version{Major: 14, Minor: 3, Patch: 1} + +// maybeFilterCertAuthorityWatches will inject a CA filter and return a +// function that removes the filter from the OpInit event if the client version +// does not support DatabaseClientCA type and if the client's CA WatchKind is +// trivial, i.e. it's not already filtering. Otherwise we assume that the client +// knows what it's doing and this function does nothing. +// This is a hack to avoid client cache re-init during CA rotation in older +// services that don't use a CA watch filter, i.e. every service except Node +// since v9. +// The returned function, if non-nil, must be called on the OpInit event, to +// remove the injected filter from the OpInit event's WatchStatus. This is to +// maintain the illusion to the client that CA events have not been filtered. +// If we did not remove the injected filter, then validateWatchRequest would +// fail on the client side because the confirmed kind filter is not as narrow or +// narrower than a trivial WatchKind. +// +// TODO(gavin): DELETE IN 16.0.0 - no supported clients will require this at +// that point. +func maybeFilterCertAuthorityWatches(ctx context.Context, watch *types.Watch) func(*types.Event) { + // check client version to see if it knows the DatabaseClientCA type. + clientVersion, err := getClientVersion(ctx) + if err != nil { + log.Debugf("Unable to determine client version: %v", err) + return nil + } + if versionHandlesDatabaseClientCAEvents(*clientVersion) { + // don't need to inject a CA filter if the client support DB Client CA. + return nil + } + + // search for trivial CA WatchKinds to inject a filter into - all of them + // must be trivial so we can remove the filter(s) from OpInit later. + var targets []*types.WatchKind + for i, k := range watch.Kinds { + if k.Kind != types.KindCertAuthority { + continue + } + + if !k.IsTrivial() { + // We need to remove the injected filter(s) from the OpInit event + // later. + // As a precaution, do nothing when any of the CA WatchKind(s) are + // non-trivial. + log.Debugf("Cannot inject filter into non-trivial CertAuthority watcher with client version %s.", clientVersion) + return nil + } + targets = append(targets, &watch.Kinds[i]) + } + if len(targets) == 0 { + return nil + } + + // create a CA filter that excludes DatabaseClientCA. + caFilter := make(types.CertAuthorityFilter, len(types.CertAuthTypes)-1) + for _, caType := range types.CertAuthTypes { + // exclude db client CA. + if caType == types.DatabaseClientCA { + continue + } + caFilter[caType] = types.Wildcard + } + + log.Debugf("Injecting filter for CertAuthority watcher with client version %s.", clientVersion) + for _, t := range targets { + t.Filter = caFilter.IntoMap() + } + + // return a func that removes the injected filter from the OpInit event. + // otherwise, client watchers may get confused by the upstream confirmed + // kinds. + return removeOpInitWatchStatusCAFilters +} + +func getClientVersion(ctx context.Context) (*semver.Version, error) { + clientVersionString, ok := metadata.ClientVersionFromContext(ctx) + if !ok { + return nil, trace.NotFound("no client version found in grpc context") + } + clientVersion, err := semver.NewVersion(clientVersionString) + return clientVersion, trace.Wrap(err) +} + +// versionHandlesDatabaseClientCAEvents returns true if the client version can +// handle the DatabaseClientCA, either because the client knows of the +// DatabaseClientCA or it uses a CA filter. This CA was introduced in backports. +// Client version in the intervals [v12.x, v13.0), [v13.y, v14.0), [v14.z, inf) +// can handle the DatabaseClientCA type, where x, y, z are the minor release +// versions that the DatabaseClientCA is backported to. +func versionHandlesDatabaseClientCAEvents(v semver.Version) bool { + v.PreRelease = "" // ignore pre-release tags + return !v.LessThan(dbClientCACutoffVersion) +} + +func removeOpInitWatchStatusCAFilters(e *types.Event) { + // this is paranoid, but make sure we don't panic or modify events that + // aren't OpInit. + if e == nil || e.Resource == nil || e.Type != types.OpInit { + return + } + status, ok := e.Resource.(types.WatchStatus) + if !ok || status == nil { + return + } + + kinds := status.GetKinds() + for i, k := range kinds { + if k.Kind != types.KindCertAuthority { + continue + } + kinds[i].Filter = nil + } + status.SetKinds(kinds) +} + // resourceLabel returns the label for the provided types.Event func resourceLabel(event types.Event) string { if event.Resource == nil { @@ -1404,8 +1525,9 @@ func (g *GRPCServer) SignDatabaseCSR(ctx context.Context, req *authpb.DatabaseCS return response, nil } -// GenerateDatabaseCert generates client certificate used by a database -// service to authenticate with the database instance. +// GenerateDatabaseCert generates a client certificate used by a database +// service to authenticate with the database instance, or a server certificate +// for configuring a self-hosted database, depending on the requester_name. func (g *GRPCServer) GenerateDatabaseCert(ctx context.Context, req *authpb.DatabaseCertRequest) (*authpb.DatabaseCertResponse, error) { auth, err := g.authenticate(ctx) if err != nil { diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index 335d0c9d3809d..8f6b66138c6e0 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -21,6 +21,7 @@ package auth import ( "context" "crypto/x509" + "crypto/x509/pkix" "encoding/base32" "encoding/pem" "fmt" @@ -2168,6 +2169,29 @@ func TestGenerateHostCerts(t *testing.T) { require.NotNil(t, certs) } +func TestGenerateDatabaseCerts(t *testing.T) { + t.Parallel() + ctx := context.Background() + srv := newTestTLSServer(t) + + clt, err := srv.NewClient(TestAdmin()) + require.NoError(t, err) + + // Generate CSR once for speed sake. + priv, err := testauthority.New().GeneratePrivateKey() + require.NoError(t, err) + csr, err := tlsca.GenerateCertificateRequestPEM(pkix.Name{CommonName: "test"}, priv) + require.NoError(t, err) + + certs, err := clt.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{CSR: csr}) + require.NoError(t, err) + require.NotNil(t, certs) + + certs, err = clt.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{CSR: csr, RequesterName: proto.DatabaseCertRequest_TCTL}) + require.NoError(t, err) + require.NotNil(t, certs) +} + // TestInstanceCertAndControlStream uses an instance cert to send an // inventory ping via the control stream. func TestInstanceCertAndControlStream(t *testing.T) { @@ -4235,3 +4259,105 @@ func TestUpsertApplicationServerOrigin(t *testing.T) { _, err = client.UpsertApplicationServer(ctx, appServer) require.NoError(t, err) } + +func TestDropDBClientCAEvents(t *testing.T) { + server := newTestTLSServer(t) + client, err := server.NewClient(TestIdentity{ + I: authz.BuiltinRole{ + Role: types.RoleDatabase, + AdditionalSystemRoles: []types.SystemRole{ + types.RoleNode, + }, + Username: server.ClusterName(), + }, + }) + require.NoError(t, err) + + dbClientCAs, err := client.GetCertAuthorities(context.Background(), types.DatabaseClientCA, false) + require.NoError(t, err) + require.Len(t, dbClientCAs, 1) + + dbCAs, err := client.GetCertAuthorities(context.Background(), types.DatabaseCA, false) + require.NoError(t, err) + require.Len(t, dbCAs, 1) + + tests := []struct { + desc string + clientVersion string + filter map[string]string + expectDrop bool + }{ + { + desc: "send db client CA events to supported client", + clientVersion: dbClientCACutoffVersion.String(), + }, + { + desc: "drop db client CA events to unsupported client", + clientVersion: "14.0.0", + expectDrop: true, + }, + } + + for i, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + clientCtx := metadata.AddMetadataToContext(ctx, + map[string]string{ + metadata.VersionKey: test.clientVersion, + }) + + requestedKind := types.WatchKind{Kind: types.KindCertAuthority, LoadSecrets: false} + watcher, err := client.NewWatcher(clientCtx, types.Watch{ + Name: "cas", + Kinds: []types.WatchKind{requestedKind}, + }) + require.NoError(t, err) + defer watcher.Close() + + // Swallow the init event + e := <-watcher.Events() + require.Equal(t, types.OpInit, e.Type) + status, ok := e.Resource.(types.WatchStatus) + require.True(t, ok) + require.NotNil(t, status) + kinds := status.GetKinds() + for _, k := range kinds { + if k.Kind == types.KindCertAuthority { + require.Equal(t, requestedKind, k) + } + } + + // update the db client ca so the watcher gets an OpPut event + dbClientCAs[0].SetName(fmt.Sprintf("stub_%v", i)) + err = server.Auth().UpsertCertAuthority(ctx, dbClientCAs[0]) + require.NoError(t, err) + + // update the db ca so the watcher gets an OpPut event + dbCAs[0].SetName(fmt.Sprintf("stub_%v", i)) + err = server.Auth().UpsertCertAuthority(ctx, dbCAs[0]) + require.NoError(t, err) + + gotCA, err := func() (types.CertAuthority, error) { + for { + select { + case <-watcher.Done(): + return nil, watcher.Error() + case e := <-watcher.Events(): + if ca, ok := e.Resource.(types.CertAuthority); ok { + return ca, nil + } + } + } + }() + require.NoError(t, err) + if test.expectDrop { + // the watcher should only see the second ca event. + require.Equal(t, types.DatabaseCA, gotCA.GetType(), "db client CA event was supposed to be dropped") + return + } + // watcher should see the first event if it wasn't dropped. + require.Equal(t, types.DatabaseClientCA, gotCA.GetType(), "db client CA event was not supposed to be dropped") + }) + } +} diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index d3cf98cb1564f..370f2dee3815f 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -632,6 +632,17 @@ func (a *TestAuthServer) Trust(ctx context.Context, remote *TestAuthServer, role if err != nil { return trace.Wrap(err) } + remoteCA, err = remote.AuthServer.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseClientCA, + DomainName: remote.ClusterName, + }, false) + if err != nil { + return trace.Wrap(err) + } + err = a.AuthServer.UpsertCertAuthority(ctx, remoteCA) + if err != nil { + return trace.Wrap(err) + } remoteCA, err = remote.AuthServer.GetCertAuthority(ctx, types.CertAuthID{ Type: types.OpenSSHCA, DomainName: remote.ClusterName, diff --git a/lib/auth/init.go b/lib/auth/init.go index a40f9712f64eb..94b738a99ea4e 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -476,6 +476,12 @@ func initCluster(ctx context.Context, cfg InitConfig, asrv *Server) error { if err := migration.Apply(ctx, cfg.Backend); err != nil { return trace.Wrap(err, "applying migrations") } + span.AddEvent("migrating db_client_authority") + err = migrateDBClientAuthority(ctx, asrv.Trust, cfg.ClusterName.GetClusterName()) + if err != nil { + return trace.Wrap(err) + } + span.AddEvent("completed migration db_client_authority") // generate certificate authorities if they don't exist var ( @@ -906,7 +912,7 @@ func checkResourceConsistency(ctx context.Context, keyStore *keystore.Manager, c switch r.GetType() { case types.HostCA, types.UserCA, types.OpenSSHCA: _, signerErr = keyStore.GetSSHSigner(ctx, r) - case types.DatabaseCA, types.SAMLIDPCA: + case types.DatabaseCA, types.DatabaseClientCA, types.SAMLIDPCA: _, _, signerErr = keyStore.GetTLSCertAndSigner(ctx, r) case types.JWTSigner, types.OIDCIdPCA: _, signerErr = keyStore.GetJWTSigner(ctx, r) @@ -1432,3 +1438,14 @@ func applyResources(ctx context.Context, service *Services, resources []types.Re } return nil } + +// migrateDBClientAuthority copies Database CA as Database Client CA. +// Does nothing if the Database Client CA already exists. +// +// TODO(gavin): DELETE IN 16.0.0 +func migrateDBClientAuthority(ctx context.Context, trustSvc services.Trust, cluster string) error { + migrationStart(ctx, "db_client_authority") + defer migrationEnd(ctx, "db_client_authority") + err := migration.MigrateDBClientAuthority(ctx, trustSvc, cluster) + return trace.Wrap(err) +} diff --git a/lib/auth/init_test.go b/lib/auth/init_test.go index 8094b2548108b..a0929158a2df0 100644 --- a/lib/auth/init_test.go +++ b/lib/auth/init_test.go @@ -49,6 +49,7 @@ import ( "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/suite" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/proxy" @@ -1066,6 +1067,21 @@ spec: key: LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBdmZQYlVPV3RMYU9QYndzTDVSRENWemFXMnliRUVmbzZNd3RrYVVFditnQjhocVBwCjhyU1VPa2EyYzBJbU11WGFIa05Qa1lUaTNPQmsvQmdUN3RKQlh0TS9mSDJjUy85S0l2NENwRlBQdHhEWWRaKzIKUHArYlJENkpiNVh0R0VaM2w4T1k0SkdBblpramRTM3JvMnFNL1NhMnJrOStEaDVkN1VvWU13RVJMbnoxRTlqKwpjbVFMQWJxSUhOQnNEY3RTZklya013cUNqTXlqZDMyN1YxMG9ua3M1NXBUOG5Oc3N4NnBrM1lEK3JndVdwRk00CkRxdlhXbkVTYWxHMEZNbHZSRlR1TlhUL3BmdGNMT29FWVVwL3FFYkFTdkprbEppT3UvRmRlcUxhamtlbmZjbGIKeVZPa1o5ZjlXOTJ6UXlxK3BGeUpkQmtGN041bys3UUFOMDdBUlFJREFRQUJBb0lCQVFDWkFiTHBxUGdrU1JtaQpncTFrS0duQ3dwQWxtMFpZak0wUWpONm5BZ0ZaU2NjRTFVZi9Yb0gvcHpJVUNYYW5qUXB6VWhqbnlMak0zbHU1CnpOTlJqajlsMkpmTStZbEtsaXJyb053VDdnYmxHVWFqQ0xGT0pGWjNWRUIwaDduaDBmRkhhQ0RlMDVWY1hSeDQKcVRLa0FaSHI0S0ZLSzNJSWdXRjdZREc1OCtRWkl0N01BdDQyc1NrUlhaenNyaUVzVWxZNU9xT1dlVHlkYy93dQpDK292V1FlOUU2bnRSc0EvY25nS0M1bnp6T0pCano3SlVTMGQ2UnBGK1Zodk9SRlZ6YTZKMnpRRXJsWWtFdVVUCkJYdThsVmhhcDVOU09nd3ByZXVmb2tNZVY4eGRNdUxuVEthdHVjajIwT3JFRUlHZVFrb08yTzhTeDJtNitOS0kKbzBTT2kzU0pBb0dCQU5LVkk4R0JHQUpSUlV2NVBnVkNDQ1VHSkFBRmtxY0FMTFlaZ0NMZCtLYStBVW41L2VONwpWemR4b2VqNG1UcnFOb0RLTnk5b0N2Z1RBWTlWZ3RZZllpdlNQbW5ZZk9FbTJUR2pLRTJSYzhBd0R5UWdHWFJVCm5WOUVueHRUUHF5dU9tQTFqa0ladXBLeUU0b0drdTE4MDRVaTB2d2pJUWI3cDdML0ZGYksvaGlmQW9HQkFPYnIKclBrOHdicnBlVk9QWmVudzVzbjVhTUpjdDM3emt4dHZzejVSVWlsZnZtVkdja1lyMU41SGpmZjVVVGcwNmNWdgp1N0dGY0NLQmhncFdqSXM3dHZ4cDhUVFM1VEZaWHZqY2VhRC91ZkFiN3ZUVnR0L1hRMFUwQ28rcDdzMENKWlZKCjY2MUtUU1RCYlFtR0xYYmNqNmFYdVM3bzFJM1FRU1l4NW5LV1F5aWJBb0dCQUtBTWZocUtGVWRkb1g5MnRhNmwKV3k5WWxXLzJ6RmxsQnBaNGx5em83QjArK0JmVGl5V2tEc3V5NzgzemMvS1ZKRXVLWlpzQVJxWDVQQXhHZjZSaQpRZWp3YUVObUtMT3ZKUkJXNDBEaE5jcHlQRy9HZmRJdXBWVk5BR2h5UW9aWC9VSTJNaU1IRHdpRGs5b3AyTzNyCkc1QnF3VlNsRm1zS1JaRUQwZCtOZE1ZZEFvR0FkVDNQRXJQd1FHL3RzNmtvdTBBZVRRbWVVS0EyWWZSVkNpY0sKUUdlVmFZQTg4THAxcG43Mmt1eU5mZ3ROVzFZeUlwWDZHOFYrQzJicm9UQVVKMVRvTVB1eEJYclY5dHBEUitMWQp0ZzlnWGpJd2ZvcExVUmJBQnRESFUrMlpXdWp1SC8vcDhvKzQzeUo5czhvMkp4VVFzaXB5VVFqUmNqYjcvT0owCitGU21RR1VDZ1lBdU5XcFhrbVkzbGtRSEFpU3RlblhTT2Zqc0xsbm5XWFJoTDVBYTZqRVZRT3FnUldVQnZ2YWkKK1RQRTNUTHQ5MGwveTZOdjU0dDdrT1QvSlEvREU4WmFiMERlTjdBRzRwRjMwNkxpYlpZNmswc3M1UDNXME8vbAozekJzQ0lEY3BHOWw4bDZSNkUwdHN0Z0I1c25hOE4rRzA2V3Q1R0M0UitRSVd1YTVpUmtVTXc9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo= type: user sub_kind: user +version: v2` + databaseClientCAYAML = ` +kind: cert_authority +metadata: + id: 1696989861240620000 + name: me.localhost +spec: + active_keys: + tls: + - cert: LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURqVENDQW5XZ0F3SUJBZ0lSQU5XcUtsOWR3WGYrVTBWUmhxNGEyaTB3RFFZSktvWklodmNOQVFFTEJRQXcKWURFVk1CTUdBMVVFQ2hNTWJXVXViRzlqWVd4b2IzTjBNUlV3RXdZRFZRUURFd3h0WlM1c2IyTmhiR2h2YzNReApNREF1QmdOVkJBVVRKekk0TkRBd09URXhNams0TlRBek1qYzRPVEV3TURJNE5qUXlPREV6TmpRMk9UQTVNamt3Ck9UQWVGdzB5TXpFd01URXdNakEwTWpGYUZ3MHpNekV3TURnd01qQTBNakZhTUdBeEZUQVRCZ05WQkFvVERHMWwKTG14dlkyRnNhRzl6ZERFVk1CTUdBMVVFQXhNTWJXVXViRzlqWVd4b2IzTjBNVEF3TGdZRFZRUUZFeWN5T0RRdwpNRGt4TVRJNU9EVXdNekkzT0RreE1EQXlPRFkwTWpneE16WTBOamt3T1RJNU1Ea3dnZ0VpTUEwR0NTcUdTSWIzCkRRRUJBUVVBQTRJQkR3QXdnZ0VLQW9JQkFRRGlDNU5GVDhmUy9hSzdSSVVRVnFjVmFxaFBzMFNuU0prTmd3azEKUHZkeFV1OWZZMlNwek5NaUUzSGZlb0Y4S1h2YUU0aHJzMEFGOVRmYlpJTnM1RjNHNTNzOUg3Q2JXWHpOWVRtZApCN0gyWEVxVGp3N0xGL2pzYzkwcTN4ZnZqMkk0Z29tOUdYK3dGMXdaRldjZXVJRkJTdXdCRkV6a1Yzc1o5NEVqClBsWUIxK2lnNlJoWGhvUjdhRlJUNDFvZmtMUUovMDdBVmR4blUranp5VkVFSVk3SjUwUWU3bFc2Nk9wL3BncmwKR1FBSnkwbnowUVpVYVJjVmZrODVHK3NwMnhjcUJ6clJHbXNybmw1TmhMdGJqcUJIUkZ2cU5XS1pLa1V2M0NjUApiTytWT1krV1FmV0UzRThhekxUQ2ppcnJXYWVXeTNLR0RTZGF5YktOK0FKbVpqU3hBZ01CQUFHalFqQkFNQTRHCkExVWREd0VCL3dRRUF3SUJwakFQQmdOVkhSTUJBZjhFQlRBREFRSC9NQjBHQTFVZERnUVdCQlFhVmMzdXlYWnoKdDBWLzFnNzE3MzMrRjFhaHNqQU5CZ2txaGtpRzl3MEJBUXNGQUFPQ0FRRUF6NTVkVnVFdmVLdnJtYThzL0dWSQo2Q0t5akNNYXNjWmhwV1JIT3QxRjQ1T0pjcXg1RDBQeVhSenZXS2NTYzlZTkN0M1BzSi8yNGp3VDlLaElqK2NiClQ5Z0h5WXNkb3pWY2NzMXNZTkFjK3VFSmRSOEsydHJqa1JJN0Q5VmZvTEJJVFlHUkJGTWpSOEE1bENlUzVnTkgKRG42V09rSlpRUi9UQS9IbFFlUmttMW5teUp3VVVQOVA0aUVWVlVSS0lMRVVNTS9EdERXdTZuNnM2K0pVVXNDNwp5QmI2T3JQeVRGbkV4TFljN2RhYUM1bm5UVDZHY2xUSm4wYkJ2UmtXdUFVa1FtWXJyYkpBMnhEVjFBL0JOcmp3Ci9aU2ErU1ZlVWJxSW05ZEVESE4zQUhXcmJzbWwyVjI3YUtrMHVUK0JmeUJBZ3NSdGpMN0U2YUdJanlNcStlOW4KYmc9PQotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0tCg== + key: LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFb3dJQkFBS0NBUUVBNGd1VFJVL0gwdjJpdTBTRkVGYW5GV3FvVDdORXAwaVpEWU1KTlQ3M2NWTHZYMk5rCnFjelRJaE54MzNxQmZDbDcyaE9JYTdOQUJmVTMyMlNEYk9SZHh1ZDdQUit3bTFsOHpXRTVuUWV4OWx4S2s0OE8KeXhmNDdIUGRLdDhYNzQ5aU9JS0p2Umwvc0JkY0dSVm5IcmlCUVVyc0FSUk01RmQ3R2ZlQkl6NVdBZGZvb09rWQpWNGFFZTJoVVUrTmFINUMwQ2Y5T3dGWGNaMVBvODhsUkJDR095ZWRFSHU1VnV1anFmNllLNVJrQUNjdEo4OUVHClZHa1hGWDVQT1J2cktkc1hLZ2M2MFJwcks1NWVUWVM3VzQ2Z1IwUmI2alZpbVNwRkw5d25EMnp2bFRtUGxrSDEKaE54UEdzeTB3bzRxNjFtbmxzdHloZzBuV3NteWpmZ0NabVkwc1FJREFRQUJBb0lCQUQ5R01EWkJxOVRDek0rUQowWktPUHZ6K3V4aDhQT1o2cXVVZVhmQjZyTGNiR1FoaGdTY0t2N3NWS0ZYL0s4bStydjJQWkN1SnBJMUdaQmxVCm5IbFp2MnBURjZzM2VLOHpzSHlwRDRDR1MrbURVaGpWL2JVYUE4TGtkKzl0UFgwQWJPVVduVW5Dbm55RFBYT0UKQ3phTlBSa3l5TGRRb0dsMmwyM2dXMVNyT1ZZUVBEUjZncWVJZVFYa3pHYUFUQ2twZWYrOVk4US9pTkZUR05oZAppamtXSUZOdEYzQjdIODUwdnR0VFFRckE3QXQ3ZnN2bmo2YVVDUkQ3MmFYZGJmeHIwK0VQbUR6WGNhejM3U20yClg3ZkJrakRFa0pCa0gwVnBnZHdvMDh3cmtzbnBieUNpbE95alp4Z3lhUWw2NGFKcGVUN1FHbEMxcm1kUEFmTU4KSEdweFBwVUNnWUVBK21vbllBVW1oSVpTR0R6VjI0NFc0QVRnVFdGb2gxb24wKzlZZ2xxR1RUZFFBM2dXYy9ReQowSmJ6QXpOVFZnODNibWhvU3NtbUMwZ1BoSytBb3BCZXc5ZlVxMmhIOVptUzVjZE1CaTJ6cFdvZHQrWXMvNys1Ckk3d3d2bGgvY3llelQvU0ZyazVCVVB0azZFR1pLZk1KdGZlNDcvb05uUzRmc3lmYlAyWDVUWHNDZ1lFQTV4WkcKSmhiYkwwWFljL0plc0JucXBuRFZHNXhkbkh6aGx0aVB3QzdGbHRDRUpuNFdhTUNpSUtsL0o3VGUzbndqMHk0YQpSVzFTWGN6anc5dHZxY0V3aE9CQUZBUHlsd0FWWjUyOTRzdWZzMnk4SHBFcFhhMjlIRXNReW50TE9JRFZyYkVsClJCV1pEb0xhbllVRTNtWml5WnZ0ZU9TT0hTajVhUnIzRTg1UmtNTUNnWUFweUVLUG8reGNXbWtpUUN4U3VPK2EKSzFZZHN5NFV2M2M3eG9qWEh6R2ZlcVl3SGY1cEZJclNBUTNGTC9Bc3dOYzM1ZFhZL0xKbTJYdzFZRzh2TUxXUApLZGtEVEtBTkc3WEYveTN4TGZqMmxiRWx1Uk16RFJOZ0lndGtCeklremEvK25FY2Q0VkxHcDF1YjRTNGtNTGdqCkU1VlkvVGorUys3Z0hydFhaYlZtTndLQmdET3A2eTBBMXlnT2VZSVNvZERGT296VGxSR0ROL3FRZ083MG84N1gKcGgwOXFRM2lDcWlJeUxaOHJvejJCdzIrdTFPdmJ2Z3VwTWVMMHpBcWt5QmtyTEJJWW9zWEJ0bHpqMVdIRXJqdAp4VnFiNk1MOHVUN1VaUDg2V1JxcnpmbG45RjNNeVFRYndBaGFnUDNPaTNRZGQrQ1RGOWg3WUxwc09yYWc3TFJrCjRCOTVBb0dCQU9maHNVSzVSZm1RU1ZzN08wMmMrdWFVelB2dnEwaXNqNW45dWlOaFQxdjFDUGY4YStZdkkyTisKcWV5bHkwRjN3L2sxbElaUzFjWlMwRDVWMUd6bmVHTUgreWYwVFAvNmlUcElHTC94N1pTNGJEZjE3dEc5dklDdgoya1JBTno5WHpzVzFETm9CRkJmZXg5NDFmT0RzdHlvdThvYmF3dDdJWThTdU1GMHV3aDlBCi0tLS0tRU5EIFJTQSBQUklWQVRFIEtFWS0tLS0tCg== + additional_trusted_keys: {} + cluster_name: me.localhost + type: db_client +sub_kind: db_client version: v2` databaseCAYAML = `kind: cert_authority metadata: @@ -1141,7 +1157,8 @@ func TestInit_bootstrap(t *testing.T) { hostCA := resourceFromYAML(t, hostCAYAML).(types.CertAuthority) userCA := resourceFromYAML(t, userCAYAML).(types.CertAuthority) jwtCA := resourceFromYAML(t, jwtCAYAML).(types.CertAuthority) - dbCA := resourceFromYAML(t, databaseCAYAML).(types.CertAuthority) + dbServerCA := resourceFromYAML(t, databaseCAYAML).(types.CertAuthority) + dbClientCA := resourceFromYAML(t, databaseClientCAYAML).(types.CertAuthority) osshCA := resourceFromYAML(t, openSSHCAYAML).(types.CertAuthority) samlCA := resourceFromYAML(t, samlCAYAML).(types.CertAuthority) @@ -1151,8 +1168,10 @@ func TestInit_bootstrap(t *testing.T) { invalidUserCA.(*types.CertAuthorityV2).Spec.ActiveKeys.SSH = nil invalidJWTCA := resourceFromYAML(t, jwtCAYAML).(types.CertAuthority) invalidJWTCA.(*types.CertAuthorityV2).Spec.ActiveKeys.JWT = nil - invalidDBCA := resourceFromYAML(t, databaseCAYAML).(types.CertAuthority) - invalidDBCA.(*types.CertAuthorityV2).Spec.ActiveKeys.TLS = nil + invalidDBServerCA := resourceFromYAML(t, databaseCAYAML).(types.CertAuthority) + invalidDBServerCA.(*types.CertAuthorityV2).Spec.ActiveKeys.TLS = nil + invalidDBClientCA := resourceFromYAML(t, databaseClientCAYAML).(types.CertAuthority) + invalidDBClientCA.(*types.CertAuthorityV2).Spec.ActiveKeys.TLS = nil invalidOSSHCA := resourceFromYAML(t, openSSHCAYAML).(types.CertAuthority) invalidOSSHCA.(*types.CertAuthorityV2).Spec.ActiveKeys.SSH = nil invalidSAMLCA := resourceFromYAML(t, samlCAYAML).(types.CertAuthority) @@ -1172,7 +1191,8 @@ func TestInit_bootstrap(t *testing.T) { hostCA.Clone(), userCA.Clone(), jwtCA.Clone(), - dbCA.Clone(), + dbServerCA.Clone(), + dbClientCA.Clone(), osshCA.Clone(), samlCA.Clone(), ) @@ -1187,7 +1207,8 @@ func TestInit_bootstrap(t *testing.T) { invalidHostCA.Clone(), userCA.Clone(), jwtCA.Clone(), - dbCA.Clone(), + dbServerCA.Clone(), + dbClientCA.Clone(), osshCA.Clone(), samlCA.Clone(), ) @@ -1202,7 +1223,8 @@ func TestInit_bootstrap(t *testing.T) { hostCA.Clone(), invalidUserCA.Clone(), jwtCA.Clone(), - dbCA.Clone(), + dbServerCA.Clone(), + dbClientCA.Clone(), osshCA.Clone(), samlCA.Clone(), ) @@ -1217,7 +1239,8 @@ func TestInit_bootstrap(t *testing.T) { hostCA.Clone(), userCA.Clone(), invalidJWTCA.Clone(), - dbCA.Clone(), + dbServerCA.Clone(), + dbClientCA.Clone(), osshCA.Clone(), samlCA.Clone(), ) @@ -1232,7 +1255,23 @@ func TestInit_bootstrap(t *testing.T) { hostCA.Clone(), userCA.Clone(), jwtCA.Clone(), - invalidDBCA.Clone(), + invalidDBServerCA.Clone(), + osshCA.Clone(), + samlCA.Clone(), + ) + }, + assertError: require.Error, + }, + { + name: "NOK bootstrap Database Client CA missing keys", + modifyConfig: func(cfg *InitConfig) { + cfg.BootstrapResources = append( + cfg.BootstrapResources, + hostCA.Clone(), + userCA.Clone(), + jwtCA.Clone(), + dbServerCA.Clone(), + invalidDBClientCA.Clone(), osshCA.Clone(), samlCA.Clone(), ) @@ -1247,7 +1286,8 @@ func TestInit_bootstrap(t *testing.T) { hostCA.Clone(), userCA.Clone(), jwtCA.Clone(), - dbCA.Clone(), + dbServerCA.Clone(), + dbClientCA.Clone(), invalidOSSHCA.Clone(), samlCA.Clone(), ) @@ -1262,7 +1302,8 @@ func TestInit_bootstrap(t *testing.T) { hostCA.Clone(), userCA.Clone(), jwtCA.Clone(), - dbCA.Clone(), + dbServerCA.Clone(), + dbClientCA.Clone(), osshCA.Clone(), invalidSAMLCA.Clone(), ) @@ -1705,3 +1746,26 @@ func TestInitCreatesCertsIfMissing(t *testing.T) { require.Len(t, cert, 1) } } + +func TestMigrateDatabaseClientCA(t *testing.T) { + ctx := context.Background() + conf := setupConfig(t) + + hostCA := suite.NewTestCA(types.HostCA, "me.localhost") + userCA := suite.NewTestCA(types.UserCA, "me.localhost") + dbServerCA := suite.NewTestCA(types.DatabaseCA, "me.localhost") + + conf.Authorities = []types.CertAuthority{hostCA, userCA, dbServerCA} + auth, err := Init(ctx, conf) + require.NoError(t, err) + t.Cleanup(func() { + err = auth.Close() + require.NoError(t, err) + }) + + dbClientCAs, err := auth.GetCertAuthorities(ctx, types.DatabaseClientCA, true) + require.NoError(t, err) + require.Len(t, dbClientCAs, 1) + require.Equal(t, dbServerCA.Spec.ActiveKeys.TLS[0].Cert, dbClientCAs[0].GetActiveKeys().TLS[0].Cert) + require.Equal(t, dbServerCA.Spec.ActiveKeys.TLS[0].Key, dbClientCAs[0].GetActiveKeys().TLS[0].Key) +} diff --git a/lib/auth/migration/0001_db_ca.go b/lib/auth/migration/0001_db_ca.go index 0d7366c77ce60..4b3b18399cdd9 100644 --- a/lib/auth/migration/0001_db_ca.go +++ b/lib/auth/migration/0001_db_ca.go @@ -22,9 +22,7 @@ import ( "context" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" - "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/services" @@ -48,9 +46,9 @@ func (d createDBAuthority) Name() string { return "create_db_cas" } -// Up creates a Database CA for all known clusters. If a Database CA -// already exist for a cluster, it is skipped. If no Host or Database CA -// exist for a cluster, it is also skipped. +// Up creates a new CA for all known clusters as a copy of the old CA. +// If the new CA already exists for a cluster, it is skipped. +// If neither the old nor the new CA exist for a cluster, it is also skipped. func (d createDBAuthority) Up(ctx context.Context, b backend.Backend) error { ctx, span := tracer.Start(ctx, "createDBAuthority/Up") defer span.End() @@ -81,10 +79,6 @@ func (d createDBAuthority) Up(ctx context.Context, b backend.Backend) error { } presenceSvc := d.presenceServiceFn(b) - return trace.Wrap(d.up(ctx, configSvc, trustSvc, presenceSvc)) -} - -func (d createDBAuthority) up(ctx context.Context, configSvc services.ClusterConfiguration, trustSvc services.Trust, presenceSvc services.Presence) error { localClusterName, err := configSvc.GetClusterName() if err != nil { return trace.Wrap(err) @@ -102,65 +96,17 @@ func (d createDBAuthority) up(ctx context.Context, configSvc services.ClusterCon } for _, cluster := range allClusters { - _, err := trustSvc.GetCertAuthority(ctx, types.CertAuthID{Type: types.DatabaseCA, DomainName: cluster}, false) - // The migration for this cluster can be skipped since - // a Database CA already exists. - if err == nil { - continue - } - - if err != nil && !trace.IsNotFound(err) { - return trace.Wrap(err) - } - - // The Database CA does not exists, so we must check to - // see if the Host CA exists before proceeding with the migration. - // If both the Database and Host CA do not exist, then this cluster - // is brand new and the migration can be avoided because they will - // both automatically be created. If the Host CA does exist, then - // a new Database CA should be constructed from it. - hostCA, err := trustSvc.GetCertAuthority(ctx, types.CertAuthID{Type: types.HostCA, DomainName: cluster}, false) - if trace.IsNotFound(err) { - continue - } - if err != nil { - return trace.Wrap(err) - } - - logrus.Infof("Migrating Database CA cluster: %s", cluster) - - ca, ok := hostCA.(*types.CertAuthorityV2) - if !ok { - return trace.BadParameter("expected host CA to be *types.CertAuthorityV2, got %T", hostCA) - } - - dbCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ - Type: types.DatabaseCA, - ClusterName: cluster, - ActiveKeys: types.CAKeySet{ - TLS: ca.Spec.ActiveKeys.TLS, - }, - }) + err := migrateDBAuthority(ctx, trustSvc, cluster, types.HostCA, types.DatabaseCA) if err != nil { return trace.Wrap(err) } - - err = trustSvc.CreateCertAuthority(ctx, dbCA) - if trace.IsAlreadyExists(err) { - logrus.Warn("Database CA has already been created by a different Auth instance") - continue - } else if err != nil { - return trace.Wrap(err) - } } - return nil } -// Down deletes existing Database CAs for all clusters. +// Down deletes any existing CAs of the new CA type for all clusters. func (d createDBAuthority) Down(ctx context.Context, b backend.Backend) error { - tracer := tracing.NewTracer("migrations") - _, span := tracer.Start(ctx, "migrations/CreateDBAuthorityDown") + _, span := tracer.Start(ctx, "CreateDBAuthorityDown") defer span.End() if d.trustServiceFn == nil { @@ -172,3 +118,72 @@ func (d createDBAuthority) Down(ctx context.Context, b backend.Backend) error { trustSvc := d.trustServiceFn(b) return trace.Wrap(trustSvc.DeleteAllCertAuthorities(types.DatabaseCA)) } + +// MigrateDBClientAuthority performs a migration which creates the db_client CA +// as a copy of the existing db CA for backwards compatibility. +func MigrateDBClientAuthority(ctx context.Context, trustSvc services.Trust, cluster string) error { + err := migrateDBAuthority(ctx, trustSvc, cluster, types.DatabaseCA, types.DatabaseClientCA) + return trace.Wrap(err) +} + +// migrateDBAuthority performs a migration which creates a new CA from an +// existing CA. +// The new CA is created as a copy of the existing CA for backwards +// compatibility. +// This func is generalized for copying db/db_client CAs, although it may appear +// to be usable for other CA types - that's why it is unexported. +func migrateDBAuthority(ctx context.Context, trustSvc services.Trust, cluster string, fromType, toType types.CertAuthType) error { + _, err := trustSvc.GetCertAuthority(ctx, types.CertAuthID{ + Type: toType, + DomainName: cluster, + }, false) + // The migration for this cluster can be skipped since + // the new CA already exists. + if err == nil { + log.Debugf("Migrations: cert authority %q already exists.", toType) + return nil + } + if !trace.IsNotFound(err) { + return trace.Wrap(err) + } + + // The new CA type does not exist, so we must check to + // see if the existing CA exists before proceeding with the split migration. + // If both the existing and new CA do not exist, then this cluster + // is brand new and the migration can be avoided because they will + // both automatically be created. If the existing CA does exist, then + // a new CA should be constructed from it as a copy. + existingCA, err := trustSvc.GetCertAuthority(ctx, types.CertAuthID{ + Type: fromType, + DomainName: cluster, + }, true) + if trace.IsNotFound(err) { + return nil + } + if err != nil { + return trace.Wrap(err) + } + + log.Infof("Migrating %s CA for cluster: %s", toType, cluster) + + existingCAV2, ok := existingCA.(*types.CertAuthorityV2) + if !ok { + return trace.BadParameter("expected %s CA to be *types.CertAuthorityV2, got %T", fromType, existingCA) + } + + newCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: toType, + ClusterName: cluster, + ActiveKeys: existingCAV2.Spec.ActiveKeys, + }) + if err != nil { + return trace.Wrap(err) + } + + err = trustSvc.CreateCertAuthority(ctx, newCA) + if trace.IsAlreadyExists(err) { + log.Warnf("%s CA has already been created by a different Auth instance", toType) + return nil + } + return trace.Wrap(err) +} diff --git a/lib/auth/migration/migration.go b/lib/auth/migration/migration.go index 45973174a09c2..d33dfe41993b8 100644 --- a/lib/auth/migration/migration.go +++ b/lib/auth/migration/migration.go @@ -30,6 +30,7 @@ import ( "go.opentelemetry.io/otel/codes" oteltrace "go.opentelemetry.io/otel/trace" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/lib/backend" ) @@ -49,6 +50,10 @@ func withMigrations(m []migration) func(c *applyConfig) { } } +var log = logrus.WithFields(logrus.Fields{ + trace.Component: teleport.ComponentAuth, +}) + var tracer = tracing.NewTracer("migrations") // migration is an interface responsible for applying data migrations to the backend. @@ -120,7 +125,7 @@ func Apply(ctx context.Context, b backend.Backend, opts ...func(c *applyConfig)) continue } - logrus.Infof("Starting migration %d %s", version, m.Name()) + log.Infof("Starting migration %d %s", version, m.Name()) span.AddEvent("Starting migration", oteltrace.WithAttributes(attribute.Int("migration", version))) started := time.Now().UTC() @@ -140,7 +145,7 @@ func Apply(ctx context.Context, b backend.Backend, opts ...func(c *applyConfig)) return trace.Wrap(err) } - logrus.Infof("Completed migration %d %s", version, m.Name()) + log.Infof("Completed migration %d %s", version, m.Name()) span.AddEvent("Completed migration", oteltrace.WithAttributes(attribute.Int("migration", version))) } diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index 7ddc8692c2499..365135251931f 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -1982,6 +1982,11 @@ func TestGetCertAuthority(t *testing.T) { Type: types.DatabaseCA, }, true) require.True(t, trace.IsAccessDenied(err)) + _, err = proxyClt.GetCertAuthority(ctx, types.CertAuthID{ + DomainName: testSrv.ClusterName(), + Type: types.DatabaseClientCA, + }, true) + require.True(t, trace.IsAccessDenied(err)) _, err = proxyClt.GetCertAuthority(ctx, types.CertAuthID{ DomainName: testSrv.ClusterName(), diff --git a/lib/auth/windows/ldap.go b/lib/auth/windows/ldap.go index 9b50e3115ada1..4970500c34628 100644 --- a/lib/auth/windows/ldap.go +++ b/lib/auth/windows/ldap.go @@ -314,7 +314,7 @@ func crlDN(clusterName string, config LDAPConfig, caType types.CertAuthType) str // Note: UserCA must use "Teleport" to keep backwards compatibility. func crlKeyName(caType types.CertAuthType) string { switch caType { - case types.DatabaseCA: + case types.DatabaseClientCA, types.DatabaseCA: return "TeleportDB" default: return "Teleport" diff --git a/lib/auth/windows/windows_test.go b/lib/auth/windows/windows_test.go index f40c546861999..1331af1796ed6 100644 --- a/lib/auth/windows/windows_test.go +++ b/lib/auth/windows/windows_test.go @@ -161,7 +161,7 @@ func TestCRLDN(t *testing.T) { { name: "database CA", clusterName: "cluster.goteleport.com", - caType: types.DatabaseCA, + caType: types.DatabaseClientCA, crlDN: "CN=cluster.goteleport.com,CN=TeleportDB,CN=CDP,CN=Public Key Services,CN=Services,CN=Configuration,DC=test,DC=goteleport,DC=com", }, { diff --git a/lib/cache/collections.go b/lib/cache/collections.go index aadf1353b5b89..01de74c710864 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -1010,8 +1010,11 @@ func (e certAuthorityExecutor) getAll(ctx context.Context, cache *Cache, loadSec cas, err := cache.Trust.GetCertAuthorities(ctx, caType, loadSecrets) // if caType was added in this major version we might get a BadParameter // error if we're connecting to an older upstream that doesn't know about it - if err != nil && !(caType.NewlyAdded() && trace.IsBadParameter(err)) { - return nil, trace.Wrap(err) + if err != nil { + if !(types.IsUnsupportedAuthorityErr(err) && caType.NewlyAdded()) { + return nil, trace.Wrap(err) + } + continue } // this can be removed once we get the ability to fetch CAs with a filter, diff --git a/lib/client/api_test.go b/lib/client/api_test.go index 33a5bad286135..47c6669ff1a38 100644 --- a/lib/client/api_test.go +++ b/lib/client/api_test.go @@ -792,6 +792,15 @@ func TestVirtualPathNames(t *testing.T) { "TSH_VIRTUAL_PATH_CA", }, }, + { + name: "database client ca", + kind: VirtualPathCA, + params: VirtualPathCAParams(types.DatabaseClientCA), + expected: []string{ + "TSH_VIRTUAL_PATH_CA_DB_CLIENT", + "TSH_VIRTUAL_PATH_CA", + }, + }, { name: "host ca", kind: VirtualPathCA, diff --git a/lib/client/ca_export.go b/lib/client/ca_export.go index 531eed120b6b7..c27a6f2306c21 100644 --- a/lib/client/ca_export.go +++ b/lib/client/ca_export.go @@ -119,6 +119,20 @@ func exportAuth(ctx context.Context, client auth.ClientI, req ExportAuthoritiesR ExportPrivateKeys: exportSecrets, } return exportTLSAuthority(ctx, client, req) + case "db-client": + req := exportTLSAuthorityRequest{ + AuthType: types.DatabaseClientCA, + UnpackPEM: false, + ExportPrivateKeys: exportSecrets, + } + return exportTLSAuthority(ctx, client, req) + case "db-client-der": + req := exportTLSAuthorityRequest{ + AuthType: types.DatabaseClientCA, + UnpackPEM: true, + ExportPrivateKeys: exportSecrets, + } + return exportTLSAuthority(ctx, client, req) case "tls-user-der", "windows": req := exportTLSAuthorityRequest{ AuthType: types.UserCA, diff --git a/lib/client/ca_export_test.go b/lib/client/ca_export_test.go index d37d16bb07353..ddb26629d29fe 100644 --- a/lib/client/ca_export_test.go +++ b/lib/client/ca_export_test.go @@ -25,6 +25,7 @@ import ( "fmt" "testing" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -34,6 +35,8 @@ import ( type mockAuthClient struct { auth.ClientI server *auth.Server + + unsupportedCATypes []types.CertAuthType } func (m *mockAuthClient) GetDomainName(ctx context.Context) (string, error) { @@ -41,10 +44,20 @@ func (m *mockAuthClient) GetDomainName(ctx context.Context) (string, error) { } func (m *mockAuthClient) GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool) ([]types.CertAuthority, error) { + for _, unsupported := range m.unsupportedCATypes { + if unsupported == caType { + return nil, trace.BadParameter("%q authority type is not supported", unsupported) + } + } return m.server.GetCertAuthorities(ctx, caType, loadKeys) } func (m *mockAuthClient) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) { + for _, unsupported := range m.unsupportedCATypes { + if unsupported == id.Type { + return nil, trace.BadParameter("%q authority type is not supported", unsupported) + } + } return m.server.GetCertAuthority(ctx, id, loadKeys) } @@ -202,6 +215,15 @@ func TestExportAuthorities(t *testing.T) { }, assertSecrets: validatePrivateKeyPEMFunc, }, + { + name: "db", + req: ExportAuthoritiesRequest{ + AuthType: "db", + }, + errorCheck: require.NoError, + assertNoSecrets: validateTLSCertificatePEMFunc, + assertSecrets: validatePrivateKeyPEMFunc, + }, { name: "db-der", req: ExportAuthoritiesRequest{ @@ -211,6 +233,24 @@ func TestExportAuthorities(t *testing.T) { assertNoSecrets: validateTLSCertificateDERFunc, assertSecrets: validatePrivateKeyDERFunc, }, + { + name: "db-client", + req: ExportAuthoritiesRequest{ + AuthType: "db-client", + }, + errorCheck: require.NoError, + assertNoSecrets: validateTLSCertificatePEMFunc, + assertSecrets: validatePrivateKeyPEMFunc, + }, + { + name: "db-client-der", + req: ExportAuthoritiesRequest{ + AuthType: "db-client-der", + }, + errorCheck: require.NoError, + assertNoSecrets: validateTLSCertificateDERFunc, + assertSecrets: validatePrivateKeyDERFunc, + }, } { t.Run(fmt.Sprintf("%s_exportSecrets_%v", tt.name, exportSecrets), func(t *testing.T) { mockedClient := &mockAuthClient{ diff --git a/lib/client/db/database_certificates.go b/lib/client/db/database_certificates.go index 8c14009b18084..b562d5d958c6f 100644 --- a/lib/client/db/database_certificates.go +++ b/lib/client/db/database_certificates.go @@ -26,6 +26,7 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/identityfile" @@ -47,8 +48,8 @@ type GenerateDatabaseCertificatesRequest struct { Password string } -// GenerateDatabaseCertificates to be used by databases to set up mTLS authentication -func GenerateDatabaseCertificates(ctx context.Context, req GenerateDatabaseCertificatesRequest) ([]string, error) { +// GenerateDatabaseServerCertificates to be used by databases to set up mTLS authentication +func GenerateDatabaseServerCertificates(ctx context.Context, req GenerateDatabaseCertificatesRequest) ([]string, error) { if len(req.Principals) == 0 || (len(req.Principals) == 1 && req.Principals[0] == "" && req.OutputFormat != identityfile.FormatSnowflake) { @@ -65,6 +66,11 @@ func GenerateDatabaseCertificates(ctx context.Context, req GenerateDatabaseCerti subject := pkix.Name{CommonName: req.Principals[0]} + clusterNameType, err := req.ClusterAPI.GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + clusterName := clusterNameType.GetClusterName() if req.OutputFormat == identityfile.FormatMongo { // Include Organization attribute in MongoDB certificates as well. // @@ -77,12 +83,7 @@ func GenerateDatabaseCertificates(ctx context.Context, req GenerateDatabaseCerti // MongoDB cluster members so set it to the Teleport cluster name // to avoid hardcoding anything. - clusterNameType, err := req.ClusterAPI.GetClusterName() - if err != nil { - return nil, trace.Wrap(err) - } - - subject.Organization = []string{clusterNameType.GetClusterName()} + subject.Organization = []string{clusterName} } if req.Key == nil { @@ -114,6 +115,25 @@ func GenerateDatabaseCertificates(ctx context.Context, req GenerateDatabaseCerti return nil, trace.Wrap(err) } + // For CockroachDB we provide node.crt, node.key, ca.crt, and ca-client.crt, + // and the user must use their own CA to issue client.node.crt, + // client.node.key, and add their own CA's cert to ca-client.crt. + // The response CA certs are for Teleport DB Client CA, so fetch the + // Teleport Database CA certs as well. + var additionalCACerts [][]byte + if req.OutputFormat == identityfile.FormatCockroach { + dbServerCA, err := req.ClusterAPI.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseCA, + DomainName: clusterName, + }, false) + if err != nil { + return nil, trace.Wrap(err) + } + for _, keyPair := range dbServerCA.GetTrustedTLSKeyPairs() { + additionalCACerts = append(additionalCACerts, keyPair.Cert) + } + } + req.Key.TLSCert = resp.Cert req.Key.TrustedCerts = []auth.TrustedCerts{{ ClusterName: req.Key.ClusterName, @@ -126,6 +146,7 @@ func GenerateDatabaseCertificates(ctx context.Context, req GenerateDatabaseCerti OverwriteDestination: req.OutputCanOverwrite, Writer: req.IdentityFileWriter, Password: req.Password, + AdditionalCACerts: additionalCACerts, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/client/identityfile/identity.go b/lib/client/identityfile/identity.go index b93df2bd58de9..85cbc48d802e5 100644 --- a/lib/client/identityfile/identity.go +++ b/lib/client/identityfile/identity.go @@ -203,6 +203,9 @@ type WriteConfig struct { Writer ConfigWriter // Password is the password for the JKS keystore used by Cassandra format and Oracle wallet. Password string + // AdditionalCACerts contains additional CA certs, used by Cockroach format + // to distinguish DB Server CA certs from DB Client CA certs. + AdditionalCACerts [][]byte } // Write writes user credentials to disk in a specified format. @@ -293,18 +296,54 @@ func Write(ctx context.Context, cfg WriteConfig) (filesWritten []string, err err } } - case FormatTLS, FormatDatabase, FormatCockroach, FormatRedis, FormatElasticsearch, FormatScylla: + case FormatCockroach: + // CockroachDB expects files to be named node.crt, node.key, ca.crt, + // ca-client.crt + certPath := filepath.Join(cfg.OutputPath, "node.crt") + keyPath := filepath.Join(cfg.OutputPath, "node.key") + casPath := filepath.Join(cfg.OutputPath, "ca.crt") + clientCAsPath := filepath.Join(cfg.OutputPath, "ca-client.crt") + + filesWritten = append(filesWritten, certPath, keyPath, casPath, clientCAsPath) + if err := checkOverwrite(ctx, writer, cfg.OverwriteDestination, filesWritten...); err != nil { + return nil, trace.Wrap(err) + } + + err = writer.WriteFile(certPath, cfg.Key.TLSCert, identityfile.FilePermissions) + if err != nil { + return nil, trace.Wrap(err) + } + + err = writer.WriteFile(keyPath, cfg.Key.PrivateKeyPEM(), identityfile.FilePermissions) + if err != nil { + return nil, trace.Wrap(err) + } + + var serverCACerts []byte + for _, cert := range cfg.AdditionalCACerts { + serverCACerts = append(serverCACerts, cert...) + } + err = writer.WriteFile(casPath, serverCACerts, identityfile.FilePermissions) + if err != nil { + return nil, trace.Wrap(err) + } + + var clientCACerts []byte + for _, ca := range cfg.Key.TrustedCerts { + for _, cert := range ca.TLSCertificates { + clientCACerts = append(clientCACerts, cert...) + } + } + err = writer.WriteFile(clientCAsPath, clientCACerts, identityfile.FilePermissions) + if err != nil { + return nil, trace.Wrap(err) + } + + case FormatTLS, FormatDatabase, FormatRedis, FormatElasticsearch, FormatScylla: keyPath := cfg.OutputPath + ".key" certPath := cfg.OutputPath + ".crt" casPath := cfg.OutputPath + ".cas" - // CockroachDB expects files to be named ca.crt, node.crt and node.key. - if cfg.Format == FormatCockroach { - keyPath = filepath.Join(cfg.OutputPath, "node.key") - certPath = filepath.Join(cfg.OutputPath, "node.crt") - casPath = filepath.Join(cfg.OutputPath, "ca.crt") - } - filesWritten = append(filesWritten, keyPath, certPath, casPath) if err := checkOverwrite(ctx, writer, cfg.OverwriteDestination, filesWritten...); err != nil { return nil, trace.Wrap(err) @@ -482,31 +521,39 @@ func writeOracleFormat(cfg WriteConfig, writer ConfigWriter) ([]string, error) { if err != nil { return nil, trace.Wrap(err) } - var caCerts []*x509.Certificate - for _, ca := range cfg.Key.TrustedCerts { - for _, cert := range ca.TLSCertificates { - c, err := tlsca.ParseCertificatePEM(cert) - if err != nil { - return nil, trace.Wrap(err) - } - caCerts = append(caCerts, c) - } - } - pf, err := pkcs12.LegacyRC2.WithRand(rand.Reader).Encode(keyK, certBlock, caCerts, cfg.Password) + // encode the private key and cert. + // orapki import_pkcs12 refuses to add trusted certs unless they are an + // issuer for an oracle wallet user_cert, and the server cert we create + // is not signed by the DB Client CA, so don't pass trusted certs + // (DB Client CA) here. + pf, err := pkcs12.LegacyRC2.WithRand(rand.Reader).Encode(keyK, certBlock, nil, cfg.Password) if err != nil { return nil, trace.Wrap(err) } p12Path := cfg.OutputPath + ".p12" - certPath := cfg.OutputPath + ".crt" - if err := writer.WriteFile(p12Path, pf, identityfile.FilePermissions); err != nil { return nil, trace.Wrap(err) } - err = writer.WriteFile(certPath, cfg.Key.TLSCert, identityfile.FilePermissions) - if err != nil { - return nil, trace.Wrap(err) + + clientCAs := cfg.Key.TLSCAs() + var caPaths []string + for i, caPEM := range clientCAs { + var caPath string + if len(clientCAs) > 1 { + // orapki wallet add can only add one trusted cert at a time, so we + // output up to two files - one for each CA key to trust during a + // rotation. + caPath = fmt.Sprintf("%s.ca-client-%d.crt", cfg.OutputPath, i) + } else { + caPath = cfg.OutputPath + ".ca-client.crt" + } + err = writer.WriteFile(caPath, caPEM, identityfile.FilePermissions) + if err != nil { + return nil, trace.Wrap(err) + } + caPaths = append(caPaths, caPath) } // Is ORAPKI binary is available is user env run command ang generate autologin Oracle wallet. @@ -514,15 +561,17 @@ func writeOracleFormat(cfg WriteConfig, writer ConfigWriter) ([]string, error) { // Is Orapki is available in the user env create the Oracle wallet directly. // otherwise Orapki tool needs to be executed on the server site to import keypair to // Oracle wallet. - if err := createOracleWallet(cfg.OutputPath, p12Path, certPath, cfg.Password); err != nil { + if err := createOracleWallet(caPaths, cfg.OutputPath, p12Path, cfg.Password); err != nil { return nil, trace.Wrap(err) } // If Oracle Wallet was created the raw p12 keypair and trusted cert are no longer needed. if err := os.Remove(p12Path); err != nil { return nil, trace.Wrap(err) } - if err := os.Remove(certPath); err != nil { - return nil, trace.Wrap(err) + for _, caPath := range caPaths { + if err := os.Remove(caPath); err != nil { + return nil, trace.Wrap(err) + } } // Return the path to the Oracle wallet. return []string{cfg.OutputPath}, nil @@ -530,7 +579,7 @@ func writeOracleFormat(cfg WriteConfig, writer ConfigWriter) ([]string, error) { // Otherwise return destinations to p12 keypair and trusted CA allowing a user to run the convert flow on the // Oracle server instance in order to create Oracle wallet file. - return []string{p12Path, certPath}, nil + return append([]string{p12Path}, caPaths...), nil } const ( @@ -542,7 +591,7 @@ func isOrapkiAvailable() bool { return err == nil } -func createOracleWallet(walletPath, p12Path, certPath, password string) error { +func createOracleWallet(caCertPaths []string, walletPath, p12Path, password string) error { errDetailsFormat := "\n\nOrapki command:\n%s \n\nCompleted with following error: \n%s" // Create Raw Oracle wallet with auto_login_only flag - no password required. args := []string{ @@ -566,16 +615,18 @@ func createOracleWallet(walletPath, p12Path, certPath, password string) error { return trace.Wrap(err, fmt.Sprintf(errDetailsFormat, cmd.String(), output)) } - // Add import teleport CA to the oracle wallet. - args = []string{ - "wallet", "add", "-wallet", walletPath, - "-trusted_cert", - "-auto_login_only", - "-cert", certPath, - } - cmd = exec.Command(orapkiBinary, args...) - if output, err := exec.Command(orapkiBinary, args...).CombinedOutput(); err != nil { - return trace.Wrap(err, fmt.Sprintf(errDetailsFormat, cmd.String(), output)) + // Add import teleport CA(s) to the oracle wallet. + for _, certPath := range caCertPaths { + args = []string{ + "wallet", "add", "-wallet", walletPath, + "-trusted_cert", + "-auto_login_only", + "-cert", certPath, + } + cmd = exec.Command(orapkiBinary, args...) + if output, err := exec.Command(orapkiBinary, args...).CombinedOutput(); err != nil { + return trace.Wrap(err, fmt.Sprintf(errDetailsFormat, cmd.String(), output)) + } } return nil } diff --git a/lib/fixtures/keys.go b/lib/fixtures/keys.go index b9d3ebaa93af1..fa6c06111aab9 100644 --- a/lib/fixtures/keys.go +++ b/lib/fixtures/keys.go @@ -61,6 +61,34 @@ MHeDg2Bs7/XZsIrn6vo7kXmQSoQKA8O2E7rYSigUayBKa/+5thbnjKlEP+slBzmp 7JPquig/B6L2pNoxPa41VDGawQjJY5m4l3ap/oJj61HBB+Auf29BWXqg7V7B7XMB NFJgTFxC2o3mVBkQ/s6FeDl62hpMheCuO6jRYbZjsM2tUeAKORws -----END RSA PRIVATE KEY-----`), + "rsa-db-client": []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEA3scV81B4bpa0qkpBsJDUf3UOIs6A4+WZnf5eXJcJg7zDi5/J +2vtBuk8CTvp8eQhu4Pq2G0RhHZzrYiBMWLf/ORzBSKrN+bQi9pRKNzN/hJ7SOq+T +tkyv5tGNn+PIGzX712Ao9Iw5TIJTy2QpiWKQ7MFiVAs399B0Ow7aHNRmdrB3jcJN +V/HizGkno8SF7EgMGwlIG+z8pE9PmlKV5ZkX0fz+W1IMMkiveGgy0tTNazTsnwOE +meBDTyB8YbLGPqiQJxoTEHWaQUTLNfWXdvD2x15UBgWiKA/Ng05fPpDqR4cYgM+n +Nv8lhwKprWcJmF34HVNO6xhFkMU78EIW4lFeYQIDAQABAoIBACIHsU+wnCTwenqE +y1IIXZ12qQkiGEg3u2aKA6oLHFX2ULyUVQZRWTH3fbfIxZjLc/yD76tsn5UhckdT +/bWTrbXwsYnDJaGeJbUa49dY04LTq/Nw/JRdVIViv0qMRfX6IhU9SCRLAzmvstMf +4sRsvQydYcLKz+rX+dlHpIPA4kIA27wqEGbaCb4WatOYf4kUvIyBT/A/QrdRQfPK +YTsQoQk8TNMdeTCGyyHqBSnbiI9r7EmWrH2r7FdNSFNsoH1FyX35sD4dLC/vjXNT +dDT7cnwwIHHKpQIDdBKG8ivVM2nuvSNNVp+LUs9rWKZ3R2xyln3NOTLAHCMDUoAi +TCY/YAECgYEA/VGBIoH4kozoY7SJAEjMPNzHFJDmUVo4U/bN/hGAcGM+YXjSnuO0 +ljgSn7h7U2Z09KfIO8GjWtHq6K2gCw23of1D/CRPDIRvrONCVrhWcqebAcFD9Zw0 +4EbPYDnTq4M9gC2RLsCIFVVL4Gsw9+4iSAKhCsHdW/dORUCDyBRbOwECgYEA4SLQ +0Myni4U7v8QVnwREJifjAvE0xrALpdF82CFgyimkCsgzxxBwdc2qhbtiIoNEYx/X +OEPpFU+SuCAQe6xYCsjP3rh1kCN8ETu9NlnpD5o2BUYVgPR7xYYQ3aci/UYWHfZ1 +BGus1PNkL+87d62bvMcpDC9VfAvjGAvOqqUPA2ECgYEAlB4EE9lLLuWVPDdjo/bs +9OlivnO7N/Y42V+GMvio0Q42e2faP22FOhCvUxTbh3hxClzQh6BBk+kKIeLjoZLz +vJQKHHRehEMryTtYnrxKT+AQkoYe5o3fnQPKXclyKuciHsCGE4AgEdk99Iq4pz9m +bBSddVzFwfBoo7WFWIgOkAECgYEAzS1cnx4Uh5vJ4y/CAKTzss5hHlpTHcxtIRa1 +L4fj3PpsLQNd5Mp/o2znPm+StR9qoOfwza9eafSWI0Xdn8hmiJWQlEsJoW4lcNM/ +0pvIQlbpao7/pAGsF0zibA8ZXTeVioMFDB1RatXSdbkSOjS3HSloqFkvEBkJQu3n +0C8TaqECgYAic82i4ZMIJNeFtI2eGw2a89ofO3gpJvGmaJwY8RgZYWtU7+YiI1ts +RBeQG4aFIeOs/3nf0n8pp/xsgreLVZJBXjoWyvw7pDi60N4C07d2gA+hqK3rAvQC +0Be4kdn0Jxx/OSYuGKl1PI0DB1RaCkWZHNay73amkkP+HD/BqcIeLA== +-----END RSA PRIVATE KEY----- +`), } // LocalhostCert is a PEM-encoded TLS cert with SAN IPs diff --git a/lib/services/authority.go b/lib/services/authority.go index 554fd6a76dbfb..380c5204b1aeb 100644 --- a/lib/services/authority.go +++ b/lib/services/authority.go @@ -70,7 +70,7 @@ func ValidateCertAuthority(ca types.CertAuthority) (err error) { switch ca.GetType() { case types.UserCA, types.HostCA: err = checkUserOrHostCA(ca) - case types.DatabaseCA: + case types.DatabaseCA, types.DatabaseClientCA: err = checkDatabaseCA(ca) case types.OpenSSHCA: err = checkOpenSSHCA(ca) @@ -118,7 +118,7 @@ func checkDatabaseCA(cai types.CertAuthority) error { } if len(ca.Spec.ActiveKeys.TLS) == 0 { - return trace.BadParameter("DB certificate authority missing TLS key pairs") + return trace.BadParameter("%s certificate authority missing TLS key pairs", ca.GetType()) } for _, pair := range ca.GetTrustedTLSKeyPairs() { diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index 5897be48fd205..f7dd3f876adc9 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -78,7 +78,15 @@ type TestCAConfig struct { func NewTestCAWithConfig(config TestCAConfig) *types.CertAuthorityV2 { // privateKeys is to specify another RSA private key if len(config.PrivateKeys) == 0 { - config.PrivateKeys = [][]byte{fixtures.PEMBytes["rsa"]} + // db client CA gets its own private key to distinguish its pub key + // from the other CAs. Snowflake uses public key to verify JWT signer, + // so if we don't do this then tests verifying that the correct + // signer was used are pointless. + if config.Type == types.DatabaseClientCA { + config.PrivateKeys = [][]byte{fixtures.PEMBytes["rsa-db-client"]} + } else { + config.PrivateKeys = [][]byte{fixtures.PEMBytes["rsa"]} + } } keyBytes := config.PrivateKeys[0] rsaKey, err := ssh.ParseRawPrivateKey(keyBytes) @@ -120,7 +128,7 @@ func NewTestCAWithConfig(config TestCAConfig) *types.CertAuthorityV2 { // Match the key set to lib/auth/auth.go:newKeySet(). switch config.Type { - case types.DatabaseCA: + case types.DatabaseCA, types.DatabaseClientCA, types.SAMLIDPCA: ca.Spec.ActiveKeys.TLS = []*types.TLSKeyPair{{Cert: cert, Key: keyBytes}} case types.KindJWT, types.OIDCIdPCA: // Generating keys is CPU intensive operation. Generate JWT keys only @@ -148,8 +156,6 @@ func NewTestCAWithConfig(config TestCAConfig) *types.CertAuthorityV2 { PrivateKey: keyBytes, }}, } - case types.SAMLIDPCA: - ca.Spec.ActiveKeys.TLS = []*types.TLSKeyPair{{Cert: cert, Key: keyBytes}} default: panic("unknown CA type") } diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index fdb87140a7360..fa6fe72275866 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -823,8 +823,8 @@ func TestCertAuthorityWatcher(t *testing.T) { waitForEvent(t, sub, types.UserCA, "unknown", types.OpPut) // Should NOT receive any HostCA events from another cluster. - // Should NOT receive any DatabaseCA events. require.NoError(t, caService.UpsertCertAuthority(ctx, newCertAuthority(t, "unknown", types.HostCA))) + // Should NOT receive any DatabaseCA events. require.NoError(t, caService.UpsertCertAuthority(ctx, newCertAuthority(t, "test", types.DatabaseCA))) ensureNoEvents(t, sub) }) diff --git a/lib/srv/db/auth_test.go b/lib/srv/db/auth_test.go index 365977b3959ea..109fb3898c06f 100644 --- a/lib/srv/db/auth_test.go +++ b/lib/srv/db/auth_test.go @@ -20,27 +20,19 @@ package db import ( "context" - "crypto/x509" - "crypto/x509/pkix" "testing" - "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" - "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/srv/db/common" - "github.com/gravitational/teleport/lib/tlsca" ) // TestAuthTokens verifies that proper IAM auth tokens are used when connecting @@ -334,94 +326,6 @@ func (a *testAuth) GetAWSIAMCreds(ctx context.Context, sessionCtx *common.Sessio return atlasAuthUser, atlasAuthToken, atlasAuthSessionToken, nil } -func TestDBCertSigning(t *testing.T) { - authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ - Clock: clockwork.NewFakeClockAt(time.Now()), - ClusterName: "local.me", - Dir: t.TempDir(), - }) - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, authServer.Close()) }) - - ctx := context.Background() - - privateKey, err := testauthority.New().GeneratePrivateKey() - require.NoError(t, err) - - csr, err := tlsca.GenerateCertificateRequestPEM(pkix.Name{ - CommonName: "localhost", - }, privateKey) - require.NoError(t, err) - - // Set rotation to init phase. New CA will be generated. - // DB service should still use old key to sign certificates. - // tctl should use new key to sign certificates. - err = authServer.AuthServer.RotateCertAuthority(ctx, types.RotateRequest{ - Type: types.DatabaseCA, - TargetPhase: types.RotationPhaseInit, - Mode: types.RotationModeManual, - }) - require.NoError(t, err) - - dbCAs, err := authServer.AuthServer.GetCertAuthorities(ctx, types.DatabaseCA, false) - require.NoError(t, err) - require.Len(t, dbCAs, 1) - require.NotNil(t, dbCAs[0].GetActiveKeys().TLS) - require.NotNil(t, dbCAs[0].GetAdditionalTrustedKeys().TLS) - - tests := []struct { - name string - requester proto.DatabaseCertRequest_Requester - getCertFn func(dbCAs []types.CertAuthority) []byte - }{ - { - name: "sign from DB service", - requester: proto.DatabaseCertRequest_UNSPECIFIED, // default behavior - getCertFn: func(dbCAs []types.CertAuthority) []byte { - return dbCAs[0].GetActiveKeys().TLS[0].Cert - }, - }, - { - name: "sign from tctl", - requester: proto.DatabaseCertRequest_TCTL, - getCertFn: func(dbCAs []types.CertAuthority) []byte { - return dbCAs[0].GetAdditionalTrustedKeys().TLS[0].Cert - }, - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - certResp, err := authServer.AuthServer.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{ - CSR: csr, - ServerName: "localhost", - TTL: proto.Duration(time.Hour), - RequesterName: tt.requester, - }) - require.NoError(t, err) - require.NotNil(t, certResp.Cert) - require.Len(t, certResp.CACerts, 2) - - dbCert, err := tlsca.ParseCertificatePEM(certResp.Cert) - require.NoError(t, err) - - certPool := x509.NewCertPool() - ok := certPool.AppendCertsFromPEM(tt.getCertFn(dbCAs)) - require.True(t, ok) - - opts := x509.VerifyOptions{ - Roots: certPool, - } - - // Verify if the generated certificate can be verified with the correct CA. - _, err = dbCert.Verify(opts) - require.NoError(t, err) - }) - } -} - func TestMongoDBAtlas(t *testing.T) { t.Parallel() diff --git a/lib/srv/db/common/auth_test.go b/lib/srv/db/common/auth_test.go index e8015d8aa64c6..815da368cf959 100644 --- a/lib/srv/db/common/auth_test.go +++ b/lib/srv/db/common/auth_test.go @@ -903,6 +903,9 @@ type authClientMock struct { // GenerateDatabaseCert generates a cert using fixtures TLS CA. func (m *authClientMock) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { + if req.GetRequesterName() != proto.DatabaseCertRequest_UNSPECIFIED { + return nil, trace.BadParameter("db agent should not specify requester name") + } csr, err := tlsca.ParseCertificateRequestPEM(req.CSR) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/db/common/test.go b/lib/srv/db/common/test.go index 1f8e14c555251..1dbee19083c1e 100644 --- a/lib/srv/db/common/test.go +++ b/lib/srv/db/common/test.go @@ -133,9 +133,10 @@ func MakeTestServerTLSConfig(config TestServerConfig) (*tls.Config, error) { } resp, err := config.AuthClient.GenerateDatabaseCert(context.Background(), &proto.DatabaseCertRequest{ - CSR: csr, - ServerName: cn, - TTL: proto.Duration(time.Hour), + CSR: csr, + ServerName: cn, + TTL: proto.Duration(time.Hour), + RequesterName: proto.DatabaseCertRequest_TCTL, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/db/snowflake/test.go b/lib/srv/db/snowflake/test.go index e97f575581ce5..3cc0dba24364d 100644 --- a/lib/srv/db/snowflake/test.go +++ b/lib/srv/db/snowflake/test.go @@ -187,14 +187,14 @@ func (s *TestServer) handleToken(w http.ResponseWriter, r *http.Request) { s.authorizationToken = "sessionToken-123" } -// verifyJWT verifies the provided JWT token. It checks if the token was signed with the Database CA +// verifyJWT verifies the provided JWT token. It checks if the token was signed with the Database Client CA // and asserts token claims. func (s *TestServer) verifyJWT(ctx context.Context, accName, loginName, token string) error { clusterName := "root.example.com" caCert, err := s.cfg.AuthClient.GetCertAuthority(ctx, types.CertAuthID{ DomainName: clusterName, - Type: types.DatabaseCA, + Type: types.DatabaseClientCA, }, false) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/db/sqlserver/connect_test.go b/lib/srv/db/sqlserver/connect_test.go index 93835e0d7a847..5453c382f04b1 100644 --- a/lib/srv/db/sqlserver/connect_test.go +++ b/lib/srv/db/sqlserver/connect_test.go @@ -29,6 +29,7 @@ import ( "testing" "time" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/client/proto" @@ -224,7 +225,10 @@ func (m *mockAuth) GetClusterName(opts ...services.MarshalOption) (types.Cluster }) } -func (m *mockAuth) GenerateDatabaseCert(context.Context, *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { +func (m *mockAuth) GenerateDatabaseCert(_ context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { + if req.GetRequesterName() != proto.DatabaseCertRequest_UNSPECIFIED { + return nil, trace.BadParameter("db agent should not specify requester name") + } return &proto.DatabaseCertResponse{Cert: []byte(mockCA), CACerts: [][]byte{[]byte(mockCA)}}, nil } diff --git a/lib/srv/db/sqlserver/kinit/kinit.go b/lib/srv/db/sqlserver/kinit/kinit.go index d2bcf67113d01..fc9ec2328d5ad 100644 --- a/lib/srv/db/sqlserver/kinit/kinit.go +++ b/lib/srv/db/sqlserver/kinit/kinit.go @@ -213,7 +213,7 @@ func (d *DBCertGetter) GetCertificateBytes(ctx context.Context) (*WindowsCAAndKe } certPEM, keyPEM, caCerts, err := windows.CertKeyPEM(ctx, &windows.GenerateCredentialsRequest{ - CAType: types.DatabaseCA, + CAType: types.DatabaseClientCA, Username: d.UserName, Domain: d.RealmName, TTL: certTTL, diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index a0f42385f13c1..7b3f49a94eb01 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -3803,6 +3803,24 @@ func TestAuthExport(t *testing.T) { expectedStatus: http.StatusOK, assertBody: validateTLSCertificatePEMFunc, }, + { + name: "db-der", + authType: "db-der", + expectedStatus: http.StatusOK, + assertBody: validateTLSCertificateDERFunc, + }, + { + name: "db-client", + authType: "db-client", + expectedStatus: http.StatusOK, + assertBody: validateTLSCertificatePEMFunc, + }, + { + name: "db-client-der", + authType: "db-client-der", + expectedStatus: http.StatusOK, + assertBody: validateTLSCertificateDERFunc, + }, { name: "tls", authType: "tls", diff --git a/lib/web/databases.go b/lib/web/databases.go index 627c448526bea..7e05abc161626 100644 --- a/lib/web/databases.go +++ b/lib/web/databases.go @@ -35,6 +35,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/tlsutils" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/reversetunnelclient" @@ -322,21 +323,12 @@ func (h *Handler) sqlServerConfigureADScriptHandle(w http.ResponseWriter, r *htt return "", trace.NotFound("no proxy servers found") } - clusterName, err := h.GetProxyClient().GetDomainName(r.Context()) + certAuthority, err := getCAForSQLServerConfigureADScript(r.Context(), h.GetProxyClient()) if err != nil { return nil, trace.Wrap(err) } - certAuthority, err := h.GetProxyClient().GetCertAuthority( - r.Context(), - types.CertAuthID{Type: types.DatabaseCA, DomainName: clusterName}, - false, - ) - if err != nil { - return nil, trace.Wrap(err) - } - - caCRL, err := h.GetProxyClient().GenerateCertAuthorityCRL(r.Context(), types.DatabaseCA) + caCRL, err := h.GetProxyClient().GenerateCertAuthorityCRL(r.Context(), types.DatabaseClientCA) if err != nil { return nil, trace.Wrap(err) } @@ -440,3 +432,31 @@ func encodeCRLPEM(contents []byte) []byte { Bytes: contents, }) } + +// getCAForSQLServerConfigureADScript is a helper for sql server configuration +// that fetches the DatabaseClientCA if the auth service supports it or falls back +// to the DatabaseCA if auth service does not support it. +// TODO(gavin): DELETE IN 16.0.0 +func getCAForSQLServerConfigureADScript(ctx context.Context, clusterAPI auth.ClientI) (types.CertAuthority, error) { + domainName, err := clusterAPI.GetDomainName(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + dbClientCA, err := clusterAPI.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseClientCA, + DomainName: domainName, + }, false) + if err == nil { + return dbClientCA, nil + } + if !types.IsUnsupportedAuthorityErr(err) { + return nil, trace.Wrap(err) + } + + // fallback to DatabaseCA if DatabaseClientCA isn't supported by backend. + dbServerCA, err := clusterAPI.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseCA, + DomainName: domainName, + }, false) + return dbServerCA, trace.Wrap(err) +} diff --git a/lib/web/sign.go b/lib/web/sign.go index b401d61ded05f..cada8ec93c7db 100644 --- a/lib/web/sign.go +++ b/lib/web/sign.go @@ -78,7 +78,7 @@ func (h *Handler) signDatabaseCertificate(w http.ResponseWriter, r *http.Request IdentityFileWriter: virtualFS, TTL: req.TTL, } - filesWritten, err := db.GenerateDatabaseCertificates(r.Context(), dbCertReq) + filesWritten, err := db.GenerateDatabaseServerCertificates(r.Context(), dbCertReq) if err != nil { return nil, trace.Wrap(err) } diff --git a/tool/tctl/common/auth_command.go b/tool/tctl/common/auth_command.go index c5abda2cab896..3265d40bd8b67 100644 --- a/tool/tctl/common/auth_command.go +++ b/tool/tctl/common/auth_command.go @@ -24,7 +24,6 @@ import ( "io" "net/url" "os" - "path/filepath" "strings" "text/template" "time" @@ -195,6 +194,8 @@ var allowedCertificateTypes = []string{ "windows", "db", "db-der", + "db-client", + "db-client-der", "openssh", "saml-idp", } @@ -204,6 +205,7 @@ var allowedCertificateTypes = []string{ var allowedCRLCertificateTypes = []string{ string(types.HostCA), string(types.DatabaseCA), + string(types.DatabaseClientCA), string(types.UserCA), } @@ -357,20 +359,11 @@ func (a *AuthCommand) generateSnowflakeKey(ctx context.Context, clusterAPI auth. return trace.Wrap(err) } - cn, err := clusterAPI.GetClusterName() - if err != nil { - return trace.Wrap(err) - } - certAuthID := types.CertAuthID{ - Type: types.DatabaseCA, - DomainName: cn.GetClusterName(), - } - databaseCA, err := clusterAPI.GetCertAuthority(ctx, certAuthID, false) + dbClientCA, err := getDatabaseClientCA(ctx, clusterAPI) if err != nil { return trace.Wrap(err) } - - key.TrustedCerts = []auth.TrustedCerts{{TLSCertificates: services.GetTLSCerts(databaseCA)}} + key.TrustedCerts = []auth.TrustedCerts{{TLSCertificates: services.GetTLSCerts(dbClientCA)}} filesWritten, err := identityfile.Write(ctx, identityfile.WriteConfig{ OutputPath: a.output, @@ -532,7 +525,7 @@ func (a *AuthCommand) generateDatabaseKeysForKey(ctx context.Context, clusterAPI Password: a.password, IdentityFileWriter: a.identityWriter, } - filesWritten, err := db.GenerateDatabaseCertificates(ctx, dbCertReq) + filesWritten, err := db.GenerateDatabaseServerCertificates(ctx, dbCertReq) if err != nil { return trace.Wrap(err) } @@ -569,9 +562,21 @@ func writeHelperMessageDBmTLS(writer io.Writer, filesWritten []string, output st "output": output, "tarOutput": tarOutput, } - if outputFormat == defaults.ProtocolOracle { + switch outputFormat { + case defaults.ProtocolCockroachDB: + tplVars["clientCAPath"] = "/path/to/client-ca.key" + case defaults.ProtocolOracle: tplVars["manualOrapkiFlow"] = len(filesWritten) != 1 - tplVars["walletDir"] = filepath.Dir(output) + // use a generic example path since they will have to copy the files + // to the oracle server. + tplVars["walletDir"] = "/path/to/oracleWalletDir" + var caCertPaths []string + for _, f := range filesWritten { + if strings.HasSuffix(f, ".crt") { + caCertPaths = append(caCertPaths, f) + } + } + tplVars["caCertPaths"] = caCertPaths } return trace.Wrap(tpl.Execute(writer, tplVars)) @@ -624,12 +629,31 @@ net: `)) cockroachAuthSignTpl = template.Must(template.New("").Parse(`Database credentials have been written to {{.files}}. -To enable mutual TLS on your CockroachDB server, point it to the certs -directory using --certs-dir flag: +To enable mutual TLS on your CockroachDB server, generate a client CA and client +certs for your node: + +# --overwrite flag prepends the client CA cert to {{.output}}/ca-client.crt +cockroach cert create-client-ca \ + --certs-dir={{.output}} \ + --ca-key={{.clientCAPath}} \ + --overwrite + +cockroach cert create-client node \ + --certs-dir={{.output}} + --ca-key={{.clientCAPath}} + +Then point cockroach to the certs directory using the --certs-dir flag: cockroach start \ --certs-dir={{.output}} \ # other flags... + +For more information about creating a client CA and issuing certs, see: +https://www.cockroachlabs.com/docs/stable/cockroach-cert + +Teleport uses a split CA architecture for database access. +For more information about using a split CA with CockroachDB, see: +https://www.cockroachlabs.com/docs/stable/authentication#using-split-ca-certificates `)) redisAuthSignTpl = template.Must(template.New("").Parse( @@ -647,6 +671,9 @@ tls-ca-cert-file /path/to/{{.output}}.cas tls-cert-file /path/to/{{.output}}.crt tls-key-file /path/to/{{.output}}.key tls-protocols "TLSv1.2 TLSv1.3" + +For information on enabling Redis Cluster bus communication TLS, see: +https://goteleport.com/docs/database-access/guides/redis-cluster `)) snowflakeAuthSignTpl = template.Must(template.New("").Parse(`Database credentials have been written to {{.files}}. @@ -717,9 +744,12 @@ $ tctl auth sign ${FLAGS} | tar -x {{- if .manualOrapkiFlow}} Orapki binary was not found. Please create oracle wallet file manually by running the following commands on the Oracle server: -orapki wallet create -wallet {{.walletDir}} -auto_login_only -orapki wallet import_pkcs12 -wallet {{.walletDir}} -auto_login_only -pkcs12file {{.output}}.p12 -pkcs12pwd {{.password}} -orapki wallet add -wallet {{.walletDir}} -trusted_cert -auto_login_only -cert {{.output}}.crt +WALLET_DIR="{{.walletDir}}" +orapki wallet create -wallet "$WALLET_DIR" -auto_login_only +orapki wallet import_pkcs12 -wallet "$WALLET_DIR" -auto_login_only -pkcs12file {{.output}}.p12 -pkcs12pwd {{.password}} +{{- range $certPath := .caCertPaths }} +orapki wallet add -wallet "$WALLET_DIR" -trusted_cert -auto_login_only -cert {{ $certPath }} +{{- end}} If copying these files to your Oracle server, ensure the cert file permissions are readable by the "oracle" user. {{else}} @@ -728,7 +758,7 @@ Oracle wallet stored in {{.output}} directory created with Oracle Orapki. {{end}} To enable mutual TLS on your Oracle server, add the following settings to Oracle sqlnet.ora configuration file: -WALLET_LOCATION = (SOURCE = (METHOD = FILE)(METHOD_DATA = (DIRECTORY = /path/to/oracleWalletDir))) +WALLET_LOCATION = (SOURCE = (METHOD = FILE)(METHOD_DATA = (DIRECTORY = {{.walletDir}}))) SSL_CLIENT_AUTHENTICATION = TRUE SQLNET.AUTHENTICATION_SERVICES = (TCPS) @@ -742,7 +772,7 @@ LISTENER = ) ) -WALLET_LOCATION = (SOURCE = (METHOD = FILE)(METHOD_DATA = (DIRECTORY = /path/to/oracleWalletDir))) +WALLET_LOCATION = (SOURCE = (METHOD = FILE)(METHOD_DATA = (DIRECTORY = {{.walletDir}}))) SSL_CLIENT_AUTHENTICATION = TRUE `)) @@ -1147,3 +1177,28 @@ func (a *AuthCommand) helperMsgDst() io.Writer { } return os.Stdout } + +// TODO(gavin): DELETE IN 16.0.0 +func getDatabaseClientCA(ctx context.Context, clusterAPI auth.ClientI) (types.CertAuthority, error) { + cn, err := clusterAPI.GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + dbClientCA, err := clusterAPI.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseClientCA, + DomainName: cn.GetClusterName(), + }, false) + if err == nil { + return dbClientCA, nil + } + if !types.IsUnsupportedAuthorityErr(err) { + return nil, trace.Wrap(err) + } + + // fallback to DatabaseCA if DatabaseClientCA isn't supported by backend. + dbServerCA, err := clusterAPI.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseCA, + DomainName: cn.GetClusterName(), + }, false) + return dbServerCA, trace.Wrap(err) +} diff --git a/tool/tctl/common/auth_command_test.go b/tool/tctl/common/auth_command_test.go index ce9a18f44b90a..bb4b42858ac2c 100644 --- a/tool/tctl/common/auth_command_test.go +++ b/tool/tctl/common/auth_command_test.go @@ -396,6 +396,8 @@ type mockClient struct { appSession types.WebSession networkConfig types.ClusterNetworkingConfig crl []byte + + unsupportedCATypes []types.CertAuthType } func (c *mockClient) GetClusterName(...services.MarshalOption) (types.ClusterName, error) { @@ -415,15 +417,25 @@ func (c *mockClient) GenerateUserCerts(ctx context.Context, userCertsReq proto.U } func (c *mockClient) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadSigningKeys bool) (types.CertAuthority, error) { + for _, unsupported := range c.unsupportedCATypes { + if unsupported == id.Type { + return nil, trace.BadParameter("%q authority type is not supported", unsupported) + } + } for _, v := range c.cas { if v.GetType() == id.Type && v.GetClusterName() == id.DomainName { return v, nil } } - return nil, trace.NotFound("not found") + return nil, trace.NotFound("%q CA not found", id) } -func (c *mockClient) GetCertAuthorities(context.Context, types.CertAuthType, bool) ([]types.CertAuthority, error) { +func (c *mockClient) GetCertAuthorities(_ context.Context, caType types.CertAuthType, _ bool) ([]types.CertAuthority, error) { + for _, unsupported := range c.unsupportedCATypes { + if unsupported == caType { + return nil, trace.BadParameter("%q authority type is not supported", unsupported) + } + } return c.cas, nil } @@ -440,6 +452,9 @@ func (c *mockClient) GetKubernetesServers(context.Context) ([]types.KubeServer, } func (c *mockClient) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { + if req.GetRequesterName() != proto.DatabaseCertRequest_TCTL { + return nil, trace.BadParameter("need tctl requester name in tctl database cert request") + } c.dbCertsReq = req return c.dbCerts, nil } @@ -572,14 +587,24 @@ func TestGenerateDatabaseKeys(t *testing.T) { require.NoError(t, err) certBytes := []byte("TLS cert") - caBytes := []byte("CA cert") + dbClientCABytes := []byte("DB Client CA cert") + dbServerCABytes := []byte("DB Server CA cert") + dbCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.DatabaseCA, + ClusterName: "example.com", + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{Cert: dbServerCABytes}}, + }, + }) + require.NoError(t, err) authClient := &mockClient{ clusterName: clusterName, dbCerts: &proto.DatabaseCertResponse{ Cert: certBytes, - CACerts: [][]byte{caBytes}, + CACerts: [][]byte{dbClientCABytes}, }, + cas: []types.CertAuthority{dbCA}, } key, err := client.GenerateRSAKey() @@ -593,13 +618,9 @@ func TestGenerateDatabaseKeys(t *testing.T) { inOutFile string outSubject pkix.Name outServerNames []string - outKeyFile string - outCertFile string - outCAFile string - outKey []byte - outCert []byte - outCA []byte - genKeyErrMsg string + // maps filename -> file contents + wantFiles map[string][]byte + genKeyErrMsg string }{ { name: "database certificate", @@ -609,12 +630,11 @@ func TestGenerateDatabaseKeys(t *testing.T) { inOutFile: "db", outSubject: pkix.Name{CommonName: "postgres.example.com"}, outServerNames: []string{"postgres.example.com"}, - outKeyFile: "db.key", - outCertFile: "db.crt", - outCAFile: "db.cas", - outKey: key.PrivateKeyPEM(), - outCert: certBytes, - outCA: caBytes, + wantFiles: map[string][]byte{ + "db.key": key.PrivateKeyPEM(), + "db.crt": certBytes, + "db.cas": dbClientCABytes, + }, }, { name: "database certificate multiple SANs", @@ -624,12 +644,11 @@ func TestGenerateDatabaseKeys(t *testing.T) { inOutFile: "db", outSubject: pkix.Name{CommonName: "mysql.external.net"}, outServerNames: []string{"mysql.external.net", "mysql.internal.net", "192.168.1.1"}, - outKeyFile: "db.key", - outCertFile: "db.crt", - outCAFile: "db.cas", - outKey: key.PrivateKeyPEM(), - outCert: certBytes, - outCA: caBytes, + wantFiles: map[string][]byte{ + "db.key": key.PrivateKeyPEM(), + "db.crt": certBytes, + "db.cas": dbClientCABytes, + }, }, { name: "mongodb certificate", @@ -639,10 +658,10 @@ func TestGenerateDatabaseKeys(t *testing.T) { inOutFile: "mongo", outSubject: pkix.Name{CommonName: "mongo.example.com", Organization: []string{"example.com"}}, outServerNames: []string{"mongo.example.com"}, - outCertFile: "mongo.crt", - outCAFile: "mongo.cas", - outCert: append(certBytes, key.PrivateKeyPEM()...), - outCA: caBytes, + wantFiles: map[string][]byte{ + "mongo.crt": append(certBytes, key.PrivateKeyPEM()...), + "mongo.cas": dbClientCABytes, + }, }, { name: "cockroachdb certificate", @@ -651,12 +670,12 @@ func TestGenerateDatabaseKeys(t *testing.T) { inOutDir: t.TempDir(), outSubject: pkix.Name{CommonName: "node"}, outServerNames: []string{"node", "localhost", "roach1"}, // "node" principal should always be added - outKeyFile: "node.key", - outCertFile: "node.crt", - outCAFile: "ca.crt", - outKey: key.PrivateKeyPEM(), - outCert: certBytes, - outCA: caBytes, + wantFiles: map[string][]byte{ + "node.key": key.PrivateKeyPEM(), + "node.crt": certBytes, + "ca.crt": dbServerCABytes, + "ca-client.crt": dbClientCABytes, + }, }, { name: "redis certificate", @@ -666,12 +685,11 @@ func TestGenerateDatabaseKeys(t *testing.T) { inOutFile: "db", outSubject: pkix.Name{CommonName: "localhost"}, outServerNames: []string{"localhost", "redis1", "172.0.0.1"}, - outKeyFile: "db.key", - outCertFile: "db.crt", - outCAFile: "db.cas", - outKey: key.PrivateKeyPEM(), - outCert: certBytes, - outCA: caBytes, + wantFiles: map[string][]byte{ + "db.key": key.PrivateKeyPEM(), + "db.crt": certBytes, + "db.cas": dbClientCABytes, + }, }, { name: "missing host", @@ -709,22 +727,10 @@ func TestGenerateDatabaseKeys(t *testing.T) { require.Equal(t, test.outServerNames, authClient.dbCertsReq.ServerNames) require.Equal(t, test.outServerNames[0], authClient.dbCertsReq.ServerName) - if len(test.outKey) > 0 { - keyBytes, err := os.ReadFile(filepath.Join(test.inOutDir, test.outKeyFile)) + for wantFilename, wantContents := range test.wantFiles { + contents, err := os.ReadFile(filepath.Join(test.inOutDir, wantFilename)) require.NoError(t, err) - require.Equal(t, test.outKey, keyBytes, "keys match") - } - - if len(test.outCert) > 0 { - certBytes, err := os.ReadFile(filepath.Join(test.inOutDir, test.outCertFile)) - require.NoError(t, err) - require.Equal(t, test.outCert, certBytes, "certificates match") - } - - if len(test.outCA) > 0 { - caBytes, err := os.ReadFile(filepath.Join(test.inOutDir, test.outCAFile)) - require.NoError(t, err) - require.Equal(t, test.outCA, caBytes, "CA certificates match") + require.Equal(t, wantContents, contents, "contents of %s match", wantFilename) } }) } @@ -972,49 +978,83 @@ func TestGenerateAndSignKeys(t *testing.T) { _, cert, err := tlsca.GenerateSelfSignedCA(pkix.Name{CommonName: "example.com"}, nil, time.Minute) require.NoError(t, err) - firstCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + dbCARoot, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ Type: types.DatabaseCA, ClusterName: "example.com", ActiveKeys: types.CAKeySet{ - SSH: []*types.SSHKeyPair{{PublicKey: []byte("SSH CA cert")}}, TLS: []*types.TLSKeyPair{{Cert: cert}}, }, }) require.NoError(t, err) - secondCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + dbCALeaf, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ Type: types.DatabaseCA, ClusterName: "leaf.example.com", ActiveKeys: types.CAKeySet{ - SSH: []*types.SSHKeyPair{{PublicKey: []byte("SSH CA cert")}}, TLS: []*types.TLSKeyPair{{Cert: cert}}, }, }) require.NoError(t, err) - certBytes := []byte("TLS cert") - caBytes := []byte("CA cert") - authClient := &mockClient{ - clusterName: clusterName, - dbCerts: &proto.DatabaseCertResponse{ - Cert: certBytes, - CACerts: [][]byte{caBytes}, + dbClientCARoot, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.DatabaseClientCA, + ClusterName: "example.com", + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{Cert: cert}}, }, - cas: []types.CertAuthority{firstCA, secondCA}, - } + }) + require.NoError(t, err) + + dbClientCALeaf, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.DatabaseClientCA, + ClusterName: "leaf.example.com", + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{Cert: cert}}, + }, + }) + require.NoError(t, err) + + allCAs := []types.CertAuthority{dbCARoot, dbCALeaf, dbClientCARoot, dbClientCALeaf} + + certBytes := []byte("TLS cert") + caBytes := []byte("CA cert") tests := []struct { - name string - inFormat identityfile.Format - inHost string - inOutDir string - inOutFile string + name string + inFormat identityfile.Format + inHost string + inOutDir string + inOutFile string + authClient *mockClient }{ { name: "snowflake format", inFormat: identityfile.FormatSnowflake, inOutDir: t.TempDir(), - inOutFile: "server", + inOutFile: "ca", + authClient: &mockClient{ + clusterName: clusterName, + dbCerts: &proto.DatabaseCertResponse{ + Cert: certBytes, + CACerts: [][]byte{caBytes}, + }, + cas: allCAs, + }, + }, + { + name: "snowflake format db client ca not supported upstream", + inFormat: identityfile.FormatSnowflake, + inOutDir: t.TempDir(), + inOutFile: "ca", + authClient: &mockClient{ + clusterName: clusterName, + dbCerts: &proto.DatabaseCertResponse{ + Cert: certBytes, + CACerts: [][]byte{caBytes}, + }, + cas: []types.CertAuthority{dbCARoot, dbCALeaf}, + unsupportedCATypes: []types.CertAuthType{types.DatabaseClientCA}, + }, }, { name: "db format", @@ -1022,6 +1062,14 @@ func TestGenerateAndSignKeys(t *testing.T) { inOutDir: t.TempDir(), inOutFile: "server", inHost: "localhost", + authClient: &mockClient{ + clusterName: clusterName, + dbCerts: &proto.DatabaseCertResponse{ + Cert: certBytes, + CACerts: [][]byte{caBytes}, + }, + cas: allCAs, + }, }, } @@ -1035,7 +1083,7 @@ func TestGenerateAndSignKeys(t *testing.T) { genTTL: time.Hour, } - err = ac.GenerateAndSignKeys(context.Background(), authClient) + err = ac.GenerateAndSignKeys(context.Background(), test.authClient) require.NoError(t, err) }) } @@ -1058,3 +1106,62 @@ func TestGenerateCRLForCA(t *testing.T) { require.Error(t, ac.GenerateCRLForCA(ctx, authClient)) }) } + +func TestGetDatabaseClientCA(t *testing.T) { + _, cert, err := tlsca.GenerateSelfSignedCA(pkix.Name{CommonName: "example.com"}, nil, time.Minute) + require.NoError(t, err) + + dbClientCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.DatabaseClientCA, + ClusterName: "example.com", + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{Cert: cert}}, + }, + }) + require.NoError(t, err) + + dbServerCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.DatabaseCA, + ClusterName: "example.com", + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{Cert: cert}}, + }, + }) + require.NoError(t, err) + + clusterName, err := services.NewClusterNameWithRandomID( + types.ClusterNameSpecV2{ + ClusterName: "example.com", + }) + require.NoError(t, err) + tests := []struct { + desc string + authClient *mockClient + wantCA types.CertAuthority + }{ + { + desc: "db client ca exists", + authClient: &mockClient{ + clusterName: clusterName, + cas: []types.CertAuthority{dbClientCA, dbServerCA}, + }, + wantCA: dbClientCA, + }, + { + desc: "db client ca not supported", + authClient: &mockClient{ + clusterName: clusterName, + unsupportedCATypes: []types.CertAuthType{types.DatabaseClientCA}, + cas: []types.CertAuthority{dbServerCA}, + }, + wantCA: dbServerCA, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ca, err := getDatabaseClientCA(context.Background(), test.authClient) + require.NoError(t, err) + require.Equal(t, test.wantCA, ca) + }) + } +}