diff --git a/integration/proxy/teleterm_test.go b/integration/proxy/teleterm_test.go index dd735bfa82110..93bc1046ead68 100644 --- a/integration/proxy/teleterm_test.go +++ b/integration/proxy/teleterm_test.go @@ -35,6 +35,7 @@ import ( "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/daemon" + "github.com/gravitational/teleport/lib/tlsca" ) // testTeletermGatewaysCertRenewal is run from within TestALPNSNIProxyDatabaseAccess to amortize the @@ -120,17 +121,23 @@ func testGatewayCertRenewal(t *testing.T, pack *dbhelpers.DatabasePack, albAddr }) // Here the test setup ends and actual test code starts. - gateway, err := daemonService.CreateGateway(context.Background(), daemon.CreateGatewayParams{ TargetURI: databaseURI.String(), TargetUser: "root", }) require.NoError(t, err, trace.DebugReport(err)) + routeToDatabase := tlsca.RouteToDatabase{ + ServiceName: databaseURI.GetDbName(), + Protocol: "mysql", + Username: "root", + } + // Open a new connection. client, err := mysql.MakeTestClientWithoutTLS( net.JoinHostPort(gateway.LocalAddress(), gateway.LocalPort()), - gateway.RouteToDatabase()) + routeToDatabase) + require.NoError(t, err) // Execute a query. @@ -160,7 +167,7 @@ func testGatewayCertRenewal(t *testing.T, pack *dbhelpers.DatabasePack, albAddr // will let the connection through. client, err = mysql.MakeTestClientWithoutTLS( net.JoinHostPort(gateway.LocalAddress(), gateway.LocalPort()), - gateway.RouteToDatabase()) + routeToDatabase) require.NoError(t, err) // Execute a query. diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index 813ae59e7c0a8..facfd0cc5eb61 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -369,26 +369,7 @@ func (l *LocalProxy) CheckDBCerts(dbRoute tlsca.RouteToDatabase) error { return trace.Wrap(err) } - return trace.Wrap(CheckCertSubject(cert, dbRoute)) -} - -// CheckCertSubject checks if the route to the database from the cert matches the provided route in -// terms of username and database (if present). -func CheckCertSubject(cert *x509.Certificate, dbRoute tlsca.RouteToDatabase) error { - identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) - if err != nil { - return trace.Wrap(err) - } - if dbRoute.Username != "" && dbRoute.Username != identity.RouteToDatabase.Username { - return trace.Errorf("certificate subject is for user %s, but need %s", - identity.RouteToDatabase.Username, dbRoute.Username) - } - if dbRoute.Database != "" && dbRoute.Database != identity.RouteToDatabase.Database { - return trace.Errorf("certificate subject is for database name %s, but need %s", - identity.RouteToDatabase.Database, dbRoute.Database) - } - - return nil + return trace.Wrap(dbRoute.CheckCertSubject(cert)) } // SetCerts sets the local proxy's configured TLS certificates. diff --git a/lib/teleterm/clusters/cluster_databases.go b/lib/teleterm/clusters/cluster_databases.go index 7283812a7b8e3..f2448fc8f97ec 100644 --- a/lib/teleterm/clusters/cluster_databases.go +++ b/lib/teleterm/clusters/cluster_databases.go @@ -18,6 +18,7 @@ package clusters import ( "context" + "crypto/tls" "github.com/gravitational/trace" @@ -25,6 +26,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/keys" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" @@ -33,6 +35,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" ) // Database describes database @@ -150,8 +153,8 @@ func (c *Cluster) GetDatabases(ctx context.Context, r *api.GetDatabasesRequest) return response, nil } -// ReissueDBCerts issues new certificates for specific DB access and saves them to disk. -func (c *Cluster) ReissueDBCerts(ctx context.Context, routeToDatabase tlsca.RouteToDatabase) error { +// reissueDBCerts issues new certificates for specific DB access and saves them to disk. +func (c *Cluster) reissueDBCerts(ctx context.Context, routeToDatabase tlsca.RouteToDatabase) error { // When generating certificate for MongoDB access, database username must // be encoded into it. This is required to be able to tell which database // user to authenticate the connection as. @@ -197,6 +200,34 @@ func (c *Cluster) ReissueDBCerts(ctx context.Context, routeToDatabase tlsca.Rout return nil } +func (c *Cluster) loadDBCert(routeToDatabase tlsca.RouteToDatabase) (tls.Certificate, error) { + tlsCert, err := keys.LoadX509KeyPair( + c.status.DatabaseCertPathForCluster(c.clusterClient.SiteName, routeToDatabase.ServiceName), + c.status.KeyPath(), + ) + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + cert, err := utils.TLSCertLeaf(tlsCert) + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + if err := routeToDatabase.CheckCertSubject(cert); err != nil { + return tls.Certificate{}, trace.Wrap(err, "database certificate check failed, try restarting the database connection") + } + + return tlsCert, nil +} + +func (c *Cluster) reissueAndLoadDBCert(ctx context.Context, routeToDatabase tlsca.RouteToDatabase) (tls.Certificate, error) { + if err := c.reissueDBCerts(ctx, routeToDatabase); err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + tlsCert, err := c.loadDBCert(routeToDatabase) + return tlsCert, trace.Wrap(err) +} // GetAllowedDatabaseUsers returns allowed users for the given database based on the role set. func (c *Cluster) GetAllowedDatabaseUsers(ctx context.Context, dbURI string) ([]string, error) { diff --git a/lib/teleterm/clusters/cluster_gateways.go b/lib/teleterm/clusters/cluster_gateways.go index d901b779cc49a..0be1ba25962e0 100644 --- a/lib/teleterm/clusters/cluster_gateways.go +++ b/lib/teleterm/clusters/cluster_gateways.go @@ -18,13 +18,25 @@ package clusters import ( "context" + "crypto/tls" "github.com/gravitational/trace" + "github.com/gravitational/teleport/lib/client/db/dbcmd" "github.com/gravitational/teleport/lib/teleterm/gateway" "github.com/gravitational/teleport/lib/tlsca" ) +// ReissueCertFunc is a callback function for Cluster to actually do the issue +// of user certificates with TeleportClient. +type ReissueCertFunc func(context.Context) error + +// GatewayCertReissuer defines an interface of a helper that manages the +// process of reissuing certificates. +type GatewayCertReissuer interface { + ReissueCert(ctx context.Context, gateway *gateway.Gateway, doReissueCert ReissueCertFunc) error +} + type CreateGatewayParams struct { // TargetURI is the cluster resource URI TargetURI string @@ -36,12 +48,15 @@ type CreateGatewayParams struct { // LocalPort is the gateway local port LocalPort string CLICommandProvider gateway.CLICommandProvider - TCPPortAllocator gateway.TCPPortAllocator - OnExpiredCert gateway.OnExpiredCertFunc + CertReissuer GatewayCertReissuer } // CreateGateway creates a gateway func (c *Cluster) CreateGateway(ctx context.Context, params CreateGatewayParams) (*gateway.Gateway, error) { + if params.CLICommandProvider == nil { + params.CLICommandProvider = NewDbcmdCLICommandProvider(c, dbcmd.SystemExecer{}) + } + db, err := c.GetDatabase(ctx, params.TargetURI) if err != nil { return nil, trace.Wrap(err) @@ -53,7 +68,8 @@ func (c *Cluster) CreateGateway(ctx context.Context, params CreateGatewayParams) Username: params.TargetUser, } - if err := c.ReissueDBCerts(ctx, routeToDatabase); err != nil { + tlsCert, err := c.reissueAndLoadDBCert(ctx, routeToDatabase) + if err != nil { return nil, trace.Wrap(err) } @@ -63,15 +79,14 @@ func (c *Cluster) CreateGateway(ctx context.Context, params CreateGatewayParams) TargetUser: params.TargetUser, TargetName: db.GetName(), TargetSubresourceName: params.TargetSubresourceName, + Cert: tlsCert, Protocol: db.GetProtocol(), - KeyPath: c.status.KeyPath(), - CertPath: c.status.DatabaseCertPathForCluster(c.clusterClient.SiteName, db.GetName()), Insecure: c.clusterClient.InsecureSkipVerify, WebProxyAddr: c.clusterClient.WebProxyAddr, Log: c.Log, CLICommandProvider: params.CLICommandProvider, - TCPPortAllocator: params.TCPPortAllocator, - OnExpiredCert: params.OnExpiredCert, + TCPPortAllocator: gateway.NetTCPPortAllocator{}, + ReissueCert: c.makeGatewayReissueDBCertFunc(params.CertReissuer, routeToDatabase), Clock: c.clock, TLSRoutingConnUpgradeRequired: c.clusterClient.TLSRoutingConnUpgradeRequired, RootClusterCACertPoolFunc: c.clusterClient.RootClusterCACertPool, @@ -82,3 +97,20 @@ func (c *Cluster) CreateGateway(ctx context.Context, params CreateGatewayParams) return gw, nil } + +// makeGatewayReissueDBCertFunc creates a gateway.ReissueCertFunc that reissues +// the database certificate using provided GatewayCertReissuer, then loads the +// certificate. +func (c *Cluster) makeGatewayReissueDBCertFunc(certReissuer GatewayCertReissuer, routeToDatabase tlsca.RouteToDatabase) gateway.ReissueCertFunc { + return func(ctx context.Context, gateway *gateway.Gateway) (tls.Certificate, error) { + err := certReissuer.ReissueCert(ctx, gateway, func(ctx context.Context) error { + return trace.Wrap(c.reissueDBCerts(ctx, routeToDatabase)) + }) + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + + tlsCert, err := c.loadDBCert(routeToDatabase) + return tlsCert, trace.Wrap(err) + } +} diff --git a/lib/teleterm/clusters/dbcmd_cli_command_provider.go b/lib/teleterm/clusters/dbcmd_cli_command_provider.go index c2a4a92b6f1fb..90843b4dea410 100644 --- a/lib/teleterm/clusters/dbcmd_cli_command_provider.go +++ b/lib/teleterm/clusters/dbcmd_cli_command_provider.go @@ -27,27 +27,18 @@ import ( // DbcmdCLICommandProvider provides CLI commands for database gateways. It needs Storage to read // fresh profile state from the disk. type DbcmdCLICommandProvider struct { - storage StorageByResourceURI + cluster *Cluster execer dbcmd.Execer } -type StorageByResourceURI interface { - GetByResourceURI(string) (*Cluster, error) -} - -func NewDbcmdCLICommandProvider(storage StorageByResourceURI, execer dbcmd.Execer) DbcmdCLICommandProvider { +func NewDbcmdCLICommandProvider(cluster *Cluster, execer dbcmd.Execer) DbcmdCLICommandProvider { return DbcmdCLICommandProvider{ - storage: storage, + cluster: cluster, execer: execer, } } func (d DbcmdCLICommandProvider) GetCommand(gateway *gateway.Gateway) (*exec.Cmd, error) { - cluster, err := d.storage.GetByResourceURI(gateway.TargetURI()) - if err != nil { - return nil, trace.Wrap(err) - } - routeToDb := tlsca.RouteToDatabase{ ServiceName: gateway.TargetName(), Protocol: gateway.Protocol(), @@ -55,14 +46,14 @@ func (d DbcmdCLICommandProvider) GetCommand(gateway *gateway.Gateway) (*exec.Cmd Database: gateway.TargetSubresourceName(), } - cmd, err := dbcmd.NewCmdBuilder(cluster.clusterClient, &cluster.status, routeToDb, + cmd, err := dbcmd.NewCmdBuilder(d.cluster.clusterClient, &d.cluster.status, routeToDb, // TODO(ravicious): Pass the root cluster name here. cluster.Name returns leaf name for leaf // clusters. // // At this point it doesn't matter though because this argument is used only for // generating correct CA paths. We use dbcmd.WithNoTLS here which means that the CA paths aren't // included in the returned CLI command. - cluster.Name, + d.cluster.Name, dbcmd.WithLogger(gateway.Log()), dbcmd.WithLocalProxy(gateway.LocalAddress(), gateway.LocalPortInt(), ""), dbcmd.WithNoTLS(), diff --git a/lib/teleterm/clusters/dbcmd_cli_command_provider_test.go b/lib/teleterm/clusters/dbcmd_cli_command_provider_test.go index 3c9b6926e18a0..3e806eb82a489 100644 --- a/lib/teleterm/clusters/dbcmd_cli_command_provider_test.go +++ b/lib/teleterm/clusters/dbcmd_cli_command_provider_test.go @@ -17,10 +17,8 @@ package clusters import ( "os/exec" "path/filepath" - "strings" "testing" - "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/lib/client" @@ -47,20 +45,6 @@ func (f fakeExec) Command(name string, arg ...string) *exec.Cmd { return cmd } -type fakeStorage struct { - clusters []*Cluster -} - -func (f fakeStorage) GetByResourceURI(resourceURI string) (*Cluster, error) { - for _, cluster := range f.clusters { - if strings.HasPrefix(resourceURI, cluster.URI.String()) { - return cluster, nil - } - } - - return nil, trace.NotFound("not found") -} - func TestDbcmdCLICommandProviderGetCommand(t *testing.T) { testCases := []struct { name string @@ -87,12 +71,9 @@ func TestDbcmdCLICommandProviderGetCommand(t *testing.T) { }, }, } - fakeStorage := fakeStorage{ - clusters: []*Cluster{&cluster}, - } - dbcmdCLICommandProvider := NewDbcmdCLICommandProvider(fakeStorage, fakeExec{}) + dbcmdCLICommandProvider := NewDbcmdCLICommandProvider(&cluster, fakeExec{}) - keyPairPaths := gatewaytest.MustGenAndSaveCert(t, tlsca.Identity{ + tlsCert := gatewaytest.MustGenDBCert(t, tlsca.Identity{ Username: "alice", Groups: []string{"test-group"}, RouteToDatabase: tlsca.RouteToDatabase{ @@ -112,8 +93,7 @@ func TestDbcmdCLICommandProviderGetCommand(t *testing.T) { LocalAddress: "localhost", WebProxyAddr: "localhost:1337", Insecure: true, - CertPath: keyPairPaths.CertPath, - KeyPath: keyPairPaths.KeyPath, + Cert: tlsCert, CLICommandProvider: dbcmdCLICommandProvider, TCPPortAllocator: gateway.NetTCPPortAllocator{}, }, @@ -130,43 +110,3 @@ func TestDbcmdCLICommandProviderGetCommand(t *testing.T) { }) } } - -func TestDbcmdCLICommandProviderGetCommand_ReturnsErrorIfClusterIsNotFound(t *testing.T) { - fakeStorage := fakeStorage{ - clusters: []*Cluster{}, - } - dbcmdCLICommandProvider := NewDbcmdCLICommandProvider(fakeStorage, fakeExec{}) - - keyPairPaths := gatewaytest.MustGenAndSaveCert(t, tlsca.Identity{ - Username: "alice", - Groups: []string{"test-group"}, - RouteToDatabase: tlsca.RouteToDatabase{ - ServiceName: "foo", - Protocol: defaults.ProtocolPostgres, - Username: "alice", - }, - }) - - gateway, err := gateway.New( - gateway.Config{ - TargetURI: uri.NewClusterURI("quux").AppendDB("foo").String(), - TargetName: "foo", - TargetUser: "alice", - TargetSubresourceName: "", - Protocol: defaults.ProtocolPostgres, - LocalAddress: "localhost", - WebProxyAddr: "localhost:1337", - Insecure: true, - CertPath: keyPairPaths.CertPath, - KeyPath: keyPairPaths.KeyPath, - CLICommandProvider: dbcmdCLICommandProvider, - TCPPortAllocator: gateway.NetTCPPortAllocator{}, - }, - ) - require.NoError(t, err) - t.Cleanup(func() { gateway.Close() }) - - _, err = dbcmdCLICommandProvider.GetCommand(gateway) - require.Error(t, err) - require.True(t, trace.IsNotFound(err), "err is not trace.NotFound") -} diff --git a/lib/teleterm/daemon/config.go b/lib/teleterm/daemon/config.go index c66c960cc38b0..7ae67d0924b84 100644 --- a/lib/teleterm/daemon/config.go +++ b/lib/teleterm/daemon/config.go @@ -22,7 +22,6 @@ import ( "google.golang.org/grpc" "github.com/gravitational/teleport/lib/teleterm/clusters" - "github.com/gravitational/teleport/lib/teleterm/gateway" ) // Config is the cluster service config @@ -30,9 +29,8 @@ type Config struct { // Storage is a storage service that reads/writes to tsh profiles Storage *clusters.Storage // Log is a component logger - Log *logrus.Entry - GatewayCreator GatewayCreator - TCPPortAllocator gateway.TCPPortAllocator + Log *logrus.Entry + GatewayCreator GatewayCreator // CreateTshdEventsClientCredsFunc lazily creates creds for the tshd events server ran by the // Electron app. This is to ensure that the server public key is written to the disk under the // expected location by the time we get around to creating the client. @@ -54,10 +52,6 @@ func (c *Config) CheckAndSetDefaults() error { c.GatewayCreator = clusters.NewGatewayCreator(c.Storage) } - if c.TCPPortAllocator == nil { - c.TCPPortAllocator = gateway.NetTCPPortAllocator{} - } - if c.Log == nil { c.Log = logrus.NewEntry(logrus.StandardLogger()).WithField(trace.Component, "daemon") } diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index ffdfe2b532ee1..52f5025770ebc 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -24,7 +24,6 @@ import ( "github.com/gravitational/teleport/api/types" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" - "github.com/gravitational/teleport/lib/client/db/dbcmd" "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/gateway" usagereporter "github.com/gravitational/teleport/lib/usagereporter/daemon" @@ -185,15 +184,12 @@ type GatewayCreator interface { // createGateway assumes that mu is already held by a public method. func (s *Service) createGateway(ctx context.Context, params CreateGatewayParams) (*gateway.Gateway, error) { - cliCommandProvider := clusters.NewDbcmdCLICommandProvider(s.cfg.Storage, dbcmd.SystemExecer{}) clusterCreateGatewayParams := clusters.CreateGatewayParams{ TargetURI: params.TargetURI, TargetUser: params.TargetUser, TargetSubresourceName: params.TargetSubresourceName, LocalPort: params.LocalPort, - CLICommandProvider: cliCommandProvider, - TCPPortAllocator: s.cfg.TCPPortAllocator, - OnExpiredCert: s.onExpiredGatewayCert, + CertReissuer: s.cfg.GatewayCertReissuer, } gateway, err := s.cfg.GatewayCreator.CreateGateway(ctx, clusterCreateGatewayParams) @@ -212,15 +208,6 @@ func (s *Service) createGateway(ctx context.Context, params CreateGatewayParams) return gateway, nil } -func (s *Service) onExpiredGatewayCert(ctx context.Context, gateway *gateway.Gateway) error { - cluster, err := s.ResolveCluster(gateway.TargetURI()) - if err != nil { - return trace.Wrap(err) - } - - return trace.Wrap(s.cfg.GatewayCertReissuer.ReissueCert(ctx, gateway, cluster)) -} - // RemoveGateway removes cluster gateway func (s *Service) RemoveGateway(gatewayURI string) error { s.mu.Lock() diff --git a/lib/teleterm/daemon/daemon_test.go b/lib/teleterm/daemon/daemon_test.go index b92c3dd8fc187..696c1cce02f3a 100644 --- a/lib/teleterm/daemon/daemon_test.go +++ b/lib/teleterm/daemon/daemon_test.go @@ -35,8 +35,9 @@ import ( ) type mockGatewayCreator struct { - t *testing.T - callCount int + t *testing.T + tcpPortAllocator gateway.TCPPortAllocator + callCount int } func (m *mockGatewayCreator) CreateGateway(ctx context.Context, params clusters.CreateGatewayParams) (*gateway.Gateway, error) { @@ -49,7 +50,7 @@ func (m *mockGatewayCreator) CreateGateway(ctx context.Context, params clusters. resourceURI := uri.New(params.TargetURI) - keyPairPaths := gatewaytest.MustGenAndSaveCert(m.t, tlsca.Identity{ + tlsCert := gatewaytest.MustGenDBCert(m.t, tlsca.Identity{ Username: params.TargetUser, Groups: []string{"test-group"}, RouteToDatabase: tlsca.RouteToDatabase{ @@ -66,12 +67,11 @@ func (m *mockGatewayCreator) CreateGateway(ctx context.Context, params clusters. TargetName: params.TargetURI, TargetSubresourceName: params.TargetSubresourceName, Protocol: defaults.ProtocolPostgres, - CertPath: keyPairPaths.CertPath, - KeyPath: keyPairPaths.KeyPath, + Cert: tlsCert, Insecure: true, WebProxyAddr: hs.Listener.Addr().String(), - CLICommandProvider: params.CLICommandProvider, - TCPPortAllocator: params.TCPPortAllocator, + CLICommandProvider: mockCLICommandProvider{}, + TCPPortAllocator: m.tcpPortAllocator, }) if err != nil { return nil, trace.Wrap(err) @@ -214,7 +214,10 @@ func TestGatewayCRUD(t *testing.T) { } homeDir := t.TempDir() - mockGatewayCreator := &mockGatewayCreator{t: t} + mockGatewayCreator := &mockGatewayCreator{ + t: t, + tcpPortAllocator: tt.tcpPortAllocator, + } storage, err := clusters.NewStorage(clusters.Config{ Dir: homeDir, @@ -223,9 +226,8 @@ func TestGatewayCRUD(t *testing.T) { require.NoError(t, err) daemon, err := New(Config{ - Storage: storage, - GatewayCreator: mockGatewayCreator, - TCPPortAllocator: tt.tcpPortAllocator, + Storage: storage, + GatewayCreator: mockGatewayCreator, }) require.NoError(t, err) diff --git a/lib/teleterm/daemon/gateway_cert_reissuer.go b/lib/teleterm/daemon/gateway_cert_reissuer.go index f0b7330119d9d..9ad8b2e991d6e 100644 --- a/lib/teleterm/daemon/gateway_cert_reissuer.go +++ b/lib/teleterm/daemon/gateway_cert_reissuer.go @@ -28,8 +28,8 @@ import ( api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/teleterm/api/uri" + "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/gateway" - "github.com/gravitational/teleport/lib/tlsca" ) // GatewayCertReissuer is responsible for managing the process of reissuing a db cert for a gateway @@ -43,14 +43,6 @@ type GatewayCertReissuer struct { Log *logrus.Entry } -// DBCertReissuer lets us pass a mock in tests and clusters.Cluster (which makes calls to the -// cluster) in production code. -type DBCertReissuer interface { - // ReissueDBCerts reaches out to the cluster to get a cert for the specific tlsca.RouteToDatabase - // and saves it to disk. - ReissueDBCerts(context.Context, tlsca.RouteToDatabase) error -} - // TSHDEventsClient takes only those methods from api.TshdEventsServiceClient that // GatewayCertReissuer actually needs. It makes mocking the client in tests easier and future-proof. // @@ -78,8 +70,8 @@ type TSHDEventsClient interface { // Any error ReissueCert returns is also forwarded to the Electron app so that it can show an error // notification. GatewayCertReissuer is typically called from within a goroutine that handles the // gateway, so without forwarding the error to the app, it would be visible only in the logs. -func (r *GatewayCertReissuer) ReissueCert(ctx context.Context, gateway *gateway.Gateway, dbCertReissuer DBCertReissuer) error { - if err := r.reissueCert(ctx, gateway, dbCertReissuer); err != nil { +func (r *GatewayCertReissuer) ReissueCert(ctx context.Context, gateway *gateway.Gateway, doReissueCert clusters.ReissueCertFunc) error { + if err := r.reissueCert(ctx, gateway, doReissueCert); err != nil { r.notifyAppAboutError(ctx, err, gateway) // Return the error to the alpn.LocalProxy's middleware. @@ -89,7 +81,7 @@ func (r *GatewayCertReissuer) ReissueCert(ctx context.Context, gateway *gateway. return nil } -func (r *GatewayCertReissuer) reissueCert(ctx context.Context, gateway *gateway.Gateway, dbCertReissuer DBCertReissuer) error { +func (r *GatewayCertReissuer) reissueCert(ctx context.Context, gateway *gateway.Gateway, doReissueCert clusters.ReissueCertFunc) error { // Make the first attempt at reissuing the db cert. // // It might happen that the db cert has expired but the user cert is still active, allowing us to @@ -98,7 +90,7 @@ func (r *GatewayCertReissuer) reissueCert(ctx context.Context, gateway *gateway. // This can happen if the user cert was refreshed by anything other than the gateway itself. For // example, if you execute `tsh ssh` within Connect after your user cert expires or there are two // gateways that subsequently go through this flow. - err := r.reissueAndReloadGatewayCert(ctx, gateway, dbCertReissuer) + err := doReissueCert(ctx) if err == nil { return nil @@ -129,7 +121,7 @@ func (r *GatewayCertReissuer) reissueCert(ctx context.Context, gateway *gateway. return trace.Wrap(err) } - err = r.reissueAndReloadGatewayCert(ctx, gateway, dbCertReissuer) + err = doReissueCert(ctx) if err != nil { return trace.Wrap(err) } @@ -137,15 +129,6 @@ func (r *GatewayCertReissuer) reissueCert(ctx context.Context, gateway *gateway. return nil } -func (r *GatewayCertReissuer) reissueAndReloadGatewayCert(ctx context.Context, gateway *gateway.Gateway, dbCertReissuer DBCertReissuer) error { - err := dbCertReissuer.ReissueDBCerts(ctx, gateway.RouteToDatabase()) - if err != nil { - return trace.Wrap(err) - } - - return trace.Wrap(gateway.ReloadCert()) -} - func (r *GatewayCertReissuer) requestReloginFromElectronApp(ctx context.Context, req *api.ReloginRequest) error { const reloginUserTimeout = time.Minute timeoutCtx, cancelTshdEventsCtx := context.WithTimeout(ctx, reloginUserTimeout) diff --git a/lib/teleterm/daemon/gateway_cert_reissuer_test.go b/lib/teleterm/daemon/gateway_cert_reissuer_test.go index 88bc3ab94e47d..8387a1c8e97ec 100644 --- a/lib/teleterm/daemon/gateway_cert_reissuer_test.go +++ b/lib/teleterm/daemon/gateway_cert_reissuer_test.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/gateway" - "github.com/gravitational/teleport/lib/tlsca" ) var log = logrus.WithField(trace.Component, "reissuer") @@ -138,7 +137,7 @@ func TestReissueCert(t *testing.T) { if tt.reissuerOpt != nil { tt.reissuerOpt(t, reissuer) } - err := reissuer.ReissueCert(ctx, gateway, dbCertReissuer) + err := reissuer.ReissueCert(ctx, gateway, dbCertReissuer.ReissueDBCerts) if tt.wantErr != nil { require.ErrorIs(t, err, tt.wantErr) require.ErrorContains(t, err, tt.wantAddedMessage) @@ -173,7 +172,7 @@ type mockDBCertReissuer struct { returnValuesForSubsequentCalls []error } -func (r *mockDBCertReissuer) ReissueDBCerts(context.Context, tlsca.RouteToDatabase) error { +func (r *mockDBCertReissuer) ReissueDBCerts(context.Context) error { var err error if r.callCount < len(r.returnValuesForSubsequentCalls) { err = r.returnValuesForSubsequentCalls[r.callCount] diff --git a/lib/teleterm/gateway/config.go b/lib/teleterm/gateway/config.go index 448850d33fc48..b22ee8cfa852c 100644 --- a/lib/teleterm/gateway/config.go +++ b/lib/teleterm/gateway/config.go @@ -18,6 +18,7 @@ package gateway import ( "context" + "crypto/tls" "runtime" "github.com/google/uuid" @@ -52,10 +53,8 @@ type Config struct { LocalAddress string // Protocol is the gateway protocol Protocol string - // CertPath - CertPath string - // KeyPath - KeyPath string + // Cert is the initial client certificate used to connect to the remote Teleport Proxy. + Cert tls.Certificate // Insecure Insecure bool // WebProxyAddr @@ -69,11 +68,11 @@ type Config struct { TCPPortAllocator TCPPortAllocator // Clock is used by Gateway.localProxy to check cert expiration. Clock clockwork.Clock - // OnExpiredCert is called when a new downstream connection is accepted by the + // ReissueCert is called when a new downstream connection is accepted by the // gateway but cannot be proxied because the cert used by the gateway has expired. // - // Handling of the connection is blocked until OnExpiredCert returns. - OnExpiredCert OnExpiredCertFunc + // Handling of the connection is blocked until ReissueCert returns. + ReissueCert ReissueCertFunc // TLSRoutingConnUpgradeRequired indicates that ALPN connection upgrades // are required for making TLS routing requests. TLSRoutingConnUpgradeRequired bool @@ -82,11 +81,11 @@ type Config struct { RootClusterCACertPoolFunc alpnproxy.GetClusterCACertPoolFunc } -// OnExpiredCertFunc is the type of a function that is called when a new downstream connection is +// ReissueCertFunc is the type of a function that is called when a new downstream connection is // accepted by the gateway but cannot be proxied because the cert used by the gateway has expired. // // Handling of the connection is blocked until the function returns. -type OnExpiredCertFunc func(context.Context, *Gateway) error +type ReissueCertFunc func(context.Context, *Gateway) (tls.Certificate, error) // CheckAndSetDefaults checks and sets the defaults func (c *Config) CheckAndSetDefaults() error { diff --git a/lib/teleterm/gateway/local_proxy_middleware.go b/lib/teleterm/gateway/db_middleware.go similarity index 89% rename from lib/teleterm/gateway/local_proxy_middleware.go rename to lib/teleterm/gateway/db_middleware.go index b163ae99d347c..1f62ae1ce7478 100644 --- a/lib/teleterm/gateway/local_proxy_middleware.go +++ b/lib/teleterm/gateway/db_middleware.go @@ -27,7 +27,7 @@ import ( "github.com/gravitational/teleport/lib/tlsca" ) -type localProxyMiddleware struct { +type dbMiddleware struct { onExpiredCert func(context.Context) error log *logrus.Entry dbRoute tlsca.RouteToDatabase @@ -39,7 +39,7 @@ type localProxyMiddleware struct { // // In the future, DBCertChecker is going to be extended so that it's used by both tsh and Connect // and this middleware will be removed. -func (m *localProxyMiddleware) OnNewConnection(ctx context.Context, lp *alpn.LocalProxy, conn net.Conn) error { +func (m *dbMiddleware) OnNewConnection(ctx context.Context, lp *alpn.LocalProxy, conn net.Conn) error { err := lp.CheckDBCerts(m.dbRoute) if err == nil { return nil @@ -58,6 +58,6 @@ func (m *localProxyMiddleware) OnNewConnection(ctx context.Context, lp *alpn.Loc // OnStart is a noop. client.DBCertChecker.OnStart checks cert validity too. However in Connect // there's no flow which would allow the user to create a local proxy without valid // certs. -func (m *localProxyMiddleware) OnStart(context.Context, *alpn.LocalProxy) error { +func (m *dbMiddleware) OnStart(context.Context, *alpn.LocalProxy) error { return nil } diff --git a/lib/teleterm/gateway/local_proxy_middleware_test.go b/lib/teleterm/gateway/db_middleware_test.go similarity index 97% rename from lib/teleterm/gateway/local_proxy_middleware_test.go rename to lib/teleterm/gateway/db_middleware_test.go index 03617d60ff50a..59ce57b7c96b0 100644 --- a/lib/teleterm/gateway/local_proxy_middleware_test.go +++ b/lib/teleterm/gateway/db_middleware_test.go @@ -34,7 +34,7 @@ import ( "github.com/gravitational/teleport/lib/utils/cert" ) -func TestLocalProxyMiddleware_OnNewConnection(t *testing.T) { +func Test_dbMiddleware_OnNewConnection(t *testing.T) { testCert, err := cert.GenerateSelfSignedCert([]string{"localhost"}) require.NoError(t, err) tlsCert, err := keys.X509KeyPair(testCert.Cert, testCert.PrivateKey) @@ -103,7 +103,7 @@ func TestLocalProxyMiddleware_OnNewConnection(t *testing.T) { hasCalledOnExpiredCert := false - middleware := &localProxyMiddleware{ + middleware := &dbMiddleware{ onExpiredCert: func(context.Context) error { hasCalledOnExpiredCert = true return nil diff --git a/lib/teleterm/gateway/gateway.go b/lib/teleterm/gateway/gateway.go index 4287b916d2af2..380ad8db0f08a 100644 --- a/lib/teleterm/gateway/gateway.go +++ b/lib/teleterm/gateway/gateway.go @@ -28,12 +28,9 @@ import ( "github.com/gravitational/trace" "github.com/sirupsen/logrus" - "github.com/gravitational/teleport/api/utils/keys" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" alpn "github.com/gravitational/teleport/lib/srv/alpnproxy" "github.com/gravitational/teleport/lib/teleterm/api/uri" - "github.com/gravitational/teleport/lib/tlsca" - "github.com/gravitational/teleport/lib/utils" ) // New creates an instance of Gateway. It starts a listener on the specified port but it doesn't @@ -70,38 +67,28 @@ func New(cfg Config) (*Gateway, error) { cfg.LocalPort = port - tlsCert, err := keys.LoadX509KeyPair(cfg.CertPath, cfg.KeyPath) - if err != nil { - return nil, trace.Wrap(err) - } - - if err := checkCertSubject(tlsCert, cfg.RouteToDatabase()); err != nil { - return nil, trace.Wrap(err, - "database certificate check failed, try restarting the database connection") - } - localProxyConfig := alpn.LocalProxyConfig{ InsecureSkipVerify: cfg.Insecure, RemoteProxyAddr: cfg.WebProxyAddr, Listener: listener, ParentContext: closeContext, - Certs: []tls.Certificate{tlsCert}, Clock: cfg.Clock, ALPNConnUpgradeRequired: cfg.TLSRoutingConnUpgradeRequired, } - localProxyMiddleware := &localProxyMiddleware{ + localProxyMiddleware := &dbMiddleware{ log: cfg.Log, dbRoute: cfg.RouteToDatabase(), } - if cfg.OnExpiredCert != nil { + if cfg.ReissueCert != nil { localProxyConfig.Middleware = localProxyMiddleware } localProxy, err := alpn.NewLocalProxy(localProxyConfig, alpn.WithDatabaseProtocol(cfg.Protocol), alpn.WithClusterCAsIfConnUpgrade(closeContext, cfg.RootClusterCACertPoolFunc), + alpn.WithClientCerts(cfg.Cert), ) if err != nil { return nil, trace.Wrap(err) @@ -114,10 +101,15 @@ func New(cfg Config) (*Gateway, error) { localProxy: localProxy, } - if cfg.OnExpiredCert != nil { + if cfg.ReissueCert != nil { localProxyMiddleware.onExpiredCert = func(ctx context.Context) error { - err := cfg.OnExpiredCert(ctx, gateway) - return trace.Wrap(err) + tlsCert, err := cfg.ReissueCert(ctx, gateway) + if err != nil { + return trace.Wrap(err) + } + + localProxy.SetCerts([]tls.Certificate{tlsCert}) + return nil } } @@ -236,53 +228,6 @@ func (g *Gateway) CLICommand() (*api.GatewayCLICommand, error) { }, nil } -// RouteToDatabase returns tlsca.RouteToDatabase based on the config of the gateway. -// -// The tlsca.RouteToDatabase.Database field is skipped, as it's an optional field and gateways can -// change their Config.TargetSubresourceName at any moment. -func (g *Gateway) RouteToDatabase() tlsca.RouteToDatabase { - return g.cfg.RouteToDatabase() -} - -// ReloadCert loads the key pair from cfg.CertPath & cfg.KeyPath and updates the cert of the running -// local proxy. This is typically done after the cert is reissued and saved to disk. -// -// In the future, we're probably going to make this method accept the cert as an arg rather than -// reading from disk. -func (g *Gateway) ReloadCert() error { - g.cfg.Log.Debug("Reloading cert") - - tlsCert, err := keys.LoadX509KeyPair(g.cfg.CertPath, g.cfg.KeyPath) - if err != nil { - return trace.Wrap(err) - } - - if err := checkCertSubject(tlsCert, g.RouteToDatabase()); err != nil { - return trace.Wrap(err, - "database certificate check failed, try restarting the database connection") - } - - g.localProxy.SetCerts([]tls.Certificate{tlsCert}) - - return nil -} - -// checkCertSubject checks if the cert subject matches the expected db route. -// -// Database certs are scoped per database server but not per database user or database name. -// It might happen that after we save the cert but before we load it, another process obtains a -// cert for another db user. -// -// Before using the cert for the proxy, we have to perform this check. -func checkCertSubject(tlsCert tls.Certificate, dbRoute tlsca.RouteToDatabase) error { - cert, err := utils.TLSCertLeaf(tlsCert) - if err != nil { - return trace.Wrap(err) - } - - return trace.Wrap(alpn.CheckCertSubject(cert, dbRoute)) -} - // Gateway describes local proxy that creates a gateway to the remote Teleport resource. // // Gateway is not safe for concurrent use in itself. However, all access to gateways is gated by diff --git a/lib/teleterm/gateway/gateway_test.go b/lib/teleterm/gateway/gateway_test.go index 3419313a9e0e2..87f9876dbf014 100644 --- a/lib/teleterm/gateway/gateway_test.go +++ b/lib/teleterm/gateway/gateway_test.go @@ -60,7 +60,7 @@ func TestGatewayStart(t *testing.T) { hs.Close() }) - keyPairPaths := gatewaytest.MustGenAndSaveCert(t, tlsca.Identity{ + tlsCert := gatewaytest.MustGenDBCert(t, tlsca.Identity{ Username: "alice", Groups: []string{"test-group"}, RouteToDatabase: tlsca.RouteToDatabase{ @@ -76,8 +76,7 @@ func TestGatewayStart(t *testing.T) { TargetURI: uri.NewClusterURI("bar").AppendDB("foo").String(), TargetUser: "alice", Protocol: defaults.ProtocolPostgres, - CertPath: keyPairPaths.CertPath, - KeyPath: keyPairPaths.KeyPath, + Cert: tlsCert, Insecure: true, WebProxyAddr: hs.Listener.Addr().String(), CLICommandProvider: mockCLICommandProvider{}, @@ -180,7 +179,7 @@ func createGateway(t *testing.T, tcpPortAllocator TCPPortAllocator) *Gateway { hs.Close() }) - keyPairPaths := gatewaytest.MustGenAndSaveCert(t, tlsca.Identity{ + tlsCert := gatewaytest.MustGenDBCert(t, tlsca.Identity{ Username: "alice", Groups: []string{"test-group"}, RouteToDatabase: tlsca.RouteToDatabase{ @@ -196,8 +195,7 @@ func createGateway(t *testing.T, tcpPortAllocator TCPPortAllocator) *Gateway { TargetURI: uri.NewClusterURI("bar").AppendDB("foo").String(), TargetUser: "alice", Protocol: defaults.ProtocolPostgres, - CertPath: keyPairPaths.CertPath, - KeyPath: keyPairPaths.KeyPath, + Cert: tlsCert, Insecure: true, WebProxyAddr: hs.Listener.Addr().String(), CLICommandProvider: mockCLICommandProvider{}, diff --git a/lib/teleterm/gatewaytest/helpers.go b/lib/teleterm/gatewaytest/helpers.go index e8e7be6d0a17e..3efe0f0cb258f 100644 --- a/lib/teleterm/gatewaytest/helpers.go +++ b/lib/teleterm/gatewaytest/helpers.go @@ -23,7 +23,6 @@ import ( "encoding/pem" "fmt" "net" - "os" "testing" "time" @@ -140,16 +139,9 @@ func (m *MockListener) RealAddr() net.Addr { return m.realListener.Addr() } -type KeyPairPaths struct { - CertPath string - KeyPath string -} - -func MustGenAndSaveCert(t *testing.T, identity tlsca.Identity) KeyPairPaths { +func MustGenDBCert(t *testing.T, identity tlsca.Identity) tls.Certificate { t.Helper() - dir := t.TempDir() - ca := mustGenCACert(t) tlsCert := mustGenCertSignedWithCA(t, ca, identity) @@ -157,32 +149,12 @@ func MustGenAndSaveCert(t *testing.T, identity tlsca.Identity) KeyPairPaths { privateKey, ok := tlsCert.PrivateKey.(*rsa.PrivateKey) require.True(t, ok, "Failed to cast tlsCert.PrivateKey") - // Save the cert. - - certFile, err := os.CreateTemp(dir, "cert") - require.NoError(t, err) - pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tlsCert.Certificate[0]}) - - _, err = certFile.Write(pemCert) - require.NoError(t, err) - require.NoError(t, certFile.Close()) - - // Save the private key. - - keyFile, err := os.CreateTemp(dir, "key") - require.NoError(t, err) - pemPrivateKey := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) - _, err = keyFile.Write(pemPrivateKey) + tlsCert, err := tls.X509KeyPair(pemCert, pemPrivateKey) require.NoError(t, err) - require.NoError(t, keyFile.Close()) - - return KeyPairPaths{ - CertPath: certFile.Name(), - KeyPath: keyFile.Name(), - } + return tlsCert } func mustGenCACert(t *testing.T) *tlsca.CertAuthority { diff --git a/lib/tlsca/ca.go b/lib/tlsca/ca.go index bddf8d2c689fc..4270c212e84f1 100644 --- a/lib/tlsca/ca.go +++ b/lib/tlsca/ca.go @@ -249,6 +249,24 @@ func (r RouteToDatabase) String() string { r.ServiceName, r.Protocol, r.Username, r.Database) } +// CheckCertSubject checks if the route to the database from the cert matches +// the provided route in terms of username and database (if present). +func (r RouteToDatabase) CheckCertSubject(cert *x509.Certificate) error { + identity, err := FromSubject(cert.Subject, cert.NotAfter) + if err != nil { + return trace.Wrap(err) + } + if r.Username != "" && r.Username != identity.RouteToDatabase.Username { + return trace.Errorf("certificate subject is for user %s, but need %s", + identity.RouteToDatabase.Username, r.Username) + } + if r.Database != "" && r.Database != identity.RouteToDatabase.Database { + return trace.Errorf("certificate subject is for database name %s, but need %s", + identity.RouteToDatabase.Database, r.Database) + } + return nil +} + // DeviceExtensions holds device-aware extensions for the identity. type DeviceExtensions struct { // DeviceID is the trusted device identifier.