Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
094fa43
ALPN connect test improvements
greedy52 Mar 19, 2023
7e848cf
fix typos
greedy52 Mar 21, 2023
2c9b116
remove extra period
greedy52 Mar 21, 2023
54019f3
simplify error check
greedy52 Mar 28, 2023
d98dd80
moving things over
greedy52 Mar 28, 2023
f3399cb
tsh dials
greedy52 Mar 28, 2023
9ace6ca
reverse tunnel
greedy52 Mar 29, 2023
1b4f4c5
fix auth connect
greedy52 Mar 29, 2023
c688918
move ping
greedy52 Mar 29, 2023
cf113f2
Merge branch 'master' of github.com:gravitational/teleport into STeve…
greedy52 Mar 29, 2023
1928f65
add ssh support
greedy52 Mar 30, 2023
b6d6c8a
add HTTP client support
greedy52 Mar 30, 2023
1d5760a
Move ALPN dialer, ALPN conn upgrade, Ping conn to api
greedy52 Mar 30, 2023
83abe8b
Merge branch 'STeve/21870_move_to_api' into STeve/21870_ssh_auth_reve…
greedy52 Mar 30, 2023
a5de0d9
beatify
greedy52 Mar 30, 2023
29f493a
add test
greedy52 Apr 1, 2023
7a8a309
beautify round 2
greedy52 Apr 2, 2023
511b9d0
fix timeout
greedy52 Apr 2, 2023
12f9f71
Merge branch 'master' of github.com:gravitational/teleport into STeve…
greedy52 Apr 6, 2023
ff3c869
Implement alpn-ping upgrade for reversetunnel and ssh
greedy52 Apr 6, 2023
922fc8b
clean up
greedy52 Apr 6, 2023
5c576c0
fix proxy test
greedy52 Apr 7, 2023
b1c76ee
Merge branch 'master' of github.com:gravitational/teleport into STeve…
greedy52 Apr 7, 2023
a3e4735
minor refactor
greedy52 Apr 12, 2023
e619b94
remove WebProxyAddr
greedy52 Apr 13, 2023
9ada770
require IsALPNConnUpgradeRequiredFunc
greedy52 Apr 13, 2023
4e0c056
add tlsRoutingWithConnUpgradeConnect
greedy52 Apr 13, 2023
a3d664f
fix lint
greedy52 Apr 13, 2023
e6b4825
simplify
greedy52 Apr 14, 2023
fd185ad
remove debug log and change unknown upgrade type to 404
greedy52 Apr 14, 2023
9165f12
Force new proxy client to use web proxy when TLS routing is enabled
greedy52 Apr 14, 2023
6521050
Merge branch 'master' of github.com:gravitational/teleport into STeve…
greedy52 Apr 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 100 additions & 9 deletions api/client/alpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,28 @@ package client
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"strings"
"time"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/client/webclient"
"github.com/gravitational/teleport/api/constants"
)

// GetClusterCAsFunc is a function to fetch cluster CAs.
type GetClusterCAsFunc func(ctx context.Context) (*x509.CertPool, error)

// ClusterCAsFromCertPool returns a GetClusterCAsFunc with provided static cert
// pool.
func ClusterCAsFromCertPool(cas *x509.CertPool) GetClusterCAsFunc {
return func(_ context.Context) (*x509.CertPool, error) {
return cas, nil
}
}

// ALPNDialerConfig is the config for ALPNDialer.
type ALPNDialerConfig struct {
// KeepAlivePeriod defines period between keep alives.
Expand All @@ -35,12 +51,17 @@ type ALPNDialerConfig struct {
TLSConfig *tls.Config
// ALPNConnUpgradeRequired specifies if ALPN connection upgrade is required.
ALPNConnUpgradeRequired bool
// GetClusterCAs is an optional callback function to fetch cluster
// CAs when connection upgrade is required. If not provided, it's assumed
// the proper CAs are already present in TLSConfig.
GetClusterCAs GetClusterCAsFunc
}

// ALPNDialer is a ContextDialer that dials a connection to the Proxy Service
// with ALPN and SNI configured in the provided TLSConfig. An ALPN connection
// upgrade is also performed at the initial connection, if an upgrade is
// required.
// required. If the negotiated protocol is a Ping protocol, it will return the
// de-multiplexed connection without the Ping.
type ALPNDialer struct {
cfg ALPNDialerConfig
}
Expand All @@ -52,25 +73,73 @@ func NewALPNDialer(cfg ALPNDialerConfig) ContextDialer {
}
}

