Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions integration/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package proxy
import (
"bytes"
"context"
"crypto/x509"
"net"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -778,6 +779,7 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) {
Protocols: []alpncommon.Protocol{alpncommon.ProtocolMySQL},
InsecureSkipVerify: true,
Middleware: libclient.NewDBCertChecker(tc, routeToDatabase, fakeClock),
Clock: fakeClock,
})

client, err := mysql.MakeTestClientWithoutTLS(lp.GetAddr(), routeToDatabase)
Expand All @@ -790,19 +792,15 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) {

// Disconnect.
require.NoError(t, client.Close())
certs := lp.GetCerts()
require.NotEmpty(t, certs)
cert1, err := utils.TLSCertToX509(certs[0])
require.NoError(t, err)
// sanity check that cert equality check works
require.Equal(t, cert1, cert1, "cert should be equal to itself")

// mock db cert expiration (as far as the middleware thinks anyway)
// Unfortunately, mocking cert expiration by advancing a fake clock
// does not cause an invalid certificate error even if no cert renewal is done by the middleware,
// because TLS handshakes are done with real system time.
require.Greater(t, cert1.NotAfter, fakeClock.Now())
fakeClock.Advance(cert1.NotAfter.Sub(fakeClock.Now()) + time.Second)
// advance the fake clock and verify that the local proxy thinks its cert expired.
fakeClock.Advance(time.Hour * 48)
err = lp.CheckDBCerts(routeToDatabase)
require.Error(t, err)
var x509Err x509.CertificateInvalidError
require.ErrorAs(t, err, &x509Err)
require.Equal(t, x509Err.Reason, x509.Expired)
require.Contains(t, x509Err.Detail, "is after")

// Open a new connection
client, err = mysql.MakeTestClientWithoutTLS(lp.GetAddr(), routeToDatabase)
Expand All @@ -815,11 +813,6 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) {

// Disconnect.
require.NoError(t, client.Close())
certs = lp.GetCerts()
require.NotEmpty(t, certs)
cert2, err := utils.TLSCertToX509(certs[0])
require.NoError(t, err)
require.NotEqual(t, cert1, cert2, "cert should have been renewed by middleware")
})
}

Expand Down
34 changes: 1 addition & 33 deletions lib/client/local_proxy_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"time"
Expand All @@ -32,7 +31,6 @@ import (
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/lib/srv/alpnproxy"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
)

// DBCertChecker is a middleware that ensures that the local proxy has valid TLS database certs.
Expand Down Expand Up @@ -71,39 +69,9 @@ func (c *DBCertChecker) OnStart(ctx context.Context, lp *alpnproxy.LocalProxy) e
return trace.Wrap(c.ensureValidCerts(ctx, lp))
}

// checkCerts checks if the local proxy TLS certs are configured, not expired, and match the db route.
func (c *DBCertChecker) checkCerts(lp *alpnproxy.LocalProxy) error {
log.Debug("checking local proxy database certs")
certs := lp.GetCerts()
if len(certs) == 0 {
return trace.Wrap(trace.NotFound("local proxy has no TLS certificates configured"))
}
cert, err := utils.TLSCertToX509(certs[0])
if err != nil {
return trace.Wrap(err)
}
err = utils.VerifyCertificateExpiry(cert, c.clock)
if err != nil {
return trace.Wrap(err)
}
identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter)
if err != nil {
return trace.Wrap(err)
}
if c.dbRoute.Username != "" && c.dbRoute.Username != identity.RouteToDatabase.Username {
msg := fmt.Sprintf("certificate subject is for user %s, but need %s", identity.RouteToDatabase.Username, c.dbRoute.Username)
return trace.Wrap(errors.New(msg))
}
if c.dbRoute.Database != "" && c.dbRoute.Database != identity.RouteToDatabase.Database {
msg := fmt.Sprintf("certificate subject is for database name %s, but need %s", identity.RouteToDatabase.Database, c.dbRoute.Database)
return trace.Wrap(errors.New(msg))
}
return nil
}

