diff --git a/api/client/alpn_conn_upgrade.go b/api/client/alpn_conn_upgrade.go index a700021581ec6..76cf524950156 100644 --- a/api/client/alpn_conn_upgrade.go +++ b/api/client/alpn_conn_upgrade.go @@ -113,7 +113,11 @@ func isUnadvertisedALPNError(err error) bool { // OverwriteALPNConnUpgradeRequirementByEnv overwrites ALPN connection upgrade // requirement by environment variable. // -// TODO(greedy52) DELETE in 15.0 +// TODO(greedy52) DELETE in ??. Note that this toggle was planned to be deleted +// in 15.0 when the feature exits preview. However, many users still rely on +// this manual toggle as IsALPNConnUpgradeRequired cannot detect many +// situations where connection upgrade is required. This can be deleted once +// IsALPNConnUpgradeRequired is improved. func OverwriteALPNConnUpgradeRequirementByEnv(addr string) (bool, bool) { envValue := os.Getenv(defaults.TLSRoutingConnUpgradeEnvVar) if envValue == "" { @@ -184,8 +188,6 @@ func newALPNConnUpgradeDialer(dialer ContextDialer, tlsConfig *tls.Config, withP // DialContext implements ContextDialer func (d *alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - logrus.Debugf("ALPN connection upgrade for %v.", addr) - tlsConn, err := tlsutils.TLSDial(ctx, d.dialer, network, addr, d.tlsConfig.Clone()) if err != nil { return nil, trace.Wrap(err) @@ -210,14 +212,28 @@ func (d *alpnConnUpgradeDialer) upgradeType() string { return constants.WebAPIConnUpgradeTypeALPN } -func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) (net.Conn, error) { +func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, alpnUpgradeType string) (net.Conn, error) { req, err := http.NewRequest(http.MethodGet, api.String(), nil) if err != nil { return nil, trace.Wrap(err) } - req.Header.Add(constants.WebAPIConnUpgradeHeader, upgradeType) - req.Header.Add(constants.WebAPIConnUpgradeTeleportHeader, upgradeType) + challengeKey, err := generateWebSocketChallengeKey() + if err != nil { + return nil, trace.Wrap(err) + } + + // Prefer "websocket". + if useConnUpgradeMode.useWebSocket() { + applyWebSocketUpgradeHeaders(req, alpnUpgradeType, challengeKey) + } + + // Append "legacy" custom upgrade type. + // TODO(greedy52) DELETE in 17.0 + if useConnUpgradeMode.useLegacy() { + req.Header.Add(constants.WebAPIConnUpgradeHeader, alpnUpgradeType) + req.Header.Add(constants.WebAPIConnUpgradeTeleportHeader, alpnUpgradeType) + } // Set "Connection" header to meet RFC spec: // https://datatracker.ietf.org/doc/html/rfc2616#section-14.42 @@ -229,7 +245,7 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) (n // require this header to be set to complete the upgrade flow. The header // must be set on both the upgrade request here and the 101 Switching // Protocols response from the server. - req.Header.Add(constants.WebAPIConnUpgradeConnectionHeader, constants.WebAPIConnUpgradeConnectionType) + req.Header.Set(constants.WebAPIConnUpgradeConnectionHeader, constants.WebAPIConnUpgradeConnectionType) // Send the request and check if upgrade is successful. if err = req.Write(conn); err != nil { @@ -246,15 +262,44 @@ func upgradeConnThroughWebAPI(conn net.Conn, api url.URL, upgradeType string) (n return nil, trace.NotImplemented( "connection upgrade call to %q with upgrade type %v failed with status code %v. Please upgrade the server and try again.", constants.WebAPIConnUpgrade, - upgradeType, + alpnUpgradeType, resp.StatusCode, ) } return nil, trace.BadParameter("failed to switch Protocols %v", resp.StatusCode) } - if upgradeType == constants.WebAPIConnUpgradeTypeALPNPing { + // Handle WebSocket. + if resp.Header.Get(constants.WebAPIConnUpgradeHeader) == constants.WebAPIConnUpgradeTypeWebSocket { + if err := checkWebSocketUpgradeResponse(resp, alpnUpgradeType, challengeKey); err != nil { + return nil, trace.Wrap(err) + } + + logrus.WithField("hostname", api.Host).Debug("Performing ALPN WebSocket connection upgrade.") + return newWebSocketALPNClientConn(conn), nil + } + + // Handle "legacy". + // TODO(greedy52) DELETE in 17.0. + logrus.WithField("hostname", api.Host).Debug("Performing ALPN legacy connection upgrade.") + if alpnUpgradeType == constants.WebAPIConnUpgradeTypeALPNPing { return pingconn.New(conn), nil } return conn, nil } + +type connUpgradeMode string + +func (m connUpgradeMode) useWebSocket() bool { + // Use WebSocket as long as it's not legacy only. + return strings.ToLower(string(m)) != "legacy" +} + +func (m connUpgradeMode) useLegacy() bool { + // Use legacy as long as it's not WebSocket only. + return strings.ToLower(string(m)) != "websocket" +} + +var ( + useConnUpgradeMode connUpgradeMode = connUpgradeMode(os.Getenv(defaults.TLSRoutingConnUpgradeModeEnvVar)) +) diff --git a/api/client/alpn_conn_upgrade_test.go b/api/client/alpn_conn_upgrade_test.go index 2d0c8fda0d208..6fae34d15b56c 100644 --- a/api/client/alpn_conn_upgrade_test.go +++ b/api/client/alpn_conn_upgrade_test.go @@ -20,6 +20,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/base64" "errors" "net" "net/http" @@ -28,6 +29,7 @@ import ( "testing" "time" + "github.com/gobwas/ws" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -164,12 +166,23 @@ func TestALPNConnUpgradeDialer(t *testing.T) { wantError bool }{ { - name: "connection upgrade", - serverHandler: mockConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")), + // TODO(greedy52) DELETE in 17.0 + name: "connection upgrade (legacy)", + serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")), }, { - name: "connection upgrade with ping", - serverHandler: mockConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")), + // TODO(greedy52) DELETE in 17.0 + name: "connection upgrade with ping (legacy)", + serverHandler: mockLegacyConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")), + withPing: true, + }, + { + name: "connection upgrade (WebSocket)", + serverHandler: mockWebSocketConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPN, []byte("hello")), + }, + { + name: "connection upgrade with ping (WebSocket)", + serverHandler: mockWebSocketConnUpgradeHandler(t, constants.WebAPIConnUpgradeTypeALPNPing, []byte("hello")), withPing: true, }, { @@ -230,11 +243,27 @@ func TestALPNConnUpgradeDialer(t *testing.T) { } func mustReadConnData(t *testing.T, conn net.Conn, wantText string) { - data := make([]byte, len(wantText)*2) + t.Helper() + + require.NotEmpty(t, wantText) + + // Use a small buffer. + bufferSize := len(wantText) - 1 + data := make([]byte, bufferSize) n, err := conn.Read(data) require.NoError(t, err) - require.Len(t, wantText, n) - require.Equal(t, wantText, string(data[:n])) + require.Equal(t, bufferSize, n) + actualText := string(data) + + // Now read it again to get the full text. This tests + // websocketALPNClientConn.readBuffer is implemented correctly. + data = make([]byte, bufferSize) + n, err = conn.Read(data) + require.NoError(t, err) + require.Equal(t, 1, n) + actualText += string(data[:1]) + + require.Equal(t, wantText, actualText) } type mockALPNServer struct { @@ -291,15 +320,15 @@ func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNSe return m } -// mockConnUpgradeHandler mocks the server side implementation to handle an -// upgrade request and sends back some data inside the tunnel. -func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler { +// mockLegacyConnUpgradeHandler mocks the server side implementation to handle +// an upgrade request and sends back some data inside the tunnel. +func mockLegacyConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler { t.Helper() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path) - require.Equal(t, upgradeType, r.Header.Get(constants.WebAPIConnUpgradeHeader)) - require.Equal(t, upgradeType, r.Header.Get(constants.WebAPIConnUpgradeTeleportHeader)) + require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeHeader), upgradeType) + require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeTeleportHeader), upgradeType) require.Equal(t, constants.WebAPIConnUpgradeConnectionType, r.Header.Get(constants.WebAPIConnUpgradeConnectionHeader)) hj, ok := w.(http.Hijacker) @@ -334,6 +363,49 @@ func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http }) } +// mockWebSocketConnUpgradeHandler mocks the server side implementation to handle +// a WebSocket upgrade request and sends back some data inside the tunnel. +func mockWebSocketConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler { + t.Helper() + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, constants.WebAPIConnUpgrade, r.URL.Path) + require.Contains(t, r.Header.Values(constants.WebAPIConnUpgradeHeader), "websocket") + require.Equal(t, constants.WebAPIConnUpgradeConnectionType, r.Header.Get(constants.WebAPIConnUpgradeConnectionHeader)) + require.Equal(t, upgradeType, r.Header.Get("Sec-Websocket-Protocol")) + require.Equal(t, "13", r.Header.Get("Sec-Websocket-Version")) + + challengeKey := r.Header.Get("Sec-Websocket-Key") + challengeKeyDecoded, err := base64.StdEncoding.DecodeString(challengeKey) + require.NoError(t, err) + require.Len(t, challengeKeyDecoded, 16) + + hj, ok := w.(http.Hijacker) + require.True(t, ok) + + conn, _, err := hj.Hijack() + require.NoError(t, err) + defer conn.Close() + + // Upgrade response. + response := &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + } + response.Header.Set("Upgrade", "websocket") + response.Header.Set("Sec-WebSocket-Protocol", upgradeType) + response.Header.Set("Sec-WebSocket-Accept", computeWebSocketAcceptKey(challengeKey)) + require.NoError(t, response.Write(conn)) + + // Upgraded. + frame := ws.NewFrame(ws.OpBinary, true, write) + frame.Header.Masked = true + require.NoError(t, ws.WriteFrame(conn, frame)) + }) +} + func mustStartForwardProxy(t *testing.T) (*testhelpers.ProxyHandler, *url.URL) { t.Helper() @@ -350,3 +422,43 @@ func mustStartForwardProxy(t *testing.T) (*testhelpers.ProxyHandler, *url.URL) { go http.Serve(listener, handler) return handler, url } + +func Test_connUpgradeMode(t *testing.T) { + tests := []struct { + envVarValue string + wantUseWebSocket require.BoolAssertionFunc + wantUseLegacy require.BoolAssertionFunc + }{ + { + envVarValue: "", + wantUseWebSocket: require.True, + wantUseLegacy: require.True, + }, + { + envVarValue: "WebSocket", + wantUseWebSocket: require.True, + wantUseLegacy: require.False, + }, + { + envVarValue: "websocket", + wantUseWebSocket: require.True, + wantUseLegacy: require.False, + }, + { + envVarValue: "legacy", + wantUseWebSocket: require.False, + wantUseLegacy: require.True, + }, + { + envVarValue: "default", + wantUseWebSocket: require.True, + wantUseLegacy: require.True, + }, + } + + for _, test := range tests { + mode := connUpgradeMode(test.envVarValue) + test.wantUseWebSocket(t, mode.useWebSocket()) + test.wantUseLegacy(t, mode.useLegacy()) + } +} diff --git a/api/client/alpn_websocket.go b/api/client/alpn_websocket.go new file mode 100644 index 0000000000000..3894d5598c32c --- /dev/null +++ b/api/client/alpn_websocket.go @@ -0,0 +1,162 @@ +/* +Copyright 2024 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "io" + "net" + "net/http" + "sync" + + "github.com/gobwas/ws" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/constants" +) + +func applyWebSocketUpgradeHeaders(req *http.Request, alpnUpgradeType, challengeKey string) { + // Set standard WebSocket upgrade type. + req.Header.Add(constants.WebAPIConnUpgradeHeader, constants.WebAPIConnUpgradeTypeWebSocket) + + // Set "Connection" header to meet RFC spec: + // https://datatracker.ietf.org/doc/html/rfc2616#section-14.42 + // Quote: "the upgrade keyword MUST be supplied within a Connection header + // field (section 14.10) whenever Upgrade is present in an HTTP/1.1 + // message." + req.Header.Set(constants.WebAPIConnUpgradeConnectionHeader, constants.WebAPIConnUpgradeConnectionType) + + // Set alpnUpgradeType as sub protocol. + req.Header.Set(websocketHeaderKeyProtocol, alpnUpgradeType) + req.Header.Set(websocketHeaderKeyVersion, "13") + req.Header.Set(websocketHeaderKeyChallengeKey, challengeKey) +} + +func computeWebSocketAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write([]byte(websocketAcceptKeyMagicString)) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func generateWebSocketChallengeKey() (string, error) { + // Quote from https://www.rfc-editor.org/rfc/rfc6455: + // + // A |Sec-WebSocket-Key| header field with a base64-encoded (see Section 4 + // of [RFC4648]) value that, when decoded, is 16 bytes in length. + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", trace.Wrap(err) + } + return base64.StdEncoding.EncodeToString(p), nil +} + +func checkWebSocketUpgradeResponse(resp *http.Response, alpnUpgradeType, challengeKey string) error { + if alpnUpgradeType != resp.Header.Get(websocketHeaderKeyProtocol) { + return trace.BadParameter("WebSocket handshake failed: Sec-WebSocket-Protocol does not match") + } + if computeWebSocketAcceptKey(challengeKey) != resp.Header.Get(websocketHeaderKeyAccept) { + return trace.BadParameter("WebSocket handshake failed: invalid Sec-WebSocket-Accept") + } + return nil +} + +type websocketALPNClientConn struct { + net.Conn + readBuffer []byte + readMutex sync.Mutex + writeMutex sync.Mutex +} + +func newWebSocketALPNClientConn(conn net.Conn) *websocketALPNClientConn { + return &websocketALPNClientConn{ + Conn: conn, + } +} + +func (c *websocketALPNClientConn) Read(b []byte) (int, error) { + c.readMutex.Lock() + defer c.readMutex.Unlock() + + n, err := c.readLocked(b) + return n, trace.Wrap(err) +} + +func (c *websocketALPNClientConn) readLocked(b []byte) (int, error) { + if len(c.readBuffer) > 0 { + n := copy(b, c.readBuffer) + if n < len(c.readBuffer) { + c.readBuffer = c.readBuffer[n:] + } else { + c.readBuffer = nil + } + return n, nil + } + + for { + frame, err := ws.ReadFrame(c.Conn) + if err != nil { + return 0, trace.Wrap(err) + } + + switch frame.Header.OpCode { + case ws.OpClose: + return 0, io.EOF + case ws.OpPing: + pong := ws.NewPongFrame(frame.Payload) + if err := c.writeFrame(pong); err != nil { + return 0, trace.Wrap(err) + } + case ws.OpBinary: + c.readBuffer = frame.Payload + return c.readLocked(b) + } + } +} + +func (c *websocketALPNClientConn) Write(b []byte) (int, error) { + frame := ws.NewFrame(ws.OpBinary, true, b) + return len(b), trace.Wrap(c.writeFrame(frame)) +} + +func (c *websocketALPNClientConn) writeFrame(frame ws.Frame) error { + c.writeMutex.Lock() + defer c.writeMutex.Unlock() + // By RFC standard, all client frames must be masked: + // https://datatracker.ietf.org/doc/html/rfc6455#section-5.1 + frame.Header.Masked = true + return trace.Wrap(ws.WriteFrame(c.Conn, frame)) +} + +const ( + websocketHeaderKeyProtocol = "Sec-WebSocket-Protocol" + websocketHeaderKeyVersion = "Sec-WebSocket-Version" + websocketHeaderKeyChallengeKey = "Sec-WebSocket-Key" + websocketHeaderKeyAccept = "Sec-WebSocket-Accept" + + // websocketAcceptKeyMagicString is the magic string used for computing + // the accept key during WebSocket handshake. + // + // RFC reference: + // https://www.rfc-editor.org/rfc/rfc6455 + // + // Server side uses gorilla: + // https://github.com/gorilla/websocket/blob/dcea2f088ce10b1b0722c4eb995a4e145b5e9047/util.go#L17-L24 + websocketAcceptKeyMagicString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +) diff --git a/api/client/apln_websocket_test.go b/api/client/apln_websocket_test.go new file mode 100644 index 0000000000000..71140df9118c1 --- /dev/null +++ b/api/client/apln_websocket_test.go @@ -0,0 +1,102 @@ +/* +Copyright 2024 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "net" + "testing" + "time" + + "github.com/gobwas/ws" + "github.com/stretchr/testify/require" +) + +func Test_websocketALPNClientConn(t *testing.T) { + clientRawConn, serverRawConn := net.Pipe() + t.Cleanup(func() { + clientRawConn.Close() + serverRawConn.Close() + }) + + clientConn := newWebSocketALPNClientConn(clientRawConn) + + t.Run("Read", func(t *testing.T) { + wait := make(chan struct{}, 1) + + // Send a ping and some text from server. + go func() { + require.NoError(t, ws.WriteFrame(serverRawConn, ws.NewPingFrame([]byte("foo")))) + frame, err := ws.ReadFrame(serverRawConn) + require.NoError(t, err) + require.Equal(t, ws.OpPong, frame.Header.OpCode) + require.NoError(t, ws.WriteFrame(serverRawConn, ws.NewBinaryFrame([]byte("hello client")))) + wait <- struct{}{} + }() + + mustReadWebsocketALPNClientConn(t, clientConn, "hello c") + mustReadWebsocketALPNClientConn(t, clientConn, "lient") + + <-wait + }) + + t.Run("Write", func(t *testing.T) { + wait := make(chan struct{}, 1) + text := "hello server" + + go func() { + n, err := clientConn.Write([]byte(text)) + require.NoError(t, err) + require.Equal(t, len(text), n) + wait <- struct{}{} + }() + + wantFrame := ws.NewBinaryFrame([]byte(text)) + wantFrame.Header.Masked = true + + actualFrame, err := ws.ReadFrame(serverRawConn) + require.NoError(t, err) + require.Equal(t, wantFrame, actualFrame) + + <-wait + }) +} + +func mustReadWebsocketALPNClientConn(t *testing.T, conn *websocketALPNClientConn, wantText string) { + t.Helper() + + actualTextChan := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + readBuff := make([]byte, len(wantText)) + _, err := conn.Read(readBuff) + if err != nil { + errChan <- err + } else { + actualTextChan <- string(readBuff) + } + }() + + select { + case actualText := <-actualTextChan: + require.Equal(t, wantText, actualText) + case err := <-errChan: + require.NoError(t, err) + case <-time.After(time.Second): + require.Fail(t, "timed out waiting for %v from Read", wantText) + } +} diff --git a/api/constants/constants.go b/api/constants/constants.go index 6234bb64b554c..b9ac64bcb6748 100644 --- a/api/constants/constants.go +++ b/api/constants/constants.go @@ -408,6 +408,8 @@ const ( MaxAssumeStartDuration = time.Hour * 24 * 7 ) +// Constants for TLS routing connection upgrade. See RFD for more details: +// https://github.com/gravitational/teleport/blob/master/rfd/0123-tls-routing-behind-layer7-lb.md const ( // WebAPIConnUpgrade is the HTTP web API to make the connection upgrade // call. @@ -431,6 +433,8 @@ const ( // long-lived connections alive as L7 LB usually ignores TCP keepalives and // has very short idle timeouts. WebAPIConnUpgradeTypeALPNPing = "alpn-ping" + // WebAPIConnUpgradeTypeWebSocket is the standard upgrade type for WebSocket. + WebAPIConnUpgradeTypeWebSocket = "websocket" // WebAPIConnUpgradeConnectionHeader is the standard header that controls // whether the network connection stays open after the current transaction // finishes. diff --git a/api/defaults/defaults.go b/api/defaults/defaults.go index 6a1b233fe3f44..a4a34da93dc22 100644 --- a/api/defaults/defaults.go +++ b/api/defaults/defaults.go @@ -176,6 +176,21 @@ const ( // =yes,=no // 0,=1 // - // TODO(greedy52) DELETE IN 15.0 + // TODO(greedy52) DELETE in ??. Note that this toggle was planned to be + // deleted in 15.0 when the feature exits preview. However, many users + // still rely on this manual toggle as IsALPNConnUpgradeRequired cannot + // detect many situations where connection upgrade is required. This can be + // deleted once IsALPNConnUpgradeRequired is improved. TLSRoutingConnUpgradeEnvVar = "TELEPORT_TLS_ROUTING_CONN_UPGRADE" + + // TLSRoutingConnUpgradeModeEnvVar overwrites the upgrade mode used when + // performing connection upgrades by the clients: + // - "websocket": client only requests "websocket" in the "Upgrade" header. + // - "legacy": client only requests legacy "alpn"/"alpn-ping" in the + // "Upgrade" header. + // - "", "default", or any other value than above: client sends both + // WebSocket and legacy in the "Upgrade" header. + // + // TODO(greedy52) DELETE in 17.0 + TLSRoutingConnUpgradeModeEnvVar = "TELEPORT_TLS_ROUTING_CONN_UPGRADE_MODE" ) diff --git a/api/go.mod b/api/go.mod index 3b1848be2eb0b..e12120765db8f 100644 --- a/api/go.mod +++ b/api/go.mod @@ -5,6 +5,7 @@ go 1.21 require ( github.com/coreos/go-semver v0.3.1 github.com/go-piv/piv-go v1.11.0 + github.com/gobwas/ws v1.3.0 github.com/gogo/protobuf v1.3.2 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.5.0 @@ -39,6 +40,8 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.3.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/gobwas/httphead v0.1.0 // indirect + github.com/gobwas/pool v0.2.1 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect diff --git a/api/go.sum b/api/go.sum index 226ae53f90ed9..5080dfc307e0f 100644 --- a/api/go.sum +++ b/api/go.sum @@ -47,6 +47,12 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-piv/piv-go v1.11.0 h1:5vAaCdRTFSIW4PeqMbnsDlUZ7odMYWnHBDGdmtU/Zhg= github.com/go-piv/piv-go v1.11.0/go.mod h1:NZ2zmjVkfFaL/CF8cVQ/pXdXtuj110zEKGdJM6fJZZM= +github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.3.0 h1:sbeU3Y4Qzlb+MOzIe6mQGf7QR4Hkv6ZD0qhGkBFL2O0= +github.com/gobwas/ws v1.3.0/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= diff --git a/go.mod b/go.mod index a6f686b77b6db..91fc66de4906a 100644 --- a/go.mod +++ b/go.mod @@ -89,6 +89,7 @@ require ( github.com/go-piv/piv-go v1.11.0 github.com/go-resty/resty/v2 v2.11.0 github.com/go-webauthn/webauthn v0.10.0 + github.com/gobwas/ws v1.3.0 github.com/gocql/gocql v1.6.0 github.com/gofrs/flock v0.8.1 github.com/gogo/protobuf v1.3.2 // replaced @@ -328,6 +329,8 @@ require ( github.com/go-webauthn/x v0.1.6 // indirect github.com/gobuffalo/flect v1.0.2 // indirect github.com/gobwas/glob v0.2.3 // indirect + github.com/gobwas/httphead v0.1.0 // indirect + github.com/gobwas/pool v0.2.1 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect github.com/golang-jwt/jwt/v5 v5.2.0 // indirect diff --git a/go.sum b/go.sum index 041da60adf419..c782811a7ae0d 100644 --- a/go.sum +++ b/go.sum @@ -651,6 +651,12 @@ github.com/gobuffalo/packr/v2 v2.8.3 h1:xE1yzvnO56cUC0sTpKR3DIbxZgB54AftTFMhB2XE github.com/gobuffalo/packr/v2 v2.8.3/go.mod h1:0SahksCVcx4IMnigTjiFuyldmTrdTctXsOdiU5KwbKc= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.3.0 h1:sbeU3Y4Qzlb+MOzIe6mQGf7QR4Hkv6ZD0qhGkBFL2O0= +github.com/gobwas/ws v1.3.0/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY= diff --git a/lib/web/conn_upgrade.go b/lib/web/conn_upgrade.go index 7b2bd583c7c7a..3cb10f64b3ff5 100644 --- a/lib/web/conn_upgrade.go +++ b/lib/web/conn_upgrade.go @@ -23,11 +23,16 @@ import ( "io" "net" "net/http" + "slices" + "sync" "time" + "github.com/gorilla/websocket" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" + "github.com/sirupsen/logrus" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/utils/pingconn" "github.com/gravitational/teleport/lib/defaults" @@ -41,16 +46,18 @@ func (h *Handler) selectConnectionUpgrade(r *http.Request) (string, ConnectionHa r.Header.Values(constants.WebAPIConnUpgradeTeleportHeader), r.Header.Values(constants.WebAPIConnUpgradeHeader)..., ) - for _, upgradeType := range upgrades { - switch upgradeType { - case constants.WebAPIConnUpgradeTypeALPNPing: - return upgradeType, h.upgradeALPNWithPing, nil - case constants.WebAPIConnUpgradeTypeALPN: - return upgradeType, h.upgradeALPN, nil - } - } - return "", nil, trace.NotFound("unsupported upgrade types: %v", upgrades) + // Prefer WebSocket when multiple types are provided. + switch { + case slices.Contains(upgrades, constants.WebAPIConnUpgradeTypeWebSocket): + return constants.WebAPIConnUpgradeTypeWebSocket, h.upgradeALPN, nil + case slices.Contains(upgrades, constants.WebAPIConnUpgradeTypeALPNPing): + return constants.WebAPIConnUpgradeTypeALPNPing, h.upgradeALPNWithPing, nil + case slices.Contains(upgrades, constants.WebAPIConnUpgradeTypeALPN): + return constants.WebAPIConnUpgradeTypeALPN, h.upgradeALPN, nil + default: + return "", nil, trace.NotFound("unsupported upgrade types: %v", upgrades) + } } // connectionUpgrade handles connection upgrades. @@ -60,6 +67,11 @@ func (h *Handler) connectionUpgrade(w http.ResponseWriter, r *http.Request, p ht return nil, trace.Wrap(err) } + if upgradeType == constants.WebAPIConnUpgradeTypeWebSocket { + return h.upgradeALPNWebSocket(w, r, upgradeHandler) + } + + // TODO(greedy52) DELETE legacy upgrade in 17.0. hj, ok := w.(http.Hijacker) if !ok { return nil, trace.BadParameter("failed to hijack connection") @@ -84,6 +96,48 @@ func (h *Handler) connectionUpgrade(w http.ResponseWriter, r *http.Request, p ht return nil, nil } +func (h *Handler) upgradeALPNWebSocket(w http.ResponseWriter, r *http.Request, upgradeHandler ConnectionHandler) (interface{}, error) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + Subprotocols: []string{ + constants.WebAPIConnUpgradeTypeALPN, + constants.WebAPIConnUpgradeTypeALPNPing, + }, + } + wsConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + h.log.WithError(err).Debug("Failed to upgrade weboscket.") + return nil, trace.Wrap(err) + } + defer wsConn.Close() + + logrus.WithField("protocol", wsConn.Subprotocol()).Trace("Received WebSocket upgrade.") + + conn := newWebSocketALPNServerConn(wsConn) + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + switch wsConn.Subprotocol() { + case constants.WebAPIConnUpgradeTypeALPNPing: + // Starts native WebSocket ping for "alpn-ping". + go h.startPing(ctx, conn) + case constants.WebAPIConnUpgradeTypeALPN: + // Nothing to do + default: + // Just close the connection. Upgrader hijacks the connection so no + // point returning an error. + h.log.WithField("client-protocols", websocket.Subprotocols(r)). + Debug("Unknown or empty WebSocket subprotocol.") + return nil, nil + } + + if err := upgradeHandler(ctx, conn); err != nil && !utils.IsOKNetworkError(err) { + // Upgrader hijacks the connection so no point returning an error here. + h.log.WithError(err).WithField("protocol", wsConn.Subprotocol()).Errorf("Failed to handle WebSocket upgrade request.") + } + return nil, nil +} + func (h *Handler) upgradeALPN(ctx context.Context, conn net.Conn) error { if h.cfg.ALPNHandler == nil { return trace.BadParameter("missing ALPNHandler") @@ -113,7 +167,11 @@ func (h *Handler) upgradeALPNWithPing(ctx context.Context, conn net.Conn) error return h.upgradeALPN(ctx, pingConn) } -func (h *Handler) startPing(ctx context.Context, pingConn *pingconn.PingConn) { +type pingWriter interface { + WritePing() error +} + +func (h *Handler) startPing(ctx context.Context, pingConn pingWriter) { ticker := time.NewTicker(defaults.ProxyPingInterval) defer ticker.Stop() for { @@ -176,3 +234,95 @@ func (conn *waitConn) Close() error { conn.cancel() return trace.Wrap(err) } + +type websocketALPNServerConn struct { + *websocket.Conn + readBuffer []byte + readError error + readMutex sync.Mutex + writeMutex sync.Mutex +} + +func newWebSocketALPNServerConn(wsConn *websocket.Conn) *websocketALPNServerConn { + return &websocketALPNServerConn{ + Conn: wsConn, + } +} + +func (c *websocketALPNServerConn) convertError(err error) error { + if isOKWebsocketCloseError(err) { + return io.EOF + } + return err +} + +func (c *websocketALPNServerConn) Read(b []byte) (int, error) { + c.readMutex.Lock() + defer c.readMutex.Unlock() + + n, err := c.readLocked(b) + return n, trace.Wrap(err) +} + +func (c *websocketALPNServerConn) readLocked(b []byte) (int, error) { + // Stop reading if any previous read err. + if c.readError != nil { + return 0, trace.Wrap(c.readError) + } + + if len(c.readBuffer) > 0 { + n := copy(b, c.readBuffer) + if n < len(c.readBuffer) { + c.readBuffer = c.readBuffer[n:] + } else { + c.readBuffer = nil + } + return n, nil + } + + for { + messageType, data, err := c.Conn.ReadMessage() + if err != nil { + c.readError = c.convertError(err) + return 0, trace.Wrap(c.readError) + } + + switch messageType { + case websocket.CloseMessage: + return 0, nil + case websocket.BinaryMessage: + c.readBuffer = data + return c.readLocked(b) + case websocket.PongMessage: + // Receives Pong as response to Ping. Nothing to do. + } + } +} + +func (c *websocketALPNServerConn) Write(b []byte) (n int, err error) { + c.writeMutex.Lock() + defer c.writeMutex.Unlock() + if err := c.Conn.WriteMessage(websocket.BinaryMessage, b); err != nil { + return 0, trace.Wrap(c.convertError(err)) + } + return len(b), nil +} + +func (c *websocketALPNServerConn) WritePing() error { + c.writeMutex.Lock() + defer c.writeMutex.Unlock() + + // Send some identifier with Ping. Note that we are not validating the Pong + // response. + err := c.Conn.WriteMessage(websocket.PingMessage, []byte(teleport.ComponentTeleport)) + return trace.Wrap(c.convertError(err)) +} + +func (c *websocketALPNServerConn) SetDeadline(t time.Time) error { + c.writeMutex.Lock() + defer c.writeMutex.Unlock() + return trace.NewAggregate( + c.Conn.SetReadDeadline(t), + c.Conn.SetWriteDeadline(t), + ) +} diff --git a/lib/web/conn_upgrade_test.go b/lib/web/conn_upgrade_test.go index d77edfdf21fdf..f5da9e8cafbfc 100644 --- a/lib/web/conn_upgrade_test.go +++ b/lib/web/conn_upgrade_test.go @@ -29,6 +29,7 @@ import ( "testing" "time" + "github.com/gobwas/ws" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -87,32 +88,54 @@ func TestHandlerConnectionUpgrade(t *testing.T) { tests := []struct { name string - inputUpgradeHeaderKey string - inputUpgradeType string + inputRequest *http.Request + expectUpgradeType string checkHandlerError func(error) bool checkClientConnString func(*testing.T, net.Conn, string) }{ { name: "unsupported type", - inputUpgradeType: "unsupported-protocol", + inputRequest: makeConnUpgradeRequest(t, "", "unsupported-protocol", expectedIP), checkHandlerError: trace.IsNotFound, }, { - name: "upgraded to ALPN", - inputUpgradeType: constants.WebAPIConnUpgradeTypeALPN, + // TODO(greedy52) DELETE in 17.0 + name: "upgraded to ALPN (legacy)", + inputRequest: makeConnUpgradeRequest(t, "", constants.WebAPIConnUpgradeTypeALPN, expectedIP), + expectUpgradeType: constants.WebAPIConnUpgradeTypeALPN, checkClientConnString: mustReadClientConnString, }, { - name: "upgraded to ALPN with Ping", - inputUpgradeType: constants.WebAPIConnUpgradeTypeALPNPing, + // TODO(greedy52) DELETE in 17.0 + name: "upgraded to ALPN with Ping (legacy)", + inputRequest: makeConnUpgradeRequest(t, "", constants.WebAPIConnUpgradeTypeALPNPing, expectedIP), + expectUpgradeType: constants.WebAPIConnUpgradeTypeALPNPing, checkClientConnString: mustReadClientPingConnString, }, { name: "upgraded to ALPN with Teleport-specific header", - inputUpgradeHeaderKey: constants.WebAPIConnUpgradeTeleportHeader, - inputUpgradeType: constants.WebAPIConnUpgradeTypeALPN, + inputRequest: makeConnUpgradeRequest(t, constants.WebAPIConnUpgradeTeleportHeader, constants.WebAPIConnUpgradeTypeALPN, expectedIP), + expectUpgradeType: constants.WebAPIConnUpgradeTypeALPN, checkClientConnString: mustReadClientConnString, }, + { + name: "upgraded to WebSocket", + inputRequest: makeConnUpgradeWebSocketRequest(t, constants.WebAPIConnUpgradeTypeALPN, expectedIP), + expectUpgradeType: constants.WebAPIConnUpgradeTypeWebSocket, + checkClientConnString: mustReadClientWebSocketConnString, + }, + { + name: "upgraded to WebSocket with ping", + inputRequest: makeConnUpgradeWebSocketRequest(t, constants.WebAPIConnUpgradeTypeALPNPing, expectedIP), + expectUpgradeType: constants.WebAPIConnUpgradeTypeWebSocket, + checkClientConnString: mustReadClientWebSocketConnString, + }, + { + name: "unsupported WebSocket sub-protocol", + inputRequest: makeConnUpgradeWebSocketRequest(t, "unsupported-protocol", expectedIP), + expectUpgradeType: constants.WebAPIConnUpgradeTypeWebSocket, + checkClientConnString: mustReadClientWebSocketClosed, + }, } for _, test := range tests { @@ -123,7 +146,6 @@ func TestHandlerConnectionUpgrade(t *testing.T) { // serverConn will be hijacked. w := newResponseWriterHijacker(nil, serverConn) - r := makeConnUpgradeRequest(t, test.inputUpgradeHeaderKey, test.inputUpgradeType, expectedIP) // Serve the handler with XForwardedFor middleware to set IPs. handlerErrChan := make(chan error, 1) @@ -133,7 +155,7 @@ func TestHandlerConnectionUpgrade(t *testing.T) { handlerErrChan <- err }) - NewXForwardedForMiddleware(connUpgradeHandler).ServeHTTP(w, r) + NewXForwardedForMiddleware(connUpgradeHandler).ServeHTTP(w, test.inputRequest) }() select { @@ -146,7 +168,7 @@ func TestHandlerConnectionUpgrade(t *testing.T) { } case <-w.hijackedCtx.Done(): - mustReadSwitchProtocolsResponse(t, r, clientConn, test.inputUpgradeType) + mustReadSwitchProtocolsResponse(t, test.inputRequest, clientConn, test.expectUpgradeType) test.checkClientConnString(t, clientConn, expectedPayload) case <-time.After(5 * time.Second): @@ -170,6 +192,25 @@ func makeConnUpgradeRequest(t *testing.T, upgradeHeaderKey, upgradeType, xForwar return r } +func makeConnUpgradeWebSocketRequest(t *testing.T, alpnUpgradeType, xForwardedFor string) *http.Request { + t.Helper() + + r, err := http.NewRequest("GET", "http://localhost/webapi/connectionupgrade", nil) + require.NoError(t, err) + + // Append "legacy" upgrade. This tests whether the handler prefers "websocket". + r.Header.Add(constants.WebAPIConnUpgradeHeader, alpnUpgradeType) + r.Header.Add("X-Forwarded-For", xForwardedFor) + + // Add WebSocket headers + r.Header.Add(constants.WebAPIConnUpgradeHeader, "websocket") + r.Header.Add(constants.WebAPIConnUpgradeConnectionHeader, "upgrade") + r.Header.Set("Sec-Websocket-Protocol", alpnUpgradeType) + r.Header.Set("Sec-Websocket-Version", "13") + r.Header.Set("Sec-Websocket-Key", "MTIzNDU2Nzg5MDEyMzQ1Ng==") + return r +} + func mustReadSwitchProtocolsResponse(t *testing.T, r *http.Request, clientConn net.Conn, upgradeType string) { t.Helper() @@ -180,8 +221,10 @@ func mustReadSwitchProtocolsResponse(t *testing.T, r *http.Request, clientConn n io.Copy(io.Discard, resp.Body) _ = resp.Body.Close() + if upgradeType != "websocket" { + require.Equal(t, upgradeType, resp.Header.Get(constants.WebAPIConnUpgradeTeleportHeader)) + } require.Equal(t, upgradeType, resp.Header.Get(constants.WebAPIConnUpgradeHeader)) - require.Equal(t, upgradeType, resp.Header.Get(constants.WebAPIConnUpgradeTeleportHeader)) require.Equal(t, constants.WebAPIConnUpgradeConnectionType, resp.Header.Get(constants.WebAPIConnUpgradeConnectionHeader)) require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) } @@ -200,6 +243,32 @@ func mustReadClientPingConnString(t *testing.T, clientConn net.Conn, expectedPay mustReadClientConnString(t, pingconn.New(clientConn), expectedPayload) } +func mustReadClientWebSocketConnString(t *testing.T, clientConn net.Conn, expectedPayload string) { + t.Helper() + + for { + frame, err := ws.ReadFrame(clientConn) + require.NoError(t, err) + + switch frame.Header.OpCode { + case ws.OpBinary: + require.Equal(t, expectedPayload, string(frame.Payload)) + return + case ws.OpPing: + continue + default: + require.Fail(t, "does not expect WebSocket frame %v", frame) + } + } +} + +func mustReadClientWebSocketClosed(t *testing.T, clientConn net.Conn, expectedPayload string) { + t.Helper() + + _, err := ws.ReadFrame(clientConn) + require.True(t, utils.IsOKNetworkError(err)) +} + // responseWriterHijacker is a mock http.ResponseWriter that also serves a // net.Conn for http.Hijacker. type responseWriterHijacker struct { @@ -226,5 +295,9 @@ func newResponseWriterHijacker(w http.ResponseWriter, conn net.Conn) *responseWr func (h *responseWriterHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { h.hijackedCtxCancel() - return h.conn, nil, nil + // buf is used by gorilla websocket upgrader. + reader := bufio.NewReaderSize(nil, 10) + writer := bufio.NewWriter(h.conn) + buf := bufio.NewReadWriter(reader, writer) + return h.conn, buf, nil } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 3c40c7a18c15c..e596de09f242f 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -1095,6 +1095,14 @@ func (t *WSStream) writeError(msg string) { } } +func isOKWebsocketCloseError(err error) bool { + return websocket.IsCloseError(err, + websocket.CloseAbnormalClosure, + websocket.CloseGoingAway, + websocket.CloseNormalClosure, + ) +} + func (t *WSStream) processMessages(ctx context.Context) { defer func() { t.close() @@ -1108,8 +1116,7 @@ func (t *WSStream) processMessages(ctx context.Context) { default: ty, bytes, err := t.ws.ReadMessage() if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || - websocket.IsCloseError(err, websocket.CloseAbnormalClosure, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || isOKWebsocketCloseError(err) { return }