// DialContext implements ContextDialer.
func (d ALPNDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
func (d *ALPNDialer) shouldUpdateTLSConfig() bool {
return d.shouldUpdateServerName() || d.shouldGetClusterCAs()
}

// shouldUpdateServerName returns true if ServerName is not in the provided TLS
// config. It will default to the host of the dialing address.
func (d *ALPNDialer) shouldUpdateServerName() bool {
return d.cfg.TLSConfig.ServerName == ""
}

// shouldGetClusterCAs returns true if RootCAs of the provided TLS config needs
// to be set to the Teleport cluster CAs.
//
// When Teleport Proxy is behind a L7 load balancer, the load balancer
// usually terminates TLS with public certs, and the Proxy is usually in
// private subnets with self-signed web certs. During the connection
// upgrade flow for TLS Routing, instead of serving these self-signed web
// certs, the TLS Routing handler at the Proxy server will present the
// Cluster CAs so clients here can still verify the server.
func (d *ALPNDialer) shouldGetClusterCAs() bool {
return d.cfg.ALPNConnUpgradeRequired && d.cfg.TLSConfig.RootCAs == nil && d.cfg.GetClusterCAs != nil
}
Comment thread
greedy52 marked this conversation as resolved.

func (d *ALPNDialer) getTLSConfig(ctx context.Context, addr string) (*tls.Config, error) {
if d.cfg.TLSConfig == nil {
return nil, trace.BadParameter("missing TLS config")
}
if !d.shouldUpdateTLSConfig() {
return d.cfg.TLSConfig, nil
}
Comment thread
greedy52 marked this conversation as resolved.

dialer := NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout, WithTLSConfig(d.cfg.TLSConfig))
if d.cfg.ALPNConnUpgradeRequired {
dialer = newALPNConnUpgradeDialer(dialer, &tls.Config{
InsecureSkipVerify: d.cfg.TLSConfig.InsecureSkipVerify,
})
var err error
tlsConfig := d.cfg.TLSConfig.Clone()
if d.shouldGetClusterCAs() {
tlsConfig.RootCAs, err = d.cfg.GetClusterCAs(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
}
Comment thread
greedy52 marked this conversation as resolved.
if d.shouldUpdateServerName() {
tlsConfig.ServerName, _, err = webclient.ParseHostPort(addr)
if err != nil {
return nil, trace.Wrap(err)
}
}
return tlsConfig, nil
}

// DialContext implements ContextDialer.
func (d *ALPNDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
tlsConfig, err := d.getTLSConfig(ctx, addr)
if err != nil {
return nil, trace.Wrap(err)
}

dialer := NewDialer(ctx, d.cfg.DialTimeout, d.cfg.DialTimeout,
WithInsecureSkipVerify(d.cfg.TLSConfig.InsecureSkipVerify),
WithALPNConnUpgrade(d.cfg.ALPNConnUpgradeRequired),
WithALPNConnUpgradePing(shouldALPNConnUpgradeWithPing(tlsConfig)),
)

conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, trace.Wrap(err)
}

tlsConn := tls.Client(conn, d.cfg.TLSConfig)
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {
defer tlsConn.Close()
return nil, trace.Wrap(err)
Expand All @@ -91,3 +160,25 @@ func DialALPN(ctx context.Context, addr string, cfg ALPNDialerConfig) (*tls.Conn
}
return tlsConn, nil
}

// IsALPNPingProtocol checks if the provided protocol is suffixed with Ping.
func IsALPNPingProtocol(protocol string) bool {
return strings.HasSuffix(protocol, constants.ALPNSNIProtocolPingSuffix)
}

// shouldALPNConnUpgradeWithPing returns true if Ping wrapper is required
// during connection upgrade.
func shouldALPNConnUpgradeWithPing(config *tls.Config) bool {
for _, proto := range config.NextProtos {
switch proto {
// Server usually sends SSH keepalives or HTTP2 pings every five
// minutes for reverse tunnel and SSH connections. Load balancers
// usually have a shorter idle timeout. Thus wrapping the connection
// with Ping protocol at the connection upgrade layer to keepalive.
case constants.ALPNSNIProtocolReverseTunnel,
Comment thread
greedy52 marked this conversation as resolved.
constants.ALPNSNIProtocolSSH:
return true
}
}
return false
}
50 changes: 32 additions & 18 deletions api/client/alpn_conn_upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/pingconn"
)

// IsALPNConnUpgradeRequired returns true if a tunnel is required through a HTTP
Expand Down Expand Up @@ -145,18 +146,20 @@ func isALPNConnUpgradeRequiredByEnv(addr, envValue string) bool {
type alpnConnUpgradeDialer struct {
dialer ContextDialer
tlsConfig *tls.Config
withPing bool
}

// newALPNConnUpgradeDialer creates a new alpnConnUpgradeDialer.
func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config) ContextDialer {
func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config, withPing bool) ContextDialer {
return &alpnConnUpgradeDialer{
dialer: dialer,
tlsConfig: tlsConfig,
withPing: withPing,
}
}

// DialContext implements ContextDialer
func (d alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
func (d *alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
logrus.Debugf("ALPN connection upgrade for %v.", addr)

conn, err := d.dialer.DialContext(ctx, network, addr)
Expand All @@ -181,47 +184,58 @@ func (d alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr st
}

tlsConn := tls.Client(conn, cfg)

err = upgradeConnThroughWebAPI(tlsConn, url.URL{
upgradeURL := url.URL{
Host: addr,
Scheme: "https",
Path: constants.WebAPIConnUpgrade,
})
}

conn, err = upgradeConnThroughWebAPI(tlsConn, upgradeURL, d.upgradeType())
if err != nil {
defer tlsConn.Close()
return nil, trace.Wrap(err)
return nil, trace.NewAggregate(tlsConn.Close(), err)
}
return tlsConn, nil
return conn, nil
}

func upgradeConnThroughWebAPI(conn net.Conn, api url.URL) error {
func (d *alpnConnUpgradeDialer) upgradeType() string {
if d.withPing {
return constants.WebAPIConnUpgradeTypeALPNPing
}
return constants.WebAPIConnUpgradeTypeALPN
}

func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) (net.Conn, error) {
req, err := http.NewRequest(http.MethodGet, api.String(), nil)
if err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}

// For now, only "alpn" is supported.
req.Header.Add(constants.WebAPIConnUpgradeHeader, constants.WebAPIConnUpgradeTypeALPN)
req.Header.Add(constants.WebAPIConnUpgradeHeader, upgradeType)

// Send the request and check if upgrade is successful.
if err = req.Write(conn); err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}
defer resp.Body.Close()

if http.StatusSwitchingProtocols != resp.StatusCode {
if http.StatusNotFound == resp.StatusCode {
return trace.NotImplemented(
"connection upgrade call to %q failed with status code %v. Please upgrade the server and try again.",
return nil, trace.NotImplemented(
"connection upgrade call to %q with upgrade type %v failed with status code %v. Please upgrade the server and try again.",
constants.WebAPIConnUpgrade,
upgradeType,
resp.StatusCode,
)
}
return trace.BadParameter("failed to switch Protocols %v", resp.StatusCode)
return nil, trace.BadParameter("failed to switch Protocols %v", resp.StatusCode)
}

if upgradeType == constants.WebAPIConnUpgradeTypeALPNPing {
return pingconn.New(conn), nil
}
return nil
return conn, nil
}
104 changes: 67 additions & 37 deletions api/client/alpn_conn_upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/fixtures"
"github.com/gravitational/teleport/api/utils/pingconn"
)