// ensureValidCerts ensures that the local proxy is configured with valid certs.
func (c *DBCertChecker) ensureValidCerts(ctx context.Context, lp *alpnproxy.LocalProxy) error {
if err := c.checkCerts(lp); err != nil {
if err := lp.CheckDBCerts(c.dbRoute); err != nil {
log.WithError(err).Debug("local proxy tunnel certificates need to be reissued")
} else {
return nil
Expand Down
13 changes: 10 additions & 3 deletions lib/srv/alpnproxy/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ func mustGenSelfSignedCert(t *testing.T) *tlsca.CertAuthority {

type signOptions struct {
identity tlsca.Identity
clock clockwork.Clock
}

func withIdentity(identity tlsca.Identity) signOptionsFunc {
Expand All @@ -170,11 +171,18 @@ func withIdentity(identity tlsca.Identity) signOptionsFunc {
}
}

func withClock(clock clockwork.Clock) signOptionsFunc {
return func(o *signOptions) {
o.clock = clock
}
}

type signOptionsFunc func(o *signOptions)

func mustGenCertSignedWithCA(t *testing.T, ca *tlsca.CertAuthority, opts ...signOptionsFunc) tls.Certificate {
options := signOptions{
identity: tlsca.Identity{Username: "test-user"},
clock: clockwork.NewRealClock(),
}

for _, opt := range opts {
Expand All @@ -187,12 +195,11 @@ func mustGenCertSignedWithCA(t *testing.T, ca *tlsca.CertAuthority, opts ...sign
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

clock := clockwork.NewRealClock()
tlsCert, err := ca.GenerateCertificate(tlsca.CertificateRequest{
Clock: clock,
Clock: options.clock,
PublicKey: privateKey.Public(),
Subject: subj,
NotAfter: clock.Now().UTC().Add(time.Minute),
NotAfter: options.clock.Now().UTC().Add(time.Minute),
DNSNames: []string{"localhost", "*.localhost"},
})
require.NoError(t, err)
Expand Down
76 changes: 66 additions & 10 deletions lib/srv/alpnproxy/local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ import (
"net"
"net/http"
"net/http/httputil"
"sync"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/aws"
)
Expand All @@ -40,6 +43,7 @@ type LocalProxy struct {
cfg LocalProxyConfig
context context.Context
cancel context.CancelFunc
certsMu sync.RWMutex
}

// LocalProxyConfig is configuration for LocalProxy.
Expand Down Expand Up @@ -74,6 +78,10 @@ type LocalProxyConfig struct {
ALPNConnUpgradeRequired bool
// Middleware provides callback functions to the local proxy.
Middleware LocalProxyMiddleware
// Clock is used to override time in tests.
Clock clockwork.Clock
// Log is the Logger.
Log logrus.FieldLogger
}

// LocalProxyMiddleware provides callback functions for LocalProxy.
Expand All @@ -97,6 +105,12 @@ func (cfg *LocalProxyConfig) CheckAndSetDefaults() error {
if cfg.ParentContext == nil {
return trace.BadParameter("missing parent context")
}
if cfg.Clock == nil {
cfg.Clock = clockwork.NewRealClock()
}
if cfg.Log == nil {
cfg.Log = logrus.WithField(trace.Component, "localproxy")
}
return nil
}

Expand Down Expand Up @@ -144,15 +158,15 @@ func (l *LocalProxy) Start(ctx context.Context) error {
if utils.IsOKNetworkError(err) {
return nil
}
log.WithError(err).Errorf("Failed to accept client connection.")
l.cfg.Log.WithError(err).Error("Failed to accept client connection.")
return trace.Wrap(err)
}

if l.cfg.Middleware != nil {
if err := l.cfg.Middleware.OnNewConnection(ctx, l, conn); err != nil {
log.WithError(err).Error("Middleware failed to handle client connection.")
l.cfg.Log.WithError(err).Error("Middleware failed to handle client connection.")
if err := conn.Close(); err != nil && !utils.IsUseOfClosedNetworkError(err) {
log.WithError(err).Debug("Failed to close client connection.")
l.cfg.Log.WithError(err).Debug("Failed to close client connection.")
}
continue
}
Expand All @@ -163,7 +177,7 @@ func (l *LocalProxy) Start(ctx context.Context) error {
if utils.IsOKNetworkError(err) {
return
}
log.WithError(err).Errorf("Failed to handle connection.")
l.cfg.Log.WithError(err).Error("Failed to handle connection.")
}
}()
}
Expand All @@ -185,7 +199,7 @@ func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamC
NextProtos: l.cfg.GetProtocols(),
InsecureSkipVerify: l.cfg.InsecureSkipVerify,
ServerName: l.cfg.SNI,
Certificates: l.GetCerts(),
Certificates: l.getCerts(),
RootCAs: l.cfg.RootCAs,
},
})
Expand All @@ -196,7 +210,7 @@ func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamC

var upstreamConn net.Conn = tlsConn
if common.IsPingProtocol(common.Protocol(tlsConn.ConnectionState().NegotiatedProtocol)) {
log.Debug("Using ping connection")
l.cfg.Log.Debug("Using ping connection")
upstreamConn = NewPingConn(tlsConn)
}

Expand All @@ -220,7 +234,7 @@ func (l *LocalProxy) StartAWSAccessProxy(ctx context.Context) error {
NextProtos: l.cfg.GetProtocols(),
InsecureSkipVerify: l.cfg.InsecureSkipVerify,
ServerName: l.cfg.SNI,
Certificates: l.GetCerts(),
Certificates: l.getCerts(),
},
}
proxy := &httputil.ReverseProxy{
Expand All @@ -232,7 +246,7 @@ func (l *LocalProxy) StartAWSAccessProxy(ctx context.Context) error {
}
err := http.Serve(l.cfg.Listener, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if err := aws.VerifyAWSSignature(req, l.cfg.AWSCredentials); err != nil {
log.WithError(err).Errorf("AWS signature verification failed.")
l.cfg.Log.WithError(err).Error("AWS signature verification failed.")
rw.WriteHeader(http.StatusForbidden)
return
}
Expand All @@ -251,10 +265,52 @@ func (l *LocalProxy) StartAWSAccessProxy(ctx context.Context) error {
return nil
}

func (l *LocalProxy) GetCerts() []tls.Certificate {
// getCerts returns the local proxy's configured TLS certificates.
// For thread-safety, it is important that the returned slice and its contents are not be mutated by callers,
// therefore this method is not exported.
func (l *LocalProxy) getCerts() []tls.Certificate {
l.certsMu.RLock()
defer l.certsMu.RUnlock()
Comment thread
GavinFrazar marked this conversation as resolved.
return l.cfg.Certs
}

// CheckDBCerts checks the proxy certificates for expiration and that the cert subject matches a database route.
func (l *LocalProxy) CheckDBCerts(dbRoute tlsca.RouteToDatabase) error {
l.cfg.Log.Debug("checking local proxy database certs")
l.certsMu.RLock()
defer l.certsMu.RUnlock()
if len(l.cfg.Certs) == 0 {
return trace.NotFound("local proxy has no TLS certificates configured")
}
cert, err := utils.TLSCertToX509(l.cfg.Certs[0])
if err != nil {
return trace.Wrap(err)
}

// Check for cert expiration.
if err := utils.VerifyCertificateExpiry(cert, l.cfg.Clock); err != nil {
return trace.Wrap(err)
}

// Check the subject matches.
identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter)
if err != nil {
return trace.Wrap(err)
}
if dbRoute.Username != "" && dbRoute.Username != identity.RouteToDatabase.Username {
return trace.Errorf("certificate subject is for user %s, but need %s",
identity.RouteToDatabase.Username, dbRoute.Username)
}
if dbRoute.Database != "" && dbRoute.Database != identity.RouteToDatabase.Database {
return trace.Errorf("certificate subject is for database name %s, but need %s",
identity.RouteToDatabase.Database, dbRoute.Database)
}
return nil
}

// SetCerts sets the local proxy's configured TLS certificates.
func (l *LocalProxy) SetCerts(certs []tls.Certificate) {
l.certsMu.Lock()
defer l.certsMu.Unlock()
l.cfg.Certs = certs
}
Loading