diff --git a/api/client/alpn_conn_upgrade.go b/api/client/alpn_conn_upgrade.go index d08b1fe7968e4..e80639663f147 100644 --- a/api/client/alpn_conn_upgrade.go +++ b/api/client/alpn_conn_upgrade.go @@ -173,9 +173,10 @@ func isALPNConnUpgradeRequiredByEnv(addr, envValue string) bool { // alpnConnUpgradeDialer makes an "HTTP" upgrade call to the Proxy Service then // tunnels the connection with this connection upgrade. type alpnConnUpgradeDialer struct { - dialer ContextDialer - tlsConfig *tls.Config - withPing bool + dialer ContextDialer + tlsConfig *tls.Config + withPing bool + useLegacyMode bool } // newALPNConnUpgradeDialer creates a new alpnConnUpgradeDialer. @@ -184,6 +185,8 @@ func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config, withP dialer: dialer, tlsConfig: tlsConfig, withPing: withPing, + // Only use "legacy" mode when it's explicitly set by the env var. + useLegacyMode: strings.ToLower(os.Getenv(defaults.TLSRoutingConnUpgradeModeEnvVar)) == "legacy", } } @@ -199,7 +202,7 @@ func (d *alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr s Path: constants.WebAPIConnUpgrade, } - conn, err := upgradeConnThroughWebAPI(tlsConn, upgradeURL, d.upgradeType()) + conn, err := upgradeConnThroughWebAPI(tlsConn, upgradeURL, d.upgradeType(), d.useLegacyMode) if err != nil { return nil, trace.NewAggregate(tlsConn.Close(), err) } @@ -213,7 +216,7 @@ func (d *alpnConnUpgradeDialer) upgradeType() string { return constants.WebAPIConnUpgradeTypeALPN } -func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, alpnUpgradeType string) (net.Conn, error) { +func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, alpnUpgradeType string, useLegacyMode bool) (net.Conn, error) { req, err := http.NewRequest(http.MethodGet, api.String(), nil) if err != nil { return nil, trace.Wrap(err) @@ -224,16 +227,12 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, alpnUpgradeType string return nil, trace.Wrap(err) } - // Prefer "websocket". - if useConnUpgradeMode.useWebSocket() { - applyWebSocketUpgradeHeaders(req, alpnUpgradeType, challengeKey) - } - - // Append "legacy" custom upgrade type. - // TODO(greedy52) DELETE in 17.0 - if useConnUpgradeMode.useLegacy() { + // Only set one mode at a time. + if useLegacyMode { req.Header.Add(constants.WebAPIConnUpgradeHeader, alpnUpgradeType) req.Header.Add(constants.WebAPIConnUpgradeTeleportHeader, alpnUpgradeType) + } else { + applyWebSocketUpgradeHeaders(req, alpnUpgradeType, challengeKey) } // Set "Connection" header to meet RFC spec: @@ -282,26 +281,9 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, alpnUpgradeType string } // Handle "legacy". - // TODO(greedy52) DELETE in 17.0. logger.DebugContext(req.Context(), "Performing ALPN legacy connection upgrade.") if alpnUpgradeType == constants.WebAPIConnUpgradeTypeALPNPing { return pingconn.New(conn), nil } return conn, nil } - -type connUpgradeMode string - -func (m connUpgradeMode) useWebSocket() bool { - // Use WebSocket as long as it's not legacy only. - return strings.ToLower(string(m)) != "legacy" -} - -func (m connUpgradeMode) useLegacy() bool { - // Use legacy as long as it's not WebSocket only. - return strings.ToLower(string(m)) != "websocket" -} - -var ( - useConnUpgradeMode connUpgradeMode = connUpgradeMode(os.Getenv(defaults.TLSRoutingConnUpgradeModeEnvVar)) -) diff --git a/api/client/alpn_conn_upgrade_test.go b/api/client/alpn_conn_upgrade_test.go index 6fae34d15b56c..437a8a115c6a0 100644 --- a/api/client/alpn_conn_upgrade_test.go +++ b/api/client/alpn_conn_upgrade_test.go @@ -163,16 +163,17 @@ func TestALPNConnUpgradeDialer(t *testing.T) { name string serverHandler http.Handler withPing bool + useLegacyMode bool wantError bool }{ { - // TODO(greedy52) DELETE in 17.0 name: "connection upgrade (legacy)", + useLegacyMode: true, serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")), }, { - // TODO(greedy52) DELETE in 17.0 name: "connection upgrade with ping (legacy)", + useLegacyMode: true, serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")), withPing: true, }, @@ -210,6 +211,8 @@ func TestALPNConnUpgradeDialer(t *testing.T) { t.Run("direct", func(t *testing.T) { dialer := newALPNConnUpgradeDialer(directDialer, tlsConfig, test.withPing) + dialer.(*alpnConnUpgradeDialer).useLegacyMode = test.useLegacyMode + conn, err := dialer.DialContext(ctx, "tcp", addr.Host) if test.wantError { require.Error(t, err) @@ -227,6 +230,8 @@ func TestALPNConnUpgradeDialer(t *testing.T) { proxyURLDialer := newProxyURLDialer(forwardProxyURL, directDialer) dialer := newALPNConnUpgradeDialer(proxyURLDialer, tlsConfig, test.withPing) + dialer.(*alpnConnUpgradeDialer).useLegacyMode = test.useLegacyMode + conn, err := dialer.DialContext(ctx, "tcp", addr.Host) if test.wantError { require.Error(t, err) @@ -422,43 +427,3 @@ func mustStartForwardProxy(t *testing.T) (*testhelpers.ProxyHandler, *url.URL) { go http.Serve(listener, handler) return handler, url } - -func Test_connUpgradeMode(t *testing.T) { - tests := []struct { - envVarValue string - wantUseWebSocket require.BoolAssertionFunc - wantUseLegacy require.BoolAssertionFunc - }{ - { - envVarValue: "", - wantUseWebSocket: require.True, - wantUseLegacy: require.True, - }, - { - envVarValue: "WebSocket", - wantUseWebSocket: require.True, - wantUseLegacy: require.False, - }, - { - envVarValue: "websocket", - wantUseWebSocket: require.True, - wantUseLegacy: require.False, - }, - { - envVarValue: "legacy", - wantUseWebSocket: require.False, - wantUseLegacy: require.True, - }, - { - envVarValue: "default", - wantUseWebSocket: require.True, - wantUseLegacy: require.True, - }, - } - - for _, test := range tests { - mode := connUpgradeMode(test.envVarValue) - test.wantUseWebSocket(t, mode.useWebSocket()) - test.wantUseLegacy(t, mode.useLegacy()) - } -}