Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions lib/srv/alpnproxy/local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ import (
"github.com/gravitational/teleport/lib/utils/aws"
)

// OnNewConnectionFunc is a callback triggered when a new downstream connection is
// accepted by the local proxy.
type OnNewConnectionFunc func(lp *LocalProxy, conn net.Conn)

// LocalProxy allows upgrading incoming connection to TLS where custom TLS values are set SNI ALPN and
// updated connection is forwarded to remote ALPN SNI teleport proxy service.
type LocalProxy struct {
Expand Down Expand Up @@ -71,6 +75,11 @@ type LocalProxyConfig struct {
Certs []tls.Certificate
// AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification.
AWSCredentials *credentials.Credentials
// OnNewConnection is a callback triggered when a new downstream connection
// is accepted by the local proxy.
//
// Note that the callback blocks handling of the connection.
Comment on lines +80 to +81
Copy link
Copy Markdown
Contributor

@GavinFrazar GavinFrazar Sep 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this note but I just want to clarify that your intention was to block the LocalProxy.Start for {...} loop entirely or just an individual connection? We are doing the former - but if that was your intention could you explain why not call the callback inside the spawned goroutine just before calling l.handleDownstreamConnection?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good question. I guess this note is inaccurate. In current use cases, I think we want to block the LocalProxy.Start, meaning when detecting a certain situation by this callback, there is no point in accepting another new connection until the situation is resolved. For the MFA flow, feel free to do anything that works for you.

OnNewConnection OnNewConnectionFunc
}

// CheckAndSetDefaults verifies the constraints for LocalProxyConfig.
Expand Down Expand Up @@ -128,6 +137,11 @@ func (l *LocalProxy) Start(ctx context.Context) error {
log.WithError(err).Errorf("Failed to accept client connection.")
return trace.Wrap(err)
}

if l.cfg.OnNewConnection != nil {
l.cfg.OnNewConnection(l, conn)
}

go func() {
if err := l.handleDownstreamConnection(ctx, conn, l.cfg.SNI); err != nil {
if utils.IsOKNetworkError(err) {
Expand Down
39 changes: 34 additions & 5 deletions tool/tsh/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package main

import (
"context"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
Expand All @@ -34,6 +35,7 @@ import (
"github.com/gravitational/teleport/lib/client/db/dbcmd"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/alpnproxy"
"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -598,6 +600,12 @@ type localProxyConfig struct {
// it's always true for Snowflake database. Value is copied here to not modify
// cli arguments directly.
localProxyTunnel bool
// onNewConnection is a callback triggered when the ALPN local proxy
// accepts a new connection. Note that this callback always provides the
// database certificate in addition to the parameters in the original
// alpnproxy.OnNewConnectionFunc, regardless of whether the database
// certificate is provided to the ALPN local proxy or not.
onNewConnection func(dbCert *x509.Certificate, lp *alpnproxy.LocalProxy, conn net.Conn)
}

// prepareLocalProxyOptions created localProxyOpts needed to create local proxy from localProxyConfig.
Expand All @@ -620,6 +628,17 @@ func prepareLocalProxyOptions(arg *localProxyConfig) (localProxyOpts, error) {
keyFile: keyFile,
}

if arg.onNewConnection != nil {
dbCert, err := certFromPath(arg.profile.DatabaseCertPathForCluster(arg.cliConf.SiteName, arg.routeToDatabase.ServiceName))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the cert can be reissued without restarting the local proxy ? If yes the dbCert will contains the stale cert data. Reading the cert from disk for each connection is not effective but maybe we can load cert periodically once per some timeout ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the cert can be reissued without restarting the local proxy?

Yes, the cert can be reissued if the user runs other tsh commands (or if Teleport Connect tries to do something). However, the LocalProxy is also using the stale cert with today's implementation.

Do we also want to add the functionality to overwrite the cert LocalProxy is using (without restarting) in this PR? There are many approaches we can do too.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we cache the last cert loaded and check if it has expired? If it has expired then prompt user for credentials to reissue the cert.

if err != nil {
return localProxyOpts{}, trace.Wrap(err)
}

opts.onNewConnection = func(lp *alpnproxy.LocalProxy, conn net.Conn) {
arg.onNewConnection(dbCert, lp, conn)
}
}

// For SQL Server connections, local proxy must be configured with the
// client certificate that will be used to route connections.
if arg.routeToDatabase.Protocol == defaults.ProtocolSQLServer {
Expand Down Expand Up @@ -818,11 +837,7 @@ func dbInfoHasChanged(cf *CLIConf, certPath string) (bool, error) {
return false, nil
}

buff, err := os.ReadFile(certPath)
if err != nil {
return false, trace.Wrap(err)
}
cert, err := tlsca.ParseCertificatePEM(buff)
cert, err := certFromPath(certPath)
if err != nil {
return false, trace.Wrap(err)
}
Expand All @@ -842,6 +857,20 @@ func dbInfoHasChanged(cf *CLIConf, certPath string) (bool, error) {
return false, nil
}

// certFromPath parses the PEM-encoded certificate from the provided path. Note
// that this function expects only one certificate in the file.
func certFromPath(path string) (*x509.Certificate, error) {
buff, err := os.ReadFile(path)
if err != nil {
return nil, trace.Wrap(err)
}
cert, err := tlsca.ParseCertificatePEM(buff)
if err != nil {
return nil, trace.Wrap(err)
}
return cert, nil
}

// isMFADatabaseAccessRequired calls the IsMFARequired endpoint in order to get from user roles if access to the database
// requires MFA.
func isMFADatabaseAccessRequired(cf *CLIConf, tc *client.TeleportClient, database *tlsca.RouteToDatabase) (bool, error) {
Expand Down
24 changes: 18 additions & 6 deletions tool/tsh/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,13 +355,23 @@ func onProxyCommandDB(cf *CLIConf) error {
}
}()

onNewConnection := func(dbCert *x509.Certificate, lp *alpnproxy.LocalProxy, conn net.Conn) {
if time.Now().After(dbCert.NotAfter) {
fmt.Fprintln(cf.Stdout())
fmt.Fprintln(cf.Stdout(), "Your database session is expired. Please restart the local proxy.")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fmt.Fprintln(cf.Stdout(), "Your database session is expired. Please restart the local proxy.")
fmt.Fprintln(cf.Stdout(), "Your database session has expired. Please restart the local proxy.")

I think "has" is more correct here though I'm not 100% sure.

Copy link
Copy Markdown
Contributor

@GavinFrazar GavinFrazar Sep 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should they restart the local proxy? If the database cli is providing the certs and the local proxy is just forwarding, can't they just use tsh db login and it will work? This works without --tunnel already, they just need to know to do it

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This message is shown at the end of tsh proxy db command to ask the user to rerun whatever tsh proxy db command they are currently running.

Copy link
Copy Markdown
Contributor

@GavinFrazar GavinFrazar Sep 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see - does tsh proxy db not work with reissued certs if it is started with expired certs? I actually haven't tried that: I only tried waiting for the 1 minute cert to expire

lp.Close()
return
}
}

proxyOpts, err := prepareLocalProxyOptions(&localProxyConfig{
cliConf: cf,
teleportClient: client,
profile: profile,
routeToDatabase: routeToDatabase,
listener: listener,
localProxyTunnel: cf.LocalProxyTunnel,
onNewConnection: onNewConnection,
})
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -421,12 +431,13 @@ func onProxyCommandDB(cf *CLIConf) error {
}

type localProxyOpts struct {
proxyAddr string
listener net.Listener
protocols []alpncommon.Protocol
insecure bool
certFile string
keyFile string
proxyAddr string
listener net.Listener
protocols []alpncommon.Protocol
insecure bool
certFile string
keyFile string
onNewConnection alpnproxy.OnNewConnectionFunc
}

// protocol returns the first protocol or string if configuration doesn't contain any protocols.
Expand Down Expand Up @@ -458,6 +469,7 @@ func mkLocalProxy(ctx context.Context, opts localProxyOpts) (*alpnproxy.LocalPro
ParentContext: ctx,
SNI: address.Host(),
Certs: certs,
OnNewConnection: opts.onNewConnection,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down