func TestIsALPNConnUpgradeRequired(t *testing.T) {
Expand Down Expand Up @@ -123,42 +124,58 @@ func TestIsALPNConnUpgradeRequiredByEnv(t *testing.T) {
func TestALPNConnUpgradeDialer(t *testing.T) {
t.Parallel()

t.Run("connection upgraded", func(t *testing.T) {
ctx := context.Background()

server := httptest.NewTLSServer(mockConnUpgradeHandler(t, "alpn", []byte("hello")))
t.Cleanup(server.Close)
addr, err := url.Parse(server.URL)
require.NoError(t, err)
pool := x509.NewCertPool()
pool.AddCert(server.Certificate())

tlsConfig := &tls.Config{RootCAs: pool}
preDialer := NewDialer(ctx, 0, 5*time.Second)
dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig)
conn, err := dialer.DialContext(ctx, "tcp", addr.Host)
require.NoError(t, err)

data := make([]byte, 100)
n, err := conn.Read(data)
require.NoError(t, err)
require.Equal(t, string(data[:n]), "hello")
})

t.Run("connection upgrade API not found", func(t *testing.T) {
ctx := context.Background()

server := httptest.NewTLSServer(http.NotFoundHandler())
t.Cleanup(server.Close)
addr, err := url.Parse(server.URL)
require.NoError(t, err)
tests := []struct {
name string
serverHandler http.Handler
withPing bool
wantError bool
}{
{
name: "connection upgrade",
serverHandler: mockConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")),
},
{
name: "connection upgrade with ping",
serverHandler: mockConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")),
withPing: true,
},
{
name: "connection upgrade API not found",
serverHandler: http.NotFoundHandler(),
wantError: true,
},
}

tlsConfig := &tls.Config{InsecureSkipVerify: true}
preDialer := NewDialer(ctx, 0, 5*time.Second)
dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig)
_, err = dialer.DialContext(ctx, "tcp", addr.Host)
require.Error(t, err)
})
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()

server := httptest.NewTLSServer(test.serverHandler)
t.Cleanup(server.Close)
addr, err := url.Parse(server.URL)
require.NoError(t, err)
pool := x509.NewCertPool()
pool.AddCert(server.Certificate())

tlsConfig := &tls.Config{RootCAs: pool}
preDialer := newDirectDialer(0, 5*time.Second)
dialer := newALPNConnUpgradeDialer(preDialer, tlsConfig, test.withPing)
conn, err := dialer.DialContext(ctx, "tcp", addr.Host)
if test.wantError {
require.Error(t, err)
return
}
require.NoError(t, err)
defer conn.Close()

data := make([]byte, 100)
n, err := conn.Read(data)
require.NoError(t, err)
require.Equal(t, string(data[:n]), "hello")
})
}
}

type mockALPNServer struct {
Expand Down Expand Up @@ -218,6 +235,8 @@ func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNSe
// mockConnUpgradeHandler mocks the server side implementation to handle an
// upgrade request and sends back some data inside the tunnel.
func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler {
t.Helper()

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path)
require.Equal(t, upgradeType, r.Header.Get(constants.WebAPIConnUpgradeHeader))
Expand All @@ -238,7 +257,18 @@ func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http
require.NoError(t, response.Write(conn))

// Upgraded.
_, err = conn.Write(write)
require.NoError(t, err)
switch upgradeType {
case constants.WebAPIConnUpgradeTypeALPNPing:
// Wrap conn with Ping and write some pings.
pingConn := pingconn.New(conn)
pingConn.WritePing()
_, err = pingConn.Write(write)
require.NoError(t, err)
pingConn.WritePing()

default:
_, err = conn.Write(write)
require.NoError(t, err)
}
})
}
Loading