diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index ab019707551b3..954784f1b27b8 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -34,6 +34,10 @@ import ( "github.com/gravitational/teleport/lib/utils/aws" ) +// OnNewConnectionFunc is a callback triggered when a new downstream connection is +// accepted by the local proxy. +type OnNewConnectionFunc func(lp *LocalProxy, conn net.Conn) + // LocalProxy allows upgrading incoming connection to TLS where custom TLS values are set SNI ALPN and // updated connection is forwarded to remote ALPN SNI teleport proxy service. type LocalProxy struct { @@ -71,6 +75,11 @@ type LocalProxyConfig struct { Certs []tls.Certificate // AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification. AWSCredentials *credentials.Credentials + // OnNewConnection is a callback triggered when a new downstream connection + // is accepted by the local proxy. + // + // Note that the callback blocks handling of the connection. + OnNewConnection OnNewConnectionFunc } // CheckAndSetDefaults verifies the constraints for LocalProxyConfig. @@ -128,6 +137,11 @@ func (l *LocalProxy) Start(ctx context.Context) error { log.WithError(err).Errorf("Failed to accept client connection.") return trace.Wrap(err) } + + if l.cfg.OnNewConnection != nil { + l.cfg.OnNewConnection(l, conn) + } + go func() { if err := l.handleDownstreamConnection(ctx, conn, l.cfg.SNI); err != nil { if utils.IsOKNetworkError(err) { diff --git a/tool/tsh/db.go b/tool/tsh/db.go index 2f49c901f2533..211b0bb505586 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -18,6 +18,7 @@ package main import ( "context" + "crypto/x509" "encoding/base64" "fmt" "io" @@ -34,6 +35,7 @@ import ( "github.com/gravitational/teleport/lib/client/db/dbcmd" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/alpnproxy" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" @@ -598,6 +600,12 @@ type localProxyConfig struct { // it's always true for Snowflake database. Value is copied here to not modify // cli arguments directly. localProxyTunnel bool + // onNewConnection is a callback triggered when the ALPN local proxy + // accepts a new connection. Note that this callback always provides the + // database certificate in addition to the parameters in the original + // alpnproxy.OnNewConnectionFunc, regardless of whether the database + // certificate is provided to the ALPN local proxy or not. + onNewConnection func(dbCert *x509.Certificate, lp *alpnproxy.LocalProxy, conn net.Conn) } // prepareLocalProxyOptions created localProxyOpts needed to create local proxy from localProxyConfig. @@ -620,6 +628,17 @@ func prepareLocalProxyOptions(arg *localProxyConfig) (localProxyOpts, error) { keyFile: keyFile, } + if arg.onNewConnection != nil { + dbCert, err := certFromPath(arg.profile.DatabaseCertPathForCluster(arg.cliConf.SiteName, arg.routeToDatabase.ServiceName)) + if err != nil { + return localProxyOpts{}, trace.Wrap(err) + } + + opts.onNewConnection = func(lp *alpnproxy.LocalProxy, conn net.Conn) { + arg.onNewConnection(dbCert, lp, conn) + } + } + // For SQL Server connections, local proxy must be configured with the // client certificate that will be used to route connections. if arg.routeToDatabase.Protocol == defaults.ProtocolSQLServer { @@ -818,11 +837,7 @@ func dbInfoHasChanged(cf *CLIConf, certPath string) (bool, error) { return false, nil } - buff, err := os.ReadFile(certPath) - if err != nil { - return false, trace.Wrap(err) - } - cert, err := tlsca.ParseCertificatePEM(buff) + cert, err := certFromPath(certPath) if err != nil { return false, trace.Wrap(err) } @@ -842,6 +857,20 @@ func dbInfoHasChanged(cf *CLIConf, certPath string) (bool, error) { return false, nil } +// certFromPath parses the PEM-encoded certificate from the provided path. Note +// that this function expects only one certificate in the file. +func certFromPath(path string) (*x509.Certificate, error) { + buff, err := os.ReadFile(path) + if err != nil { + return nil, trace.Wrap(err) + } + cert, err := tlsca.ParseCertificatePEM(buff) + if err != nil { + return nil, trace.Wrap(err) + } + return cert, nil +} + // isMFADatabaseAccessRequired calls the IsMFARequired endpoint in order to get from user roles if access to the database // requires MFA. func isMFADatabaseAccessRequired(cf *CLIConf, tc *client.TeleportClient, database *tlsca.RouteToDatabase) (bool, error) { diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index 765a51002d2df..8c1f6f1ade78d 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -355,6 +355,15 @@ func onProxyCommandDB(cf *CLIConf) error { } }() + onNewConnection := func(dbCert *x509.Certificate, lp *alpnproxy.LocalProxy, conn net.Conn) { + if time.Now().After(dbCert.NotAfter) { + fmt.Fprintln(cf.Stdout()) + fmt.Fprintln(cf.Stdout(), "Your database session is expired. Please restart the local proxy.") + lp.Close() + return + } + } + proxyOpts, err := prepareLocalProxyOptions(&localProxyConfig{ cliConf: cf, teleportClient: client, @@ -362,6 +371,7 @@ func onProxyCommandDB(cf *CLIConf) error { routeToDatabase: routeToDatabase, listener: listener, localProxyTunnel: cf.LocalProxyTunnel, + onNewConnection: onNewConnection, }) if err != nil { return trace.Wrap(err) @@ -421,12 +431,13 @@ func onProxyCommandDB(cf *CLIConf) error { } type localProxyOpts struct { - proxyAddr string - listener net.Listener - protocols []alpncommon.Protocol - insecure bool - certFile string - keyFile string + proxyAddr string + listener net.Listener + protocols []alpncommon.Protocol + insecure bool + certFile string + keyFile string + onNewConnection alpnproxy.OnNewConnectionFunc } // protocol returns the first protocol or string if configuration doesn't contain any protocols. @@ -458,6 +469,7 @@ func mkLocalProxy(ctx context.Context, opts localProxyOpts) (*alpnproxy.LocalPro ParentContext: ctx, SNI: address.Host(), Certs: certs, + OnNewConnection: opts.onNewConnection, }) if err != nil { return nil, trace.Wrap(err)