diff --git a/api/client/alpn.go b/api/client/alpn.go index 15ac2451194cc..91a8d968b64d6 100644 --- a/api/client/alpn.go +++ b/api/client/alpn.go @@ -19,12 +19,28 @@ package client import ( "context" "crypto/tls" + "crypto/x509" "net" + "strings" "time" "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/api/constants" ) +// GetClusterCAsFunc is a function to fetch cluster CAs. +type GetClusterCAsFunc func(ctx context.Context) (*x509.CertPool, error) + +// ClusterCAsFromCertPool returns a GetClusterCAsFunc with provided static cert +// pool. +func ClusterCAsFromCertPool(cas *x509.CertPool) GetClusterCAsFunc { + return func(_ context.Context) (*x509.CertPool, error) { + return cas, nil + } +} + // ALPNDialerConfig is the config for ALPNDialer. type ALPNDialerConfig struct { // KeepAlivePeriod defines period between keep alives. @@ -35,12 +51,17 @@ type ALPNDialerConfig struct { TLSConfig *tls.Config // ALPNConnUpgradeRequired specifies if ALPN connection upgrade is required. ALPNConnUpgradeRequired bool + // GetClusterCAs is an optional callback function to fetch cluster + // CAs when connection upgrade is required. If not provided, it's assumed + // the proper CAs are already present in TLSConfig. + GetClusterCAs GetClusterCAsFunc } // ALPNDialer is a ContextDialer that dials a connection to the Proxy Service // with ALPN and SNI configured in the provided TLSConfig. An ALPN connection // upgrade is also performed at the initial connection, if an upgrade is -// required. +// required. If the negotiated protocol is a Ping protocol, it will return the +// de-multiplexed connection without the Ping. type ALPNDialer struct { cfg ALPNDialerConfig } @@ -52,25 +73,73 @@ func NewALPNDialer(cfg ALPNDialerConfig) ContextDialer { } } -// DialContext implements ContextDialer. -func (d ALPNDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { +func (d *ALPNDialer) shouldUpdateTLSConfig() bool { + return d.shouldUpdateServerName() || d.shouldGetClusterCAs() +} + +// shouldUpdateServerName returns true if ServerName is not in the provided TLS +// config. It will default to the host of the dialing address. +func (d *ALPNDialer) shouldUpdateServerName() bool { + return d.cfg.TLSConfig.ServerName == "" +} + +// shouldGetClusterCAs returns true if RootCAs of the provided TLS config needs +// to be set to the Teleport cluster CAs. +// +// When Teleport Proxy is behind a L7 load balancer, the load balancer +// usually terminates TLS with public certs, and the Proxy is usually in +// private subnets with self-signed web certs. During the connection +// upgrade flow for TLS Routing, instead of serving these self-signed web +// certs, the TLS Routing handler at the Proxy server will present the +// Cluster CAs so clients here can still verify the server. +func (d *ALPNDialer) shouldGetClusterCAs() bool { + return d.cfg.ALPNConnUpgradeRequired && d.cfg.TLSConfig.RootCAs == nil && d.cfg.GetClusterCAs != nil +} + +func (d *ALPNDialer) getTLSConfig(ctx context.Context, addr string) (*tls.Config, error) { if d.cfg.TLSConfig == nil { return nil, trace.BadParameter("missing TLS config") } + if !d.shouldUpdateTLSConfig() { + return d.cfg.TLSConfig, nil + } - dialer := NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout, WithTLSConfig(d.cfg.TLSConfig)) - if d.cfg.ALPNConnUpgradeRequired { - dialer = newALPNConnUpgradeDialer(dialer, &tls.Config{ - InsecureSkipVerify: d.cfg.TLSConfig.InsecureSkipVerify, - }) + var err error + tlsConfig := d.cfg.TLSConfig.Clone() + if d.shouldGetClusterCAs() { + tlsConfig.RootCAs, err = d.cfg.GetClusterCAs(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + } + if d.shouldUpdateServerName() { + tlsConfig.ServerName, _, err = webclient.ParseHostPort(addr) + if err != nil { + return nil, trace.Wrap(err) + } + } + return tlsConfig, nil +} + +// DialContext implements ContextDialer. +func (d *ALPNDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + tlsConfig, err := d.getTLSConfig(ctx, addr) + if err != nil { + return nil, trace.Wrap(err) } + dialer := NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout, + WithInsecureSkipVerify(d.cfg.TLSConfig.InsecureSkipVerify), + WithALPNConnUpgrade(d.cfg.ALPNConnUpgradeRequired), + WithALPNConnUpgradePing(shouldALPNConnUpgradeWithPing(tlsConfig)), + ) + conn, err := dialer.DialContext(ctx, network, addr) if err != nil { return nil, trace.Wrap(err) } - tlsConn := tls.Client(conn, d.cfg.TLSConfig) + tlsConn := tls.Client(conn, tlsConfig) if err := tlsConn.HandshakeContext(ctx); err != nil { defer tlsConn.Close() return nil, trace.Wrap(err) @@ -91,3 +160,25 @@ func DialALPN(ctx context.Context, addr string, cfg ALPNDialerConfig) (*tls.Conn } return tlsConn, nil } + +// IsALPNPingProtocol checks if the provided protocol is suffixed with Ping. +func IsALPNPingProtocol(protocol string) bool { + return strings.HasSuffix(protocol, constants.ALPNSNIProtocolPingSuffix) +} + +// shouldALPNConnUpgradeWithPing returns true if Ping wrapper is required +// during connection upgrade. +func shouldALPNConnUpgradeWithPing(config *tls.Config) bool { + for _, proto := range config.NextProtos { + switch proto { + // Server usually sends SSH keepalives or HTTP2 pings every five + // minutes for reverse tunnel and SSH connections. Load balancers + // usually have a shorter idle timeout. Thus wrapping the connection + // with Ping protocol at the connection upgrade layer to keepalive. + case constants.ALPNSNIProtocolReverseTunnel, + constants.ALPNSNIProtocolSSH: + return true + } + } + return false +} diff --git a/api/client/alpn_conn_upgrade.go b/api/client/alpn_conn_upgrade.go index c82f91dd7ff63..c96dec445b0cf 100644 --- a/api/client/alpn_conn_upgrade.go +++ b/api/client/alpn_conn_upgrade.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/api/utils/pingconn" ) // IsALPNConnUpgradeRequired returns true if a tunnel is required through a HTTP @@ -145,18 +146,20 @@ func isALPNConnUpgradeRequiredByEnv(addr, envValue string) bool { type alpnConnUpgradeDialer struct { dialer ContextDialer tlsConfig *tls.Config + withPing bool } // newALPNConnUpgradeDialer creates a new alpnConnUpgradeDialer. -func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config) ContextDialer { +func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config, withPing bool) ContextDialer { return &alpnConnUpgradeDialer{ dialer: dialer, tlsConfig: tlsConfig, + withPing: withPing, } } // DialContext implements ContextDialer -func (d alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { +func (d *alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { logrus.Debugf("ALPN connection upgrade for %v.", addr) conn, err := d.dialer.DialContext(ctx, network, addr) @@ -181,47 +184,58 @@ func (d alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr st } tlsConn := tls.Client(conn, cfg) - - err = upgradeConnThroughWebAPI(tlsConn, url.URL{ + upgradeURL := url.URL{ Host: addr, Scheme: "https", Path: constants.WebAPIConnUpgrade, - }) + } + + conn, err = upgradeConnThroughWebAPI(tlsConn, upgradeURL, d.upgradeType()) if err != nil { - defer tlsConn.Close() - return nil, trace.Wrap(err) + return nil, trace.NewAggregate(tlsConn.Close(), err) } - return tlsConn, nil + return conn, nil } -func upgradeConnThroughWebAPI(conn net.Conn, api url.URL) error { +func (d *alpnConnUpgradeDialer) upgradeType() string { + if d.withPing { + return constants.WebAPIConnUpgradeTypeALPNPing + } + return constants.WebAPIConnUpgradeTypeALPN +} + +func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) (net.Conn, error) { req, err := http.NewRequest(http.MethodGet, api.String(), nil) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - // For now, only "alpn" is supported. - req.Header.Add(constants.WebAPIConnUpgradeHeader, constants.WebAPIConnUpgradeTypeALPN) + req.Header.Add(constants.WebAPIConnUpgradeHeader, upgradeType) // Send the request and check if upgrade is successful. if err = req.Write(conn); err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } resp, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } defer resp.Body.Close() if http.StatusSwitchingProtocols != resp.StatusCode { if http.StatusNotFound == resp.StatusCode { - return trace.NotImplemented( - "connection upgrade call to %q failed with status code %v. Please upgrade the server and try again.", + return nil, trace.NotImplemented( + "connection upgrade call to %q with upgrade type %v failed with status code %v. Please upgrade the server and try again.", constants.WebAPIConnUpgrade, + upgradeType, resp.StatusCode, ) } - return trace.BadParameter("failed to switch Protocols %v", resp.StatusCode) + return nil, trace.BadParameter("failed to switch Protocols %v", resp.StatusCode) + } + + if upgradeType == constants.WebAPIConnUpgradeTypeALPNPing { + return pingconn.New(conn), nil } - return nil + return conn, nil } diff --git a/api/client/alpn_conn_upgrade_test.go b/api/client/alpn_conn_upgrade_test.go index 4c02b3ebfebb6..a52a412134a8a 100644 --- a/api/client/alpn_conn_upgrade_test.go +++ b/api/client/alpn_conn_upgrade_test.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/fixtures" + "github.com/gravitational/teleport/api/utils/pingconn" ) func TestIsALPNConnUpgradeRequired(t *testing.T) { @@ -123,42 +124,58 @@ func TestIsALPNConnUpgradeRequiredByEnv(t *testing.T) { func TestALPNConnUpgradeDialer(t *testing.T) { t.Parallel() - t.Run("connection upgraded", func(t *testing.T) { - ctx := context.Background() - - server := httptest.NewTLSServer(mockConnUpgradeHandler(t, "alpn", []byte("hello"))) - t.Cleanup(server.Close) - addr, err := url.Parse(server.URL) - require.NoError(t, err) - pool := x509.NewCertPool() - pool.AddCert(server.Certificate()) - - tlsConfig := &tls.Config{RootCAs: pool} - preDialer := NewDialer(ctx, 0, 5*time.Second) - dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig) - conn, err := dialer.DialContext(ctx, "tcp", addr.Host) - require.NoError(t, err) - - data := make([]byte, 100) - n, err := conn.Read(data) - require.NoError(t, err) - require.Equal(t, string(data[:n]), "hello") - }) - - t.Run("connection upgrade API not found", func(t *testing.T) { - ctx := context.Background() - - server := httptest.NewTLSServer(http.NotFoundHandler()) - t.Cleanup(server.Close) - addr, err := url.Parse(server.URL) - require.NoError(t, err) + tests := []struct { + name string + serverHandler http.Handler + withPing bool + wantError bool + }{ + { + name: "connection upgrade", + serverHandler: mockConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")), + }, + { + name: "connection upgrade with ping", + serverHandler: mockConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")), + withPing: true, + }, + { + name: "connection upgrade API not found", + serverHandler: http.NotFoundHandler(), + wantError: true, + }, + } - tlsConfig := &tls.Config{InsecureSkipVerify: true} - preDialer := NewDialer(ctx, 0, 5*time.Second) - dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig) - _, err = dialer.DialContext(ctx, "tcp", addr.Host) - require.Error(t, err) - }) + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + + server := httptest.NewTLSServer(test.serverHandler) + t.Cleanup(server.Close) + addr, err := url.Parse(server.URL) + require.NoError(t, err) + pool := x509.NewCertPool() + pool.AddCert(server.Certificate()) + + tlsConfig := &tls.Config{RootCAs: pool} + preDialer := newDirectDialer(0, 5*time.Second) + dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig, test.withPing) + conn, err := dialer.DialContext(ctx, "tcp", addr.Host) + if test.wantError { + require.Error(t, err) + return + } + require.NoError(t, err) + defer conn.Close() + + data := make([]byte, 100) + n, err := conn.Read(data) + require.NoError(t, err) + require.Equal(t, string(data[:n]), "hello") + }) + } } type mockALPNServer struct { @@ -218,6 +235,8 @@ func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNSe // mockConnUpgradeHandler mocks the server side implementation to handle an // upgrade request and sends back some data inside the tunnel. func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler { + t.Helper() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path) require.Equal(t, upgradeType, r.Header.Get(constants.WebAPIConnUpgradeHeader)) @@ -238,7 +257,18 @@ func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http require.NoError(t, response.Write(conn)) // Upgraded. - _, err = conn.Write(write) - require.NoError(t, err) + switch upgradeType { + case constants.WebAPIConnUpgradeTypeALPNPing: + // Wrap conn with Ping and write some pings. + pingConn := pingconn.New(conn) + pingConn.WritePing() + _, err = pingConn.Write(write) + require.NoError(t, err) + pingConn.WritePing() + + default: + _, err = conn.Write(write) + require.NoError(t, err) + } }) } diff --git a/api/client/alpn_test.go b/api/client/alpn_test.go new file mode 100644 index 0000000000000..92c8ddaedc1b7 --- /dev/null +++ b/api/client/alpn_test.go @@ -0,0 +1,115 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "context" + "crypto/tls" + "crypto/x509" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +func TestALPNDialer_getTLSConfig(t *testing.T) { + t.Parallel() + cas := x509.NewCertPool() + + tests := []struct { + name string + input ALPNDialerConfig + wantTLSConfig *tls.Config + wantError bool + }{ + { + name: "missing tls config", + input: ALPNDialerConfig{}, + wantError: true, + }, + { + name: "no update", + input: ALPNDialerConfig{ + TLSConfig: &tls.Config{ + ServerName: "example.com", + }, + }, + wantTLSConfig: &tls.Config{ + ServerName: "example.com", + }, + }, + { + name: "no update when upgrade required", + input: ALPNDialerConfig{ + TLSConfig: &tls.Config{ + ServerName: "example.com", + RootCAs: cas, + }, + ALPNConnUpgradeRequired: true, + }, + wantTLSConfig: &tls.Config{ + ServerName: "example.com", + RootCAs: cas, + }, + }, + { + name: "name updated", + input: ALPNDialerConfig{ + TLSConfig: &tls.Config{}, + }, + wantTLSConfig: &tls.Config{ + ServerName: "example.com", + }, + }, + { + name: "get cas failed", + input: ALPNDialerConfig{ + TLSConfig: &tls.Config{}, + ALPNConnUpgradeRequired: true, + GetClusterCAs: func(_ context.Context) (*x509.CertPool, error) { + return nil, trace.AccessDenied("fail it") + }, + }, + wantError: true, + }, + { + name: "cas updated", + input: ALPNDialerConfig{ + TLSConfig: &tls.Config{}, + ALPNConnUpgradeRequired: true, + GetClusterCAs: ClusterCAsFromCertPool(cas), + }, + wantTLSConfig: &tls.Config{ + ServerName: "example.com", + RootCAs: cas, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dialer := NewALPNDialer(test.input).(*ALPNDialer) + tlsConfig, err := dialer.getTLSConfig(context.Background(), "example.com:443") + if test.wantError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, test.wantTLSConfig, tlsConfig) + } + }) + } +} diff --git a/api/client/client.go b/api/client/client.go index f86e420aca234..081e713ebb8f5 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -151,9 +151,16 @@ func newClient(cfg Config, dialer ContextDialer, tlsConfig *tls.Config) *Client } // connectInBackground connects the client to the server in the background. -// The client will use the first credentials and the given dialer. If -// no dialer is given, the first address will be used. This address must -// be an auth server address. +// +// The client will use the first credentials and the given dialer. +// +// If no dialer is given, the first address will be used. If no +// ALPNSNIAuthDialClusterName is given, this address must be an auth server +// address. +// +// If ALPNSNIAuthDialClusterName is given, the address is expected to be a web +// proxy address and the client will connect auth through the web proxy server +// using TLS Routing. func connectInBackground(ctx context.Context, cfg Config) (*Client, error) { tlsConfig, err := cfg.Credentials[0].TLSConfig() if err != nil { @@ -259,7 +266,7 @@ func connect(ctx context.Context, cfg Config) (*Client, error) { addr: addr, }) if sshConfig != nil { - for _, cf := range []connectFunc{proxyConnect, tunnelConnect, tlsRoutingConnect} { + for _, cf := range []connectFunc{proxyConnect, tunnelConnect, tlsRoutingConnect, tlsRoutingWithConnUpgradeConnect} { syncConnect(ctx, cf, connectParams{ cfg: cfg, tlsConfig: tlsConfig, @@ -325,9 +332,14 @@ type ( } ) -// authConnect connects to the Teleport Auth Server directly. +// authConnect connects to the Teleport Auth Server directly or through Proxy. func authConnect(ctx context.Context, params connectParams) (*Client, error) { - dialer := NewDialer(ctx, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, WithTLSConfig(params.tlsConfig)) + dialer := NewDialer(ctx, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, + WithInsecureSkipVerify(params.cfg.InsecureAddressDiscovery), + WithALPNConnUpgrade(params.cfg.ALPNConnUpgradeRequired), + WithALPNConnUpgradePing(true), // Use Ping protocol for long-lived connections. + ) + clt := newClient(params.cfg, dialer, params.tlsConfig) if err := clt.dialGRPC(ctx, params.addr); err != nil { return nil, trace.Wrap(err, "failed to connect to addr %v as an auth server", params.addr) @@ -340,7 +352,7 @@ func tunnelConnect(ctx context.Context, params connectParams) (*Client, error) { if params.sshConfig == nil { return nil, trace.BadParameter("must provide ssh client config") } - dialer := newTunnelDialer(*params.sshConfig, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, WithTLSConfig(params.tlsConfig)) + dialer := newTunnelDialer(*params.sshConfig, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, WithInsecureSkipVerify(params.cfg.InsecureAddressDiscovery)) clt := newClient(params.cfg, dialer, params.tlsConfig) if err := clt.dialGRPC(ctx, params.addr); err != nil { return nil, trace.Wrap(err, "failed to connect to addr %v as a reverse tunnel proxy", params.addr) @@ -353,7 +365,7 @@ func proxyConnect(ctx context.Context, params connectParams) (*Client, error) { if params.sshConfig == nil { return nil, trace.BadParameter("must provide ssh client config") } - dialer := NewProxyDialer(*params.sshConfig, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, params.addr, params.cfg.InsecureAddressDiscovery, WithTLSConfig(params.tlsConfig)) + dialer := NewProxyDialer(*params.sshConfig, params.cfg.KeepAlivePeriod, params.cfg.DialTimeout, params.addr, params.cfg.InsecureAddressDiscovery, WithInsecureSkipVerify(params.cfg.InsecureAddressDiscovery)) clt := newClient(params.cfg, dialer, params.tlsConfig) if err := clt.dialGRPC(ctx, params.addr); err != nil { return nil, trace.Wrap(err, "failed to connect to addr %v as a web proxy", params.addr) @@ -374,6 +386,20 @@ func tlsRoutingConnect(ctx context.Context, params connectParams) (*Client, erro return clt, nil } +// tlsRoutingWithConnUpgradeConnect connects to the Teleport Auth Server +// through the proxy using TLS Routing with ALPN connection upgrade. +func tlsRoutingWithConnUpgradeConnect(ctx context.Context, params connectParams) (*Client, error) { + if params.sshConfig == nil { + return nil, trace.BadParameter("must provide ssh client config") + } + dialer := newTLSRoutingWithConnUpgradeDialer(*params.sshConfig, params) + clt := newClient(params.cfg, dialer, params.tlsConfig) + if err := clt.dialGRPC(ctx, params.addr); err != nil { + return nil, trace.Wrap(err, "failed to connect to addr %v with TLS Routing with ALPN connection upgrade dialer", params.addr) + } + return clt, nil +} + // dialerConnect connects to the Teleport Auth Server through a custom dialer. // The dialer must provide the address in a custom ContextDialerFunc function. func dialerConnect(ctx context.Context, params connectParams) (*Client, error) { @@ -525,6 +551,17 @@ type Config struct { CircuitBreakerConfig breaker.Config // Context is the base context to use for dialing. If not provided context.Background is used Context context.Context + // ALPNConnUpgradeRequired indicates that ALPN connection upgrades are + // required for making TLS Routing requests. + // + // In DialInBackground mode without a Dialer, a valid value must be + // provided as it's assumed that the caller knows the context if connection + // upgrades are required for TLS Routing. + // + // In default mode, this value is optional as some of the connect methods + // will perform necessary tests to decide if connection upgrade is + // required. + ALPNConnUpgradeRequired bool } // CheckAndSetDefaults checks and sets default config values. @@ -558,7 +595,6 @@ func (c *Config) CheckAndSetDefaults() error { if !c.DialInBackground { c.DialOpts = append(c.DialOpts, grpc.WithBlock()) } - return nil } diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index e8e5d544b2884..693ac9ebd44d3 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -19,7 +19,9 @@ package client import ( "context" "crypto/tls" + "crypto/x509" "net" + "net/url" "time" "github.com/gravitational/trace" @@ -34,6 +36,44 @@ import ( "github.com/gravitational/teleport/api/utils/sshutils" ) +type dialConfig struct { + tlsConfig *tls.Config + // alpnConnUpgradeRequired specifies if ALPN connection upgrade is + // required. + alpnConnUpgradeRequired bool + // alpnConnUpgradeWithPing specifies if Ping is required during ALPN + // connection upgrade. This is only effective when alpnConnUpgradeRequired + // is true. + alpnConnUpgradeWithPing bool +} + +// WithInsecureSkipVerify specifies if dialing insecure when using an HTTPS proxy. +func WithInsecureSkipVerify(insecure bool) DialOption { + return func(cfg *dialProxyConfig) { + cfg.tlsConfig = &tls.Config{ + InsecureSkipVerify: insecure, + } + } +} + +// WithALPNConnUpgrade specifies if ALPN connection upgrade is required. +func WithALPNConnUpgrade(alpnConnUpgradeRequired bool) DialOption { + return func(cfg *dialProxyConfig) { + cfg.alpnConnUpgradeRequired = alpnConnUpgradeRequired + } +} + +// WithALPNConnUpgradePing specifies if Ping is required during ALPN connection +// upgrade. This is only effective when alpnConnUpgradeRequired is true. +func WithALPNConnUpgradePing(alpnConnUpgradeWithPing bool) DialOption { + return func(cfg *dialProxyConfig) { + cfg.alpnConnUpgradeWithPing = alpnConnUpgradeWithPing + } +} + +// DialOption allows setting options as functional arguments to api.NewDialer. +type DialOption func(cfg *dialConfig) + // ContextDialer represents network dialer interface that uses context type ContextDialer interface { // DialContext is a function that dials the specified address @@ -56,6 +96,12 @@ func newDirectDialer(keepAlivePeriod, dialTimeout time.Duration) *net.Dialer { } } +func newProxyURLDialer(proxyURL *url.URL, dialer *net.Dialer, opts ...DialProxyOption) ContextDialer { + return ContextDialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + return DialProxyWithDialer(ctx, proxyURL, addr, dialer, opts...) + }) +} + // tracedDialer ensures that the provided ContextDialerFunc is given a context // which contains tracing information. In the event that a grpc dial occurs without // a grpc.WithBlock dialing option, the context provided to the dial function will @@ -78,12 +124,29 @@ func tracedDialer(ctx context.Context, fn ContextDialerFunc) ContextDialerFunc { // NewDialer makes a new dialer that connects to an Auth server either directly or via an HTTP proxy, depending // on the environment. -func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration, opts ...DialProxyOption) ContextDialer { +func NewDialer(ctx context.Context, keepAlivePeriod, dialTimeout time.Duration, opts ...DialOption) ContextDialer { + var cfg dialConfig + for _, opt := range opts { + opt(&cfg) + } + return tracedDialer(ctx, func(ctx context.Context, network, addr string) (net.Conn, error) { - dialer := newDirectDialer(keepAlivePeriod, dialTimeout) + netDialer := newDirectDialer(keepAlivePeriod, dialTimeout) + + // Base direct dialer. + var dialer ContextDialer = netDialer + + // Wrap with proxy URL dialer if proxy URL is detected. if proxyURL := utils.GetProxyURL(addr); proxyURL != nil { - return DialProxyWithDialer(ctx, proxyURL, addr, dialer, opts...) + dialer = newProxyURLDialer(proxyURL, netDialer, opts...) + } + + // Wrap with alpnConnUpgradeDialer if upgrade is required for TLS Routing. + if cfg.alpnConnUpgradeRequired { + dialer = newALPNConnUpgradeDialer(dialer, cfg.tlsConfig, cfg.alpnConnUpgradeWithPing) } + + // Dial. return dialer.DialContext(ctx, network, addr) }) } @@ -112,6 +175,15 @@ func NewProxyDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Dura }) } +// GRPCContextDialer converts a ContextDialer to a function used for +// grpc.WithContextDialer. +func GRPCContextDialer(dialer ContextDialer) func(context.Context, string) (net.Conn, error) { + return func(ctx context.Context, addr string) (net.Conn, error) { + conn, err := dialer.DialContext(ctx, "tcp", addr) + return conn, trace.Wrap(err) + } +} + // newTunnelDialer makes a dialer to connect to an Auth server through the SSH reverse tunnel on the proxy. func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, opts ...DialProxyOption) ContextDialer { dialer := newDirectDialer(keepAlivePeriod, dialTimeout) @@ -184,6 +256,56 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeou }) } +// newTLSRoutingWithConnUpgradeDialer makes a reverse tunnel TLS Routing dialer +// through the web proxy with ALPN connection upgrade. +func newTLSRoutingWithConnUpgradeDialer(ssh ssh.ClientConfig, params connectParams) ContextDialer { + return ContextDialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + insecure := params.cfg.InsecureAddressDiscovery + resp, err := webclient.Find(&webclient.Config{ + Context: ctx, + ProxyAddr: params.addr, + Insecure: insecure, + }) + if err != nil { + return nil, trace.Wrap(err) + } + if !resp.Proxy.TLSRoutingEnabled { + return nil, trace.NotImplemented("TLS routing is not enabled") + } + + host, _, err := webclient.ParseHostPort(params.addr) + if err != nil { + return nil, trace.Wrap(err) + } + conn, err := DialALPN(ctx, params.addr, ALPNDialerConfig{ + DialTimeout: params.cfg.DialTimeout, + KeepAlivePeriod: params.cfg.KeepAlivePeriod, + TLSConfig: &tls.Config{ + NextProtos: []string{constants.ALPNSNIProtocolReverseTunnel}, + InsecureSkipVerify: insecure, + ServerName: host, + }, + ALPNConnUpgradeRequired: IsALPNConnUpgradeRequired(params.addr, insecure), + GetClusterCAs: func(_ context.Context) (*x509.CertPool, error) { + tlsConfig, err := params.cfg.Credentials[0].TLSConfig() + if err != nil { + return nil, trace.Wrap(err) + } + return tlsConfig.RootCAs, nil + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + sconn, err := sshConnect(ctx, conn, ssh, params.cfg.DialTimeout, params.addr) + if err != nil { + return nil, trace.Wrap(err) + } + return sconn, nil + }) +} + // sshConnect upgrades the underling connection to ssh and connects to the Auth service. func sshConnect(ctx context.Context, conn net.Conn, ssh ssh.ClientConfig, dialTimeout time.Duration, addr string) (net.Conn, error) { ssh.Timeout = dialTimeout diff --git a/api/client/proxy.go b/api/client/proxy.go index 2c222f0b0dee6..500217fd9941e 100644 --- a/api/client/proxy.go +++ b/api/client/proxy.go @@ -29,12 +29,10 @@ import ( "golang.org/x/net/proxy" ) -type dialProxyConfig struct { - tlsConfig *tls.Config -} +type dialProxyConfig = dialConfig // DialProxyOption allows setting options as functional arguments to DialProxy. -type DialProxyOption func(cfg *dialProxyConfig) +type DialProxyOption = DialOption // WithTLSConfig provides the dialer with the TLS config to use when using an // HTTPS proxy. diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index b0299c8fa97bf..678e4a441264c 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -82,6 +82,11 @@ type ClientConfig struct { DialTimeout time.Duration // DialOpts define options for dialing the client connection. DialOpts []grpc.DialOption + // ALPNConnUpgradeRequired indicates that ALPN connection upgrades are + // required for making TLS routing requests. + ALPNConnUpgradeRequired bool + // InsecureSkipVerify is an option to skip HTTPS cert check + InsecureSkipVerify bool // The below items are intended to be used by tests to connect without mTLS. // The gRPC transport credentials to use when establishing the connection to proxy. @@ -108,7 +113,6 @@ func (c *ClientConfig) CheckAndSetDefaults() error { if c.DialTimeout <= 0 { c.DialTimeout = defaults.DefaultIOTimeout } - if c.TLSConfig != nil { c.clientCreds = func() client.Credentials { return client.LoadTLS(c.TLSConfig.Clone()) @@ -293,11 +297,18 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error dialCtx, cancel := context.WithTimeout(ctx, cfg.DialTimeout) defer cancel() + // Dial web proxy with TLS Routing. + addr := cfg.ProxySSHAddress + if cfg.TLSRoutingEnabled { + addr = cfg.ProxyWebAddress + } + c := &clusterName{} conn, err := grpc.DialContext( dialCtx, - cfg.ProxySSHAddress, - append(cfg.DialOpts, + addr, + append([]grpc.DialOption{ + grpc.WithContextDialer(newDialerForGRPCClient(ctx, cfg)), grpc.WithTransportCredentials(&clusterCredentials{TransportCredentials: cfg.creds(), clusterName: c}), grpc.WithChainUnaryInterceptor( append(cfg.UnaryInterceptors, @@ -311,7 +322,7 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error metadata.StreamClientInterceptor, )..., ), - )..., + }, cfg.DialOpts...)..., ) if err != nil { return nil, trace.Wrap(err) @@ -336,6 +347,14 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error }, nil } +func newDialerForGRPCClient(ctx context.Context, cfg *ClientConfig) func(context.Context, string) (net.Conn, error) { + return client.GRPCContextDialer(client.NewDialer(ctx, defaults.DefaultIdleTimeout, cfg.DialTimeout, + client.WithInsecureSkipVerify(cfg.InsecureSkipVerify), + client.WithALPNConnUpgrade(cfg.ALPNConnUpgradeRequired), + client.WithALPNConnUpgradePing(true), // Use Ping protocol for long-lived connections. + )) +} + // teleportAuthority is the extension set by the server // which contains the name of the cluster it is in. const teleportAuthority = "x-teleport-authority" @@ -445,6 +464,7 @@ func (c *Client) ClientConfig(ctx context.Context, cluster string) client.Config Credentials: []client.Credentials{c.cfg.clientCreds()}, ALPNSNIAuthDialClusterName: cluster, CircuitBreakerConfig: breaker.NoopBreakerConfig(), + ALPNConnUpgradeRequired: c.cfg.ALPNConnUpgradeRequired, } case c.sshClient != nil: return client.Config{ diff --git a/api/constants/constants.go b/api/constants/constants.go index 47b3b6c513c29..1d58d56ba69ad 100644 --- a/api/constants/constants.go +++ b/api/constants/constants.go @@ -319,6 +319,10 @@ const ( ALPNSNIAuthProtocol = "teleport-auth@" // ALPNSNIProtocolReverseTunnel is TLS ALPN protocol value used to indicate Proxy reversetunnel protocol. ALPNSNIProtocolReverseTunnel = "teleport-reversetunnel" + // ALPNSNIProtocolSSH is the TLS ALPN protocol value used to indicate Proxy SSH protocol. + ALPNSNIProtocolSSH = "teleport-proxy-ssh" + // ALPNSNIProtocolPingSuffix is TLS ALPN suffix used to wrap connections with Ping. + ALPNSNIProtocolPingSuffix = "-ping" ) const ( @@ -413,4 +417,12 @@ const ( // WebAPIConnUpgradeTypeALPN is a connection upgrade type that specifies // the upgraded connection should be handled by the ALPN handler. WebAPIConnUpgradeTypeALPN = "alpn" + // WebAPIConnUpgradeTypeALPNPing is a connection upgrade type that + // specifies the upgraded connection should be handled by the ALPN handler + // wrapped with the Ping protocol. + // + // This should be used when the tunneled TLS Routing protocol cannot keep + // long-lived connections alive as L7 LB usually ignores TCP keepalives and + // has very short idle timeouts. + WebAPIConnUpgradeTypeALPNPing = "alpn-ping" ) diff --git a/api/utils/slices.go b/api/utils/slices.go index ef0bdc08cf2b4..61dc124200718 100644 --- a/api/utils/slices.go +++ b/api/utils/slices.go @@ -47,7 +47,7 @@ func JoinStrings[T ~string](elems []T, sep string) T { // Deduplicate deduplicates list of comparable values. func Deduplicate[T comparable](in []T) []T { - if len(in) == 0 { + if len(in) <= 1 { return in } out := make([]T, 0, len(in)) @@ -63,7 +63,7 @@ func Deduplicate[T comparable](in []T) []T { // DeduplicateAny deduplicates list of any values with compare function. func DeduplicateAny[T any](in []T, compare func(T, T) bool) []T { - if len(in) == 0 { + if len(in) <= 1 { return in } out := make([]T, 0, len(in)) diff --git a/integration/helpers/instance.go b/integration/helpers/instance.go index 7ab15f6d1128f..2d9a36361a8cb 100644 --- a/integration/helpers/instance.go +++ b/integration/helpers/instance.go @@ -1255,6 +1255,8 @@ type ClientConfig struct { Stderr io.Writer // Stdout overrides standard output for the session Stdout io.Writer + // ALBAddr is the address to a local server that simulates a layer 7 load balancer. + ALBAddr string } // NewClientWithCreds creates client with credentials @@ -1280,12 +1282,16 @@ func (i *TeleInstance) NewUnauthenticatedClient(cfg ClientConfig) (tc *client.Te var webProxyAddr string var sshProxyAddr string - if cfg.Proxy == nil { - webProxyAddr = i.Web - sshProxyAddr = i.SSHProxy - } else { + switch { + case cfg.Proxy != nil: webProxyAddr = cfg.Proxy.WebAddr sshProxyAddr = cfg.Proxy.SSHAddr + case cfg.ALBAddr != "": + webProxyAddr = cfg.ALBAddr + sshProxyAddr = cfg.ALBAddr + default: + webProxyAddr = i.Web + sshProxyAddr = i.SSHProxy } fwdAgentMode := client.ForwardAgentNo @@ -1294,25 +1300,26 @@ func (i *TeleInstance) NewUnauthenticatedClient(cfg ClientConfig) (tc *client.Te } cconf := &client.Config{ - Username: cfg.Login, - Host: cfg.Host, - HostPort: cfg.Port, - HostLogin: cfg.Login, - InsecureSkipVerify: true, - KeysDir: keyDir, - SiteName: cfg.Cluster, - ForwardAgent: fwdAgentMode, - Labels: cfg.Labels, - WebProxyAddr: webProxyAddr, - SSHProxyAddr: sshProxyAddr, - InteractiveCommand: cfg.Interactive, - TLSRoutingEnabled: i.IsSinglePortSetup, - Tracer: tracing.NoopProvider().Tracer("test"), - EnableEscapeSequences: cfg.EnableEscapeSequences, - Stderr: cfg.Stderr, - Stdin: cfg.Stdin, - Stdout: cfg.Stdout, - NonInteractive: true, + Username: cfg.Login, + Host: cfg.Host, + HostPort: cfg.Port, + HostLogin: cfg.Login, + InsecureSkipVerify: true, + KeysDir: keyDir, + SiteName: cfg.Cluster, + ForwardAgent: fwdAgentMode, + Labels: cfg.Labels, + WebProxyAddr: webProxyAddr, + SSHProxyAddr: sshProxyAddr, + InteractiveCommand: cfg.Interactive, + TLSRoutingEnabled: i.IsSinglePortSetup, + TLSRoutingConnUpgradeRequired: cfg.ALBAddr != "", + Tracer: tracing.NoopProvider().Tracer("test"), + EnableEscapeSequences: cfg.EnableEscapeSequences, + Stderr: cfg.Stderr, + Stdin: cfg.Stdin, + Stdout: cfg.Stdout, + NonInteractive: true, } // JumpHost turns on jump host mode diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go index dbe210ba3e086..7e1dff01fbd0e 100644 --- a/integration/proxy/proxy_helpers.go +++ b/integration/proxy/proxy_helpers.go @@ -33,6 +33,7 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jackc/pgconn" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" v1 "k8s.io/api/core/v1" @@ -398,6 +399,28 @@ func withTrustedCluster() proxySuiteOptionsFunc { } } +// withTrustedClusterBehindALB creates a local server that simulates a layer 7 +// LB and puts it infront of the root cluster when the leaf connects through +// the reverse tunnel. +func withTrustedClusterBehindALB() proxySuiteOptionsFunc { + return func(options *suiteOptions) { + originalSetup := options.updateRoleMappingFunc + + options.updateRoleMappingFunc = func(t *testing.T, suite *Suite) { + t.Helper() + + if originalSetup != nil { + originalSetup(t, suite) + } + require.NotNil(t, options.trustedCluster) + + albProxy := mustStartMockALBProxy(t, suite.root.Config.Proxy.WebAddr.Addr) + options.trustedCluster.SetProxyAddress(albProxy.Addr().String()) + options.trustedCluster.SetReverseTunnelAddress(albProxy.Addr().String()) + } + } +} + func mustRunPostgresQuery(t *testing.T, client *pgconn.PgConn) { result, err := client.Exec(context.Background(), "select 1").ReadAll() require.NoError(t, err) @@ -474,7 +497,7 @@ func mustCreateKubeConfigFile(t *testing.T, config clientcmdapi.Config) string { func mustCreateListener(t *testing.T) net.Listener { t.Helper() - listener, err := net.Listen("tcp", ":0") + listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) t.Cleanup(func() { @@ -594,7 +617,7 @@ type mockAWSALBProxy struct { cert tls.Certificate } -func (m *mockAWSALBProxy) serve(ctx context.Context, t *testing.T) { +func (m *mockAWSALBProxy) serve(ctx context.Context) { for { select { case <-ctx.Done(): @@ -604,26 +627,33 @@ func (m *mockAWSALBProxy) serve(ctx context.Context, t *testing.T) { conn, err := m.Accept() if err != nil { - if utils.IsOKNetworkError(err) { - return - } - require.NoError(t, err) + logrus.WithError(err).Debugf("Failed to accept conn.") return } go func() { + defer conn.Close() + // Handshake with incoming client and drops ALPN. downstreamConn := tls.Server(conn, &tls.Config{ Certificates: []tls.Certificate{m.cert}, }) - require.NoError(t, downstreamConn.HandshakeContext(ctx)) + + // api.Client may try different connection methods. Just close the + // connection when something goes wrong. + if err := downstreamConn.HandshakeContext(ctx); err != nil { + logrus.WithError(err).Debugf("Failed to handshake.") + return + } // Make a connection to the proxy server with ALPN protos. upstreamConn, err := tls.Dial("tcp", m.proxyAddr, &tls.Config{ InsecureSkipVerify: true, }) - require.NoError(t, err) - + if err != nil { + logrus.WithError(err).Debugf("Failed to dial upstream.") + return + } utils.ProxyConn(ctx, downstreamConn, upstreamConn) }() } @@ -640,7 +670,7 @@ func mustStartMockALBProxy(t *testing.T, proxyAddr string) *mockAWSALBProxy { Listener: mustCreateListener(t), cert: mustCreateSelfSignedCert(t), } - go m.serve(ctx, t) + go m.serve(ctx) return m } diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index 16638ba9e18c7..32aa114a59d35 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -69,6 +69,7 @@ func TestALPNSNIProxyMultiCluster(t *testing.T) { secondClusterPortSetup helpers.InstanceListenerSetupFunc disableALPNListenerOnRoot bool disableALPNListenerOnLeaf bool + testALPNConnUpgrade bool }{ { name: "StandardAndOnePortSetupMasterALPNDisabled", @@ -85,17 +86,20 @@ func TestALPNSNIProxyMultiCluster(t *testing.T) { name: "TwoClusterOnePortSetup", mainClusterPortSetup: helpers.SingleProxyPortSetup, secondClusterPortSetup: helpers.SingleProxyPortSetup, + testALPNConnUpgrade: true, }, { name: "OnePortAndStandardListenerSetupLeafALPNDisabled", mainClusterPortSetup: helpers.SingleProxyPortSetup, secondClusterPortSetup: helpers.StandardListenerSetup, disableALPNListenerOnLeaf: true, + testALPNConnUpgrade: true, }, { name: "OnePortAndStandardListenerSetup", mainClusterPortSetup: helpers.SingleProxyPortSetup, secondClusterPortSetup: helpers.StandardListenerSetup, + testALPNConnUpgrade: true, }, } @@ -132,6 +136,31 @@ func TestALPNSNIProxyMultiCluster(t *testing.T) { Host: helpers.Loopback, Port: helpers.Port(t, suite.leaf.SSH), }) + + if tc.testALPNConnUpgrade { + t.Run("ALPN conn upgrade", func(t *testing.T) { + // Make a mock ALB which points to the Teleport Proxy Service. + albProxy := mustStartMockALBProxy(t, suite.root.Config.Proxy.WebAddr.Addr) + + // Run command in root through ALB address. + suite.mustConnectToClusterAndRunSSHCommand(t, helpers.ClientConfig{ + Login: username, + Cluster: suite.root.Secrets.SiteName, + Host: helpers.Loopback, + Port: helpers.Port(t, suite.root.SSH), + ALBAddr: albProxy.Addr().String(), + }) + + // Run command in leaf through ALB address. + suite.mustConnectToClusterAndRunSSHCommand(t, helpers.ClientConfig{ + Login: username, + Cluster: suite.leaf.Secrets.SiteName, + Host: helpers.Loopback, + Port: helpers.Port(t, suite.leaf.SSH), + ALBAddr: albProxy.Addr().String(), + }) + }) + } }) } } @@ -144,6 +173,7 @@ func TestALPNSNIProxyTrustedClusterNode(t *testing.T) { secondClusterListenerSetup helpers.InstanceListenerSetupFunc disableALPNListenerOnRoot bool disableALPNListenerOnLeaf bool + extraSuiteOptions []proxySuiteOptionsFunc }{ { name: "StandardAndOnePortSetupMasterALPNDisabled", @@ -172,6 +202,12 @@ func TestALPNSNIProxyTrustedClusterNode(t *testing.T) { mainClusterListenerSetup: helpers.SingleProxyPortSetup, secondClusterListenerSetup: helpers.StandardListenerSetup, }, + { + name: "TrustedClusterBehindALB", + mainClusterListenerSetup: helpers.SingleProxyPortSetup, + secondClusterListenerSetup: helpers.SingleProxyPortSetup, + extraSuiteOptions: []proxySuiteOptionsFunc{withTrustedClusterBehindALB()}, + }, } for _, tc := range testCase { t.Run(tc.name, func(t *testing.T) { @@ -180,7 +216,7 @@ func TestALPNSNIProxyTrustedClusterNode(t *testing.T) { username := helpers.MustGetCurrentUser(t).Username - suite := newSuite(t, + opts := []proxySuiteOptionsFunc{ withRootClusterConfig(rootClusterStandardConfig(t)), withLeafClusterConfig(leafClusterStandardConfig(t)), withRootClusterListeners(tc.mainClusterListenerSetup), @@ -189,7 +225,8 @@ func TestALPNSNIProxyTrustedClusterNode(t *testing.T) { withLeafClusterRoles(newRole(t, "auxdevs", username)), withRootAndLeafTrustedClusterReset(), withTrustedCluster(), - ) + } + suite := newSuite(t, append(opts, tc.extraSuiteOptions...)...) nodeHostname := "clusterauxnode" suite.addNodeToLeafCluster(t, "clusterauxnode") @@ -1174,16 +1211,63 @@ func TestALPNProxyAuthClientConnectWithUserIdentity(t *testing.T) { identity := client.LoadIdentityFile(identityFilePath) require.NoError(t, err) - tc, err := client.New(context.Background(), client.Config{ - Addrs: []string{rc.Web}, - Credentials: []client.Credentials{identity}, - InsecureAddressDiscovery: true, - }) - require.NoError(t, err) + // Make a mock ALB which points to the Teleport Proxy Service. Then + // client can point to this ALB instead. + albProxy := mustStartMockALBProxy(t, rc.Web) - resp, err := tc.Ping(context.Background()) - require.NoError(t, err) - require.Equal(t, rc.Secrets.SiteName, resp.ClusterName) + tests := []struct { + name string + clientConfig client.Config + }{ + { + name: "sync connect to Proxy", + clientConfig: client.Config{ + Addrs: []string{rc.Web}, + Credentials: []client.Credentials{identity}, + InsecureAddressDiscovery: true, + }, + }, + { + name: "sync connect to Proxy behind ALB", + clientConfig: client.Config{ + Addrs: []string{albProxy.Addr().String()}, + Credentials: []client.Credentials{identity}, + InsecureAddressDiscovery: true, + }, + }, + { + name: "background connect to Proxy", + clientConfig: client.Config{ + Addrs: []string{rc.Web}, + Credentials: []client.Credentials{identity}, + InsecureAddressDiscovery: true, + DialInBackground: true, + ALPNSNIAuthDialClusterName: cfg.ClusterName, + }, + }, + { + name: "background connect to Proxy behind ALB", + clientConfig: client.Config{ + Addrs: []string{albProxy.Addr().String()}, + Credentials: []client.Credentials{identity}, + InsecureAddressDiscovery: true, + DialInBackground: true, + ALPNSNIAuthDialClusterName: cfg.ClusterName, + ALPNConnUpgradeRequired: true, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tc, err := client.New(context.Background(), test.clientConfig) + require.NoError(t, err) + + resp, err := tc.Ping(context.Background()) + require.NoError(t, err) + require.Equal(t, rc.Secrets.SiteName, resp.ClusterName) + }) + } } // TestALPNProxyDialProxySSHWithoutInsecureMode tests dialing to the localhost with teleport-proxy-ssh diff --git a/lib/auth/authclient/authclient.go b/lib/auth/authclient/authclient.go index c76ca1f3e4b3a..00a4997a5155e 100644 --- a/lib/auth/authclient/authclient.go +++ b/lib/auth/authclient/authclient.go @@ -106,6 +106,7 @@ func Connect(ctx context.Context, cfg *Config) (auth.ClientI, error) { ClientConfig: cfg.SSH, Log: cfg.Log, InsecureSkipTLSVerify: cfg.TLS.InsecureSkipVerify, + ClusterCAs: cfg.TLS.RootCAs, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 3e89bfae29f8f..70efc5c7309b7 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -99,9 +99,12 @@ func NewClient(cfg client.Config, params ...roundtrip.ClientParam) (*Client, err if len(cfg.Addrs) == 0 { return nil, trace.BadParameter("no addresses to dial") } - contextDialer := client.NewDialer(cfg.Context, cfg.KeepAlivePeriod, cfg.DialTimeout, client.WithTLSConfig(httpTLS)) httpDialer = client.ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) { for _, addr := range cfg.Addrs { + contextDialer := client.NewDialer(cfg.Context, cfg.KeepAlivePeriod, cfg.DialTimeout, + client.WithInsecureSkipVerify(httpTLS.InsecureSkipVerify), + client.WithALPNConnUpgrade(cfg.ALPNConnUpgradeRequired), + ) conn, err = contextDialer.DialContext(ctx, network, addr) if err == nil { return conn, nil diff --git a/lib/client/api.go b/lib/client/api.go index 464d08ef7c757..d93b4d02a21f7 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2773,7 +2773,9 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (*ClusterClient, clt, err := makeProxySSHClient(ctx, tc, config) return clt, trace.Wrap(err) }), - SSHConfig: cfg.ClientConfig, + SSHConfig: cfg.ClientConfig, + ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired, + InsecureSkipVerify: tc.InsecureSkipVerify, }) if err != nil { return nil, trace.Wrap(err) @@ -3030,7 +3032,11 @@ func makeProxySSHClientWithTLSWrapper(ctx context.Context, tc *TeleportClient, s } tlsConfig.NextProtos = []string{string(alpncommon.ProtocolProxySSH)} - dialer := proxy.DialerFromEnvironment(tc.Config.WebProxyAddr, proxy.WithALPNDialer(tlsConfig)) + dialer := proxy.DialerFromEnvironment(tc.Config.WebProxyAddr, proxy.WithALPNDialer(client.ALPNDialerConfig{ + TLSConfig: tlsConfig, + ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired, + DialTimeout: sshConfig.Timeout, + })) return dialer.Dial(ctx, "tcp", proxyAddr, sshConfig) } @@ -4617,6 +4623,18 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste return kubeproto.NewKubeServiceClient(clt.GetConnection()), nil } +// IsALPNConnUpgradeRequiredForWebProxy returns true if connection upgrade is +// required for provided addr. The provided address must be a web proxy +// address. +func (tc *TeleportClient) IsALPNConnUpgradeRequiredForWebProxy(proxyAddr string) bool { + // Use cached value. + if proxyAddr == tc.WebProxyAddr { + return tc.TLSRoutingConnUpgradeRequired + } + // Do a test for other proxy addresses. + return client.IsALPNConnUpgradeRequired(proxyAddr, tc.InsecureSkipVerify) +} + // RootClusterCACertPool returns a *x509.CertPool with the root cluster CA. func (tc *TeleportClient) RootClusterCACertPool(ctx context.Context) (*x509.CertPool, error) { _, span := tc.Tracer.Start( diff --git a/lib/client/client.go b/lib/client/client.go index 0d8660b486164..fb7b1f03f09b0 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -1114,6 +1114,7 @@ func (proxy *ProxyClient) ConnectToAuthServiceThroughALPNSNIProxy(ctx context.Co }, ALPNSNIAuthDialClusterName: clusterName, CircuitBreakerConfig: breaker.NoopBreakerConfig(), + ALPNConnUpgradeRequired: proxy.teleportClient.IsALPNConnUpgradeRequiredForWebProxy(proxyAddr), }) if err != nil { return nil, trace.Wrap(err) @@ -1219,6 +1220,7 @@ func (proxy *ProxyClient) NewTracingClient(ctx context.Context, clusterName stri case proxy.teleportClient.TLSRoutingEnabled: clientConfig.Addrs = []string{proxy.teleportClient.WebProxyAddr} clientConfig.ALPNSNIAuthDialClusterName = clusterName + clientConfig.ALPNConnUpgradeRequired = proxy.teleportClient.TLSRoutingConnUpgradeRequired default: clientConfig.Dialer = client.ContextDialerFunc(func(ctx context.Context, network, _ string) (net.Conn, error) { return proxy.dialAuthServer(ctx, clusterName) diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 645f984e312bc..19df3bb4c6c74 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -19,6 +19,7 @@ package reversetunnel import ( "context" "crypto/tls" + "crypto/x509" "errors" "io" "net" @@ -31,6 +32,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" @@ -323,7 +325,7 @@ func (p *AgentPool) updateRuntimeConfig(ctx context.Context) error { return trace.Wrap(err) } - p.runtimeConfig.update(netConfig) + p.runtimeConfig.update(ctx, netConfig, p.Resolver) p.log.Debugf("Runtime config: tunnel_strategy: %v connection_count: %v", p.runtimeConfig.tunnelStrategyType, p.runtimeConfig.connectionCount) return nil @@ -480,18 +482,7 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease options := []proxy.DialerOptionFunc{proxy.WithInsecureSkipTLSVerify(lib.IsInsecureDevMode())} if p.runtimeConfig.useALPNRouting() { - tlsConfig := &tls.Config{ - NextProtos: []string{string(alpncommon.ProtocolReverseTunnel)}, - } - - if p.runtimeConfig.useReverseTunnelV2() { - tlsConfig.NextProtos = []string{ - string(alpncommon.ProtocolReverseTunnelV2), - string(alpncommon.ProtocolReverseTunnel), - } - } - - options = append(options, proxy.WithALPNDialer(tlsConfig)) + options = append(options, proxy.WithALPNDialer(p.runtimeConfig.alpnDialerConfig(p.getClusterCAs))) } dialer := &agentDialer{ @@ -524,6 +515,11 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease return agent, nil } +func (p *AgentPool) getClusterCAs(_ context.Context) (*x509.CertPool, error) { + clusterCAs, _, err := auth.ClientCertPool(p.AccessPoint, p.Cluster, types.HostCA) + return clusterCAs, trace.Wrap(err) +} + // Wait blocks until the pool context is stopped. func (p *AgentPool) Wait() { if p == nil { @@ -587,6 +583,9 @@ type agentPoolRuntimeConfig struct { // isRemoteCluster forces the agent pool to connect to all proxies // regardless of the configured tunnel strategy. isRemoteCluster bool + // tlsRoutingConnUpgradeRequired indicates that ALPN connection upgrades + // are required for making TLS routing requests. + tlsRoutingConnUpgradeRequired bool // remoteTLSRoutingEnabled caches a remote clusters tls routing setting. This helps prevent // proxy endpoint stagnation where an even numbers of proxies are hidden behind a round robin @@ -638,16 +637,35 @@ func (c *agentPoolRuntimeConfig) getConnectionCount() int { return c.connectionCount } -// useReverseTunnelV2 returns true if reverse tunnel should be used. -func (c *agentPoolRuntimeConfig) useReverseTunnelV2() bool { - c.mu.RLock() - defer c.mu.RUnlock() +// useReverseTunnelV2Locked returns true if reverse tunnel should be used. +func (c *agentPoolRuntimeConfig) useReverseTunnelV2Locked() bool { if c.isRemoteCluster { return false } return c.tunnelStrategyType == types.ProxyPeering } +// alpnDialerConfig creates a config for ALPN dialer. +func (c *agentPoolRuntimeConfig) alpnDialerConfig(getClusterCAs client.GetClusterCAsFunc) client.ALPNDialerConfig { + c.mu.RLock() + defer c.mu.RUnlock() + + protocols := []alpncommon.Protocol{alpncommon.ProtocolReverseTunnel} + if c.useReverseTunnelV2Locked() { + protocols = []alpncommon.Protocol{alpncommon.ProtocolReverseTunnelV2, alpncommon.ProtocolReverseTunnel} + } + + return client.ALPNDialerConfig{ + TLSConfig: &tls.Config{ + NextProtos: alpncommon.ProtocolsToString(protocols), + InsecureSkipVerify: lib.IsInsecureDevMode(), + }, + KeepAlivePeriod: c.keepAliveInterval, + ALPNConnUpgradeRequired: c.tlsRoutingConnUpgradeRequired, + GetClusterCAs: getClusterCAs, + } +} + // useALPNRouting returns true agents should connect using alpn routing. func (c *agentPoolRuntimeConfig) useALPNRouting() bool { c.mu.RLock() @@ -711,13 +729,18 @@ func (c *agentPoolRuntimeConfig) updateRemote(ctx context.Context, addr *utils.N c.lastRemotePing = &now c.remoteTLSRoutingEnabled = tlsRoutingEnabled + if c.remoteTLSRoutingEnabled { + c.tlsRoutingConnUpgradeRequired = client.IsALPNConnUpgradeRequired(addr.Addr, lib.IsInsecureDevMode()) + logrus.Debugf("ALPN upgrade required for remote %v: %v", addr.Addr, c.tlsRoutingConnUpgradeRequired) + } return nil } -func (c *agentPoolRuntimeConfig) update(netConfig types.ClusterNetworkingConfig) { +func (c *agentPoolRuntimeConfig) update(ctx context.Context, netConfig types.ClusterNetworkingConfig, resolver Resolver) { c.mu.Lock() defer c.mu.Unlock() + oldProxyListenerMode := c.proxyListenerMode c.keepAliveInterval = netConfig.GetKeepAliveInterval() c.proxyListenerMode = netConfig.GetProxyListenerMode() @@ -736,6 +759,15 @@ func (c *agentPoolRuntimeConfig) update(netConfig types.ClusterNetworkingConfig) if c.connectionCount <= 0 { c.connectionCount = defaultAgentConnectionCount } + + if c.proxyListenerMode == types.ProxyListenerMode_Multiplex && oldProxyListenerMode != c.proxyListenerMode { + addr, _, err := resolver(ctx) + if err == nil { + c.tlsRoutingConnUpgradeRequired = client.IsALPNConnUpgradeRequired(addr.Addr, lib.IsInsecureDevMode()) + } else { + logrus.WithError(err).Warnf("Faield to resolve addr.") + } + } } // Make sure ServerHandlerToListener implements both interfaces. diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index 0a25c705966c1..fb0431c205eff 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -19,6 +19,7 @@ package reversetunnel import ( "context" "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "io" @@ -31,6 +32,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/sshutils" @@ -62,12 +64,17 @@ type TunnelAuthDialerConfig struct { Log logrus.FieldLogger // InsecureSkipTLSVerify is whether to skip certificate validation. InsecureSkipTLSVerify bool + // ClusterCAs contains cluster CAs. + ClusterCAs *x509.CertPool } func (c *TunnelAuthDialerConfig) CheckAndSetDefaults() error { if c.Resolver == nil { return trace.BadParameter("missing tunnel address resolver") } + if c.ClusterCAs == nil { + return trace.BadParameter("missing cluster CAs") + } return nil } @@ -91,8 +98,14 @@ func (t *TunnelAuthDialer) DialContext(ctx context.Context, _, _ string) (net.Co } if mode == types.ProxyListenerMode_Multiplex { - opts = append(opts, proxy.WithALPNDialer(&tls.Config{ - NextProtos: []string{string(alpncommon.ProtocolReverseTunnel)}, + opts = append(opts, proxy.WithALPNDialer(client.ALPNDialerConfig{ + TLSConfig: &tls.Config{ + NextProtos: []string{string(alpncommon.ProtocolReverseTunnel)}, + InsecureSkipVerify: t.InsecureSkipTLSVerify, + }, + DialTimeout: t.ClientConfig.Timeout, + ALPNConnUpgradeRequired: client.IsALPNConnUpgradeRequired(addr.Addr, t.InsecureSkipTLSVerify), + GetClusterCAs: client.ClusterCAsFromCertPool(t.ClusterCAs), })) } diff --git a/lib/service/connect.go b/lib/service/connect.go index 2b5141af2502e..d10354bdd5ffd 100644 --- a/lib/service/connect.go +++ b/lib/service/connect.go @@ -1171,6 +1171,7 @@ func (process *TeleportProcess) newClientThroughTunnel(addr string, tlsConfig *t ClientConfig: sshConfig, Log: process.log, InsecureSkipTLSVerify: lib.IsInsecureDevMode(), + ClusterCAs: tlsConfig.RootCAs, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/utils/proxy/proxy.go b/lib/utils/proxy/proxy.go index 464665bf1c151..441b78394188e 100644 --- a/lib/utils/proxy/proxy.go +++ b/lib/utils/proxy/proxy.go @@ -18,7 +18,6 @@ package proxy import ( "context" - "crypto/tls" "net" "net/url" "time" @@ -32,7 +31,6 @@ import ( "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" apiutils "github.com/gravitational/teleport/api/utils" - "github.com/gravitational/teleport/lib/utils" ) var log = logrus.WithFields(logrus.Fields{ @@ -61,24 +59,7 @@ func (d directDial) dialALPNWithDeadline(ctx context.Context, network string, ad ctx, span := tracing.DefaultProvider().Tracer("dialer").Start(ctx, "directDial/dialALPNWithDeadline") defer span.End() - dialer := &net.Dialer{ - Timeout: config.Timeout, - } - address, err := utils.ParseAddr(addr) - if err != nil { - return nil, trace.Wrap(err) - } - conf, err := d.getTLSConfig(address) - if err != nil { - return nil, trace.Wrap(err) - } - - tlsDialer := tls.Dialer{ - NetDialer: dialer, - Config: conf, - } - - tlsConn, err := tlsDialer.DialContext(ctx, network, addr) + tlsConn, err := d.alpnDialer.DialContext(ctx, network, addr) if err != nil { return nil, trace.Wrap(err) } @@ -95,28 +76,13 @@ type Dialer interface { } type directDial struct { - // insecure is whether to skip certificate validation. - insecure bool - // tlsRoutingEnabled indicates that proxy is running in TLSRouting mode. - tlsRoutingEnabled bool - // tlsConfig is the TLS config to use. - tlsConfig *tls.Config -} - -// getTLSConfig configures the dialers TLS config for a specified address. -func (d directDial) getTLSConfig(addr *utils.NetAddr) (*tls.Config, error) { - if d.tlsConfig == nil { - return nil, trace.BadParameter("TLS config was nil") - } - tlsConfig := d.tlsConfig.Clone() - tlsConfig.ServerName = addr.Host() - tlsConfig.InsecureSkipVerify = d.insecure - return tlsConfig, nil + // alpnDialer is the dialer used for TLS routing. + alpnDialer apiclient.ContextDialer } // Dial calls ssh.Dial directly. func (d directDial) Dial(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) { - if d.tlsRoutingEnabled { + if d.alpnDialer != nil { client, err := d.dialALPNWithDeadline(ctx, network, addr, config) if err != nil { return nil, trace.Wrap(err) @@ -132,31 +98,15 @@ func (d directDial) Dial(ctx context.Context, network string, addr string, confi // DialTimeout acts like Dial but takes a timeout. func (d directDial) DialTimeout(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error) { + if d.alpnDialer != nil { + conn, err := d.alpnDialer.DialContext(ctx, network, address) + return conn, trace.Wrap(err) + } + dialer := &net.Dialer{ Timeout: timeout, } - if d.tlsRoutingEnabled { - addr, err := utils.ParseAddr(address) - if err != nil { - return nil, trace.Wrap(err) - } - conf, err := d.getTLSConfig(addr) - if err != nil { - return nil, trace.Wrap(err) - } - - tlsDialer := tls.Dialer{ - NetDialer: dialer, - Config: conf, - } - - tlsConn, err := tlsDialer.DialContext(ctx, "tcp", address) - if err != nil { - return nil, trace.Wrap(err) - } - return tlsConn, nil - } conn, err := dialer.DialContext(ctx, network, address) if err != nil { return nil, trace.Wrap(err) @@ -169,38 +119,8 @@ type proxyDial struct { proxyURL *url.URL // insecure is whether to skip certificate validation. insecure bool - // tlsRoutingEnabled indicates that proxy is running in TLSRouting mode. - tlsRoutingEnabled bool - // tlsConfig is the TLS config to use. - tlsConfig *tls.Config -} - -// getTLSConfig configures the dialers TLS config for a specified address. -func (d proxyDial) getTLSConfig(addr *utils.NetAddr) (*tls.Config, error) { - if d.tlsConfig == nil { - return nil, trace.BadParameter("TLS config was nil") - } - tlsConfig := d.tlsConfig.Clone() - tlsConfig.ServerName = addr.Host() - tlsConfig.InsecureSkipVerify = d.insecure - return tlsConfig, nil -} - -// getTLSConfigForProxy configures the dialer's TLS config for the HTTPS proxy -// address. If the proxy is HTTP, a nil error and nil config are returned. -func (d proxyDial) getTLSConfigForProxy() (*tls.Config, error) { - if d.proxyURL.Scheme != "https" { - return nil, nil - } - netAddr, err := utils.ParseAddr(d.proxyURL.String()) - if err != nil { - return nil, trace.Wrap(err) - } - tlsConfig, err := d.getTLSConfig(netAddr) - if err != nil { - return nil, trace.Wrap(err) - } - return tlsConfig, nil + // alpnDialer is the dialer used for TLS routing. + alpnDialer apiclient.ContextDialer } // DialTimeout acts like Dial but takes a timeout. @@ -212,62 +132,33 @@ func (d proxyDial) DialTimeout(ctx context.Context, network, address string, tim ctx = timeoutCtx } - tlsConfig, err := d.getTLSConfigForProxy() - if err != nil { - return nil, trace.Wrap(err) + // ALPN dialer handles proxy URL internally. + if d.alpnDialer != nil { + tlsConn, err := d.alpnDialer.DialContext(ctx, network, address) + return tlsConn, trace.Wrap(err) } - conn, err := apiclient.DialProxy(ctx, d.proxyURL, address, apiclient.WithTLSConfig(tlsConfig)) + conn, err := apiclient.DialProxy(ctx, d.proxyURL, address, apiclient.WithInsecureSkipVerify(d.insecure)) if err != nil { return nil, trace.Wrap(err) } - if d.tlsRoutingEnabled { - address, err := utils.ParseAddr(address) - if err != nil { - return nil, trace.Wrap(err) - } - conf, err := d.getTLSConfig(address) - if err != nil { - return nil, trace.Wrap(err) - } - tlsConn := tls.Client(conn, conf) - if err = tlsConn.HandshakeContext(ctx); err != nil { - conn.Close() - return nil, trace.Wrap(err) - } - conn = tlsConn - } return conn, nil } // Dial first connects to a proxy, then uses the connection to establish a new // SSH connection. func (d proxyDial) Dial(ctx context.Context, network string, addr string, config *ssh.ClientConfig) (*tracessh.Client, error) { - tlsConfig, err := d.getTLSConfigForProxy() - if err != nil { - return nil, trace.Wrap(err) - } // Build a proxy connection first. - pconn, err := apiclient.DialProxy(ctx, d.proxyURL, addr, apiclient.WithTLSConfig(tlsConfig)) + pconn, err := d.DialTimeout(ctx, network, addr, config.Timeout) if err != nil { return nil, trace.Wrap(err) } + if config.Timeout > 0 { if err := pconn.SetReadDeadline(time.Now().Add(config.Timeout)); err != nil { return nil, trace.Wrap(err) } } - if d.tlsRoutingEnabled { - address, err := utils.ParseAddr(addr) - if err != nil { - return nil, trace.Wrap(err) - } - conf, err := d.getTLSConfig(address) - if err != nil { - return nil, trace.Wrap(err) - } - pconn = tls.Client(pconn, conf) - } // Do the same as ssh.Dial but pass in proxy connection. c, chans, reqs, err := tracessh.NewClientConn(ctx, pconn, addr, config) @@ -285,20 +176,17 @@ func (d proxyDial) Dial(ctx context.Context, network string, addr string, config type dialerOptions struct { // insecureSkipTLSVerify is whether to skip certificate validation. insecureSkipTLSVerify bool - // tlsRoutingEnabled indicates that proxy is running in TLSRouting mode. - tlsRoutingEnabled bool - // tlsConfig is the TLS config to use for TLS routing. - tlsConfig *tls.Config + // alpnDialer is the dialer used for TLS routing. + alpnDialer apiclient.ContextDialer } // DialerOptionFunc allows setting options as functional arguments to DialerFromEnvironment type DialerOptionFunc func(options *dialerOptions) // WithALPNDialer creates a dialer that allows to Teleport running in single-port mode. -func WithALPNDialer(tlsConfig *tls.Config) DialerOptionFunc { +func WithALPNDialer(alpnDialerConfig apiclient.ALPNDialerConfig) DialerOptionFunc { return func(options *dialerOptions) { - options.tlsRoutingEnabled = true - options.tlsConfig = tlsConfig + options.alpnDialer = apiclient.NewALPNDialer(alpnDialerConfig) } } @@ -327,17 +215,14 @@ func DialerFromEnvironment(addr string, opts ...DialerOptionFunc) Dialer { if proxyURL == nil { log.Debugf("No proxy set in environment, returning direct dialer.") return directDial{ - tlsConfig: options.tlsConfig, - tlsRoutingEnabled: options.tlsRoutingEnabled, - insecure: options.insecureSkipTLSVerify, + alpnDialer: options.alpnDialer, } } log.Debugf("Found proxy %q in environment, returning proxy dialer.", proxyURL) return proxyDial{ - proxyURL: proxyURL, - insecure: options.insecureSkipTLSVerify, - tlsRoutingEnabled: options.tlsRoutingEnabled, - tlsConfig: options.tlsConfig, + proxyURL: proxyURL, + insecure: options.insecureSkipTLSVerify, + alpnDialer: options.alpnDialer, } } diff --git a/lib/web/conn_upgrade.go b/lib/web/conn_upgrade.go index cdfe3dae0556a..fdb03332b041a 100644 --- a/lib/web/conn_upgrade.go +++ b/lib/web/conn_upgrade.go @@ -21,11 +21,14 @@ import ( "io" "net" "net/http" + "time" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/utils/pingconn" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/utils" ) @@ -35,12 +38,14 @@ func (h *Handler) selectConnectionUpgrade(r *http.Request) (string, ConnectionHa upgrades := r.Header.Values(constants.WebAPIConnUpgradeHeader) for _, upgradeType := range upgrades { switch upgradeType { + case constants.WebAPIConnUpgradeTypeALPNPing: + return upgradeType, h.upgradeALPNWithPing, nil case constants.WebAPIConnUpgradeTypeALPN: return upgradeType, h.upgradeALPN, nil } } - return "", nil, trace.BadParameter("unsupported upgrade types: %v", upgrades) + return "", nil, trace.NotFound("unsupported upgrade types: %v", upgrades) } // connectionUpgrade handles connection upgrades. @@ -88,6 +93,40 @@ func (h *Handler) upgradeALPN(ctx context.Context, conn net.Conn) error { return h.cfg.ALPNHandler(ctx, waitConn) } +func (h *Handler) upgradeALPNWithPing(ctx context.Context, conn net.Conn) error { + if h.cfg.ALPNHandler == nil { + return trace.BadParameter("missing ALPNHandler") + } + + pingConn := pingconn.New(conn) + + // Cancel ping background goroutine when connection is closed. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go h.startPing(ctx, pingConn) + + return h.upgradeALPN(ctx, pingConn) +} + +func (h *Handler) startPing(ctx context.Context, pingConn *pingconn.PingConn) { + ticker := time.NewTicker(defaults.ProxyPingInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + err := pingConn.WritePing() + if err != nil { + if !utils.IsOKNetworkError(err) { + h.log.WithError(err).Warn("Failed to write ping message") + } + return + } + } + } +} + func writeUpgradeResponse(w io.Writer, upgradeType string) error { header := make(http.Header) header.Add(constants.WebAPIConnUpgradeHeader, upgradeType) diff --git a/lib/web/conn_upgrade_test.go b/lib/web/conn_upgrade_test.go index c7234e56af516..76b566beeeace 100644 --- a/lib/web/conn_upgrade_test.go +++ b/lib/web/conn_upgrade_test.go @@ -29,6 +29,9 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/utils/pingconn" ) func TestWriteUpgradeResponse(t *testing.T) { @@ -78,7 +81,7 @@ func TestHandlerConnectionUpgrade(t *testing.T) { r.Header.Add("Upgrade", "unsupported-protocol") _, err = h.connectionUpgrade(httptest.NewRecorder(), r, nil) - require.True(t, trace.IsBadParameter(err)) + require.True(t, trace.IsNotFound(err)) }) t.Run("upgraded to ALPN", func(t *testing.T) { @@ -86,35 +89,53 @@ func TestHandlerConnectionUpgrade(t *testing.T) { defer serverConn.Close() defer clientConn.Close() - r, err := http.NewRequest("GET", "http://localhost/webapi/connectionupgrade", nil) - require.NoError(t, err) - r.Header.Add("Upgrade", "alpn") - - go func() { - // serverConn will be hijacked. - w := newResponseWriterHijacker(nil, serverConn) - _, err := h.connectionUpgrade(w, r, nil) - require.NoError(t, err) - }() + sendConnUpgradeRequest(t, h, constants.WebAPIConnUpgradeTypeALPN, serverConn, clientConn) - // Verify clientConn receives http.StatusSwitchingProtocols. - clientConnReader := bufio.NewReader(clientConn) - resp, err := http.ReadResponse(clientConnReader, r) + // Verify clientConn receives data sent by Config.ALPNHandler. + receive, err := bufio.NewReader(clientConn).ReadString(byte('@')) require.NoError(t, err) + require.Equal(t, expectedPayload, receive) + }) - // Always drain/close the body. - io.Copy(io.Discard, resp.Body) - _ = resp.Body.Close() + t.Run("upgraded to ALPN with Ping", func(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() - require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + sendConnUpgradeRequest(t, h, constants.WebAPIConnUpgradeTypeALPNPing, serverConn, clientConn) - // Verify clientConn receives data sent by Config.ALPNHandler. - receive, err := clientConnReader.ReadString(byte('@')) + // Verify ping-wrapped clientConn receives data sent by Config.ALPNHandler. + receive, err := bufio.NewReader(pingconn.New(clientConn)).ReadString(byte('@')) require.NoError(t, err) require.Equal(t, expectedPayload, receive) }) } +func sendConnUpgradeRequest(t *testing.T, h *Handler, upgradeType string, serverConn, clientConn net.Conn) { + t.Helper() + + r, err := http.NewRequest("GET", "http://localhost/webapi/connectionupgrade", nil) + require.NoError(t, err) + r.Header.Add("Upgrade", upgradeType) + + go func() { + // serverConn will be hijacked. + w := newResponseWriterHijacker(nil, serverConn) + _, err := h.connectionUpgrade(w, r, nil) + require.NoError(t, err) + }() + + // Verify clientConn receives http.StatusSwitchingProtocols. + resp, err := http.ReadResponse(bufio.NewReader(clientConn), r) + require.NoError(t, err) + + // Always drain/close the body. + io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) +} + // responseWriterHijacker is a mock http.ResponseWriter that also serves a // net.Conn for http.Hijacker. type responseWriterHijacker struct { diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index e445e5cef8569..0089e10ca70d3 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -40,9 +40,9 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" - apiutils "github.com/gravitational/teleport/api/utils" libclient "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/db/dbcmd" "github.com/gravitational/teleport/lib/defaults" @@ -234,53 +234,31 @@ func dialSSHProxy(ctx context.Context, tc *libclient.TeleportClient, sp sshProxy // if sp.tlsRouting is true, remoteProxyAddr is the ALPN listener port. // if it is false, then remoteProxyAddr is the SSH proxy port. remoteProxyAddr := net.JoinHostPort(sp.proxyHost, sp.proxyPort) - httpsProxy := apiutils.GetProxyURL(remoteProxyAddr) - pool, err := tc.LocalAgent().ClientCertPool(sp.clusterName) - if err != nil { - return nil, trace.Wrap(err) - } - - // If HTTPS_PROXY is configured, we need to open a TCP connection via - // the specified HTTPS Proxy, otherwise, we can just open a plain TCP - // connection. - var tcpConn net.Conn - if httpsProxy != nil { - httpProxyTLSConfig := &tls.Config{ - RootCAs: pool, - InsecureSkipVerify: tc.InsecureSkipVerify, - ServerName: httpsProxy.Hostname(), - } - tcpConn, err = client.DialProxy(ctx, httpsProxy, remoteProxyAddr, client.WithTLSConfig(httpProxyTLSConfig)) - if err != nil { - return nil, trace.Wrap(err) - } - } else { - tcpConn, err = (&net.Dialer{}).DialContext(ctx, "tcp", remoteProxyAddr) + var dialer client.ContextDialer + switch { + case sp.tlsRouting: + pool, err := tc.LocalAgent().ClientCertPool(sp.clusterName) if err != nil { return nil, trace.Wrap(err) } - } - // If TLS routing is not enabled, just return the TCP connection - if !sp.tlsRouting { - return tcpConn, nil - } + dialer = client.NewALPNDialer(client.ALPNDialerConfig{ + TLSConfig: &tls.Config{ + RootCAs: pool, + NextProtos: []string{string(alpncommon.ProtocolProxySSH)}, + InsecureSkipVerify: tc.InsecureSkipVerify, + ServerName: sp.proxyHost, + }, + ALPNConnUpgradeRequired: tc.IsALPNConnUpgradeRequiredForWebProxy(remoteProxyAddr), + }) - // Otherwise, we need to upgrade the TCP connection to a TLS connection. - tlsConfig := &tls.Config{ - RootCAs: pool, - NextProtos: []string{string(alpncommon.ProtocolProxySSH)}, - InsecureSkipVerify: tc.InsecureSkipVerify, - ServerName: sp.proxyHost, - } - tlsConn := tls.Client(tcpConn, tlsConfig) - if err := tlsConn.HandshakeContext(ctx); err != nil { - tlsConn.Close() - return nil, trace.Wrap(err) + default: + dialer = client.NewDialer(ctx, apidefaults.DefaultIdleTimeout, apidefaults.DefaultIOTimeout, client.WithInsecureSkipVerify(tc.InsecureSkipVerify)) } - return tlsConn, nil + conn, err := dialer.DialContext(ctx, "tcp", remoteProxyAddr) + return conn, trace.Wrap(err) } func proxySubsystemName(userHost, cluster string) string {