Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 12 additions & 30 deletions api/client/alpn_conn_upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,10 @@ func isALPNConnUpgradeRequiredByEnv(addr, envValue string) bool {
// alpnConnUpgradeDialer makes an "HTTP" upgrade call to the Proxy Service then
// tunnels the connection with this connection upgrade.
type alpnConnUpgradeDialer struct {
dialer ContextDialer
tlsConfig *tls.Config
withPing bool
dialer ContextDialer
tlsConfig *tls.Config
withPing bool
useLegacyMode bool
}

// newALPNConnUpgradeDialer creates a new alpnConnUpgradeDialer.
Expand All @@ -184,6 +185,8 @@ func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config, withP
dialer: dialer,
tlsConfig: tlsConfig,
withPing: withPing,
// Only use "legacy" mode when it's explicitly set by the env var.
useLegacyMode: strings.ToLower(os.Getenv(defaults.TLSRoutingConnUpgradeModeEnvVar)) == "legacy",
}
}

Expand All @@ -199,7 +202,7 @@ func (d *alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr s
Path: constants.WebAPIConnUpgrade,
}

conn, err := upgradeConnThroughWebAPI(tlsConn, upgradeURL, d.upgradeType())
conn, err := upgradeConnThroughWebAPI(tlsConn, upgradeURL, d.upgradeType(), d.useLegacyMode)
if err != nil {
return nil, trace.NewAggregate(tlsConn.Close(), err)
}
Expand All @@ -213,7 +216,7 @@ func (d *alpnConnUpgradeDialer) upgradeType() string {
return constants.WebAPIConnUpgradeTypeALPN
}

func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, alpnUpgradeType string) (net.Conn, error) {
func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, alpnUpgradeType string, useLegacyMode bool) (net.Conn, error) {
req, err := http.NewRequest(http.MethodGet, api.String(), nil)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -224,16 +227,12 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, alpnUpgradeType string
return nil, trace.Wrap(err)
}

// Prefer "websocket".
if useConnUpgradeMode.useWebSocket() {
applyWebSocketUpgradeHeaders(req, alpnUpgradeType, challengeKey)
}

// Append "legacy" custom upgrade type.
// TODO(greedy52) DELETE in 17.0
if useConnUpgradeMode.useLegacy() {
// Only set one mode at a time.
if useLegacyMode {
req.Header.Add(constants.WebAPIConnUpgradeHeader, alpnUpgradeType)
req.Header.Add(constants.WebAPIConnUpgradeTeleportHeader, alpnUpgradeType)
} else {
applyWebSocketUpgradeHeaders(req, alpnUpgradeType, challengeKey)
}

// Set "Connection" header to meet RFC spec:
Expand Down Expand Up @@ -282,26 +281,9 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, alpnUpgradeType string
}

// Handle "legacy".
// TODO(greedy52) DELETE in 17.0.
logger.DebugContext(req.Context(), "Performing ALPN legacy connection upgrade.")
if alpnUpgradeType == constants.WebAPIConnUpgradeTypeALPNPing {
return pingconn.New(conn), nil
}
return conn, nil
}

type connUpgradeMode string

func (m connUpgradeMode) useWebSocket() bool {
// Use WebSocket as long as it's not legacy only.
return strings.ToLower(string(m)) != "legacy"
}

func (m connUpgradeMode) useLegacy() bool {
// Use legacy as long as it's not WebSocket only.
return strings.ToLower(string(m)) != "websocket"
}

var (
useConnUpgradeMode connUpgradeMode = connUpgradeMode(os.Getenv(defaults.TLSRoutingConnUpgradeModeEnvVar))
)
49 changes: 7 additions & 42 deletions api/client/alpn_conn_upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,17 @@ func TestALPNConnUpgradeDialer(t *testing.T) {
name string
serverHandler http.Handler
withPing bool
useLegacyMode bool
wantError bool
}{
{
// TODO(greedy52) DELETE in 17.0
name: "connection upgrade (legacy)",
useLegacyMode: true,
serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")),
},
{
// TODO(greedy52) DELETE in 17.0
name: "connection upgrade with ping (legacy)",
useLegacyMode: true,
serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")),
withPing: true,
},
Expand Down Expand Up @@ -210,6 +211,8 @@ func TestALPNConnUpgradeDialer(t *testing.T) {

t.Run("direct", func(t *testing.T) {
dialer := newALPNConnUpgradeDialer(directDialer, tlsConfig, test.withPing)
dialer.(*alpnConnUpgradeDialer).useLegacyMode = test.useLegacyMode

conn, err := dialer.DialContext(ctx, "tcp", addr.Host)
if test.wantError {
require.Error(t, err)
Expand All @@ -227,6 +230,8 @@ func TestALPNConnUpgradeDialer(t *testing.T) {

proxyURLDialer := newProxyURLDialer(forwardProxyURL, directDialer)
dialer := newALPNConnUpgradeDialer(proxyURLDialer, tlsConfig, test.withPing)
dialer.(*alpnConnUpgradeDialer).useLegacyMode = test.useLegacyMode

conn, err := dialer.DialContext(ctx, "tcp", addr.Host)
if test.wantError {
require.Error(t, err)
Expand Down Expand Up @@ -422,43 +427,3 @@ func mustStartForwardProxy(t *testing.T) (*testhelpers.ProxyHandler, *url.URL) {
go http.Serve(listener, handler)
return handler, url
}

func Test_connUpgradeMode(t *testing.T) {
tests := []struct {
envVarValue string
wantUseWebSocket require.BoolAssertionFunc
wantUseLegacy require.BoolAssertionFunc
}{
{
envVarValue: "",
wantUseWebSocket: require.True,
wantUseLegacy: require.True,
},
{
envVarValue: "WebSocket",
wantUseWebSocket: require.True,
wantUseLegacy: require.False,
},
{
envVarValue: "websocket",
wantUseWebSocket: require.True,
wantUseLegacy: require.False,
},
{
envVarValue: "legacy",
wantUseWebSocket: require.False,
wantUseLegacy: require.True,
},
{
envVarValue: "default",
wantUseWebSocket: require.True,
wantUseLegacy: require.True,
},
}

for _, test := range tests {
mode := connUpgradeMode(test.envVarValue)
test.wantUseWebSocket(t, mode.useWebSocket())
test.wantUseLegacy(t, mode.useLegacy())
}
}