diff --git a/lib/teleterm/clusters/cluster_databases.go b/lib/teleterm/clusters/cluster_databases.go index 59b5f9b2dc701..91a22ebac4374 100644 --- a/lib/teleterm/clusters/cluster_databases.go +++ b/lib/teleterm/clusters/cluster_databases.go @@ -155,12 +155,12 @@ func (c *Cluster) GetDatabases(ctx context.Context, r *api.GetDatabasesRequest) } // ReissueDBCerts issues new certificates for specific DB access -func (c *Cluster) ReissueDBCerts(ctx context.Context, user string, db types.Database) error { +func (c *Cluster) ReissueDBCerts(ctx context.Context, routeToDatabase tlsca.RouteToDatabase) error { // When generating certificate for MongoDB access, database username must // be encoded into it. This is required to be able to tell which database // user to authenticate the connection as. - if db.GetProtocol() == libdefaults.ProtocolMongoDB && user == "" { - return trace.BadParameter("please provide the database user name using --db-user flag") + if routeToDatabase.Protocol == libdefaults.ProtocolMongoDB && routeToDatabase.Username == "" { + return trace.BadParameter("the username must be present for MongoDB connections") } err := addMetadataToRetryableError(ctx, func() error { @@ -177,9 +177,9 @@ func (c *Cluster) ReissueDBCerts(ctx context.Context, user string, db types.Data err = c.clusterClient.ReissueUserCerts(ctx, client.CertCacheKeep, client.ReissueParams{ RouteToCluster: c.clusterClient.SiteName, RouteToDatabase: proto.RouteToDatabase{ - ServiceName: db.GetName(), - Protocol: db.GetProtocol(), - Username: user, + ServiceName: routeToDatabase.ServiceName, + Protocol: routeToDatabase.Protocol, + Username: routeToDatabase.Username, }, AccessRequests: c.status.ActiveRequests.AccessRequests, }) @@ -194,11 +194,7 @@ func (c *Cluster) ReissueDBCerts(ctx context.Context, user string, db types.Data } // Update the database-specific connection profile file. - err = dbprofile.Add(ctx, c.clusterClient, tlsca.RouteToDatabase{ - ServiceName: db.GetName(), - Protocol: db.GetProtocol(), - Username: user, - }, c.status) + err = dbprofile.Add(ctx, c.clusterClient, routeToDatabase, c.status) if err != nil { return trace.Wrap(err) } diff --git a/lib/teleterm/clusters/cluster_gateways.go b/lib/teleterm/clusters/cluster_gateways.go index a1558501cd713..ce2da150d4d1d 100644 --- a/lib/teleterm/clusters/cluster_gateways.go +++ b/lib/teleterm/clusters/cluster_gateways.go @@ -19,9 +19,10 @@ package clusters import ( "context" - "github.com/gravitational/teleport/lib/teleterm/gateway" - "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/teleterm/gateway" + "github.com/gravitational/teleport/lib/tlsca" ) type CreateGatewayParams struct { @@ -36,6 +37,7 @@ type CreateGatewayParams struct { LocalPort string CLICommandProvider gateway.CLICommandProvider TCPPortAllocator gateway.TCPPortAllocator + OnExpiredCert gateway.OnExpiredCertFunc } // CreateGateway creates a gateway @@ -45,7 +47,13 @@ func (c *Cluster) CreateGateway(ctx context.Context, params CreateGatewayParams) return nil, trace.Wrap(err) } - if err := c.ReissueDBCerts(ctx, params.TargetUser, db); err != nil { + routeToDatabase := tlsca.RouteToDatabase{ + ServiceName: db.GetName(), + Protocol: db.GetProtocol(), + Username: params.TargetUser, + } + + if err := c.ReissueDBCerts(ctx, routeToDatabase); err != nil { return nil, trace.Wrap(err) } @@ -60,9 +68,11 @@ func (c *Cluster) CreateGateway(ctx context.Context, params CreateGatewayParams) CertPath: c.status.DatabaseCertPathForCluster(c.clusterClient.SiteName, db.GetName()), Insecure: c.clusterClient.InsecureSkipVerify, WebProxyAddr: c.clusterClient.WebProxyAddr, - Log: c.Log.WithField("gateway", params.TargetURI), + Log: c.Log, CLICommandProvider: params.CLICommandProvider, TCPPortAllocator: params.TCPPortAllocator, + OnExpiredCert: params.OnExpiredCert, + Clock: c.clock, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index 314a6bd4fc01f..909dce5d41565 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -317,7 +317,7 @@ func (s *Service) SetGatewayLocalPort(gatewayURI, localPort string) (*gateway.Ga return oldGateway, nil } - newGateway, err := gateway.NewWithLocalPort(*oldGateway, localPort) + newGateway, err := gateway.NewWithLocalPort(oldGateway, localPort) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/teleterm/daemon/daemon_test.go b/lib/teleterm/daemon/daemon_test.go index 57137163a6e95..ea9d60bb8bf29 100644 --- a/lib/teleterm/daemon/daemon_test.go +++ b/lib/teleterm/daemon/daemon_test.go @@ -64,7 +64,9 @@ func (m *mockGatewayCreator) CreateGateway(ctx context.Context, params clusters. return nil, trace.Wrap(err) } m.t.Cleanup(func() { - gateway.Close() + if err := gateway.Close(); err != nil { + m.t.Logf("Ignoring error from gateway.Close() during cleanup, it appears the gateway was already closed. The error was: %s", err) + } }) return gateway, nil diff --git a/lib/teleterm/gateway/config.go b/lib/teleterm/gateway/config.go index 2c103afa3a930..60fc022462c16 100644 --- a/lib/teleterm/gateway/config.go +++ b/lib/teleterm/gateway/config.go @@ -17,15 +17,18 @@ limitations under the License. package gateway import ( + "context" "runtime" + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/teleterm/api/uri" - "github.com/gravitational/trace" - - "github.com/google/uuid" - "github.com/sirupsen/logrus" + "github.com/gravitational/teleport/lib/tlsca" ) // Config describes gateway configuration @@ -39,7 +42,7 @@ type Config struct { // TargetUser is the target user name TargetUser string // TargetSubresourceName points at a subresource of the remote resource, for example a database - // name on a database server. + // name on a database server. It is used only for generating the CLI command. TargetSubresourceName string // Port is the gateway port @@ -63,8 +66,21 @@ type Config struct { // TCPPortAllocator creates listeners on the given ports. This interface lets us avoid occupying // hardcoded ports in tests. TCPPortAllocator TCPPortAllocator + // Clock is used by Gateway.localProxy to check cert expiration. + Clock clockwork.Clock + // OnExpiredCert is called when a new downstream connection is accepted by the + // gateway but cannot be proxied because the cert used by the gateway has expired. + // + // Handling of the connection is blocked until OnExpiredCert returns. + OnExpiredCert OnExpiredCertFunc } +// OnExpiredCertFunc is the type of a function that is called when a new downstream connection is +// accepted by the gateway but cannot be proxied because the cert used by the gateway has expired. +// +// Handling of the connection is blocked until the function returns. +type OnExpiredCertFunc func(context.Context, *Gateway) error + // CheckAndSetDefaults checks and sets the defaults func (c *Config) CheckAndSetDefaults() error { if c.URI.String() == "" { @@ -84,9 +100,14 @@ func (c *Config) CheckAndSetDefaults() error { } if c.Log == nil { - c.Log = logrus.WithField("gateway", c.URI.String()) + c.Log = logrus.NewEntry(logrus.StandardLogger()) } + c.Log = c.Log.WithFields(logrus.Fields{ + "resource": c.TargetURI, + "gateway": c.URI.String(), + }) + if c.TargetName == "" { return trace.BadParameter("missing target name") } @@ -103,5 +124,21 @@ func (c *Config) CheckAndSetDefaults() error { c.TCPPortAllocator = NetTCPPortAllocator{} } + if c.Clock == nil { + c.Clock = clockwork.NewRealClock() + } + return nil } + +// RouteToDatabase returns tlsca.RouteToDatabase based on the config of the gateway. +// +// The tlsca.RouteToDatabase.Database field is skipped, as it's an optional field and gateways can +// change their Config.TargetSubresourceName at any moment. +func (c *Config) RouteToDatabase() tlsca.RouteToDatabase { + return tlsca.RouteToDatabase{ + ServiceName: c.TargetName, + Protocol: c.Protocol, + Username: c.TargetUser, + } +} diff --git a/lib/teleterm/gateway/gateway.go b/lib/teleterm/gateway/gateway.go index 3f9840fc91a99..c5ec8c0806913 100644 --- a/lib/teleterm/gateway/gateway.go +++ b/lib/teleterm/gateway/gateway.go @@ -65,6 +65,8 @@ func New(cfg Config) (*Gateway, error) { return nil, trace.Wrap(err) } + cfg.LocalPort = port + protocol, err := alpncommon.ToALPNProtocol(cfg.Protocol) if err != nil { return nil, trace.Wrap(err) @@ -80,7 +82,7 @@ func New(cfg Config) (*Gateway, error) { return nil, trace.Wrap(err) } - localProxy, err := alpn.NewLocalProxy(alpn.LocalProxyConfig{ + localProxyConfig := alpn.LocalProxyConfig{ InsecureSkipVerify: cfg.Insecure, RemoteProxyAddr: cfg.WebProxyAddr, Protocols: []alpncommon.Protocol{protocol}, @@ -88,13 +90,23 @@ func New(cfg Config) (*Gateway, error) { ParentContext: closeContext, SNI: address.Host(), Certs: []tls.Certificate{tlsCert}, - }) + Clock: cfg.Clock, + } + + localProxyMiddleware := &localProxyMiddleware{ + log: cfg.Log, + dbRoute: cfg.RouteToDatabase(), + } + + if cfg.OnExpiredCert != nil { + localProxyConfig.Middleware = localProxyMiddleware + } + + localProxy, err := alpn.NewLocalProxy(localProxyConfig) if err != nil { return nil, trace.Wrap(err) } - cfg.LocalPort = port - gateway := &Gateway{ cfg: &cfg, closeContext: closeContext, @@ -102,13 +114,20 @@ func New(cfg Config) (*Gateway, error) { localProxy: localProxy, } + if cfg.OnExpiredCert != nil { + localProxyMiddleware.onExpiredCert = func(ctx context.Context) error { + err := cfg.OnExpiredCert(ctx, gateway) + return trace.Wrap(err) + } + } + ok = true return gateway, nil } // NewWithLocalPort initializes a copy of an existing gateway which has all config fields identical // to the existing gateway with the exception of the local port. -func NewWithLocalPort(gateway Gateway, port string) (*Gateway, error) { +func NewWithLocalPort(gateway *Gateway, port string) (*Gateway, error) { if port == gateway.LocalPort() { return nil, trace.BadParameter("port is already set to %s", port) } @@ -207,6 +226,24 @@ func (g *Gateway) CLICommand() (string, error) { return cliCommand, nil } +// ReloadCert loads the key pair from cfg.CertPath & cfg.KeyPath and updates the cert of the running +// local proxy. This is typically done after the cert is reissued and saved to disk. +// +// In the future, we're probably going to make this method accept the cert as an arg rather than +// reading from disk. +func (g *Gateway) ReloadCert() error { + g.cfg.Log.Debug("Reloading cert") + + tlsCert, err := keys.LoadX509KeyPair(g.cfg.CertPath, g.cfg.KeyPath) + if err != nil { + return trace.Wrap(err) + } + + g.localProxy.SetCerts([]tls.Certificate{tlsCert}) + + return nil +} + // Gateway describes local proxy that creates a gateway to the remote Teleport resource. // // Gateway is not safe for concurrent use in itself. However, all access to gateways is gated by diff --git a/lib/teleterm/gateway/gateway_test.go b/lib/teleterm/gateway/gateway_test.go index b55b14cbc0e7e..35515c91665ec 100644 --- a/lib/teleterm/gateway/gateway_test.go +++ b/lib/teleterm/gateway/gateway_test.go @@ -67,7 +67,11 @@ func TestGatewayStart(t *testing.T) { }, ) require.NoError(t, err) - t.Cleanup(func() { gateway.Close() }) + t.Cleanup(func() { + if err := gateway.Close(); err != nil { + t.Logf("Ignoring error from gateway.Close() during cleanup, it appears the gateway was already closed. The error was: %s", err) + } + }) gatewayAddress := net.JoinHostPort(gateway.LocalAddress(), gateway.LocalPort()) require.NotEmpty(t, gateway.LocalPort()) @@ -91,7 +95,7 @@ func TestNewWithLocalPortStartsListenerOnNewPortIfPortIsFree(t *testing.T) { tcpPortAllocator := gatewaytest.MockTCPPortAllocator{} oldGateway := createAndServeGateway(t, &tcpPortAllocator) - newGateway, err := NewWithLocalPort(*oldGateway, "12345") + newGateway, err := NewWithLocalPort(oldGateway, "12345") require.NoError(t, err) require.Equal(t, "12345", newGateway.LocalPort()) require.Equal(t, oldGateway.URI(), newGateway.URI()) @@ -109,7 +113,7 @@ func TestNewWithLocalPortReturnsErrorIfNewPortIsOccupied(t *testing.T) { tcpPortAllocator := gatewaytest.MockTCPPortAllocator{PortsInUse: []string{"12345"}} gateway := createAndServeGateway(t, &tcpPortAllocator) - _, err := NewWithLocalPort(*gateway, "12345") + _, err := NewWithLocalPort(gateway, "12345") require.ErrorContains(t, err, "address already in use") } @@ -119,7 +123,7 @@ func TestNewWithLocalPortReturnsErrorIfNewPortEqualsOldPort(t *testing.T) { port := gateway.LocalPort() expectedErrMessage := fmt.Sprintf("port is already set to %s", port) - _, err := NewWithLocalPort(*gateway, port) + _, err := NewWithLocalPort(gateway, port) require.True(t, trace.IsBadParameter(err), "Expected err to be a BadParameter error") require.ErrorContains(t, err, expectedErrMessage) } @@ -171,7 +175,9 @@ func serveGatewayAndBlockUntilItAcceptsConnections(t *testing.T, gateway *Gatewa serveErr <- err }() t.Cleanup(func() { - gateway.Close() + if err := gateway.Close(); err != nil { + t.Logf("Ignoring error from gateway.Close() during cleanup, it appears the gateway was already closed. The error was: %s", err) + } require.NoError(t, <-serveErr, "Gateway %s returned error from Serve()", gateway.URI()) }) diff --git a/lib/teleterm/gateway/local_proxy_middleware.go b/lib/teleterm/gateway/local_proxy_middleware.go new file mode 100644 index 0000000000000..9db185bba5d60 --- /dev/null +++ b/lib/teleterm/gateway/local_proxy_middleware.go @@ -0,0 +1,65 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gateway + +import ( + "context" + "crypto/x509" + "errors" + "net" + + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + + alpn "github.com/gravitational/teleport/lib/srv/alpnproxy" + "github.com/gravitational/teleport/lib/tlsca" +) + +type localProxyMiddleware struct { + onExpiredCert func(context.Context) error + log *logrus.Entry + dbRoute tlsca.RouteToDatabase +} + +// OnNewConnection calls m.onExpiredCert if the cert used by the local proxy has expired. +// This is a very basic reimplementation of client.DBCertChecker.OnNewConnection. DBCertChecker +// supports per-session MFA while for now Connect needs to just check for expired certs. +// +// In the future, DBCertChecker is going to be extended so that it's used by both tsh and Connect +// and this middleware will be removed. +func (m *localProxyMiddleware) OnNewConnection(ctx context.Context, lp *alpn.LocalProxy, conn net.Conn) error { + m.log.Debug("Checking local proxy certs") + + err := lp.CheckDBCerts(m.dbRoute) + if err == nil { + return nil + } + + // Return early and don't fire onExpiredCert if certs are invalid but not due to expiry. + if !errors.As(err, &x509.CertificateInvalidError{}) { + return trace.Wrap(err) + } + + m.log.WithError(err).Debug("Gateway certificates have expired") + + return trace.Wrap(m.onExpiredCert(ctx)) +} + +// OnStart is a noop. client.DBCertChecker.OnStart checks cert validity too. However in Connect +// there's no flow which would allow the user to create a local proxy without valid +// certs. +func (m *localProxyMiddleware) OnStart(context.Context, *alpn.LocalProxy) error { + return nil +} diff --git a/lib/teleterm/gateway/local_proxy_middleware_test.go b/lib/teleterm/gateway/local_proxy_middleware_test.go new file mode 100644 index 0000000000000..8419ae9a338fb --- /dev/null +++ b/lib/teleterm/gateway/local_proxy_middleware_test.go @@ -0,0 +1,128 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gateway + +import ( + "context" + "crypto/tls" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/defaults" + alpn "github.com/gravitational/teleport/lib/srv/alpnproxy" + alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" +) + +func TestLocalProxyMiddleware_OnNewConnection(t *testing.T) { + cert, err := utils.GenerateSelfSignedCert([]string{"localhost"}) + require.NoError(t, err) + tlsCert, err := keys.X509KeyPair(cert.Cert, cert.PrivateKey) + require.NoError(t, err) + x509cert, err := utils.TLSCertToX509(tlsCert) + require.NoError(t, err) + + clockAfterCertExpiry := clockwork.NewFakeClockAt(x509cert.NotAfter) + clockAfterCertExpiry.Advance(time.Hour*4 + time.Minute*20) + + clockBeforeCertExpiry := clockwork.NewFakeClockAt(x509cert.NotBefore) + clockBeforeCertExpiry.Advance(time.Hour*4 + time.Minute*20) + + validDbRoute := tlsca.RouteToDatabase{ + Protocol: defaults.ProtocolPostgres, + ServiceName: "foo-database-server", + } + + tests := []struct { + name string + clock clockwork.Clock + dbRoute tlsca.RouteToDatabase + expectation func(t *testing.T, onNewConnectionErr error, hasCalledOnExpiredCert bool) + }{ + { + name: "With expired cert", + clock: clockAfterCertExpiry, + dbRoute: validDbRoute, + expectation: func(t *testing.T, onNewConnectionErr error, hasCalledOnExpiredCert bool) { + require.NoError(t, onNewConnectionErr) + require.True(t, hasCalledOnExpiredCert, + "Expected the onExpiredCert callback to be called by the middleware") + }, + }, + { + name: "With active cert with subject not matching dbRoute", + clock: clockBeforeCertExpiry, + dbRoute: tlsca.RouteToDatabase{ + Protocol: defaults.ProtocolPostgres, + ServiceName: "foo-database-server", + Username: "bar", + Database: "quux", + }, + expectation: func(t *testing.T, onNewConnectionErr error, hasCalledOnExpiredCert bool) { + require.Error(t, onNewConnectionErr) + require.False(t, hasCalledOnExpiredCert, + "Expected the onExpiredCert callback to not be called by the middleware") + }, + }, + { + name: "With valid cert", + clock: clockBeforeCertExpiry, + dbRoute: validDbRoute, + expectation: func(t *testing.T, onNewConnectionErr error, hasCalledOnExpiredCert bool) { + require.NoError(t, onNewConnectionErr) + require.False(t, hasCalledOnExpiredCert, + "Expected the onExpiredCert callback to not be called by the middleware") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hasCalledOnExpiredCert := false + + middleware := &localProxyMiddleware{ + onExpiredCert: func(context.Context) error { + hasCalledOnExpiredCert = true + return nil + }, + log: logrus.WithField(trace.Component, "middleware"), + dbRoute: tt.dbRoute, + } + + localProxy, err := alpn.NewLocalProxy(alpn.LocalProxyConfig{ + RemoteProxyAddr: "localhost", + Protocols: []alpncommon.Protocol{alpncommon.ProtocolHTTP}, + ParentContext: ctx, + Clock: tt.clock, + }) + require.NoError(t, err) + + localProxy.SetCerts([]tls.Certificate{tlsCert}) + + err = middleware.OnNewConnection(ctx, localProxy, nil /* net.Conn, not used by middleware */) + tt.expectation(t, err, hasCalledOnExpiredCert) + }) + } +}