diff --git a/api/client/client.go b/api/client/client.go index a938d14c42810..b48d821b07095 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -1597,8 +1597,9 @@ func (c *Client) SignDatabaseCSR(ctx context.Context, req *proto.DatabaseCSRRequ return resp, nil } -// GenerateDatabaseCert generates client certificate used by a database -// service to authenticate with the database instance. +// GenerateDatabaseCert generates a client certificate used by a database +// service to authenticate with the database instance, or a server certificate +// for configuring a self-hosted database, depending on the requester_name. func (c *Client) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { resp, err := c.grpc.GenerateDatabaseCert(ctx, req) if err != nil { diff --git a/api/types/constants.go b/api/types/constants.go index fd100e15e683b..f3cab58eb6b73 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -287,7 +287,8 @@ const ( // KindConnectionDiagnostic is a resource that tracks the result of testing a connection KindConnectionDiagnostic = "connection_diagnostic" - // KindDatabaseCertificate is a resource to control Database Certificates generation + // KindDatabaseCertificate is a resource to control db CA cert + // generation. KindDatabaseCertificate = "database_certificate" // KindInstaller is a resource that holds a node installer script diff --git a/api/types/trust.go b/api/types/trust.go index 7b971354f5f66..51ae216a485cf 100644 --- a/api/types/trust.go +++ b/api/types/trust.go @@ -31,8 +31,12 @@ const ( HostCA CertAuthType = "host" // UserCA identifies the key as a user certificate authority UserCA CertAuthType = "user" - // DatabaseCA is a certificate authority used in database access. + // DatabaseCA is a certificate authority used as a server CA in database + // access. DatabaseCA CertAuthType = "db" + // DatabaseClientCA is a certificate authority used as a client CA in + // database access. + DatabaseClientCA CertAuthType = "db_client" // OpenSSHCA is a certificate authority used when connecting to agentless nodes. OpenSSHCA CertAuthType = "openssh" // JWTSigner identifies type of certificate authority as JWT signer. In this @@ -54,7 +58,16 @@ const ( ) // CertAuthTypes lists all certificate authority types. -var CertAuthTypes = []CertAuthType{HostCA, UserCA, DatabaseCA, OpenSSHCA, JWTSigner, SAMLIDPCA, OIDCIdPCA} +var CertAuthTypes = []CertAuthType{HostCA, UserCA, DatabaseCA, DatabaseClientCA, OpenSSHCA, JWTSigner, SAMLIDPCA, OIDCIdPCA} + +// IsUnsupportedAuthorityErr returns whether an error is due to an unsupported +// CertAuthType. +func IsUnsupportedAuthorityErr(err error) bool { + return err != nil && trace.IsBadParameter(err) && + strings.Contains(err.Error(), authTypeNotSupported) +} + +const authTypeNotSupported string = "authority type is not supported" // Check checks if certificate authority type value is correct func (c CertAuthType) Check() error { @@ -64,7 +77,7 @@ func (c CertAuthType) Check() error { } } - return trace.BadParameter("%q authority type is not supported", c) + return trace.BadParameter("%q %s", c, authTypeNotSupported) } // CertAuthID - id of certificate authority (it's type and domain name) diff --git a/integration/helpers/instance.go b/integration/helpers/instance.go index b71def47c0d8a..4fcbcc9ef6d2d 100644 --- a/integration/helpers/instance.go +++ b/integration/helpers/instance.go @@ -188,6 +188,21 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) { return nil, trace.Wrap(err) } + dbClientCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.DatabaseClientCA, + ClusterName: s.SiteName, + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{ + Key: s.PrivKey, + KeyType: types.PrivateKeyType_RAW, + Cert: s.TLSCACert, + }}, + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + osshCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ Type: types.OpenSSHCA, ClusterName: s.SiteName, @@ -203,7 +218,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) { return nil, trace.Wrap(err) } - return []types.CertAuthority{hostCA, userCA, dbCA, osshCA}, nil + return []types.CertAuthority{hostCA, userCA, dbCA, dbClientCA, osshCA}, nil } func (s *InstanceSecrets) AllowedLogins() []string { diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 95fc16c5d7d21..9fa46f2705c9d 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -4876,7 +4876,7 @@ func newKeySet(ctx context.Context, keyStore *keystore.Manager, caID types.CertA } keySet.SSH = append(keySet.SSH, sshKeyPair) keySet.TLS = append(keySet.TLS, tlsKeyPair) - case types.DatabaseCA: + case types.DatabaseCA, types.DatabaseClientCA: // Database CA only contains TLS cert. tlsKeyPair, err := keyStore.NewTLSKeyPair(ctx, caID.DomainName) if err != nil { diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 41736eee676e4..db4410fed4b70 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -4318,26 +4318,40 @@ func (a *ServerWithRoles) SignDatabaseCSR(ctx context.Context, req *proto.Databa return a.authServer.SignDatabaseCSR(ctx, req) } -// GenerateDatabaseCert generates a certificate used by a database service -// to authenticate with the database instance. +// GenerateDatabaseCert generates a client certificate used by a database +// service to authenticate with the database instance, or a server certificate +// for configuring a self-hosted database, depending on the requester_name. // // This certificate can be requested by: // // - Cluster administrator using "tctl auth sign --format=db" command locally // on the auth server to produce a certificate for configuring a self-hosted // database. -// - Remote user using "tctl auth sign --format=db" command with a remote -// proxy (e.g. Teleport Cloud), as long as they can impersonate system -// role Db. +// - Remote user using "tctl auth sign --format=db" command or +// /webapi/sites/:site/sign/db with a remote proxy (e.g. Teleport Cloud), +// as long as they can impersonate system role Db. // - Database service when initiating connection to a database instance to // produce a client certificate. -// - Proxy service when generating mTLS files to a database func (a *ServerWithRoles) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { - // Check if the User can `create` DatabaseCertificates - err := a.action(apidefaults.Namespace, types.KindDatabaseCertificate, types.VerbCreate) + err := a.checkAccessToGenerateDatabaseCert(types.KindDatabaseCertificate) + if err != nil { + return nil, trace.Wrap(err) + } + return a.authServer.GenerateDatabaseCert(ctx, req) +} + +// checkAccessToGenerateDatabaseCert is a helper for checking db cert gen authz. +// Requester must have at least one of: +// - create: database_certificate or database_client_certificate. +// - built-in Admin or DB role. +// - allowed to impersonate the built-in DB role. +func (a *ServerWithRoles) checkAccessToGenerateDatabaseCert(resourceKind string) error { + const verb = types.VerbCreate + // Check if the User can `create` Database Certificates + err := a.action(apidefaults.Namespace, resourceKind, verb) if err != nil { if !trace.IsAccessDenied(err) { - return nil, trace.Wrap(err) + return trace.Wrap(err) } // Err is access denied, trying the old way @@ -4347,12 +4361,13 @@ func (a *ServerWithRoles) GenerateDatabaseCert(ctx context.Context, req *proto.D if !a.hasBuiltinRole(types.RoleDatabase, types.RoleAdmin) { if err := a.canImpersonateBuiltinRole(types.RoleDatabase); err != nil { log.WithError(err).Warnf("User %v tried to generate database certificate but does not have '%s' permission for '%s' kind, nor is allowed to impersonate %q system role", - a.context.User.GetName(), types.VerbCreate, types.KindDatabaseCertificate, types.RoleDatabase) - return nil, trace.AccessDenied(fmt.Sprintf("access denied. User must have '%s' permission for '%s' kind to generate the certificate ", types.VerbCreate, types.KindDatabaseCertificate)) + a.context.User.GetName(), verb, resourceKind, types.RoleDatabase) + return trace.AccessDenied("access denied. User must have '%s' permission for '%s' kind to generate the certificate ", + verb, resourceKind) } } } - return a.authServer.GenerateDatabaseCert(ctx, req) + return nil } // GenerateSnowflakeJWT generates JWT in the Snowflake required format. diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 93e67441d51c6..f1dbc20cdc49f 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -937,7 +937,7 @@ func TestRoleRequestDenyReimpersonation(t *testing.T) { } // TestGenerateDatabaseCert makes sure users and services with appropriate -// permissions can generate certificates for self-hosted databases. +// permissions can generate database certificates. func TestGenerateDatabaseCert(t *testing.T) { t.Parallel() ctx := context.Background() @@ -957,22 +957,26 @@ func TestGenerateDatabaseCert(t *testing.T) { require.NoError(t, srv.Auth().UpsertRole(ctx, roleDb)) tests := []struct { - desc string - identity TestIdentity - err string + desc string + identity TestIdentity + requester proto.DatabaseCertRequest_Requester + err string }{ { - desc: "user can't sign database certs", - identity: TestUser(userWithoutAccess.GetName()), - err: "access denied", + desc: "user can't sign database certs", + identity: TestUser(userWithoutAccess.GetName()), + requester: proto.DatabaseCertRequest_TCTL, + err: "access denied", }, { - desc: "user can impersonate Db and sign database certs", - identity: TestUser(userImpersonateDb.GetName()), + desc: "user can impersonate Db and sign database certs", + identity: TestUser(userImpersonateDb.GetName()), + requester: proto.DatabaseCertRequest_TCTL, }, { - desc: "built-in admin can sign database certs", - identity: TestAdmin(), + desc: "built-in admin can sign database certs", + identity: TestAdmin(), + requester: proto.DatabaseCertRequest_TCTL, }, { desc: "database service can sign database certs", @@ -992,7 +996,7 @@ func TestGenerateDatabaseCert(t *testing.T) { client, err := srv.NewClient(test.identity) require.NoError(t, err) - _, err = client.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{CSR: csr}) + _, err = client.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{CSR: csr, RequesterName: test.requester}) if test.err != "" { require.ErrorContains(t, err, test.err) } else { diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 5392b89eedc98..84617d4f34832 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -717,8 +717,9 @@ type ClientI interface { // sessions created by the SAML identity provider. CreateSAMLIdPSession(context.Context, types.CreateSAMLIdPSessionRequest) (types.WebSession, error) - // GenerateDatabaseCert generates client certificate used by a database - // service to authenticate with the database instance. + // GenerateDatabaseCert generates a client certificate used by a database + // service to authenticate with the database instance, or a server certificate + // for configuring a self-hosted database, depending on the requester_name. GenerateDatabaseCert(context.Context, *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) // GetWebSession queries the existing web session described with req. diff --git a/lib/auth/db.go b/lib/auth/db.go index 4fc1cc935877a..6a5b650e7a9ea 100644 --- a/lib/auth/db.go +++ b/lib/auth/db.go @@ -46,32 +46,89 @@ import ( // GenerateDatabaseCert generates client certificate used by a database // service to authenticate with the database instance. func (a *Server) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { - csr, err := tlsca.ParseCertificateRequestPEM(req.CSR) + if req.RequesterName == proto.DatabaseCertRequest_TCTL { + // tctl/web cert request needs to generate a db server cert and trust + // the db client CA. + return a.generateDatabaseServerCert(ctx, req) + } + // db service needs to generate a db client cert and trust the db server CA. + return a.generateDatabaseClientCert(ctx, req) +} + +// generateDatabaseServerCert generates database server certificate used by a +// database to authenticate itself to a database service. +func (a *Server) generateDatabaseServerCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { + clusterName, err := a.GetClusterName() if err != nil { return nil, trace.Wrap(err) } - clusterName, err := a.GetClusterName() + // databases should be configured to trust the DatabaseClientCA when + // clients connect so return DatabaseClientCA in the response. + dbClientCA, err := a.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseClientCA, + DomainName: clusterName.GetClusterName(), + }, false) if err != nil { return nil, trace.Wrap(err) } - databaseCA, err := a.GetCertAuthority(ctx, types.CertAuthID{ + dbServerCA, err := a.GetCertAuthority(ctx, types.CertAuthID{ Type: types.DatabaseCA, DomainName: clusterName.GetClusterName(), }, true) if err != nil { - if trace.IsNotFound(err) { - // Database CA doesn't exist. Fallback to Host CA. - // https://github.com/gravitational/teleport/issues/5029 - databaseCA, err = a.GetCertAuthority(ctx, types.CertAuthID{ - Type: types.HostCA, - DomainName: clusterName.GetClusterName(), - }, true) - } - if err != nil { - return nil, trace.Wrap(err) - } + return nil, trace.Wrap(err) + } + + cert, err := a.generateDatabaseCert(ctx, req, dbServerCA) + if err != nil { + return nil, trace.Wrap(err) + } + return &proto.DatabaseCertResponse{ + Cert: cert, + CACerts: services.GetTLSCerts(dbClientCA), + }, nil +} + +// generateDatabaseClientCert generates client certificate used by a database +// service to authenticate with the database instance. +func (a *Server) generateDatabaseClientCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { + clusterName, err := a.GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + dbClientCA, err := a.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseClientCA, + DomainName: clusterName.GetClusterName(), + }, true) + if err != nil { + return nil, trace.Wrap(err) + } + + cert, err := a.generateDatabaseCert(ctx, req, dbClientCA) + if err != nil { + return nil, trace.Wrap(err) + } + // db clients should trust the Database Server CA when establishing + // connection to a database, so return that CA's certs in the response. + dbServerCA, err := a.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseCA, + DomainName: clusterName.GetClusterName(), + }, false) + if err != nil { + return nil, trace.Wrap(err) + } + return &proto.DatabaseCertResponse{ + Cert: cert, + CACerts: services.GetTLSCerts(dbServerCA), + }, nil +} + +func (a *Server) generateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest, ca types.CertAuthority) ([]byte, error) { + csr, err := tlsca.ParseCertificateRequestPEM(req.CSR) + if err != nil { + return nil, trace.Wrap(err) } - caCert, signer, err := getCAandSigner(ctx, a.GetKeyStore(), databaseCA, req) + caCert, signer, err := getCAandSigner(ctx, a.GetKeyStore(), ca, req) if err != nil { return nil, trace.Wrap(err) } @@ -98,15 +155,21 @@ func (a *Server) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCe // has been deprecated since Go 1.15: // https://golang.org/doc/go1.15#commonname certReq.DNSNames = getServerNames(req) + + // The windows smartcard cert req already does the same in + // lib/auth/windows/windows.go, along with another ExtKeyUsage for + // smartcard logon that we don't want to override above. + switch ca.GetType() { + case types.DatabaseCA: + // Override ExtKeyUsage to ExtKeyUsageServerAuth. + certReq.ExtraExtensions = append(certReq.ExtraExtensions, extKeyUsageServerAuthExtension) + case types.DatabaseClientCA: + // Override ExtKeyUsage to ExtKeyUsageClientAuth. + certReq.ExtraExtensions = append(certReq.ExtraExtensions, extKeyUsageClientAuthExtension) + } } cert, err := tlsCA.GenerateCertificate(certReq) - if err != nil { - return nil, trace.Wrap(err) - } - return &proto.DatabaseCertResponse{ - Cert: cert, - CACerts: services.GetTLSCerts(databaseCA), - }, nil + return cert, trace.Wrap(err) } // getCAandSigner returns correct signer and CA that should be used when generating database certificate. @@ -116,6 +179,7 @@ func (a *Server) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCe func getCAandSigner(ctx context.Context, keyStore *keystore.Manager, databaseCA types.CertAuthority, req *proto.DatabaseCertRequest, ) ([]byte, crypto.Signer, error) { if req.RequesterName == proto.DatabaseCertRequest_TCTL && + databaseCA.GetType() == types.DatabaseCA && databaseCA.GetRotation().Phase == types.RotationPhaseInit { return keyStore.GetAdditionalTrustedTLSCertAndSigner(ctx, databaseCA) } @@ -242,11 +306,21 @@ func (a *Server) GenerateSnowflakeJWT(ctx context.Context, req *proto.SnowflakeJ return nil, trace.Wrap(err) } ca, err := a.GetCertAuthority(ctx, types.CertAuthID{ - Type: types.DatabaseCA, + Type: types.DatabaseClientCA, DomainName: clusterName.GetClusterName(), }, true) if err != nil { - return nil, trace.Wrap(err) + if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + // DatabaseClientCA doesn't exist, fallback to DatabaseCA. + ca, err = a.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseCA, + DomainName: clusterName.GetClusterName(), + }, true) + if err != nil { + return nil, trace.Wrap(err) + } } if len(ca.GetActiveKeys().TLS) == 0 { @@ -330,7 +404,31 @@ func filterExtensions(extensions []pkix.Extension, oids ...asn1.ObjectIdentifier return filtered } +// TODO(gavin): move OIDs from here and in lib/auth/windows to tlsca package. var ( oidExtKeyUsage = asn1.ObjectIdentifier{2, 5, 29, 37} oidSubjectAltName = asn1.ObjectIdentifier{2, 5, 29, 17} + + oidExtKeyUsageServerAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 1} + oidExtKeyUsageClientAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 2} + extKeyUsageServerAuthExtension = pkix.Extension{ + Id: oidExtKeyUsage, + Value: func() []byte { + val, err := asn1.Marshal([]asn1.ObjectIdentifier{oidExtKeyUsageServerAuth}) + if err != nil { + panic(err) + } + return val + }(), + } + extKeyUsageClientAuthExtension = pkix.Extension{ + Id: oidExtKeyUsage, + Value: func() []byte { + val, err := asn1.Marshal([]asn1.ObjectIdentifier{oidExtKeyUsageClientAuth}) + if err != nil { + panic(err) + } + return val + }(), + } ) diff --git a/lib/auth/db_test.go b/lib/auth/db_test.go index ba1ac1158716a..31d24cff8501a 100644 --- a/lib/auth/db_test.go +++ b/lib/auth/db_test.go @@ -20,9 +20,19 @@ package auth import ( + "context" + "crypto/x509" + "crypto/x509/pkix" "testing" + "time" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/tlsca" ) func Test_getSnowflakeJWTParams(t *testing.T) { @@ -88,3 +98,113 @@ func Test_getSnowflakeJWTParams(t *testing.T) { }) } } + +func TestDBCertSigning(t *testing.T) { + t.Parallel() + authServer, err := NewTestAuthServer(TestAuthServerConfig{ + Clock: clockwork.NewFakeClockAt(time.Now()), + ClusterName: "local.me", + Dir: t.TempDir(), + }) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, authServer.Close()) }) + + ctx := context.Background() + + privateKey, err := testauthority.New().GeneratePrivateKey() + require.NoError(t, err) + + csr, err := tlsca.GenerateCertificateRequestPEM(pkix.Name{ + CommonName: "localhost", + }, privateKey) + require.NoError(t, err) + + // Set rotation to init phase. New CA will be generated. + // DB service should use active key to sign certificates. + // tctl should use new key to sign certificates. + err = authServer.AuthServer.RotateCertAuthority(ctx, RotateRequest{ + Type: types.DatabaseCA, + TargetPhase: types.RotationPhaseInit, + Mode: types.RotationModeManual, + }) + require.NoError(t, err) + err = authServer.AuthServer.RotateCertAuthority(ctx, RotateRequest{ + Type: types.DatabaseClientCA, + TargetPhase: types.RotationPhaseInit, + Mode: types.RotationModeManual, + }) + require.NoError(t, err) + + dbCAs, err := authServer.AuthServer.GetCertAuthorities(ctx, types.DatabaseCA, false) + require.NoError(t, err) + require.Len(t, dbCAs, 1) + require.Len(t, dbCAs[0].GetActiveKeys().TLS, 1) + require.Len(t, dbCAs[0].GetAdditionalTrustedKeys().TLS, 1) + activeDBCACert := dbCAs[0].GetActiveKeys().TLS[0].Cert + newDBCACert := dbCAs[0].GetAdditionalTrustedKeys().TLS[0].Cert + + dbClientCAs, err := authServer.AuthServer.GetCertAuthorities(ctx, types.DatabaseClientCA, false) + require.NoError(t, err) + require.Len(t, dbClientCAs, 1) + require.Len(t, dbClientCAs[0].GetActiveKeys().TLS, 1) + require.Len(t, dbClientCAs[0].GetAdditionalTrustedKeys().TLS, 1) + activeDBClientCACert := dbClientCAs[0].GetActiveKeys().TLS[0].Cert + newDBClientCACert := dbClientCAs[0].GetAdditionalTrustedKeys().TLS[0].Cert + + tests := []struct { + name string + requester proto.DatabaseCertRequest_Requester + wantCertSigner []byte + wantCACerts [][]byte + wantKeyUsage []x509.ExtKeyUsage + }{ + { + name: "DB service request is signed by active db client CA and trusts db CAs", + wantCertSigner: activeDBClientCACert, + wantCACerts: [][]byte{activeDBCACert, newDBCACert}, + wantKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }, + { + name: "tctl request is signed by new db CA and trusts db client CAs", + requester: proto.DatabaseCertRequest_TCTL, + wantCertSigner: newDBCACert, + wantCACerts: [][]byte{activeDBClientCACert, newDBClientCACert}, + wantKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + certResp, err := authServer.AuthServer.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{ + CSR: csr, + ServerName: "localhost", + TTL: proto.Duration(time.Hour), + RequesterName: tt.requester, + }) + require.NoError(t, err) + require.Equal(t, tt.wantCACerts, certResp.CACerts) + + // verify that the response cert is a DB CA cert. + mustVerifyCert(t, tt.wantCertSigner, certResp.Cert, tt.wantKeyUsage...) + }) + } +} + +// mustVerifyCert is a helper func that verifies leaf cert with root cert. +func mustVerifyCert(t *testing.T, rootPEM, leafPEM []byte, keyUsages ...x509.ExtKeyUsage) { + t.Helper() + leafCert, err := tlsca.ParseCertificatePEM(leafPEM) + require.NoError(t, err) + + certPool := x509.NewCertPool() + ok := certPool.AppendCertsFromPEM(rootPEM) + require.True(t, ok) + opts := x509.VerifyOptions{ + Roots: certPool, + KeyUsages: keyUsages, + } + // Verify if the generated certificate can be verified with the correct CA. + _, err = leafCert.Verify(opts) + require.NoError(t, err) +} diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index d725cff3456f2..7c3d2bada82a4 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -404,7 +404,18 @@ func (g *GRPCServer) WatchEvents(watch *proto.Watch, stream proto.AuthService_Wa } } -// maybeFilterCertAuthorityWatches will add filters to the CertAuthority +// maybeFilterCertAuthorityWatches is a helper function that will first try to +// add a CertAuthority filter for clients that are unaware of the DatabaseCA. +// If that filter is not added, then it will try to add a filter for clients +// that are unaware of the DatabaseClientCA. +func maybeFilterCertAuthorityWatches(ctx context.Context, clusterName string, roleNames []string, watch *types.Watch) bool { + if ok := maybeFilterCertAuthorityWatchesForDatabaseCA(ctx, clusterName, roleNames, watch); ok { + return ok + } + return maybeFilterCertAuthorityWatchesForDBClientCA(ctx, watch) +} + +// maybeFilterCertAuthorityWatchesForDatabaseCA will add filters to the CertAuthority // WatchKinds in the watch if the client is authenticated as just a `Node` with // no other roles and if the client is older than the cutoff version, and if the // WatchKind for KindCertAuthority is trivial, i.e. it's a WatchKind{Kind: @@ -413,30 +424,31 @@ func (g *GRPCServer) WatchEvents(watch *proto.Watch, stream proto.AuthService_Wa // everything. // // DELETE IN 10.0, no supported clients should require this at that point -func maybeFilterCertAuthorityWatches(ctx context.Context, clusterName string, roleNames []string, watch *types.Watch) { +func maybeFilterCertAuthorityWatchesForDatabaseCA(ctx context.Context, clusterName string, roleNames []string, watch *types.Watch) bool { if len(roleNames) != 1 || roleNames[0] != string(types.RoleNode) { - return + return false } clientVersionString, ok := metadata.ClientVersionFromContext(ctx) if !ok { log.Debug("no client version found in grpc context") - return + return false } clientVersion, err := semver.NewVersion(clientVersionString) if err != nil { log.WithError(err).Debugf("couldn't parse client version %q", clientVersionString) - return + return false } // we treat the entire previous major version as "old" for this version // check, even if there might have been backports; compliant clients will // supply their own filter anyway if !clientVersion.LessThan(certAuthorityFilterVersionCutoff) { - return + return false } + var filtered bool for i, k := range watch.Kinds { if k.Kind != types.KindCertAuthority || !k.IsTrivial() { continue @@ -447,13 +459,113 @@ func maybeFilterCertAuthorityWatches(ctx context.Context, clusterName string, ro types.HostCA: clusterName, types.UserCA: types.Wildcard, }.IntoMap() + filtered = true } + return filtered } // certAuthorityFilterVersionCutoff is the version starting from which we stop // injecting filters for CertAuthority watches in maybeFilterCertAuthorityWatches. var certAuthorityFilterVersionCutoff = *semver.New("9.0.0") +// dbClientCAVersionCutoff is the version starting from which we stop +// injecting a filter that drops DatabaseClientCA events. +// +// TODO(gavin): adjust for release! +var dbClientCACutoffVersion = semver.Version{Major: 14, Minor: 2, Patch: 0} + +// maybeFilterCertAuthorityWatchesForDBClientCA will inject a CA filter and return a +// function that removes the filter from the OpInit event if the client version +// does not support DatabaseClientCA type and if the client's CA WatchKind is +// trivial, i.e. it's not already filtering. Otherwise we assume that the client +// knows what it's doing and this function does nothing. +// This is a hack to avoid client cache re-init during CA rotation in older +// services that don't use a CA watch filter, i.e. every service except Node +// since v9. +// The returned function, if non-nil, must be called on the OpInit event, to +// remove the injected filter from the OpInit event's WatchStatus. This is to +// maintain the illusion to the client that CA events have not been filtered. +// If we did not remove the injected filter, then validateWatchRequest would +// fail on the client side because the confirmed kind filter is not as narrow or +// narrower than a trivial WatchKind. +// +// TODO(gavin): DELETE IN 16.0.0 - no supported clients will require this at +// that point. +func maybeFilterCertAuthorityWatchesForDBClientCA(ctx context.Context, watch *types.Watch) bool { + // check client version to see if it knows the DatabaseClientCA type. + clientVersion, err := getClientVersion(ctx) + if err != nil { + log.Debugf("Unable to determine client version: %v", err) + return false + } + if versionSupportsDatabaseClientCA(*clientVersion) { + // don't need to inject a CA filter if the client support DB Client CA. + return false + } + + // search for trivial CA WatchKinds to inject a filter into - all of them + // must be trivial so we can remove the filter(s) from OpInit later. + var targets []*types.WatchKind + for i, k := range watch.Kinds { + if k.Kind != types.KindCertAuthority { + continue + } + + if !k.IsTrivial() { + // We need to remove the injected filter(s) from the OpInit event + // later. + // As a precaution, do nothing when any of the CA WatchKind(s) are + // non-trivial. + log.Warnf("Cannot inject filter into non-trivial CertAuthority watcher with client version %s.", clientVersion) + return false + } + targets = append(targets, &watch.Kinds[i]) + } + if len(targets) == 0 { + return false + } + + // create a CA filter that excludes DatabaseClientCA. + caFilter := make(types.CertAuthorityFilter, len(types.CertAuthTypes)-1) + for _, caType := range types.CertAuthTypes { + // exclude db client CA. + if caType == types.DatabaseClientCA { + continue + } + caFilter[caType] = types.Wildcard + } + + log.Debugf("Injecting filter for CertAuthority watcher with client version %s.", clientVersion) + for _, t := range targets { + t.Filter = caFilter.IntoMap() + } + return true +} + +func getClientVersion(ctx context.Context) (*semver.Version, error) { + clientVersionString, ok := metadata.ClientVersionFromContext(ctx) + if !ok { + return nil, trace.NotFound("no client version found in grpc context") + } + clientVersion, err := semver.NewVersion(clientVersionString) + return clientVersion, trace.Wrap(err) +} + +// versionSupportsDatabaseClientCA returns true if the client version supports +// the DatabaseClientCA. This CA was introduced in backports. +// Client version in the intervals [v12.x, v13.0), [v13.y, v14.0), [v14.z, inf) +// supports the DatabaseClientCA type, where x, y, z are the minor release +// versions that the DatabaseClientCA is backported to. +// Since this function needs to be aware of multiple minor release versions, +// we should first backport to v12 with a known minor version, then +// v13, then v14, and finally merge into v15. +// That way each minor release will be aware of the supported version +// intervals. +func versionSupportsDatabaseClientCA(v semver.Version) bool { + v.PreRelease = "" // ignore pre-release tags + return !v.LessThan(dbClientCACutoffVersion) +} + // resourceLabel returns the label for the provided types.Event func resourceLabel(event types.Event) string { if event.Resource == nil { @@ -1325,8 +1437,9 @@ func (g *GRPCServer) SignDatabaseCSR(ctx context.Context, req *proto.DatabaseCSR return response, nil } -// GenerateDatabaseCert generates client certificate used by a database -// service to authenticate with the database instance. +// GenerateDatabaseCert generates a client certificate used by a database +// service to authenticate with the database instance, or a server certificate +// for configuring a self-hosted database, depending on the requester_name. func (g *GRPCServer) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { auth, err := g.authenticate(ctx) if err != nil { diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index 9821a3f153072..b37975e006324 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -20,6 +20,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" "encoding/base32" "encoding/pem" "fmt" @@ -2352,8 +2353,31 @@ func TestGenerateHostCerts(t *testing.T) { require.NotNil(t, certs) } -// TestInstanceCertAndControlStream attempts to generate an instance cert via the -// assertion API and use it to handle an inventory ping via the control stream. +func TestGenerateDatabaseCerts(t *testing.T) { + t.Parallel() + ctx := context.Background() + srv := newTestTLSServer(t) + + clt, err := srv.NewClient(TestAdmin()) + require.NoError(t, err) + + // Generate CSR once for speed sake. + priv, err := testauthority.New().GeneratePrivateKey() + require.NoError(t, err) + csr, err := tlsca.GenerateCertificateRequestPEM(pkix.Name{CommonName: "test"}, priv) + require.NoError(t, err) + + certs, err := clt.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{CSR: csr}) + require.NoError(t, err) + require.NotNil(t, certs) + + certs, err = clt.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{CSR: csr, RequesterName: proto.DatabaseCertRequest_TCTL}) + require.NoError(t, err) + require.NotNil(t, certs) +} + +// TestInstanceCertAndControlStream uses an instance cert to send an +// inventory ping via the control stream. func TestInstanceCertAndControlStream(t *testing.T) { const assertionID = "test-assertion" const serverID = "test-server" @@ -3989,3 +4013,97 @@ func TestGRPCServer_GetInstallers(t *testing.T) { } } + +func TestDropDBClientCAEvents(t *testing.T) { + server := newTestTLSServer(t) + client, err := server.NewClient(TestIdentity{ + I: authz.BuiltinRole{ + Role: types.RoleDatabase, + AdditionalSystemRoles: []types.SystemRole{ + types.RoleNode, + }, + Username: server.ClusterName(), + }, + }) + require.NoError(t, err) + + dbClientCAs, err := client.GetCertAuthorities(context.Background(), types.DatabaseClientCA, false) + require.NoError(t, err) + require.Len(t, dbClientCAs, 1) + + dbCAs, err := client.GetCertAuthorities(context.Background(), types.DatabaseCA, false) + require.NoError(t, err) + require.Len(t, dbCAs, 1) + + tests := []struct { + desc string + clientVersion string + filter map[string]string + expectDrop bool + }{ + { + desc: "send db client CA events to supported client", + clientVersion: dbClientCACutoffVersion.String(), + }, + { + desc: "drop db client CA events to unsupported client", + clientVersion: "14.0.0", + expectDrop: true, + }, + } + + for i, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + clientCtx := metadata.AddMetadataToContext(ctx, + map[string]string{ + metadata.VersionKey: test.clientVersion, + }) + + requestedKind := types.WatchKind{Kind: types.KindCertAuthority, LoadSecrets: false} + watcher, err := client.NewWatcher(clientCtx, types.Watch{ + Name: "cas", + Kinds: []types.WatchKind{requestedKind}, + }) + require.NoError(t, err) + defer watcher.Close() + + // Swallow the init event + e := <-watcher.Events() + require.Equal(t, types.OpInit, e.Type) + require.Nil(t, e.Resource) + + // update the db client ca so the watcher gets an OpPut event + dbClientCAs[0].SetName(fmt.Sprintf("stub_%v", i)) + err = server.Auth().UpsertCertAuthority(dbClientCAs[0]) + require.NoError(t, err) + + // update the db ca so the watcher gets an OpPut event + dbCAs[0].SetName(fmt.Sprintf("stub_%v", i)) + err = server.Auth().UpsertCertAuthority(dbCAs[0]) + require.NoError(t, err) + + gotCA, err := func() (types.CertAuthority, error) { + for { + select { + case <-watcher.Done(): + return nil, watcher.Error() + case e := <-watcher.Events(): + if ca, ok := e.Resource.(types.CertAuthority); ok { + return ca, nil + } + } + } + }() + require.NoError(t, err) + if test.expectDrop { + // the watcher should only see the second ca event. + require.Equal(t, types.DatabaseCA, gotCA.GetType(), "db client CA event was supposed to be dropped") + return + } + // watcher should see the first event if it wasn't dropped. + require.Equal(t, types.DatabaseClientCA, gotCA.GetType(), "db client CA event was not supposed to be dropped") + }) + } +} diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 9a098d8b8ab26..999cf7f6166f3 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -553,6 +553,17 @@ func (a *TestAuthServer) Trust(ctx context.Context, remote *TestAuthServer, role if err != nil { return trace.Wrap(err) } + remoteCA, err = remote.AuthServer.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseClientCA, + DomainName: remote.ClusterName, + }, false) + if err != nil { + return trace.Wrap(err) + } + err = a.AuthServer.UpsertCertAuthority(remoteCA) + if err != nil { + return trace.Wrap(err) + } remoteCA, err = remote.AuthServer.GetCertAuthority(ctx, types.CertAuthID{ Type: types.OpenSSHCA, DomainName: remote.ClusterName, diff --git a/lib/auth/init.go b/lib/auth/init.go index 96a135a28813e..f6d935cca29db 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -43,6 +43,7 @@ import ( apisshutils "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth/keystore" + "github.com/gravitational/teleport/lib/auth/migration" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/events" @@ -426,6 +427,12 @@ func initCluster(ctx context.Context, cfg InitConfig, asrv *Server) error { if err := migrateDBAuthority(ctx, asrv); err != nil { return trace.Wrap(err, "failed to migrate database CA") } + span.AddEvent("migrating db_client_authority") + err = migrateDBClientAuthority(ctx, asrv.Trust, cfg.ClusterName.GetClusterName()) + if err != nil { + return trace.Wrap(err) + } + span.AddEvent("completed migration db_client_authority") // generate certificate authorities if they don't exist var ( @@ -777,7 +784,7 @@ func checkResourceConsistency(ctx context.Context, keyStore *keystore.Manager, c switch r.GetType() { case types.HostCA, types.UserCA, types.OpenSSHCA: _, signerErr = keyStore.GetSSHSigner(ctx, r) - case types.DatabaseCA, types.SAMLIDPCA: + case types.DatabaseCA, types.DatabaseClientCA, types.SAMLIDPCA: _, _, signerErr = keyStore.GetTLSCertAndSigner(ctx, r) case types.JWTSigner, types.OIDCIdPCA: _, signerErr = keyStore.GetJWTSigner(ctx, r) @@ -1334,3 +1341,14 @@ func applyResources(ctx context.Context, service *Services, resources []types.Re } return nil } + +// migrateDBClientAuthority copies Database CA as Database Client CA. +// Does nothing if the Database Client CA already exists. +// +// TODO(gavin): DELETE IN 16.0.0 +func migrateDBClientAuthority(ctx context.Context, trustSvc services.Trust, cluster string) error { + migrationStart(ctx, "db_client_authority") + defer migrationEnd(ctx, "db_client_authority") + err := migration.MigrateDBClientAuthority(ctx, trustSvc, cluster) + return trace.Wrap(err) +} diff --git a/lib/auth/init_test.go b/lib/auth/init_test.go index 7450f94c9c707..51d10920c9040 100644 --- a/lib/auth/init_test.go +++ b/lib/auth/init_test.go @@ -817,6 +817,21 @@ spec: key: LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBdmZQYlVPV3RMYU9QYndzTDVSRENWemFXMnliRUVmbzZNd3RrYVVFditnQjhocVBwCjhyU1VPa2EyYzBJbU11WGFIa05Qa1lUaTNPQmsvQmdUN3RKQlh0TS9mSDJjUy85S0l2NENwRlBQdHhEWWRaKzIKUHArYlJENkpiNVh0R0VaM2w4T1k0SkdBblpramRTM3JvMnFNL1NhMnJrOStEaDVkN1VvWU13RVJMbnoxRTlqKwpjbVFMQWJxSUhOQnNEY3RTZklya013cUNqTXlqZDMyN1YxMG9ua3M1NXBUOG5Oc3N4NnBrM1lEK3JndVdwRk00CkRxdlhXbkVTYWxHMEZNbHZSRlR1TlhUL3BmdGNMT29FWVVwL3FFYkFTdkprbEppT3UvRmRlcUxhamtlbmZjbGIKeVZPa1o5ZjlXOTJ6UXlxK3BGeUpkQmtGN041bys3UUFOMDdBUlFJREFRQUJBb0lCQVFDWkFiTHBxUGdrU1JtaQpncTFrS0duQ3dwQWxtMFpZak0wUWpONm5BZ0ZaU2NjRTFVZi9Yb0gvcHpJVUNYYW5qUXB6VWhqbnlMak0zbHU1CnpOTlJqajlsMkpmTStZbEtsaXJyb053VDdnYmxHVWFqQ0xGT0pGWjNWRUIwaDduaDBmRkhhQ0RlMDVWY1hSeDQKcVRLa0FaSHI0S0ZLSzNJSWdXRjdZREc1OCtRWkl0N01BdDQyc1NrUlhaenNyaUVzVWxZNU9xT1dlVHlkYy93dQpDK292V1FlOUU2bnRSc0EvY25nS0M1bnp6T0pCano3SlVTMGQ2UnBGK1Zodk9SRlZ6YTZKMnpRRXJsWWtFdVVUCkJYdThsVmhhcDVOU09nd3ByZXVmb2tNZVY4eGRNdUxuVEthdHVjajIwT3JFRUlHZVFrb08yTzhTeDJtNitOS0kKbzBTT2kzU0pBb0dCQU5LVkk4R0JHQUpSUlV2NVBnVkNDQ1VHSkFBRmtxY0FMTFlaZ0NMZCtLYStBVW41L2VONwpWemR4b2VqNG1UcnFOb0RLTnk5b0N2Z1RBWTlWZ3RZZllpdlNQbW5ZZk9FbTJUR2pLRTJSYzhBd0R5UWdHWFJVCm5WOUVueHRUUHF5dU9tQTFqa0ladXBLeUU0b0drdTE4MDRVaTB2d2pJUWI3cDdML0ZGYksvaGlmQW9HQkFPYnIKclBrOHdicnBlVk9QWmVudzVzbjVhTUpjdDM3emt4dHZzejVSVWlsZnZtVkdja1lyMU41SGpmZjVVVGcwNmNWdgp1N0dGY0NLQmhncFdqSXM3dHZ4cDhUVFM1VEZaWHZqY2VhRC91ZkFiN3ZUVnR0L1hRMFUwQ28rcDdzMENKWlZKCjY2MUtUU1RCYlFtR0xYYmNqNmFYdVM3bzFJM1FRU1l4NW5LV1F5aWJBb0dCQUtBTWZocUtGVWRkb1g5MnRhNmwKV3k5WWxXLzJ6RmxsQnBaNGx5em83QjArK0JmVGl5V2tEc3V5NzgzemMvS1ZKRXVLWlpzQVJxWDVQQXhHZjZSaQpRZWp3YUVObUtMT3ZKUkJXNDBEaE5jcHlQRy9HZmRJdXBWVk5BR2h5UW9aWC9VSTJNaU1IRHdpRGs5b3AyTzNyCkc1QnF3VlNsRm1zS1JaRUQwZCtOZE1ZZEFvR0FkVDNQRXJQd1FHL3RzNmtvdTBBZVRRbWVVS0EyWWZSVkNpY0sKUUdlVmFZQTg4THAxcG43Mmt1eU5mZ3ROVzFZeUlwWDZHOFYrQzJicm9UQVVKMVRvTVB1eEJYclY5dHBEUitMWQp0ZzlnWGpJd2ZvcExVUmJBQnRESFUrMlpXdWp1SC8vcDhvKzQzeUo5czhvMkp4VVFzaXB5VVFqUmNqYjcvT0owCitGU21RR1VDZ1lBdU5XcFhrbVkzbGtRSEFpU3RlblhTT2Zqc0xsbm5XWFJoTDVBYTZqRVZRT3FnUldVQnZ2YWkKK1RQRTNUTHQ5MGwveTZOdjU0dDdrT1QvSlEvREU4WmFiMERlTjdBRzRwRjMwNkxpYlpZNmswc3M1UDNXME8vbAozekJzQ0lEY3BHOWw4bDZSNkUwdHN0Z0I1c25hOE4rRzA2V3Q1R0M0UitRSVd1YTVpUmtVTXc9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo= type: user sub_kind: user +version: v2` + databaseClientCAYAML = ` +kind: cert_authority +metadata: + id: 1696989861240620000 + name: me.localhost +spec: + active_keys: + tls: + - cert: LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURqVENDQW5XZ0F3SUJBZ0lSQU5XcUtsOWR3WGYrVTBWUmhxNGEyaTB3RFFZSktvWklodmNOQVFFTEJRQXcKWURFVk1CTUdBMVVFQ2hNTWJXVXViRzlqWVd4b2IzTjBNUlV3RXdZRFZRUURFd3h0WlM1c2IyTmhiR2h2YzNReApNREF1QmdOVkJBVVRKekk0TkRBd09URXhNams0TlRBek1qYzRPVEV3TURJNE5qUXlPREV6TmpRMk9UQTVNamt3Ck9UQWVGdzB5TXpFd01URXdNakEwTWpGYUZ3MHpNekV3TURnd01qQTBNakZhTUdBeEZUQVRCZ05WQkFvVERHMWwKTG14dlkyRnNhRzl6ZERFVk1CTUdBMVVFQXhNTWJXVXViRzlqWVd4b2IzTjBNVEF3TGdZRFZRUUZFeWN5T0RRdwpNRGt4TVRJNU9EVXdNekkzT0RreE1EQXlPRFkwTWpneE16WTBOamt3T1RJNU1Ea3dnZ0VpTUEwR0NTcUdTSWIzCkRRRUJBUVVBQTRJQkR3QXdnZ0VLQW9JQkFRRGlDNU5GVDhmUy9hSzdSSVVRVnFjVmFxaFBzMFNuU0prTmd3azEKUHZkeFV1OWZZMlNwek5NaUUzSGZlb0Y4S1h2YUU0aHJzMEFGOVRmYlpJTnM1RjNHNTNzOUg3Q2JXWHpOWVRtZApCN0gyWEVxVGp3N0xGL2pzYzkwcTN4ZnZqMkk0Z29tOUdYK3dGMXdaRldjZXVJRkJTdXdCRkV6a1Yzc1o5NEVqClBsWUIxK2lnNlJoWGhvUjdhRlJUNDFvZmtMUUovMDdBVmR4blUranp5VkVFSVk3SjUwUWU3bFc2Nk9wL3BncmwKR1FBSnkwbnowUVpVYVJjVmZrODVHK3NwMnhjcUJ6clJHbXNybmw1TmhMdGJqcUJIUkZ2cU5XS1pLa1V2M0NjUApiTytWT1krV1FmV0UzRThhekxUQ2ppcnJXYWVXeTNLR0RTZGF5YktOK0FKbVpqU3hBZ01CQUFHalFqQkFNQTRHCkExVWREd0VCL3dRRUF3SUJwakFQQmdOVkhSTUJBZjhFQlRBREFRSC9NQjBHQTFVZERnUVdCQlFhVmMzdXlYWnoKdDBWLzFnNzE3MzMrRjFhaHNqQU5CZ2txaGtpRzl3MEJBUXNGQUFPQ0FRRUF6NTVkVnVFdmVLdnJtYThzL0dWSQo2Q0t5akNNYXNjWmhwV1JIT3QxRjQ1T0pjcXg1RDBQeVhSenZXS2NTYzlZTkN0M1BzSi8yNGp3VDlLaElqK2NiClQ5Z0h5WXNkb3pWY2NzMXNZTkFjK3VFSmRSOEsydHJqa1JJN0Q5VmZvTEJJVFlHUkJGTWpSOEE1bENlUzVnTkgKRG42V09rSlpRUi9UQS9IbFFlUmttMW5teUp3VVVQOVA0aUVWVlVSS0lMRVVNTS9EdERXdTZuNnM2K0pVVXNDNwp5QmI2T3JQeVRGbkV4TFljN2RhYUM1bm5UVDZHY2xUSm4wYkJ2UmtXdUFVa1FtWXJyYkpBMnhEVjFBL0JOcmp3Ci9aU2ErU1ZlVWJxSW05ZEVESE4zQUhXcmJzbWwyVjI3YUtrMHVUK0JmeUJBZ3NSdGpMN0U2YUdJanlNcStlOW4KYmc9PQotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0tCg== + key: LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFb3dJQkFBS0NBUUVBNGd1VFJVL0gwdjJpdTBTRkVGYW5GV3FvVDdORXAwaVpEWU1KTlQ3M2NWTHZYMk5rCnFjelRJaE54MzNxQmZDbDcyaE9JYTdOQUJmVTMyMlNEYk9SZHh1ZDdQUit3bTFsOHpXRTVuUWV4OWx4S2s0OE8KeXhmNDdIUGRLdDhYNzQ5aU9JS0p2Umwvc0JkY0dSVm5IcmlCUVVyc0FSUk01RmQ3R2ZlQkl6NVdBZGZvb09rWQpWNGFFZTJoVVUrTmFINUMwQ2Y5T3dGWGNaMVBvODhsUkJDR095ZWRFSHU1VnV1anFmNllLNVJrQUNjdEo4OUVHClZHa1hGWDVQT1J2cktkc1hLZ2M2MFJwcks1NWVUWVM3VzQ2Z1IwUmI2alZpbVNwRkw5d25EMnp2bFRtUGxrSDEKaE54UEdzeTB3bzRxNjFtbmxzdHloZzBuV3NteWpmZ0NabVkwc1FJREFRQUJBb0lCQUQ5R01EWkJxOVRDek0rUQowWktPUHZ6K3V4aDhQT1o2cXVVZVhmQjZyTGNiR1FoaGdTY0t2N3NWS0ZYL0s4bStydjJQWkN1SnBJMUdaQmxVCm5IbFp2MnBURjZzM2VLOHpzSHlwRDRDR1MrbURVaGpWL2JVYUE4TGtkKzl0UFgwQWJPVVduVW5Dbm55RFBYT0UKQ3phTlBSa3l5TGRRb0dsMmwyM2dXMVNyT1ZZUVBEUjZncWVJZVFYa3pHYUFUQ2twZWYrOVk4US9pTkZUR05oZAppamtXSUZOdEYzQjdIODUwdnR0VFFRckE3QXQ3ZnN2bmo2YVVDUkQ3MmFYZGJmeHIwK0VQbUR6WGNhejM3U20yClg3ZkJrakRFa0pCa0gwVnBnZHdvMDh3cmtzbnBieUNpbE95alp4Z3lhUWw2NGFKcGVUN1FHbEMxcm1kUEFmTU4KSEdweFBwVUNnWUVBK21vbllBVW1oSVpTR0R6VjI0NFc0QVRnVFdGb2gxb24wKzlZZ2xxR1RUZFFBM2dXYy9ReQowSmJ6QXpOVFZnODNibWhvU3NtbUMwZ1BoSytBb3BCZXc5ZlVxMmhIOVptUzVjZE1CaTJ6cFdvZHQrWXMvNys1Ckk3d3d2bGgvY3llelQvU0ZyazVCVVB0azZFR1pLZk1KdGZlNDcvb05uUzRmc3lmYlAyWDVUWHNDZ1lFQTV4WkcKSmhiYkwwWFljL0plc0JucXBuRFZHNXhkbkh6aGx0aVB3QzdGbHRDRUpuNFdhTUNpSUtsL0o3VGUzbndqMHk0YQpSVzFTWGN6anc5dHZxY0V3aE9CQUZBUHlsd0FWWjUyOTRzdWZzMnk4SHBFcFhhMjlIRXNReW50TE9JRFZyYkVsClJCV1pEb0xhbllVRTNtWml5WnZ0ZU9TT0hTajVhUnIzRTg1UmtNTUNnWUFweUVLUG8reGNXbWtpUUN4U3VPK2EKSzFZZHN5NFV2M2M3eG9qWEh6R2ZlcVl3SGY1cEZJclNBUTNGTC9Bc3dOYzM1ZFhZL0xKbTJYdzFZRzh2TUxXUApLZGtEVEtBTkc3WEYveTN4TGZqMmxiRWx1Uk16RFJOZ0lndGtCeklremEvK25FY2Q0VkxHcDF1YjRTNGtNTGdqCkU1VlkvVGorUys3Z0hydFhaYlZtTndLQmdET3A2eTBBMXlnT2VZSVNvZERGT296VGxSR0ROL3FRZ083MG84N1gKcGgwOXFRM2lDcWlJeUxaOHJvejJCdzIrdTFPdmJ2Z3VwTWVMMHpBcWt5QmtyTEJJWW9zWEJ0bHpqMVdIRXJqdAp4VnFiNk1MOHVUN1VaUDg2V1JxcnpmbG45RjNNeVFRYndBaGFnUDNPaTNRZGQrQ1RGOWg3WUxwc09yYWc3TFJrCjRCOTVBb0dCQU9maHNVSzVSZm1RU1ZzN08wMmMrdWFVelB2dnEwaXNqNW45dWlOaFQxdjFDUGY4YStZdkkyTisKcWV5bHkwRjN3L2sxbElaUzFjWlMwRDVWMUd6bmVHTUgreWYwVFAvNmlUcElHTC94N1pTNGJEZjE3dEc5dklDdgoya1JBTno5WHpzVzFETm9CRkJmZXg5NDFmT0RzdHlvdThvYmF3dDdJWThTdU1GMHV3aDlBCi0tLS0tRU5EIFJTQSBQUklWQVRFIEtFWS0tLS0tCg== + additional_trusted_keys: {} + cluster_name: me.localhost + type: db_client +sub_kind: db_client version: v2` databaseCAYAML = `kind: cert_authority metadata: @@ -892,7 +907,8 @@ func TestInit_bootstrap(t *testing.T) { hostCA := resourceFromYAML(t, hostCAYAML).(types.CertAuthority) userCA := resourceFromYAML(t, userCAYAML).(types.CertAuthority) jwtCA := resourceFromYAML(t, jwtCAYAML).(types.CertAuthority) - dbCA := resourceFromYAML(t, databaseCAYAML).(types.CertAuthority) + dbServerCA := resourceFromYAML(t, databaseCAYAML).(types.CertAuthority) + dbClientCA := resourceFromYAML(t, databaseClientCAYAML).(types.CertAuthority) osshCA := resourceFromYAML(t, openSSHCAYAML).(types.CertAuthority) samlCA := resourceFromYAML(t, samlCAYAML).(types.CertAuthority) @@ -902,8 +918,10 @@ func TestInit_bootstrap(t *testing.T) { invalidUserCA.(*types.CertAuthorityV2).Spec.ActiveKeys.SSH = nil invalidJWTCA := resourceFromYAML(t, jwtCAYAML).(types.CertAuthority) invalidJWTCA.(*types.CertAuthorityV2).Spec.ActiveKeys.JWT = nil - invalidDBCA := resourceFromYAML(t, databaseCAYAML).(types.CertAuthority) - invalidDBCA.(*types.CertAuthorityV2).Spec.ActiveKeys.TLS = nil + invalidDBServerCA := resourceFromYAML(t, databaseCAYAML).(types.CertAuthority) + invalidDBServerCA.(*types.CertAuthorityV2).Spec.ActiveKeys.TLS = nil + invalidDBClientCA := resourceFromYAML(t, databaseClientCAYAML).(types.CertAuthority) + invalidDBClientCA.(*types.CertAuthorityV2).Spec.ActiveKeys.TLS = nil invalidOSSHCA := resourceFromYAML(t, openSSHCAYAML).(types.CertAuthority) invalidOSSHCA.(*types.CertAuthorityV2).Spec.ActiveKeys.SSH = nil invalidSAMLCA := resourceFromYAML(t, samlCAYAML).(types.CertAuthority) @@ -923,7 +941,8 @@ func TestInit_bootstrap(t *testing.T) { hostCA.Clone(), userCA.Clone(), jwtCA.Clone(), - dbCA.Clone(), + dbServerCA.Clone(), + dbClientCA.Clone(), osshCA.Clone(), samlCA.Clone(), ) @@ -938,7 +957,8 @@ func TestInit_bootstrap(t *testing.T) { invalidHostCA.Clone(), userCA.Clone(), jwtCA.Clone(), - dbCA.Clone(), + dbServerCA.Clone(), + dbClientCA.Clone(), osshCA.Clone(), samlCA.Clone(), ) @@ -953,7 +973,8 @@ func TestInit_bootstrap(t *testing.T) { hostCA.Clone(), invalidUserCA.Clone(), jwtCA.Clone(), - dbCA.Clone(), + dbServerCA.Clone(), + dbClientCA.Clone(), osshCA.Clone(), samlCA.Clone(), ) @@ -968,7 +989,8 @@ func TestInit_bootstrap(t *testing.T) { hostCA.Clone(), userCA.Clone(), invalidJWTCA.Clone(), - dbCA.Clone(), + dbServerCA.Clone(), + dbClientCA.Clone(), osshCA.Clone(), samlCA.Clone(), ) @@ -983,7 +1005,23 @@ func TestInit_bootstrap(t *testing.T) { hostCA.Clone(), userCA.Clone(), jwtCA.Clone(), - invalidDBCA.Clone(), + invalidDBServerCA.Clone(), + osshCA.Clone(), + samlCA.Clone(), + ) + }, + assertError: require.Error, + }, + { + name: "NOK bootstrap Database Client CA missing keys", + modifyConfig: func(cfg *InitConfig) { + cfg.BootstrapResources = append( + cfg.BootstrapResources, + hostCA.Clone(), + userCA.Clone(), + jwtCA.Clone(), + dbServerCA.Clone(), + invalidDBClientCA.Clone(), osshCA.Clone(), samlCA.Clone(), ) @@ -998,7 +1036,8 @@ func TestInit_bootstrap(t *testing.T) { hostCA.Clone(), userCA.Clone(), jwtCA.Clone(), - dbCA.Clone(), + dbServerCA.Clone(), + dbClientCA.Clone(), invalidOSSHCA.Clone(), samlCA.Clone(), ) @@ -1013,7 +1052,8 @@ func TestInit_bootstrap(t *testing.T) { hostCA.Clone(), userCA.Clone(), jwtCA.Clone(), - dbCA.Clone(), + dbServerCA.Clone(), + dbClientCA.Clone(), osshCA.Clone(), invalidSAMLCA.Clone(), ) @@ -1318,3 +1358,26 @@ func TestRotateDuplicatedCerts(t *testing.T) { require.NotEqual(t, newHostCA.GetActiveKeys().TLS, newDatabaseCA.GetActiveKeys().TLS) require.NotEqual(t, newHostCA.GetActiveKeys().SSH, newDatabaseCA.GetActiveKeys().SSH) } + +func TestMigrateDatabaseClientCA(t *testing.T) { + ctx := context.Background() + conf := setupConfig(t) + + hostCA := suite.NewTestCA(types.HostCA, "me.localhost") + userCA := suite.NewTestCA(types.UserCA, "me.localhost") + dbServerCA := suite.NewTestCA(types.DatabaseCA, "me.localhost") + + conf.Authorities = []types.CertAuthority{hostCA, userCA, dbServerCA} + auth, err := Init(ctx, conf) + require.NoError(t, err) + t.Cleanup(func() { + err = auth.Close() + require.NoError(t, err) + }) + + dbClientCAs, err := auth.GetCertAuthorities(ctx, types.DatabaseClientCA, true) + require.NoError(t, err) + require.Len(t, dbClientCAs, 1) + require.Equal(t, dbServerCA.Spec.ActiveKeys.TLS[0].Cert, dbClientCAs[0].GetActiveKeys().TLS[0].Cert) + require.Equal(t, dbServerCA.Spec.ActiveKeys.TLS[0].Key, dbClientCAs[0].GetActiveKeys().TLS[0].Key) +} diff --git a/lib/auth/migration/0001_db_ca.go b/lib/auth/migration/0001_db_ca.go new file mode 100644 index 0000000000000..0109e96e67562 --- /dev/null +++ b/lib/auth/migration/0001_db_ca.go @@ -0,0 +1,97 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package migration + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" +) + +// MigrateDBClientAuthority performs a migration which creates the db_client CA +// as a copy of the existing db CA for backwards compatibility. +func MigrateDBClientAuthority(ctx context.Context, trustSvc services.Trust, cluster string) error { + err := migrateDBAuthority(ctx, trustSvc, cluster, types.DatabaseCA, types.DatabaseClientCA) + return trace.Wrap(err) +} + +// migrateDBAuthority performs a migration which creates a new CA from an +// existing CA. +// The new CA is created as a copy of the existing CA for backwards +// compatibility. +// This func is generalized for copying db/db_client CAs, although it may appear +// to be usable for other CA types - that's why it is unexported. +func migrateDBAuthority(ctx context.Context, trustSvc services.Trust, cluster string, fromType, toType types.CertAuthType) error { + _, err := trustSvc.GetCertAuthority(ctx, types.CertAuthID{ + Type: toType, + DomainName: cluster, + }, false) + // The migration for this cluster can be skipped since + // the new CA already exists. + if err == nil { + log.Debugf("Migrations: cert authority %q already exists.", toType) + return nil + } + if !trace.IsNotFound(err) { + return trace.Wrap(err) + } + + // The new CA type does not exist, so we must check to + // see if the existing CA exists before proceeding with the split migration. + // If both the existing and new CA do not exist, then this cluster + // is brand new and the migration can be avoided because they will + // both automatically be created. If the existing CA does exist, then + // a new CA should be constructed from it as a copy. + existingCA, err := trustSvc.GetCertAuthority(ctx, types.CertAuthID{ + Type: fromType, + DomainName: cluster, + }, true) + if trace.IsNotFound(err) { + return nil + } + if err != nil { + return trace.Wrap(err) + } + + log.Infof("Migrating %s CA for cluster: %s", toType, cluster) + + existingCAV2, ok := existingCA.(*types.CertAuthorityV2) + if !ok { + return trace.BadParameter("expected %s CA to be *types.CertAuthorityV2, got %T", fromType, existingCA) + } + + newCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: toType, + ClusterName: cluster, + ActiveKeys: existingCAV2.Spec.ActiveKeys, + }) + if err != nil { + return trace.Wrap(err) + } + + err = trustSvc.CreateCertAuthority(newCA) + if trace.IsAlreadyExists(err) { + log.Warnf("%s CA has already been created by a different Auth instance", toType) + return nil + } + return trace.Wrap(err) +} diff --git a/lib/auth/migration/migration.go b/lib/auth/migration/migration.go new file mode 100644 index 0000000000000..9a246202f074e --- /dev/null +++ b/lib/auth/migration/migration.go @@ -0,0 +1,30 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package migration + +import ( + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport" +) + +var log = logrus.WithFields(logrus.Fields{ + trace.Component: teleport.ComponentAuth, +}) diff --git a/lib/auth/rotate.go b/lib/auth/rotate.go index 76e591f8079d8..f13f82c58f02b 100644 --- a/lib/auth/rotate.go +++ b/lib/auth/rotate.go @@ -62,25 +62,10 @@ type RotateRequest struct { // Types returns cert authority types requested to be rotated. func (r *RotateRequest) Types() []types.CertAuthType { - switch r.Type { - case types.CertAuthTypeAll: + if r.Type == types.CertAuthTypeAll { return types.CertAuthTypes[:] - case types.HostCA: - return []types.CertAuthType{types.HostCA} - case types.DatabaseCA: - return []types.CertAuthType{types.DatabaseCA} - case types.UserCA: - return []types.CertAuthType{types.UserCA} - case types.OpenSSHCA: - return []types.CertAuthType{types.OpenSSHCA} - case types.JWTSigner: - return []types.CertAuthType{types.JWTSigner} - case types.SAMLIDPCA: - return []types.CertAuthType{types.SAMLIDPCA} - case types.OIDCIdPCA: - return []types.CertAuthType{types.OIDCIdPCA} } - return nil + return []types.CertAuthType{r.Type} } // CheckAndSetDefaults checks and sets default values. diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index f4231cb8b2479..8bc5db94d3321 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -1906,7 +1906,12 @@ func TestGetCertAuthority(t *testing.T) { DomainName: tt.server.ClusterName(), Type: types.DatabaseCA, }, true) - require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + _, err = proxyClt.GetCertAuthority(ctx, types.CertAuthID{ + DomainName: tt.server.ClusterName(), + Type: types.DatabaseClientCA, + }, true) + require.True(t, trace.IsAccessDenied(err)) _, err = proxyClt.GetCertAuthority(ctx, types.CertAuthID{ DomainName: tt.server.ClusterName(), diff --git a/lib/cache/collections.go b/lib/cache/collections.go index 32b1ff6f14911..3f55bf0818b51 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -879,6 +879,13 @@ func (c *certAuthority) fetch(ctx context.Context) (apply func(ctx context.Conte } else if err != nil { return nil, trace.Wrap(err) } + missingDatabaseClientCA := false + applyDatabaseClientCAs, err := c.fetchCertAuthorities(ctx, types.DatabaseClientCA) + if trace.IsBadParameter(err) { + missingDatabaseClientCA = true + } else if err != nil { + return nil, trace.Wrap(err) + } // DELETE IN 13.0. // missingOpenSSHCA is needed only when leaf cluster v11 is connected @@ -932,6 +939,17 @@ func (c *certAuthority) fetch(ctx context.Context) (apply func(ctx context.Conte } } } + if !missingDatabaseClientCA { + if err := applyDatabaseClientCAs(ctx); err != nil { + return trace.Wrap(err) + } + } else { + if err := c.trustCache.DeleteAllCertAuthorities(types.DatabaseClientCA); err != nil { + if !trace.IsNotFound(err) { + return trace.Wrap(err) + } + } + } if !missingOpenSSHCA { if err := applyOpenSSHCAs(ctx); err != nil { return trace.Wrap(err) diff --git a/lib/client/api_test.go b/lib/client/api_test.go index 1911a01514df3..2ad52a06e481f 100644 --- a/lib/client/api_test.go +++ b/lib/client/api_test.go @@ -745,6 +745,15 @@ func TestVirtualPathNames(t *testing.T) { "TSH_VIRTUAL_PATH_CA", }, }, + { + name: "database client ca", + kind: VirtualPathCA, + params: VirtualPathCAParams(types.DatabaseClientCA), + expected: []string{ + "TSH_VIRTUAL_PATH_CA_DB_CLIENT", + "TSH_VIRTUAL_PATH_CA", + }, + }, { name: "host ca", kind: VirtualPathCA, diff --git a/lib/client/ca_export.go b/lib/client/ca_export.go index 59b50756c0e7b..4a1152dfd2621 100644 --- a/lib/client/ca_export.go +++ b/lib/client/ca_export.go @@ -117,6 +117,20 @@ func exportAuth(ctx context.Context, client auth.ClientI, req ExportAuthoritiesR ExportPrivateKeys: exportSecrets, } return exportTLSAuthority(ctx, client, req) + case "db-client": + req := exportTLSAuthorityRequest{ + AuthType: types.DatabaseClientCA, + UnpackPEM: false, + ExportPrivateKeys: exportSecrets, + } + return exportTLSAuthority(ctx, client, req) + case "db-client-der": + req := exportTLSAuthorityRequest{ + AuthType: types.DatabaseClientCA, + UnpackPEM: true, + ExportPrivateKeys: exportSecrets, + } + return exportTLSAuthority(ctx, client, req) case "tls-user-der", "windows": req := exportTLSAuthorityRequest{ AuthType: types.UserCA, diff --git a/lib/client/ca_export_test.go b/lib/client/ca_export_test.go index 727e51c33099b..1bda9aa3400ac 100644 --- a/lib/client/ca_export_test.go +++ b/lib/client/ca_export_test.go @@ -23,6 +23,7 @@ import ( "fmt" "testing" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -33,6 +34,8 @@ import ( type mockAuthClient struct { auth.ClientI server *auth.Server + + unsupportedCATypes []types.CertAuthType } func (m *mockAuthClient) GetDomainName(ctx context.Context) (string, error) { @@ -40,11 +43,21 @@ func (m *mockAuthClient) GetDomainName(ctx context.Context) (string, error) { } func (m *mockAuthClient) GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool, opts ...services.MarshalOption) ([]types.CertAuthority, error) { - return m.server.GetCertAuthorities(ctx, caType, loadKeys, opts...) + for _, unsupported := range m.unsupportedCATypes { + if unsupported == caType { + return nil, trace.BadParameter("%q authority type is not supported", unsupported) + } + } + return m.server.GetCertAuthorities(ctx, caType, loadKeys) } func (m *mockAuthClient) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error) { - return m.server.GetCertAuthority(ctx, id, loadKeys, opts...) + for _, unsupported := range m.unsupportedCATypes { + if unsupported == id.Type { + return nil, trace.BadParameter("%q authority type is not supported", unsupported) + } + } + return m.server.GetCertAuthority(ctx, id, loadKeys) } func TestExportAuthorities(t *testing.T) { @@ -201,6 +214,15 @@ func TestExportAuthorities(t *testing.T) { }, assertSecrets: validatePrivateKeyPEMFunc, }, + { + name: "db", + req: ExportAuthoritiesRequest{ + AuthType: "db", + }, + errorCheck: require.NoError, + assertNoSecrets: validateTLSCertificatePEMFunc, + assertSecrets: validatePrivateKeyPEMFunc, + }, { name: "db-der", req: ExportAuthoritiesRequest{ @@ -210,6 +232,24 @@ func TestExportAuthorities(t *testing.T) { assertNoSecrets: validateTLSCertificateDERFunc, assertSecrets: validatePrivateKeyDERFunc, }, + { + name: "db-client", + req: ExportAuthoritiesRequest{ + AuthType: "db-client", + }, + errorCheck: require.NoError, + assertNoSecrets: validateTLSCertificatePEMFunc, + assertSecrets: validatePrivateKeyPEMFunc, + }, + { + name: "db-client-der", + req: ExportAuthoritiesRequest{ + AuthType: "db-client-der", + }, + errorCheck: require.NoError, + assertNoSecrets: validateTLSCertificateDERFunc, + assertSecrets: validatePrivateKeyDERFunc, + }, } { t.Run(fmt.Sprintf("%s_exportSecrets_%v", tt.name, exportSecrets), func(t *testing.T) { mockedClient := &mockAuthClient{ diff --git a/lib/client/db/database_certificates.go b/lib/client/db/database_certificates.go index 1209b40465432..462fb2cffd092 100644 --- a/lib/client/db/database_certificates.go +++ b/lib/client/db/database_certificates.go @@ -24,6 +24,7 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/identityfile" @@ -45,8 +46,8 @@ type GenerateDatabaseCertificatesRequest struct { Password string } -// GenerateDatabaseCertificates to be used by databases to set up mTLS authentication -func GenerateDatabaseCertificates(ctx context.Context, req GenerateDatabaseCertificatesRequest) ([]string, error) { +// GenerateDatabaseServerCertificates to be used by databases to set up mTLS authentication +func GenerateDatabaseServerCertificates(ctx context.Context, req GenerateDatabaseCertificatesRequest) ([]string, error) { if len(req.Principals) == 0 || (len(req.Principals) == 1 && req.Principals[0] == "" && req.OutputFormat != identityfile.FormatSnowflake) { @@ -63,6 +64,11 @@ func GenerateDatabaseCertificates(ctx context.Context, req GenerateDatabaseCerti subject := pkix.Name{CommonName: req.Principals[0]} + clusterNameType, err := req.ClusterAPI.GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + clusterName := clusterNameType.GetClusterName() if req.OutputFormat == identityfile.FormatMongo { // Include Organization attribute in MongoDB certificates as well. // @@ -75,12 +81,7 @@ func GenerateDatabaseCertificates(ctx context.Context, req GenerateDatabaseCerti // MongoDB cluster members so set it to the Teleport cluster name // to avoid hardcoding anything. - clusterNameType, err := req.ClusterAPI.GetClusterName() - if err != nil { - return nil, trace.Wrap(err) - } - - subject.Organization = []string{clusterNameType.GetClusterName()} + subject.Organization = []string{clusterName} } if req.Key == nil { @@ -112,6 +113,25 @@ func GenerateDatabaseCertificates(ctx context.Context, req GenerateDatabaseCerti return nil, trace.Wrap(err) } + // For CockroachDB we provide node.crt, node.key, ca.crt, and ca-client.crt, + // and the user must use their own CA to issue client.node.crt, + // client.node.key, and add their own CA's cert to ca-client.crt. + // The response CA certs are for Teleport DB Client CA, so fetch the + // Teleport Database CA certs as well. + var additionalCACerts [][]byte + if req.OutputFormat == identityfile.FormatCockroach { + dbServerCA, err := req.ClusterAPI.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseCA, + DomainName: clusterName, + }, false) + if err != nil { + return nil, trace.Wrap(err) + } + for _, keyPair := range dbServerCA.GetTrustedTLSKeyPairs() { + additionalCACerts = append(additionalCACerts, keyPair.Cert) + } + } + req.Key.TLSCert = resp.Cert req.Key.TrustedCerts = []auth.TrustedCerts{{ ClusterName: req.Key.ClusterName, @@ -124,6 +144,7 @@ func GenerateDatabaseCertificates(ctx context.Context, req GenerateDatabaseCerti OverwriteDestination: req.OutputCanOverwrite, Writer: req.IdentityFileWriter, Password: req.Password, + AdditionalCACerts: additionalCACerts, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/client/identityfile/identity.go b/lib/client/identityfile/identity.go index d77329f172f3a..f067bac4d286e 100644 --- a/lib/client/identityfile/identity.go +++ b/lib/client/identityfile/identity.go @@ -193,6 +193,9 @@ type WriteConfig struct { Writer ConfigWriter // Password is the password for the JKS keystore used by Cassandra format and Oracle wallet. Password string + // AdditionalCACerts contains additional CA certs, used by Cockroach format + // to distinguish DB Server CA certs from DB Client CA certs. + AdditionalCACerts [][]byte } // Write writes user credentials to disk in a specified format. @@ -283,18 +286,54 @@ func Write(ctx context.Context, cfg WriteConfig) (filesWritten []string, err err } } - case FormatTLS, FormatDatabase, FormatCockroach, FormatRedis, FormatElasticsearch, FormatScylla: + case FormatCockroach: + // CockroachDB expects files to be named node.crt, node.key, ca.crt, + // ca-client.crt + certPath := filepath.Join(cfg.OutputPath, "node.crt") + keyPath := filepath.Join(cfg.OutputPath, "node.key") + casPath := filepath.Join(cfg.OutputPath, "ca.crt") + clientCAsPath := filepath.Join(cfg.OutputPath, "ca-client.crt") + + filesWritten = append(filesWritten, certPath, keyPath, casPath, clientCAsPath) + if err := checkOverwrite(ctx, writer, cfg.OverwriteDestination, filesWritten...); err != nil { + return nil, trace.Wrap(err) + } + + err = writer.WriteFile(certPath, cfg.Key.TLSCert, identityfile.FilePermissions) + if err != nil { + return nil, trace.Wrap(err) + } + + err = writer.WriteFile(keyPath, cfg.Key.PrivateKeyPEM(), identityfile.FilePermissions) + if err != nil { + return nil, trace.Wrap(err) + } + + var serverCACerts []byte + for _, cert := range cfg.AdditionalCACerts { + serverCACerts = append(serverCACerts, cert...) + } + err = writer.WriteFile(casPath, serverCACerts, identityfile.FilePermissions) + if err != nil { + return nil, trace.Wrap(err) + } + + var clientCACerts []byte + for _, ca := range cfg.Key.TrustedCerts { + for _, cert := range ca.TLSCertificates { + clientCACerts = append(clientCACerts, cert...) + } + } + err = writer.WriteFile(clientCAsPath, clientCACerts, identityfile.FilePermissions) + if err != nil { + return nil, trace.Wrap(err) + } + + case FormatTLS, FormatDatabase, FormatRedis, FormatElasticsearch, FormatScylla: keyPath := cfg.OutputPath + ".key" certPath := cfg.OutputPath + ".crt" casPath := cfg.OutputPath + ".cas" - // CockroachDB expects files to be named ca.crt, node.crt and node.key. - if cfg.Format == FormatCockroach { - keyPath = filepath.Join(cfg.OutputPath, "node.key") - certPath = filepath.Join(cfg.OutputPath, "node.crt") - casPath = filepath.Join(cfg.OutputPath, "ca.crt") - } - filesWritten = append(filesWritten, keyPath, certPath, casPath) if err := checkOverwrite(ctx, writer, cfg.OverwriteDestination, filesWritten...); err != nil { return nil, trace.Wrap(err) @@ -472,31 +511,39 @@ func writeOracleFormat(cfg WriteConfig, writer ConfigWriter) ([]string, error) { if err != nil { return nil, trace.Wrap(err) } - var caCerts []*x509.Certificate - for _, ca := range cfg.Key.TrustedCerts { - for _, cert := range ca.TLSCertificates { - c, err := tlsca.ParseCertificatePEM(cert) - if err != nil { - return nil, trace.Wrap(err) - } - caCerts = append(caCerts, c) - } - } - pf, err := pkcs12.Encode(rand.Reader, keyK, certBlock, caCerts, cfg.Password) + // encode the private key and cert. + // orapki import_pkcs12 refuses to add trusted certs unless they are an + // issuer for an oracle wallet user_cert, and the server cert we create + // is not signed by the DB Client CA, so don't pass trusted certs + // (DB Client CA) here. + pf, err := pkcs12.Encode(rand.Reader, keyK, certBlock, nil, cfg.Password) if err != nil { return nil, trace.Wrap(err) } p12Path := cfg.OutputPath + ".p12" - certPath := cfg.OutputPath + ".crt" - if err := writer.WriteFile(p12Path, pf, identityfile.FilePermissions); err != nil { return nil, trace.Wrap(err) } - err = writer.WriteFile(certPath, cfg.Key.TLSCert, identityfile.FilePermissions) - if err != nil { - return nil, trace.Wrap(err) + + clientCAs := cfg.Key.TLSCAs() + var caPaths []string + for i, caPEM := range clientCAs { + var caPath string + if len(clientCAs) > 1 { + // orapki wallet add can only add one trusted cert at a time, so we + // output up to two files - one for each CA key to trust during a + // rotation. + caPath = fmt.Sprintf("%s.ca-client-%d.crt", cfg.OutputPath, i) + } else { + caPath = cfg.OutputPath + ".ca-client.crt" + } + err = writer.WriteFile(caPath, caPEM, identityfile.FilePermissions) + if err != nil { + return nil, trace.Wrap(err) + } + caPaths = append(caPaths, caPath) } // Is ORAPKI binary is available is user env run command ang generate autologin Oracle wallet. @@ -504,15 +551,17 @@ func writeOracleFormat(cfg WriteConfig, writer ConfigWriter) ([]string, error) { // Is Orapki is available in the user env create the Oracle wallet directly. // otherwise Orapki tool needs to be executed on the server site to import keypair to // Oracle wallet. - if err := createOracleWallet(cfg.OutputPath, p12Path, certPath, cfg.Password); err != nil { + if err := createOracleWallet(caPaths, cfg.OutputPath, p12Path, cfg.Password); err != nil { return nil, trace.Wrap(err) } // If Oracle Wallet was created the raw p12 keypair and trusted cert are no longer needed. if err := os.Remove(p12Path); err != nil { return nil, trace.Wrap(err) } - if err := os.Remove(certPath); err != nil { - return nil, trace.Wrap(err) + for _, caPath := range caPaths { + if err := os.Remove(caPath); err != nil { + return nil, trace.Wrap(err) + } } // Return the path to the Oracle wallet. return []string{cfg.OutputPath}, nil @@ -520,7 +569,7 @@ func writeOracleFormat(cfg WriteConfig, writer ConfigWriter) ([]string, error) { // Otherwise return destinations to p12 keypair and trusted CA allowing a user to run the convert flow on the // Oracle server instance in order to create Oracle wallet file. - return []string{p12Path, certPath}, nil + return append([]string{p12Path}, caPaths...), nil } const ( @@ -532,7 +581,7 @@ func isOrapkiAvailable() bool { return err == nil } -func createOracleWallet(walletPath, p12Path, certPath, password string) error { +func createOracleWallet(caCertPaths []string, walletPath, p12Path, password string) error { errDetailsFormat := "\n\nOrapki command:\n%s \n\nCompleted with following error: \n%s" // Create Raw Oracle wallet with auto_login_only flag - no password required. args := []string{ @@ -556,16 +605,18 @@ func createOracleWallet(walletPath, p12Path, certPath, password string) error { return trace.Wrap(err, fmt.Sprintf(errDetailsFormat, cmd.String(), output)) } - // Add import teleport CA to the oracle wallet. - args = []string{ - "wallet", "add", "-wallet", walletPath, - "-trusted_cert", - "-auto_login_only", - "-cert", certPath, - } - cmd = exec.Command(orapkiBinary, args...) - if output, err := exec.Command(orapkiBinary, args...).CombinedOutput(); err != nil { - return trace.Wrap(err, fmt.Sprintf(errDetailsFormat, cmd.String(), output)) + // Add import teleport CA(s) to the oracle wallet. + for _, certPath := range caCertPaths { + args = []string{ + "wallet", "add", "-wallet", walletPath, + "-trusted_cert", + "-auto_login_only", + "-cert", certPath, + } + cmd = exec.Command(orapkiBinary, args...) + if output, err := exec.Command(orapkiBinary, args...).CombinedOutput(); err != nil { + return trace.Wrap(err, fmt.Sprintf(errDetailsFormat, cmd.String(), output)) + } } return nil } diff --git a/lib/fixtures/keys.go b/lib/fixtures/keys.go index b8b6a1a987db2..55f9d72867de9 100644 --- a/lib/fixtures/keys.go +++ b/lib/fixtures/keys.go @@ -57,6 +57,34 @@ MHeDg2Bs7/XZsIrn6vo7kXmQSoQKA8O2E7rYSigUayBKa/+5thbnjKlEP+slBzmp 7JPquig/B6L2pNoxPa41VDGawQjJY5m4l3ap/oJj61HBB+Auf29BWXqg7V7B7XMB NFJgTFxC2o3mVBkQ/s6FeDl62hpMheCuO6jRYbZjsM2tUeAKORws -----END RSA PRIVATE KEY-----`), + "rsa-db-client": []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEA3scV81B4bpa0qkpBsJDUf3UOIs6A4+WZnf5eXJcJg7zDi5/J +2vtBuk8CTvp8eQhu4Pq2G0RhHZzrYiBMWLf/ORzBSKrN+bQi9pRKNzN/hJ7SOq+T +tkyv5tGNn+PIGzX712Ao9Iw5TIJTy2QpiWKQ7MFiVAs399B0Ow7aHNRmdrB3jcJN +V/HizGkno8SF7EgMGwlIG+z8pE9PmlKV5ZkX0fz+W1IMMkiveGgy0tTNazTsnwOE +meBDTyB8YbLGPqiQJxoTEHWaQUTLNfWXdvD2x15UBgWiKA/Ng05fPpDqR4cYgM+n +Nv8lhwKprWcJmF34HVNO6xhFkMU78EIW4lFeYQIDAQABAoIBACIHsU+wnCTwenqE +y1IIXZ12qQkiGEg3u2aKA6oLHFX2ULyUVQZRWTH3fbfIxZjLc/yD76tsn5UhckdT +/bWTrbXwsYnDJaGeJbUa49dY04LTq/Nw/JRdVIViv0qMRfX6IhU9SCRLAzmvstMf +4sRsvQydYcLKz+rX+dlHpIPA4kIA27wqEGbaCb4WatOYf4kUvIyBT/A/QrdRQfPK +YTsQoQk8TNMdeTCGyyHqBSnbiI9r7EmWrH2r7FdNSFNsoH1FyX35sD4dLC/vjXNT +dDT7cnwwIHHKpQIDdBKG8ivVM2nuvSNNVp+LUs9rWKZ3R2xyln3NOTLAHCMDUoAi +TCY/YAECgYEA/VGBIoH4kozoY7SJAEjMPNzHFJDmUVo4U/bN/hGAcGM+YXjSnuO0 +ljgSn7h7U2Z09KfIO8GjWtHq6K2gCw23of1D/CRPDIRvrONCVrhWcqebAcFD9Zw0 +4EbPYDnTq4M9gC2RLsCIFVVL4Gsw9+4iSAKhCsHdW/dORUCDyBRbOwECgYEA4SLQ +0Myni4U7v8QVnwREJifjAvE0xrALpdF82CFgyimkCsgzxxBwdc2qhbtiIoNEYx/X +OEPpFU+SuCAQe6xYCsjP3rh1kCN8ETu9NlnpD5o2BUYVgPR7xYYQ3aci/UYWHfZ1 +BGus1PNkL+87d62bvMcpDC9VfAvjGAvOqqUPA2ECgYEAlB4EE9lLLuWVPDdjo/bs +9OlivnO7N/Y42V+GMvio0Q42e2faP22FOhCvUxTbh3hxClzQh6BBk+kKIeLjoZLz +vJQKHHRehEMryTtYnrxKT+AQkoYe5o3fnQPKXclyKuciHsCGE4AgEdk99Iq4pz9m +bBSddVzFwfBoo7WFWIgOkAECgYEAzS1cnx4Uh5vJ4y/CAKTzss5hHlpTHcxtIRa1 +L4fj3PpsLQNd5Mp/o2znPm+StR9qoOfwza9eafSWI0Xdn8hmiJWQlEsJoW4lcNM/ +0pvIQlbpao7/pAGsF0zibA8ZXTeVioMFDB1RatXSdbkSOjS3HSloqFkvEBkJQu3n +0C8TaqECgYAic82i4ZMIJNeFtI2eGw2a89ofO3gpJvGmaJwY8RgZYWtU7+YiI1ts +RBeQG4aFIeOs/3nf0n8pp/xsgreLVZJBXjoWyvw7pDi60N4C07d2gA+hqK3rAvQC +0Be4kdn0Jxx/OSYuGKl1PI0DB1RaCkWZHNay73amkkP+HD/BqcIeLA== +-----END RSA PRIVATE KEY----- +`), } // LocalhostCert is a PEM-encoded TLS cert with SAN IPs diff --git a/lib/services/authority.go b/lib/services/authority.go index 0167895dbecb5..5328b040f2299 100644 --- a/lib/services/authority.go +++ b/lib/services/authority.go @@ -67,7 +67,7 @@ func ValidateCertAuthority(ca types.CertAuthority) (err error) { switch ca.GetType() { case types.UserCA, types.HostCA: err = checkUserOrHostCA(ca) - case types.DatabaseCA: + case types.DatabaseCA, types.DatabaseClientCA: err = checkDatabaseCA(ca) case types.OpenSSHCA: err = checkOpenSSHCA(ca) @@ -115,7 +115,7 @@ func checkDatabaseCA(cai types.CertAuthority) error { } if len(ca.Spec.ActiveKeys.TLS) == 0 { - return trace.BadParameter("DB certificate authority missing TLS key pairs") + return trace.BadParameter("%s certificate authority missing TLS key pairs", ca.GetType()) } for _, pair := range ca.GetTrustedTLSKeyPairs() { diff --git a/lib/services/local/events.go b/lib/services/local/events.go index 6bef5ea19a72c..9ed77a61dd910 100644 --- a/lib/services/local/events.go +++ b/lib/services/local/events.go @@ -59,7 +59,7 @@ func (e *EventsService) NewWatcher(ctx context.Context, watch types.Watch) (type var parser resourceParser switch kind.Kind { case types.KindCertAuthority: - parser = newCertAuthorityParser(kind.LoadSecrets) + parser = newCertAuthorityParser(kind.LoadSecrets, kind.Filter) case types.KindToken: parser = newProvisionTokenParser() case types.KindStaticTokens: @@ -298,16 +298,20 @@ func (p baseParser) match(key []byte) bool { return false } -func newCertAuthorityParser(loadSecrets bool) *certAuthorityParser { +func newCertAuthorityParser(loadSecrets bool, filter map[string]string) *certAuthorityParser { + var caFilter types.CertAuthorityFilter + caFilter.FromMap(filter) return &certAuthorityParser{ loadSecrets: loadSecrets, baseParser: newBaseParser(backend.Key(authoritiesPrefix)), + filter: caFilter, } } type certAuthorityParser struct { baseParser loadSecrets bool + filter types.CertAuthorityFilter } func (p *certAuthorityParser) parse(event backend.Event) (types.Resource, error) { @@ -332,6 +336,9 @@ func (p *certAuthorityParser) parse(event backend.Event) (types.Resource, error) if err != nil { return nil, trace.Wrap(err) } + if !p.filter.Match(ca) { + return nil, nil + } // never send private signing keys over event stream? // this might not be true setSigningKeys(ca, p.loadSecrets) diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index 28e2d49d3b159..da68542d74b7c 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -73,7 +73,15 @@ type TestCAConfig struct { func NewTestCAWithConfig(config TestCAConfig) *types.CertAuthorityV2 { // privateKeys is to specify another RSA private key if len(config.PrivateKeys) == 0 { - config.PrivateKeys = [][]byte{fixtures.PEMBytes["rsa"]} + // db client CA gets its own private key to distinguish its pub key + // from the other CAs. Snowflake uses public key to verify JWT signer, + // so if we don't do this then tests verifying that the correct + // signer was used are pointless. + if config.Type == types.DatabaseClientCA { + config.PrivateKeys = [][]byte{fixtures.PEMBytes["rsa-db-client"]} + } else { + config.PrivateKeys = [][]byte{fixtures.PEMBytes["rsa"]} + } } keyBytes := config.PrivateKeys[0] rsaKey, err := ssh.ParseRawPrivateKey(keyBytes) @@ -115,7 +123,7 @@ func NewTestCAWithConfig(config TestCAConfig) *types.CertAuthorityV2 { // Match the key set to lib/auth/auth.go:newKeySet(). switch config.Type { - case types.DatabaseCA: + case types.DatabaseCA, types.DatabaseClientCA, types.SAMLIDPCA: ca.Spec.ActiveKeys.TLS = []*types.TLSKeyPair{{Cert: cert, Key: keyBytes}} case types.KindJWT, types.OIDCIdPCA: // Generating keys is CPU intensive operation. Generate JWT keys only @@ -143,8 +151,6 @@ func NewTestCAWithConfig(config TestCAConfig) *types.CertAuthorityV2 { PrivateKey: keyBytes, }}, } - case types.SAMLIDPCA: - ca.Spec.ActiveKeys.TLS = []*types.TLSKeyPair{{Cert: cert, Key: keyBytes}} default: panic("unknown CA type") } diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index a782c9fcbc971..bf5c72f308baf 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -820,8 +820,8 @@ func TestCertAuthorityWatcher(t *testing.T) { waitForEvent(t, sub, types.UserCA, "unknown", types.OpPut) // Should NOT receive any HostCA events from another cluster. - // Should NOT receive any DatabaseCA events. require.NoError(t, caService.UpsertCertAuthority(newCertAuthority(t, "unknown", types.HostCA))) + // Should NOT receive any DatabaseCA events. require.NoError(t, caService.UpsertCertAuthority(newCertAuthority(t, "test", types.DatabaseCA))) ensureNoEvents(t, sub) }) diff --git a/lib/srv/db/auth_test.go b/lib/srv/db/auth_test.go index 647b209016dcd..4c2ece8104dc3 100644 --- a/lib/srv/db/auth_test.go +++ b/lib/srv/db/auth_test.go @@ -18,23 +18,15 @@ package db import ( "context" - "crypto/x509" - "crypto/x509/pkix" "testing" - "time" "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" - "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/srv/db/common" - "github.com/gravitational/teleport/lib/tlsca" ) // TestAuthTokens verifies that proper IAM auth tokens are used when connecting @@ -248,91 +240,3 @@ func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *c a.Infof("Generating Azure Redis token for %v.", sessionCtx) return azureRedisToken, nil } - -func TestDBCertSigning(t *testing.T) { - authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ - Clock: clockwork.NewFakeClockAt(time.Now()), - ClusterName: "local.me", - Dir: t.TempDir(), - }) - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, authServer.Close()) }) - - ctx := context.Background() - - privateKey, err := testauthority.New().GeneratePrivateKey() - require.NoError(t, err) - - csr, err := tlsca.GenerateCertificateRequestPEM(pkix.Name{ - CommonName: "localhost", - }, privateKey) - require.NoError(t, err) - - // Set rotation to init phase. New CA will be generated. - // DB service should still use old key to sign certificates. - // tctl should use new key to sign certificates. - err = authServer.AuthServer.RotateCertAuthority(ctx, auth.RotateRequest{ - Type: types.DatabaseCA, - TargetPhase: types.RotationPhaseInit, - Mode: types.RotationModeManual, - }) - require.NoError(t, err) - - dbCAs, err := authServer.AuthServer.GetCertAuthorities(ctx, types.DatabaseCA, false) - require.NoError(t, err) - require.Len(t, dbCAs, 1) - require.NotNil(t, dbCAs[0].GetActiveKeys().TLS) - require.NotNil(t, dbCAs[0].GetAdditionalTrustedKeys().TLS) - - tests := []struct { - name string - requester proto.DatabaseCertRequest_Requester - getCertFn func(dbCAs []types.CertAuthority) []byte - }{ - { - name: "sign from DB service", - requester: proto.DatabaseCertRequest_UNSPECIFIED, // default behavior - getCertFn: func(dbCAs []types.CertAuthority) []byte { - return dbCAs[0].GetActiveKeys().TLS[0].Cert - }, - }, - { - name: "sign from tctl", - requester: proto.DatabaseCertRequest_TCTL, - getCertFn: func(dbCAs []types.CertAuthority) []byte { - return dbCAs[0].GetAdditionalTrustedKeys().TLS[0].Cert - }, - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - certResp, err := authServer.AuthServer.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{ - CSR: csr, - ServerName: "localhost", - TTL: proto.Duration(time.Hour), - RequesterName: tt.requester, - }) - require.NoError(t, err) - require.NotNil(t, certResp.Cert) - require.Len(t, certResp.CACerts, 2) - - dbCert, err := tlsca.ParseCertificatePEM(certResp.Cert) - require.NoError(t, err) - - certPool := x509.NewCertPool() - ok := certPool.AppendCertsFromPEM(tt.getCertFn(dbCAs)) - require.True(t, ok) - - opts := x509.VerifyOptions{ - Roots: certPool, - } - - // Verify if the generated certificate can be verified with the correct CA. - _, err = dbCert.Verify(opts) - require.NoError(t, err) - }) - } -} diff --git a/lib/srv/db/common/auth_test.go b/lib/srv/db/common/auth_test.go index 48b00f05940b7..847ba1c8af804 100644 --- a/lib/srv/db/common/auth_test.go +++ b/lib/srv/db/common/auth_test.go @@ -604,6 +604,9 @@ type authClientMock struct { // GenerateDatabaseCert generates a cert using fixtures TLS CA. func (m *authClientMock) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { + if req.GetRequesterName() != proto.DatabaseCertRequest_UNSPECIFIED { + return nil, trace.BadParameter("db agent should not specify requester name") + } csr, err := tlsca.ParseCertificateRequestPEM(req.CSR) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/db/common/test.go b/lib/srv/db/common/test.go index a61e712a802b3..3a5cc0b894e2f 100644 --- a/lib/srv/db/common/test.go +++ b/lib/srv/db/common/test.go @@ -122,9 +122,10 @@ func MakeTestServerTLSConfig(config TestServerConfig) (*tls.Config, error) { } resp, err := config.AuthClient.GenerateDatabaseCert(context.Background(), &proto.DatabaseCertRequest{ - CSR: csr, - ServerName: cn, - TTL: proto.Duration(time.Hour), + CSR: csr, + ServerName: cn, + TTL: proto.Duration(time.Hour), + RequesterName: proto.DatabaseCertRequest_TCTL, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/db/snowflake/test.go b/lib/srv/db/snowflake/test.go index d8efb609db077..287cdc7575248 100644 --- a/lib/srv/db/snowflake/test.go +++ b/lib/srv/db/snowflake/test.go @@ -187,14 +187,14 @@ func (s *TestServer) handleToken(w http.ResponseWriter, r *http.Request) { s.authorizationToken = "sessionToken-123" } -// verifyJWT verifies the provided JWT token. It checks if the token was signed with the Database CA +// verifyJWT verifies the provided JWT token. It checks if the token was signed with the Database Client CA // and asserts token claims. func (s *TestServer) verifyJWT(ctx context.Context, accName, loginName, token string) error { clusterName := "root.example.com" caCert, err := s.cfg.AuthClient.GetCertAuthority(ctx, types.CertAuthID{ DomainName: clusterName, - Type: types.DatabaseCA, + Type: types.DatabaseClientCA, }, false) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/db/sqlserver/connect_test.go b/lib/srv/db/sqlserver/connect_test.go index eeb10171f8e89..ebb7f5a2757d1 100644 --- a/lib/srv/db/sqlserver/connect_test.go +++ b/lib/srv/db/sqlserver/connect_test.go @@ -25,6 +25,7 @@ import ( "testing" "time" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/client/proto" @@ -202,8 +203,11 @@ func (m *mockAuth) GetClusterName(opts ...services.MarshalOption) (types.Cluster }) } -func (m *mockAuth) GenerateDatabaseCert(context.Context, *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { - return &proto.DatabaseCertResponse{Cert: []byte(mockCA)}, nil +func (m *mockAuth) GenerateDatabaseCert(_ context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { + if req.GetRequesterName() != proto.DatabaseCertRequest_UNSPECIFIED { + return nil, trace.BadParameter("db agent should not specify requester name") + } + return &proto.DatabaseCertResponse{Cert: []byte(mockCA), CACerts: [][]byte{[]byte(mockCA)}}, nil } func TestConnectorKInitClient(t *testing.T) { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 790cdeb21ed98..d1ef84e4baa5d 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -3312,6 +3312,24 @@ func TestAuthExport(t *testing.T) { expectedStatus: http.StatusOK, assertBody: validateTLSCertificatePEMFunc, }, + { + name: "db-der", + authType: "db-der", + expectedStatus: http.StatusOK, + assertBody: validateTLSCertificateDERFunc, + }, + { + name: "db-client", + authType: "db-client", + expectedStatus: http.StatusOK, + assertBody: validateTLSCertificatePEMFunc, + }, + { + name: "db-client-der", + authType: "db-client-der", + expectedStatus: http.StatusOK, + assertBody: validateTLSCertificateDERFunc, + }, { name: "tls", authType: "tls", diff --git a/lib/web/sign.go b/lib/web/sign.go index 9a0a122ab4574..84bd9d341af31 100644 --- a/lib/web/sign.go +++ b/lib/web/sign.go @@ -76,7 +76,7 @@ func (h *Handler) signDatabaseCertificate(w http.ResponseWriter, r *http.Request IdentityFileWriter: virtualFS, TTL: req.TTL, } - filesWritten, err := db.GenerateDatabaseCertificates(r.Context(), dbCertReq) + filesWritten, err := db.GenerateDatabaseServerCertificates(r.Context(), dbCertReq) if err != nil { return nil, trace.Wrap(err) } diff --git a/tool/tctl/common/auth_command.go b/tool/tctl/common/auth_command.go index 3365d59806875..2d8bb302338a3 100644 --- a/tool/tctl/common/auth_command.go +++ b/tool/tctl/common/auth_command.go @@ -20,7 +20,6 @@ import ( "io" "net/url" "os" - "path/filepath" "strings" "text/template" "time" @@ -186,6 +185,8 @@ var allowedCertificateTypes = []string{ "windows", "db", "db-der", + "db-client", + "db-client-der", "openssh", "saml-idp", } @@ -195,6 +196,7 @@ var allowedCertificateTypes = []string{ var allowedCRLCertificateTypes = []string{ string(types.HostCA), string(types.DatabaseCA), + string(types.DatabaseClientCA), string(types.UserCA), } @@ -340,20 +342,11 @@ func (a *AuthCommand) generateSnowflakeKey(ctx context.Context, clusterAPI auth. return trace.Wrap(err) } - cn, err := clusterAPI.GetClusterName() - if err != nil { - return trace.Wrap(err) - } - certAuthID := types.CertAuthID{ - Type: types.DatabaseCA, - DomainName: cn.GetClusterName(), - } - databaseCA, err := clusterAPI.GetCertAuthority(ctx, certAuthID, false) + dbClientCA, err := getDatabaseClientCA(ctx, clusterAPI) if err != nil { return trace.Wrap(err) } - - key.TrustedCerts = []auth.TrustedCerts{{TLSCertificates: services.GetTLSCerts(databaseCA)}} + key.TrustedCerts = []auth.TrustedCerts{{TLSCertificates: services.GetTLSCerts(dbClientCA)}} filesWritten, err := identityfile.Write(ctx, identityfile.WriteConfig{ OutputPath: a.output, @@ -518,7 +511,7 @@ func (a *AuthCommand) generateDatabaseKeysForKey(ctx context.Context, clusterAPI Key: key, Password: a.password, } - filesWritten, err := db.GenerateDatabaseCertificates(ctx, dbCertReq) + filesWritten, err := db.GenerateDatabaseServerCertificates(ctx, dbCertReq) if err != nil { return trace.Wrap(err) } @@ -554,9 +547,21 @@ func writeHelperMessageDBmTLS(writer io.Writer, filesWritten []string, output st "password": password, "output": output, } - if outputFormat == defaults.ProtocolOracle { + switch outputFormat { + case defaults.ProtocolCockroachDB: + tplVars["clientCAPath"] = "/path/to/client-ca.key" + case defaults.ProtocolOracle: tplVars["manualOrapkiFlow"] = len(filesWritten) != 1 - tplVars["walletDir"] = filepath.Dir(output) + // use a generic example path since they will have to copy the files + // to the oracle server. + tplVars["walletDir"] = "/path/to/oracleWalletDir" + var caCertPaths []string + for _, f := range filesWritten { + if strings.HasSuffix(f, ".crt") { + caCertPaths = append(caCertPaths, f) + } + } + tplVars["caCertPaths"] = caCertPaths } return trace.Wrap(tpl.Execute(writer, tplVars)) @@ -595,12 +600,31 @@ net: `)) cockroachAuthSignTpl = template.Must(template.New("").Parse(`Database credentials have been written to {{.files}}. -To enable mutual TLS on your CockroachDB server, point it to the certs -directory using --certs-dir flag: +To enable mutual TLS on your CockroachDB server, generate a client CA and client +certs for your node: + +# --overwrite flag prepends the client CA cert to {{.output}}/ca-client.crt +cockroach cert create-client-ca \ + --certs-dir={{.output}} \ + --ca-key={{.clientCAPath}} \ + --overwrite + +cockroach cert create-client node \ + --certs-dir={{.output}} + --ca-key={{.clientCAPath}} + +Then point cockroach to the certs directory using the --certs-dir flag: cockroach start \ --certs-dir={{.output}} \ # other flags... + +For more information about creating a client CA and issuing certs, see: +https://www.cockroachlabs.com/docs/stable/cockroach-cert + +Teleport uses a split CA architecture for database access. +For more information about using a split CA with CockroachDB, see: +https://www.cockroachlabs.com/docs/stable/authentication#using-split-ca-certificates `)) redisAuthSignTpl = template.Must(template.New("").Parse(`Database credentials have been written to {{.files}}. @@ -611,6 +635,9 @@ tls-ca-cert-file /path/to/{{.output}}.cas tls-cert-file /path/to/{{.output}}.crt tls-key-file /path/to/{{.output}}.key tls-protocols "TLSv1.2 TLSv1.3" + +For information on enabling Redis Cluster bus communication TLS, see: +https://goteleport.com/docs/database-access/guides/redis-cluster `)) snowflakeAuthSignTpl = template.Must(template.New("").Parse(`Database credentials have been written to {{.files}}. @@ -661,16 +688,21 @@ client_encryption_options: {{if .manualOrapkiFlow}} Orapki binary was not found. Please create oracle wallet file manually by running the following commands on the Oracle server: -orapki wallet create -wallet {{.walletDir}} -auto_login_only -orapki wallet import_pkcs12 -wallet {{.walletDir}} -auto_login_only -pkcs12file {{.output}}.p12 -pkcs12pwd {{.password}} -orapki wallet add -wallet {{.walletDir}} -trusted_cert -auto_login_only -cert {{.output}}.crt +WALLET_DIR="{{.walletDir}}" +orapki wallet create -wallet "$WALLET_DIR" -auto_login_only +orapki wallet import_pkcs12 -wallet "$WALLET_DIR" -auto_login_only -pkcs12file {{.output}}.p12 -pkcs12pwd {{.password}} +{{- range $certPath := .caCertPaths }} +orapki wallet add -wallet "$WALLET_DIR" -trusted_cert -auto_login_only -cert {{ $certPath }} +{{- end}} + +If copying these files to your Oracle server, ensure the cert file permissions are readable by the "oracle" user. {{else}} Oracle wallet stored in {{.output}} directory created with Oracle Orapki. {{end}} To enable mutual TLS on your Oracle server, add the following settings to Oracle sqlnet.ora configuration file: -WALLET_LOCATION = (SOURCE = (METHOD = FILE)(METHOD_DATA = (DIRECTORY = /path/to/oracleWalletDir))) +WALLET_LOCATION = (SOURCE = (METHOD = FILE)(METHOD_DATA = (DIRECTORY = {{.walletDir}}))) SSL_CLIENT_AUTHENTICATION = TRUE SQLNET.AUTHENTICATION_SERVICES = (TCPS) @@ -684,7 +716,7 @@ LISTENER = ) ) -WALLET_LOCATION = (SOURCE = (METHOD = FILE)(METHOD_DATA = (DIRECTORY = /path/to/oracleWalletDir))) +WALLET_LOCATION = (SOURCE = (METHOD = FILE)(METHOD_DATA = (DIRECTORY = {{.walletDir}}))) SSL_CLIENT_AUTHENTICATION = TRUE `)) @@ -1081,3 +1113,28 @@ func getCertAuthTypes() []string { t = append(t, string(types.CertAuthTypeAll)) return t } + +// TODO(gavin): DELETE IN 16.0.0 +func getDatabaseClientCA(ctx context.Context, clusterAPI auth.ClientI) (types.CertAuthority, error) { + cn, err := clusterAPI.GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + dbClientCA, err := clusterAPI.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseClientCA, + DomainName: cn.GetClusterName(), + }, false) + if err == nil { + return dbClientCA, nil + } + if !types.IsUnsupportedAuthorityErr(err) { + return nil, trace.Wrap(err) + } + + // fallback to DatabaseCA if DatabaseClientCA isn't supported by backend. + dbServerCA, err := clusterAPI.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.DatabaseCA, + DomainName: cn.GetClusterName(), + }, false) + return dbServerCA, trace.Wrap(err) +} diff --git a/tool/tctl/common/auth_command_test.go b/tool/tctl/common/auth_command_test.go index 5df8ef3d417ef..a29b1e4e266e4 100644 --- a/tool/tctl/common/auth_command_test.go +++ b/tool/tctl/common/auth_command_test.go @@ -412,6 +412,8 @@ type mockClient struct { appSession types.WebSession networkConfig types.ClusterNetworkingConfig crl []byte + + unsupportedCATypes []types.CertAuthType } func (c *mockClient) GetClusterName(...services.MarshalOption) (types.ClusterName, error) { @@ -431,15 +433,25 @@ func (c *mockClient) GenerateUserCerts(ctx context.Context, userCertsReq proto.U } func (c *mockClient) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadSigningKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error) { + for _, unsupported := range c.unsupportedCATypes { + if unsupported == id.Type { + return nil, trace.BadParameter("%q authority type is not supported", unsupported) + } + } for _, v := range c.cas { if v.GetType() == id.Type && v.GetClusterName() == id.DomainName { return v, nil } } - return nil, trace.NotFound("not found") + return nil, trace.NotFound("%q CA not found", id) } -func (c *mockClient) GetCertAuthorities(context.Context, types.CertAuthType, bool, ...services.MarshalOption) ([]types.CertAuthority, error) { +func (c *mockClient) GetCertAuthorities(_ context.Context, caType types.CertAuthType, _ bool, opts ...services.MarshalOption) ([]types.CertAuthority, error) { + for _, unsupported := range c.unsupportedCATypes { + if unsupported == caType { + return nil, trace.BadParameter("%q authority type is not supported", unsupported) + } + } return c.cas, nil } @@ -456,6 +468,9 @@ func (c *mockClient) GetKubernetesServers(context.Context) ([]types.KubeServer, } func (c *mockClient) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) { + if req.GetRequesterName() != proto.DatabaseCertRequest_TCTL { + return nil, trace.BadParameter("need tctl requester name in tctl database cert request") + } c.dbCertsReq = req return c.dbCerts, nil } @@ -588,14 +603,24 @@ func TestGenerateDatabaseKeys(t *testing.T) { require.NoError(t, err) certBytes := []byte("TLS cert") - caBytes := []byte("CA cert") + dbClientCABytes := []byte("DB Client CA cert") + dbServerCABytes := []byte("DB Server CA cert") + dbCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.DatabaseCA, + ClusterName: "example.com", + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{Cert: dbServerCABytes}}, + }, + }) + require.NoError(t, err) authClient := &mockClient{ clusterName: clusterName, dbCerts: &proto.DatabaseCertResponse{ Cert: certBytes, - CACerts: [][]byte{caBytes}, + CACerts: [][]byte{dbClientCABytes}, }, + cas: []types.CertAuthority{dbCA}, } key, err := client.GenerateRSAKey() @@ -609,13 +634,9 @@ func TestGenerateDatabaseKeys(t *testing.T) { inOutFile string outSubject pkix.Name outServerNames []string - outKeyFile string - outCertFile string - outCAFile string - outKey []byte - outCert []byte - outCA []byte - genKeyErrMsg string + // maps filename -> file contents + wantFiles map[string][]byte + genKeyErrMsg string }{ { name: "database certificate", @@ -625,12 +646,11 @@ func TestGenerateDatabaseKeys(t *testing.T) { inOutFile: "db", outSubject: pkix.Name{CommonName: "postgres.example.com"}, outServerNames: []string{"postgres.example.com"}, - outKeyFile: "db.key", - outCertFile: "db.crt", - outCAFile: "db.cas", - outKey: key.PrivateKeyPEM(), - outCert: certBytes, - outCA: caBytes, + wantFiles: map[string][]byte{ + "db.key": key.PrivateKeyPEM(), + "db.crt": certBytes, + "db.cas": dbClientCABytes, + }, }, { name: "database certificate multiple SANs", @@ -640,12 +660,11 @@ func TestGenerateDatabaseKeys(t *testing.T) { inOutFile: "db", outSubject: pkix.Name{CommonName: "mysql.external.net"}, outServerNames: []string{"mysql.external.net", "mysql.internal.net", "192.168.1.1"}, - outKeyFile: "db.key", - outCertFile: "db.crt", - outCAFile: "db.cas", - outKey: key.PrivateKeyPEM(), - outCert: certBytes, - outCA: caBytes, + wantFiles: map[string][]byte{ + "db.key": key.PrivateKeyPEM(), + "db.crt": certBytes, + "db.cas": dbClientCABytes, + }, }, { name: "mongodb certificate", @@ -655,10 +674,10 @@ func TestGenerateDatabaseKeys(t *testing.T) { inOutFile: "mongo", outSubject: pkix.Name{CommonName: "mongo.example.com", Organization: []string{"example.com"}}, outServerNames: []string{"mongo.example.com"}, - outCertFile: "mongo.crt", - outCAFile: "mongo.cas", - outCert: append(certBytes, key.PrivateKeyPEM()...), - outCA: caBytes, + wantFiles: map[string][]byte{ + "mongo.crt": append(certBytes, key.PrivateKeyPEM()...), + "mongo.cas": dbClientCABytes, + }, }, { name: "cockroachdb certificate", @@ -667,12 +686,12 @@ func TestGenerateDatabaseKeys(t *testing.T) { inOutDir: t.TempDir(), outSubject: pkix.Name{CommonName: "node"}, outServerNames: []string{"node", "localhost", "roach1"}, // "node" principal should always be added - outKeyFile: "node.key", - outCertFile: "node.crt", - outCAFile: "ca.crt", - outKey: key.PrivateKeyPEM(), - outCert: certBytes, - outCA: caBytes, + wantFiles: map[string][]byte{ + "node.key": key.PrivateKeyPEM(), + "node.crt": certBytes, + "ca.crt": dbServerCABytes, + "ca-client.crt": dbClientCABytes, + }, }, { name: "redis certificate", @@ -682,12 +701,11 @@ func TestGenerateDatabaseKeys(t *testing.T) { inOutFile: "db", outSubject: pkix.Name{CommonName: "localhost"}, outServerNames: []string{"localhost", "redis1", "172.0.0.1"}, - outKeyFile: "db.key", - outCertFile: "db.crt", - outCAFile: "db.cas", - outKey: key.PrivateKeyPEM(), - outCert: certBytes, - outCA: caBytes, + wantFiles: map[string][]byte{ + "db.key": key.PrivateKeyPEM(), + "db.crt": certBytes, + "db.cas": dbClientCABytes, + }, }, { name: "missing host", @@ -725,22 +743,10 @@ func TestGenerateDatabaseKeys(t *testing.T) { require.Equal(t, test.outServerNames, authClient.dbCertsReq.ServerNames) require.Equal(t, test.outServerNames[0], authClient.dbCertsReq.ServerName) - if len(test.outKey) > 0 { - keyBytes, err := os.ReadFile(filepath.Join(test.inOutDir, test.outKeyFile)) + for wantFilename, wantContents := range test.wantFiles { + contents, err := os.ReadFile(filepath.Join(test.inOutDir, wantFilename)) require.NoError(t, err) - require.Equal(t, test.outKey, keyBytes, "keys match") - } - - if len(test.outCert) > 0 { - certBytes, err := os.ReadFile(filepath.Join(test.inOutDir, test.outCertFile)) - require.NoError(t, err) - require.Equal(t, test.outCert, certBytes, "certificates match") - } - - if len(test.outCA) > 0 { - caBytes, err := os.ReadFile(filepath.Join(test.inOutDir, test.outCAFile)) - require.NoError(t, err) - require.Equal(t, test.outCA, caBytes, "CA certificates match") + require.Equal(t, wantContents, contents, "contents of %s match", wantFilename) } }) } @@ -988,49 +994,83 @@ func TestGenerateAndSignKeys(t *testing.T) { _, cert, err := tlsca.GenerateSelfSignedCA(pkix.Name{CommonName: "example.com"}, nil, time.Minute) require.NoError(t, err) - firstCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + dbCARoot, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ Type: types.DatabaseCA, ClusterName: "example.com", ActiveKeys: types.CAKeySet{ - SSH: []*types.SSHKeyPair{{PublicKey: []byte("SSH CA cert")}}, TLS: []*types.TLSKeyPair{{Cert: cert}}, }, }) require.NoError(t, err) - secondCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + dbCALeaf, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ Type: types.DatabaseCA, ClusterName: "leaf.example.com", ActiveKeys: types.CAKeySet{ - SSH: []*types.SSHKeyPair{{PublicKey: []byte("SSH CA cert")}}, TLS: []*types.TLSKeyPair{{Cert: cert}}, }, }) require.NoError(t, err) - certBytes := []byte("TLS cert") - caBytes := []byte("CA cert") - authClient := &mockClient{ - clusterName: clusterName, - dbCerts: &proto.DatabaseCertResponse{ - Cert: certBytes, - CACerts: [][]byte{caBytes}, + dbClientCARoot, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.DatabaseClientCA, + ClusterName: "example.com", + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{Cert: cert}}, }, - cas: []types.CertAuthority{firstCA, secondCA}, - } + }) + require.NoError(t, err) + + dbClientCALeaf, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.DatabaseClientCA, + ClusterName: "leaf.example.com", + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{Cert: cert}}, + }, + }) + require.NoError(t, err) + + allCAs := []types.CertAuthority{dbCARoot, dbCALeaf, dbClientCARoot, dbClientCALeaf} + + certBytes := []byte("TLS cert") + caBytes := []byte("CA cert") tests := []struct { - name string - inFormat identityfile.Format - inHost string - inOutDir string - inOutFile string + name string + inFormat identityfile.Format + inHost string + inOutDir string + inOutFile string + authClient *mockClient }{ { name: "snowflake format", inFormat: identityfile.FormatSnowflake, inOutDir: t.TempDir(), - inOutFile: "server", + inOutFile: "ca", + authClient: &mockClient{ + clusterName: clusterName, + dbCerts: &proto.DatabaseCertResponse{ + Cert: certBytes, + CACerts: [][]byte{caBytes}, + }, + cas: allCAs, + }, + }, + { + name: "snowflake format db client ca not supported upstream", + inFormat: identityfile.FormatSnowflake, + inOutDir: t.TempDir(), + inOutFile: "ca", + authClient: &mockClient{ + clusterName: clusterName, + dbCerts: &proto.DatabaseCertResponse{ + Cert: certBytes, + CACerts: [][]byte{caBytes}, + }, + cas: []types.CertAuthority{dbCARoot, dbCALeaf}, + unsupportedCATypes: []types.CertAuthType{types.DatabaseClientCA}, + }, }, { name: "db format", @@ -1038,6 +1078,14 @@ func TestGenerateAndSignKeys(t *testing.T) { inOutDir: t.TempDir(), inOutFile: "server", inHost: "localhost", + authClient: &mockClient{ + clusterName: clusterName, + dbCerts: &proto.DatabaseCertResponse{ + Cert: certBytes, + CACerts: [][]byte{caBytes}, + }, + cas: allCAs, + }, }, } @@ -1051,7 +1099,7 @@ func TestGenerateAndSignKeys(t *testing.T) { genTTL: time.Hour, } - err = ac.GenerateAndSignKeys(context.Background(), authClient) + err = ac.GenerateAndSignKeys(context.Background(), test.authClient) require.NoError(t, err) }) } @@ -1074,3 +1122,62 @@ func TestGenerateCRLForCA(t *testing.T) { require.Error(t, ac.GenerateCRLForCA(ctx, authClient)) }) } + +func TestGetDatabaseClientCA(t *testing.T) { + _, cert, err := tlsca.GenerateSelfSignedCA(pkix.Name{CommonName: "example.com"}, nil, time.Minute) + require.NoError(t, err) + + dbClientCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.DatabaseClientCA, + ClusterName: "example.com", + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{Cert: cert}}, + }, + }) + require.NoError(t, err) + + dbServerCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.DatabaseCA, + ClusterName: "example.com", + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{Cert: cert}}, + }, + }) + require.NoError(t, err) + + clusterName, err := services.NewClusterNameWithRandomID( + types.ClusterNameSpecV2{ + ClusterName: "example.com", + }) + require.NoError(t, err) + tests := []struct { + desc string + authClient *mockClient + wantCA types.CertAuthority + }{ + { + desc: "db client ca exists", + authClient: &mockClient{ + clusterName: clusterName, + cas: []types.CertAuthority{dbClientCA, dbServerCA}, + }, + wantCA: dbClientCA, + }, + { + desc: "db client ca not supported", + authClient: &mockClient{ + clusterName: clusterName, + unsupportedCATypes: []types.CertAuthType{types.DatabaseClientCA}, + cas: []types.CertAuthority{dbServerCA}, + }, + wantCA: dbServerCA, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ca, err := getDatabaseClientCA(context.Background(), test.authClient) + require.NoError(t, err) + require.Equal(t, test.wantCA, ca) + }) + } +}