diff --git a/integration/helpers/instance.go b/integration/helpers/instance.go index c1ea4000ea71e..94aa57e3cd1f4 100644 --- a/integration/helpers/instance.go +++ b/integration/helpers/instance.go @@ -28,6 +28,12 @@ import ( "testing" "time" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" @@ -46,11 +52,6 @@ import ( "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - "golang.org/x/crypto/ssh" ) const ( @@ -1214,8 +1215,7 @@ func (i *TeleInstance) NewClientWithCreds(cfg ClientConfig, creds UserCreds) (tc return clt, nil } -// NewUnauthenticatedClient returns a fully configured and pre-authenticated client -// (pre-authenticated with server CAs and signed session key) +// NewUnauthenticatedClient returns a fully configured and un-authenticated client func (i *TeleInstance) NewUnauthenticatedClient(cfg ClientConfig) (tc *client.TeleportClient, err error) { keyDir, err := os.MkdirTemp(i.Config.DataDir, "tsh") if err != nil { @@ -1274,7 +1274,12 @@ func (i *TeleInstance) NewClient(cfg ClientConfig) (*client.TeleportClient, erro if err != nil { return nil, trace.Wrap(err) } + return i.AddClientCredentials(tc, cfg) +} +// AddClientCredentials adds authenticated credentials to a client. +// (server CAs and signed session key). +func (i *TeleInstance) AddClientCredentials(tc *client.TeleportClient, cfg ClientConfig) (*client.TeleportClient, error) { // Generate certificates for the user simulating login. creds, err := GenerateUserCreds(UserCredsRequest{ Process: i.Process, diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index 6b1fac0e0efb7..ac48f9497538e 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -29,6 +29,7 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/bson" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -41,6 +42,7 @@ import ( "github.com/gravitational/teleport/integration/kube" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth/testauthority" + libclient "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/srv/alpnproxy" @@ -754,6 +756,72 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) { require.NoError(t, client.Close()) }) }) + + t.Run("authenticated tunnel with cert renewal", func(t *testing.T) { + // get a teleport client + tc, err := pack.Root.Cluster.NewClient(helpers.ClientConfig{ + Login: pack.Root.User.GetName(), + Cluster: pack.Root.Cluster.Secrets.SiteName, + }) + require.NoError(t, err) + routeToDatabase := tlsca.RouteToDatabase{ + ServiceName: pack.Root.MysqlService.Name, + Protocol: pack.Root.MysqlService.Protocol, + Username: "root", + } + // inject a fake clock into the middleware so we can control when it thinks certs have expired + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // configure local proxy without certs but with cert checking/reissuing middleware + // local proxy middleware should fetch a DB cert when the local proxy starts + lp := mustStartALPNLocalProxyWithConfig(t, alpnproxy.LocalProxyConfig{ + RemoteProxyAddr: pack.Root.Cluster.SSHProxy, + Protocols: []alpncommon.Protocol{alpncommon.ProtocolMySQL}, + InsecureSkipVerify: true, + Middleware: libclient.NewDBCertChecker(tc, routeToDatabase, fakeClock), + }) + + client, err := mysql.MakeTestClientWithoutTLS(lp.GetAddr(), routeToDatabase) + require.NoError(t, err) + + // Execute a query. + result, err := client.Execute("select 1") + require.NoError(t, err) + require.Equal(t, mysql.TestQueryResponse, result) + + // Disconnect. + require.NoError(t, client.Close()) + certs := lp.GetCerts() + require.NotEmpty(t, certs) + cert1, err := utils.TLSCertToX509(certs[0]) + require.NoError(t, err) + // sanity check that cert equality check works + require.Equal(t, cert1, cert1, "cert should be equal to itself") + + // mock db cert expiration (as far as the middleware thinks anyway) + // Unfortunately, mocking cert expiration by advancing a fake clock + // does not cause an invalid certificate error even if no cert renewal is done by the middleware, + // because TLS handshakes are done with real system time. + require.Greater(t, cert1.NotAfter, fakeClock.Now()) + fakeClock.Advance(cert1.NotAfter.Sub(fakeClock.Now()) + time.Second) + + // Open a new connection + client, err = mysql.MakeTestClientWithoutTLS(lp.GetAddr(), routeToDatabase) + require.NoError(t, err) + + // Execute a query. + result, err = client.Execute("select 1") + require.NoError(t, err) + require.Equal(t, mysql.TestQueryResponse, result) + + // Disconnect. + require.NoError(t, client.Close()) + certs = lp.GetCerts() + require.NotEmpty(t, certs) + cert2, err := utils.TLSCertToX509(certs[0]) + require.NoError(t, err) + require.NotEqual(t, cert1, cert2, "cert should have been renewed by middleware") + }) } // TestALPNSNIProxyAppAccess tests application access via ALPN SNI proxy service. diff --git a/lib/client/api.go b/lib/client/api.go index ddf1ce1b29bba..6c4f73fd81521 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1720,7 +1720,7 @@ func (tc *TeleportClient) ReissueUserCerts(ctx context.Context, cachePolicy Cert // (according to RBAC), IssueCertsWithMFA will: // - for SSH certs, return the existing Key from the keystore. // - for TLS certs, fall back to ReissueUserCerts. -func (tc *TeleportClient) IssueUserCertsWithMFA(ctx context.Context, params ReissueParams) (*Key, error) { +func (tc *TeleportClient) IssueUserCertsWithMFA(ctx context.Context, params ReissueParams, applyOpts func(opts *PromptMFAChallengeOpts)) (*Key, error) { ctx, span := tc.Tracer.Start( ctx, "teleportClient/IssueUserCertsWithMFA", @@ -1737,7 +1737,7 @@ func (tc *TeleportClient) IssueUserCertsWithMFA(ctx context.Context, params Reis return proxyClient.IssueUserCertsWithMFA( ctx, params, func(ctx context.Context, proxyAddr string, c *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - return tc.PromptMFAChallenge(ctx, proxyAddr, c, nil /* applyOpts */) + return tc.PromptMFAChallenge(ctx, proxyAddr, c, applyOpts) }) } diff --git a/lib/client/api_login_test.go b/lib/client/api_login_test.go index fcb148dbc4d8b..ff22aac7806e0 100644 --- a/lib/client/api_login_test.go +++ b/lib/client/api_login_test.go @@ -264,6 +264,7 @@ func TestTeleportClient_PromptMFAChallenge(t *testing.T) { challenge := &proto.MFAAuthenticateChallenge{} customizedOpts := &client.PromptMFAChallengeOpts{ + HintBeforePrompt: "some hint explaining the imminent prompt", PromptDevicePrefix: "llama", Quiet: true, AllowStdinHijack: true, diff --git a/lib/client/local_proxy_middleware.go b/lib/client/local_proxy_middleware.go new file mode 100644 index 0000000000000..9d49a4e5bee6e --- /dev/null +++ b/lib/client/local_proxy_middleware.go @@ -0,0 +1,161 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/srv/alpnproxy" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" +) + +// DBCertChecker is a middleware that ensures that the local proxy has valid TLS database certs. +type DBCertChecker struct { + // tc is a TeleportClient used to reissue certificates when necessary. + tc *TeleportClient + // dbRoute contains database routing information. + dbRoute tlsca.RouteToDatabase + // Clock specifies the time provider. Will be used to override the time anchor + // for TLS certificate verification. + // Defaults to real clock if unspecified + clock clockwork.Clock +} + +func NewDBCertChecker(tc *TeleportClient, dbRoute tlsca.RouteToDatabase, clock clockwork.Clock) alpnproxy.LocalProxyMiddleware { + if clock == nil { + clock = clockwork.NewRealClock() + } + return &DBCertChecker{ + tc: tc, + dbRoute: dbRoute, + clock: clock, + } +} + +var _ alpnproxy.LocalProxyMiddleware = (*DBCertChecker)(nil) + +// OnNewConnection is a callback triggered when a new downstream connection is +// accepted by the local proxy. +func (c *DBCertChecker) OnNewConnection(ctx context.Context, lp *alpnproxy.LocalProxy, conn net.Conn) error { + return trace.Wrap(c.ensureValidCerts(ctx, lp)) +} + +// OnStart is a callback triggered when the local proxy starts. +func (c *DBCertChecker) OnStart(ctx context.Context, lp *alpnproxy.LocalProxy) error { + return trace.Wrap(c.ensureValidCerts(ctx, lp)) +} + +// checkCerts checks if the local proxy TLS certs are configured, not expired, and match the db route. +func (c *DBCertChecker) checkCerts(lp *alpnproxy.LocalProxy) error { + log.Debug("checking local proxy database certs") + certs := lp.GetCerts() + if len(certs) == 0 { + return trace.Wrap(trace.NotFound("local proxy has no TLS certificates configured")) + } + cert, err := utils.TLSCertToX509(certs[0]) + if err != nil { + return trace.Wrap(err) + } + err = utils.VerifyCertificateExpiry(cert, c.clock) + if err != nil { + return trace.Wrap(err) + } + identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) + if err != nil { + return trace.Wrap(err) + } + if c.dbRoute.Username != "" && c.dbRoute.Username != identity.RouteToDatabase.Username { + msg := fmt.Sprintf("certificate subject is for user %s, but need %s", identity.RouteToDatabase.Username, c.dbRoute.Username) + return trace.Wrap(errors.New(msg)) + } + if c.dbRoute.Database != "" && c.dbRoute.Database != identity.RouteToDatabase.Database { + msg := fmt.Sprintf("certificate subject is for database name %s, but need %s", identity.RouteToDatabase.Database, c.dbRoute.Database) + return trace.Wrap(errors.New(msg)) + } + return nil +} + +// ensureValidCerts ensures that the local proxy is configured with valid certs. +func (c *DBCertChecker) ensureValidCerts(ctx context.Context, lp *alpnproxy.LocalProxy) error { + if err := c.checkCerts(lp); err != nil { + log.WithError(err).Debug("local proxy tunnel certificates need to be reissued") + } else { + return nil + } + return trace.Wrap(c.renewCerts(ctx, lp)) +} + +// renewCerts attempts to renew the database certs for the local proxy. +func (c *DBCertChecker) renewCerts(ctx context.Context, lp *alpnproxy.LocalProxy) error { + var accessRequests []string + if profile, err := StatusCurrent(c.tc.HomePath, c.tc.WebProxyAddr, ""); err != nil { + log.WithError(err).Warn("unable to load profile, requesting database certs without access requests") + } else { + accessRequests = profile.ActiveRequests.AccessRequests + } + + hint := fmt.Sprintf("MFA is required to access database %q", c.dbRoute.ServiceName) + var key *Key + if err := RetryWithRelogin(ctx, c.tc, func() error { + newKey, err := c.tc.IssueUserCertsWithMFA(ctx, ReissueParams{ + RouteToCluster: c.tc.SiteName, + RouteToDatabase: proto.RouteToDatabase{ + ServiceName: c.dbRoute.ServiceName, + Protocol: c.dbRoute.Protocol, + Username: c.dbRoute.Username, + Database: c.dbRoute.Database, + }, + AccessRequests: accessRequests, + }, func(opts *PromptMFAChallengeOpts) { + opts.HintBeforePrompt = hint + }) + key = newKey + return trace.Wrap(err) + }); err != nil { + return trace.Wrap(err) + } + + dbCert, ok := key.DBTLSCerts[c.dbRoute.ServiceName] + if !ok { + return trace.NotFound("database '%v' TLS cert missing", c.dbRoute.ServiceName) + } + tlsCert, err := keys.X509KeyPair(dbCert, key.PrivateKeyPEM()) + if err != nil { + return trace.Wrap(err) + } + x509cert, err := x509.ParseCertificate(tlsCert.Certificate[0]) + if err != nil { + return trace.Wrap(err) + } + certTTL := x509cert.NotAfter.Sub(c.clock.Now()).Round(time.Minute) + fmt.Printf("Database certificate renewed: valid until %s [valid for %v]\n", + x509cert.NotAfter.Format(time.RFC3339), certTTL) + lp.SetCerts([]tls.Certificate{tlsCert}) + return nil +} diff --git a/lib/client/mfa.go b/lib/client/mfa.go index cf27d997a092d..84b932031ac06 100644 --- a/lib/client/mfa.go +++ b/lib/client/mfa.go @@ -50,6 +50,10 @@ func (p *mfaPrompt) PromptPIN() (string, error) { // PromptMFAChallengeOpts groups optional settings for PromptMFAChallenge. type PromptMFAChallengeOpts struct { + // HintBeforePrompt is an optional hint message to print before an MFA prompt. + // It is used to provide context about why the user is being prompted where it may + // not be obvious. + HintBeforePrompt string // PromptDevicePrefix is an optional prefix printed before "security key" or // "device". It is used to emphasize between different kinds of devices, like // registered vs new. @@ -105,6 +109,10 @@ func PromptMFAChallenge(ctx context.Context, c *proto.MFAAuthenticateChallenge, if opts == nil { opts = &PromptMFAChallengeOpts{} } + writer := os.Stderr + if opts.HintBeforePrompt != "" { + fmt.Fprintln(writer, opts.HintBeforePrompt) + } promptDevicePrefix := opts.PromptDevicePrefix quiet := opts.Quiet @@ -175,7 +183,7 @@ func PromptMFAChallenge(ctx context.Context, c *proto.MFAAuthenticateChallenge, msg = fmt.Sprintf("Enter an OTP code from a %sdevice", promptDevicePrefix) } - otp, err := prompt.Password(otpCtx, os.Stderr, prompt.Stdin(), msg) + otp, err := prompt.Password(otpCtx, writer, prompt.Stdin(), msg) if err != nil { respC <- response{kind: kind, err: err} return @@ -202,7 +210,7 @@ func PromptMFAChallenge(ctx context.Context, c *proto.MFAAuthenticateChallenge, defer wg.Done() log.Debugf("WebAuthn: prompting devices with origin %q", origin) - prompt := wancli.NewDefaultPrompt(ctx, os.Stderr) + prompt := wancli.NewDefaultPrompt(ctx, writer) prompt.SecondTouchMessage = fmt.Sprintf("Tap your %ssecurity key to complete login", promptDevicePrefix) switch { case quiet: diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index f636b56453407..395f762d39a58 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -72,6 +72,17 @@ type LocalProxyConfig struct { RootCAs *x509.CertPool // ALPNConnUpgradeRequired specifies if ALPN connection upgrade is required. ALPNConnUpgradeRequired bool + // Middleware provides callback functions to the local proxy. + Middleware LocalProxyMiddleware +} + +// LocalProxyMiddleware provides callback functions for LocalProxy. +type LocalProxyMiddleware interface { + // OnNewConnection is a callback triggered when a new downstream connection is + // accepted by the local proxy. + OnNewConnection(ctx context.Context, lp *LocalProxy, conn net.Conn) error + // OnStart is a callback triggered when the local proxy starts. + OnStart(ctx context.Context, lp *LocalProxy) error } // CheckAndSetDefaults verifies the constraints for LocalProxyConfig. @@ -114,6 +125,12 @@ func NewLocalProxy(cfg LocalProxyConfig) (*LocalProxy, error) { // Start starts the LocalProxy. func (l *LocalProxy) Start(ctx context.Context) error { + if l.cfg.Middleware != nil { + err := l.cfg.Middleware.OnStart(ctx, l) + if err != nil { + return trace.Wrap(err) + } + } for { select { case <-ctx.Done(): @@ -129,6 +146,14 @@ func (l *LocalProxy) Start(ctx context.Context) error { log.WithError(err).Errorf("Failed to accept client connection.") return trace.Wrap(err) } + + if l.cfg.Middleware != nil { + if err := l.cfg.Middleware.OnNewConnection(ctx, l, conn); err != nil { + log.WithError(err).Errorf("Middleware failed to handle new connection.") + continue + } + } + go func() { if err := l.handleDownstreamConnection(ctx, conn); err != nil { if utils.IsOKNetworkError(err) { @@ -156,7 +181,7 @@ func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamC NextProtos: l.cfg.GetProtocols(), InsecureSkipVerify: l.cfg.InsecureSkipVerify, ServerName: l.cfg.SNI, - Certificates: l.cfg.Certs, + Certificates: l.GetCerts(), RootCAs: l.cfg.RootCAs, }, }) @@ -191,7 +216,7 @@ func (l *LocalProxy) StartAWSAccessProxy(ctx context.Context) error { NextProtos: l.cfg.GetProtocols(), InsecureSkipVerify: l.cfg.InsecureSkipVerify, ServerName: l.cfg.SNI, - Certificates: l.cfg.Certs, + Certificates: l.GetCerts(), }, } proxy := &httputil.ReverseProxy{ @@ -221,3 +246,11 @@ func (l *LocalProxy) StartAWSAccessProxy(ctx context.Context) error { } return nil } + +func (l *LocalProxy) GetCerts() []tls.Certificate { + return l.cfg.Certs +} + +func (l *LocalProxy) SetCerts(certs []tls.Certificate) { + l.cfg.Certs = certs +} diff --git a/lib/srv/alpnproxy/local_proxy_test.go b/lib/srv/alpnproxy/local_proxy_test.go index cde2db0b8c692..c3b38c70f8e85 100644 --- a/lib/srv/alpnproxy/local_proxy_test.go +++ b/lib/srv/alpnproxy/local_proxy_test.go @@ -19,9 +19,11 @@ package alpnproxy import ( "bytes" "context" + "net" "net/http" "net/http/httptest" "net/url" + "sync" "testing" "time" @@ -135,6 +137,115 @@ func TestHandleAWSAccessS3Signing(t *testing.T) { require.NoError(t, err) } +type mockMiddlewareCounter struct { + sync.Mutex + recvStateChange chan struct{} + connCount int + startCount int +} + +func newMockMiddlewareCounter() *mockMiddlewareCounter { + return &mockMiddlewareCounter{ + recvStateChange: make(chan struct{}, 1), + } +} + +func (m *mockMiddlewareCounter) onStateChange() { + select { + case m.recvStateChange <- struct{}{}: + default: + } +} + +func (m *mockMiddlewareCounter) OnNewConnection(_ context.Context, _ *LocalProxy, _ net.Conn) error { + m.Lock() + defer m.Unlock() + m.connCount++ + m.onStateChange() + return nil +} + +func (m *mockMiddlewareCounter) OnStart(_ context.Context, _ *LocalProxy) error { + m.Lock() + defer m.Unlock() + m.startCount++ + m.onStateChange() + return nil +} + +func (m *mockMiddlewareCounter) waitForCounts(t *testing.T, wantStartCount int, wantConnCount int) { + timer := time.NewTimer(time.Second * 3) + for { + var ( + startCount int + connCount int + ) + m.Lock() + startCount = m.startCount + connCount = m.connCount + m.Unlock() + if startCount == wantStartCount && connCount == wantConnCount { + return + } + + select { + case <-m.recvStateChange: + continue + case <-timer.C: + require.FailNow(t, + "timeout waiting for middleware state change", + "have startCount=%d connCount=%d, want startCount=%d connCount=%d", + startCount, connCount, wantStartCount, wantConnCount) + } + } +} + +var _ LocalProxyMiddleware = (*mockMiddlewareCounter)(nil) + +func TestMiddleware(t *testing.T) { + m := newMockMiddlewareCounter() + hs := httptest.NewTLSServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {})) + lp, err := NewLocalProxy(LocalProxyConfig{ + Listener: mustCreateLocalListener(t), + RemoteProxyAddr: hs.Listener.Addr().String(), + Protocols: []common.Protocol{common.ProtocolHTTP}, + ParentContext: context.Background(), + InsecureSkipVerify: true, + Middleware: m, + }) + require.NoError(t, err) + t.Cleanup(func() { + err := lp.Close() + require.NoError(t, err) + hs.Close() + }) + + m.waitForCounts(t, 0, 0) + go func() { + err := lp.Start(context.Background()) + require.NoError(t, err) + }() + + // ensure that OnStart middleware is called when the proxy starts + m.waitForCounts(t, 1, 0) + url := url.URL{ + Scheme: "http", + Host: lp.GetAddr(), + Path: "/", + } + + pr := bytes.NewReader([]byte("payload content")) + req, err := http.NewRequest(http.MethodGet, url.String(), pr) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + // ensure that OnNewConnection middleware is called when a new connection is made to the proxy + m.waitForCounts(t, 1, 1) +} + func createAWSAccessProxySuite(t *testing.T, cred *credentials.Credentials) *LocalProxy { hs := httptest.NewTLSServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {})) diff --git a/lib/tlsca/ca.go b/lib/tlsca/ca.go index 36d0b52d289ec..4e8a5e338c718 100644 --- a/lib/tlsca/ca.go +++ b/lib/tlsca/ca.go @@ -218,9 +218,9 @@ type RouteToDatabase struct { } // String returns string representation of the database routing struct. -func (d RouteToDatabase) String() string { +func (r RouteToDatabase) String() string { return fmt.Sprintf("Database(Service=%v, Protocol=%v, Username=%v, Database=%v)", - d.ServiceName, d.Protocol, d.Username, d.Database) + r.ServiceName, r.Protocol, r.Username, r.Database) } // GetRouteToApp returns application routing data. If missing, returns an error. diff --git a/lib/utils/certs.go b/lib/utils/certs.go index b0f121b453bd7..d4ca5ec554ea5 100644 --- a/lib/utils/certs.go +++ b/lib/utils/certs.go @@ -19,6 +19,7 @@ import ( "crypto/ecdsa" "crypto/rand" "crypto/rsa" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" @@ -147,14 +148,14 @@ func VerifyCertificateExpiry(c *x509.Certificate, clock clockwork.Clock) error { return x509.CertificateInvalidError{ Cert: c, Reason: x509.Expired, - Detail: fmt.Sprintf("current time %s is before %s", now.Format(time.RFC3339), c.NotBefore.Format(time.RFC3339)), + Detail: fmt.Sprintf("current time %s is before %s", now.UTC().Format(time.RFC3339), c.NotBefore.UTC().Format(time.RFC3339)), } } if now.After(c.NotAfter) { return x509.CertificateInvalidError{ Cert: c, Reason: x509.Expired, - Detail: fmt.Sprintf("current time %s is after %s", now.Format(time.RFC3339), c.NotAfter.Format(time.RFC3339)), + Detail: fmt.Sprintf("current time %s is after %s", now.UTC().Format(time.RFC3339), c.NotAfter.UTC().Format(time.RFC3339)), } } return nil @@ -285,4 +286,13 @@ func NewCertPoolFromPath(path string) (*x509.CertPool, error) { return pool, nil } +// TLSCertToX509 is a helper function that converts a tls.Certificate into an *x509.Certificate +func TLSCertToX509(cert tls.Certificate) (*x509.Certificate, error) { + if len(cert.Certificate) < 1 { + return nil, trace.NotFound("invalid certificate length") + } + x509cert, err := x509.ParseCertificate(cert.Certificate[0]) + return x509cert, trace.Wrap(err) +} + const pemBlockCertificate = "CERTIFICATE" diff --git a/tool/tsh/db.go b/tool/tsh/db.go index c1adbcbb6225a..524a68bb1fbeb 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -275,18 +275,26 @@ func onDatabaseLogin(cf *CLIConf) error { return trace.Wrap(dbConnectTemplate.Execute(cf.Stdout(), templateData)) } -func databaseLogin(cf *CLIConf, tc *client.TeleportClient, db tlsca.RouteToDatabase) error { - log.Debugf("Fetching database access certificate for %s on cluster %v.", db, tc.SiteName) +// checkAndSetDBRouteDefaults checks the database route and sets defaults for certificate generation. +func checkAndSetDBRouteDefaults(r *tlsca.RouteToDatabase) error { // When generating certificate for MongoDB access, database username must // be encoded into it. This is required to be able to tell which database // user to authenticate the connection as. - if db.Protocol == defaults.ProtocolMongoDB && db.Username == "" { + if r.Protocol == defaults.ProtocolMongoDB && r.Username == "" { return trace.BadParameter("please provide the database user name using --db-user flag") } - if db.Protocol == defaults.ProtocolRedis && db.Username == "" { + if r.Protocol == defaults.ProtocolRedis && r.Username == "" { // Default to "default" in the same way as Redis does. We need the username to check access on our side. // ref: https://redis.io/commands/auth - db.Username = defaults.DefaultRedisUsername + r.Username = defaults.DefaultRedisUsername + } + return nil +} + +func databaseLogin(cf *CLIConf, tc *client.TeleportClient, db tlsca.RouteToDatabase) error { + log.Debugf("Fetching database access certificate for %s on cluster %v.", db, tc.SiteName) + if err := checkAndSetDBRouteDefaults(&db); err != nil { + return trace.Wrap(err) } profile, err := client.StatusCurrent(cf.HomePath, cf.Proxy, cf.IdentityFileIn) @@ -310,7 +318,7 @@ func databaseLogin(cf *CLIConf, tc *client.TeleportClient, db tlsca.RouteToDatab Database: db.Database, }, AccessRequests: profile.ActiveRequests.AccessRequests, - }) + }, nil /*applyOpts*/) return trace.Wrap(err) }); err != nil { return trace.Wrap(err) @@ -639,23 +647,30 @@ type localProxyConfig struct { } // prepareLocalProxyOptions created localProxyOpts needed to create local proxy from localProxyConfig. -func prepareLocalProxyOptions(arg *localProxyConfig) (localProxyOpts, error) { - // If user requested no client auth, open an authenticated tunnel using - // client cert/key of the database. +func prepareLocalProxyOptions(arg *localProxyConfig) (*localProxyOpts, error) { certFile := arg.cliConf.LocalProxyCertFile keyFile := arg.cliConf.LocalProxyKeyFile - if certFile == "" && arg.localProxyTunnel { - certFile = arg.profile.DatabaseCertPathForCluster(arg.cliConf.SiteName, arg.routeToDatabase.ServiceName) + if arg.routeToDatabase.Protocol == defaults.ProtocolSQLServer || (arg.localProxyTunnel && certFile == "") { + // For SQL Server connections, local proxy must be configured with the + // client certificate that will be used to route connections. + certFile = arg.profile.DatabaseCertPathForCluster(arg.teleportClient.SiteName, arg.routeToDatabase.ServiceName) keyFile = arg.profile.KeyPath() } + certs, err := mkLocalProxyCerts(certFile, keyFile) + if err != nil { + if !arg.localProxyTunnel { + return nil, trace.Wrap(err) + } + // local proxy with tunnel monitors its certs, so it's ok if a cert file can't be loaded. + certs = nil + } - opts := localProxyOpts{ + opts := &localProxyOpts{ proxyAddr: arg.teleportClient.WebProxyAddr, listener: arg.listener, protocols: []common.Protocol{common.Protocol(arg.routeToDatabase.Protocol)}, insecure: arg.cliConf.InsecureSkipVerify, - certFile: certFile, - keyFile: keyFile, + certs: certs, alpnConnUpgradeRequired: alpnproxy.IsALPNConnUpgradeRequired(arg.teleportClient.WebProxyAddr, arg.cliConf.InsecureSkipVerify), } @@ -664,16 +679,17 @@ func prepareLocalProxyOptions(arg *localProxyConfig) (localProxyOpts, error) { if opts.alpnConnUpgradeRequired { profileCAs, err := utils.NewCertPoolFromPath(arg.profile.CACertPathForCluster(arg.rootClusterName)) if err != nil { - return localProxyOpts{}, trace.Wrap(err) + return nil, trace.Wrap(err) } opts.rootCAs = profileCAs } - // For SQL Server connections, local proxy must be configured with the - // client certificate that will be used to route connections. - if arg.routeToDatabase.Protocol == defaults.ProtocolSQLServer { - opts.certFile = arg.profile.DatabaseCertPathForCluster(arg.teleportClient.SiteName, arg.routeToDatabase.ServiceName) - opts.keyFile = arg.profile.KeyPath() + if arg.localProxyTunnel { + dbRoute := *arg.routeToDatabase + if err := checkAndSetDBRouteDefaults(&dbRoute); err != nil { + return nil, trace.Wrap(err) + } + opts.middleware = client.NewDBCertChecker(arg.teleportClient, dbRoute, nil) } // To set correct MySQL server version DB proxy needs additional protocol. @@ -682,7 +698,7 @@ func prepareLocalProxyOptions(arg *localProxyConfig) (localProxyOpts, error) { var err error arg.database, err = getDatabase(arg.cliConf, arg.teleportClient, arg.routeToDatabase.ServiceName) if err != nil { - return localProxyOpts{}, trace.Wrap(err) + return nil, trace.Wrap(err) } } @@ -807,6 +823,13 @@ func getDatabase(cf *CLIConf, tc *client.TeleportClient, dbName string) (types.D } func needDatabaseRelogin(cf *CLIConf, tc *client.TeleportClient, database *tlsca.RouteToDatabase, profile *client.ProfileStatus) (bool, error) { + if cf.LocalProxyTunnel { + // Don't login to database here if local proxy tunnel is enabled. + // When local proxy tunnel is enabled, the local proxy will check if DB login is needed when + // it starts and on each new connection. + return false, nil + } + found := false activeDatabases, err := profile.DatabasesForCluster(tc.SiteName) if err != nil { @@ -842,7 +865,7 @@ func needDatabaseRelogin(cf *CLIConf, tc *client.TeleportClient, database *tlsca } // maybeDatabaseLogin checks if cert is still valid or DB connection requires -// MFA. If yes trigger db login logic. +// MFA, and that client is not requesting an authenticated local proxy tunnel. If yes trigger db login logic. func maybeDatabaseLogin(cf *CLIConf, tc *client.TeleportClient, profile *client.ProfileStatus, db *tlsca.RouteToDatabase) error { reloginNeeded, err := needDatabaseRelogin(cf, tc, db, profile) if err != nil { diff --git a/tool/tsh/kube.go b/tool/tsh/kube.go index 7ec5a0f9144c2..b7221a324b9f3 100644 --- a/tool/tsh/kube.go +++ b/tool/tsh/kube.go @@ -157,7 +157,7 @@ func (c *kubeJoinCommand) run(cf *CLIConf) error { k, err = tc.IssueUserCertsWithMFA(cf.Context, client.ReissueParams{ RouteToCluster: cluster, KubernetesCluster: kubeCluster, - }) + }, nil /*applyOpts*/) return trace.Wrap(err) }) @@ -600,7 +600,7 @@ func (c *kubeCredentialsCommand) run(cf *CLIConf) error { k, err = tc.IssueUserCertsWithMFA(cf.Context, client.ReissueParams{ RouteToCluster: c.teleportCluster, KubernetesCluster: c.kubeCluster, - }) + }, nil /*applyOpts*/) return err }) if err != nil { diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index 9e97b3f155f54..7bfbdafb77f2a 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -484,10 +484,10 @@ type localProxyOpts struct { listener net.Listener protocols []alpncommon.Protocol insecure bool - certFile string - keyFile string + certs []tls.Certificate rootCAs *x509.CertPool alpnConnUpgradeRequired bool + middleware alpnproxy.LocalProxyMiddleware } // protocol returns the first protocol or string if configuration doesn't contain any protocols. @@ -498,7 +498,7 @@ func (l *localProxyOpts) protocol() string { return string(l.protocols[0]) } -func mkLocalProxy(ctx context.Context, opts localProxyOpts) (*alpnproxy.LocalProxy, error) { +func mkLocalProxy(ctx context.Context, opts *localProxyOpts) (*alpnproxy.LocalProxy, error) { alpnProtocol, err := alpncommon.ToALPNProtocol(opts.protocol()) if err != nil { return nil, trace.Wrap(err) @@ -507,10 +507,6 @@ func mkLocalProxy(ctx context.Context, opts localProxyOpts) (*alpnproxy.LocalPro if err != nil { return nil, trace.Wrap(err) } - certs, err := mkLocalProxyCerts(opts.certFile, opts.keyFile) - if err != nil { - return nil, trace.Wrap(err) - } protocols := append([]alpncommon.Protocol{alpnProtocol}, opts.protocols...) if alpncommon.HasPingSupport(alpnProtocol) { @@ -524,9 +520,10 @@ func mkLocalProxy(ctx context.Context, opts localProxyOpts) (*alpnproxy.LocalPro Listener: opts.listener, ParentContext: ctx, SNI: address.Host(), - Certs: certs, + Certs: opts.certs, RootCAs: opts.rootCAs, ALPNConnUpgradeRequired: opts.alpnConnUpgradeRequired, + Middleware: opts.middleware, }) if err != nil { return nil, trace.Wrap(err) @@ -690,10 +687,7 @@ func loadAppCertificate(tc *libclient.TeleportClient, appName string) (tls.Certi // getTLSCertExpireTime returns the certificate NotAfter time. func getTLSCertExpireTime(cert tls.Certificate) (time.Time, error) { - if len(cert.Certificate) < 1 { - return time.Time{}, trace.NotFound("invalid certificate length") - } - x509cert, err := x509.ParseCertificate(cert.Certificate[0]) + x509cert, err := utils.TLSCertToX509(cert) if err != nil { return time.Time{}, trace.Wrap(err) }