diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 2d1e5f64de..11f3fab684 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -2345,6 +2345,44 @@ func (e *Environment) GraphQLWebsocketDialWithRetry(header http.Header, query ur return nil, nil, err } +// GraphQLWebsocketDialWithCompressionRetry is like GraphQLWebsocketDialWithRetry but enables +// permessage-deflate compression negotiation on the client side. +func (e *Environment) GraphQLWebsocketDialWithCompressionRetry(header http.Header, query url.Values) (*websocket.Conn, *http.Response, error) { + dialer := websocket.Dialer{ + Subprotocols: []string{"graphql-transport-ws"}, + EnableCompression: true, + } + + waitBetweenRetriesInMs := rand.Intn(10) + timeToSleep := time.Duration(waitBetweenRetriesInMs) * time.Millisecond + + var err error + + for i := 0; i <= maxSocketRetries; i++ { + urlStr := e.GraphQLWebSocketSubscriptionURL() + if query != nil { + urlStr += "?" + query.Encode() + } + conn, resp, err := dialer.Dial(urlStr, header) + + if resp != nil && err == nil { + return conn, resp, err + } + + if errors.Is(err, websocket.ErrBadHandshake) { + return conn, resp, err + } + + // Make sure that on the final attempt we won't wait + if i != maxSocketRetries { + time.Sleep(timeToSleep) + timeToSleep *= 2 + } + } + + return nil, nil, err +} + func (e *Environment) InitGraphQLWebSocketConnection(header http.Header, query url.Values, initialPayload json.RawMessage) *websocket.Conn { conn, _, err := e.GraphQLWebsocketDialWithRetry(header, query) require.NoError(e.t, err) @@ -2362,6 +2400,25 @@ func (e *Environment) InitGraphQLWebSocketConnection(header http.Header, query u return conn } +// InitGraphQLWebSocketConnectionWithCompression initializes a WebSocket connection with +// permessage-deflate compression negotiation enabled on the client side. +func (e *Environment) InitGraphQLWebSocketConnectionWithCompression(header http.Header, query url.Values, initialPayload json.RawMessage) (*websocket.Conn, *http.Response) { + conn, resp, err := e.GraphQLWebsocketDialWithCompressionRetry(header, query) + require.NoError(e.t, err) + e.t.Cleanup(func() { + _ = conn.Close() + }) + err = conn.WriteJSON(WebSocketMessage{ + Type: "connection_init", + Payload: initialPayload, + }) + require.NoError(e.t, err) + var ack WebSocketMessage + require.NoError(e.t, ReadAndCheckJSON(e.t, conn, &ack)) + require.Equal(e.t, "connection_ack", ack.Type) + return conn, resp +} + func (e *Environment) GraphQLSubscriptionOverSSE(ctx context.Context, request GraphQLRequest, handler func(data string)) { req, err := e.newGraphQLRequestOverGET(e.GraphQLRequestURL(), request) if err != nil { diff --git a/router-tests/websocket_test.go b/router-tests/websocket_test.go index 169b9fbc05..f824aa1c35 100644 --- a/router-tests/websocket_test.go +++ b/router-tests/websocket_test.go @@ -2392,6 +2392,127 @@ func TestWebSockets(t *testing.T) { }) }) + t.Run("compression enabled on server and client", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = true + cfg.Compression.Level = 6 + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Use the compression-enabled dialer + conn, resp := xEnv.InitGraphQLWebSocketConnectionWithCompression(nil, nil, nil) + + // Check that compression was negotiated via the Sec-WebSocket-Extensions header + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.Contains(t, extensions, "permessage-deflate", "Expected compression to be negotiated") + + // Verify the connection works correctly with compression + err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"{ employees { id } }"}`), + }) + require.NoError(t, err) + + var res testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &res) + require.NoError(t, err) + require.Equal(t, "next", res.Type) + require.Equal(t, "1", res.ID) + require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, string(res.Payload)) + + var complete testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &complete) + require.NoError(t, err) + require.Equal(t, "complete", complete.Type) + require.Equal(t, "1", complete.ID) + + xEnv.WaitForSubscriptionCount(0, time.Second*5) + }) + }) + + t.Run("compression disabled on server but enabled on client", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = false + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Use the compression-enabled dialer, but server has compression disabled + conn, resp := xEnv.InitGraphQLWebSocketConnectionWithCompression(nil, nil, nil) + + // Check that compression was NOT negotiated + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.NotContains(t, extensions, "permessage-deflate", "Expected compression NOT to be negotiated when disabled on server") + + // Verify the connection still works correctly without compression + err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"{ employees { id } }"}`), + }) + require.NoError(t, err) + + var res testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &res) + require.NoError(t, err) + require.Equal(t, "next", res.Type) + require.Equal(t, "1", res.ID) + require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, string(res.Payload)) + + var complete testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &complete) + require.NoError(t, err) + require.Equal(t, "complete", complete.Type) + require.Equal(t, "1", complete.ID) + + xEnv.WaitForSubscriptionCount(0, time.Second*5) + }) + }) + + t.Run("compression with custom level", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = true + cfg.Compression.Level = 9 // Best compression + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + conn, resp := xEnv.InitGraphQLWebSocketConnectionWithCompression(nil, nil, nil) + + // Check that compression was negotiated + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.Contains(t, extensions, "permessage-deflate", "Expected compression to be negotiated") + + // Run a subscription query to verify it works with max compression + err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"{ employees { id details { forename surname } } }"}`), + }) + require.NoError(t, err) + + var res testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &res) + require.NoError(t, err) + require.Equal(t, "next", res.Type) + require.Equal(t, "1", res.ID) + require.Contains(t, string(res.Payload), "forename") + require.Contains(t, string(res.Payload), "surname") + + var complete testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &complete) + require.NoError(t, err) + require.Equal(t, "complete", complete.Type) + + xEnv.WaitForSubscriptionCount(0, time.Second*5) + }) + }) + } func TestFlakyWebSockets(t *testing.T) { diff --git a/router/core/websocket.go b/router/core/websocket.go index 94aa65f75e..0534f7a4b9 100644 --- a/router/core/websocket.go +++ b/router/core/websocket.go @@ -2,10 +2,12 @@ package core import ( "bytes" + "compress/flate" "context" "encoding/json" "errors" "fmt" + "io" "net" "net/http" "regexp" @@ -17,6 +19,7 @@ import ( "github.com/buger/jsonparser" "github.com/go-chi/chi/v5/middleware" "github.com/gobwas/ws" + "github.com/gobwas/ws/wsflate" "github.com/gobwas/ws/wsutil" "github.com/gorilla/websocket" "github.com/tidwall/gjson" @@ -87,6 +90,13 @@ func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions disableVariablesRemapping: opts.DisableVariablesRemapping, apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, } + if opts.WebSocketConfiguration != nil && opts.WebSocketConfiguration.Compression.Enabled { + handler.compressionEnabled = true + handler.compressionLevel = opts.WebSocketConfiguration.Compression.Level + if handler.compressionLevel < 1 || handler.compressionLevel > 9 { + handler.compressionLevel = flate.DefaultCompression + } + } if opts.WebSocketConfiguration != nil && opts.WebSocketConfiguration.AbsintheProtocol.Enabled { handler.absintheHandlerEnabled = true handler.absintheHandlerPath = opts.WebSocketConfiguration.AbsintheProtocol.HandlerPath @@ -156,13 +166,19 @@ type wsConnectionWrapper struct { mu sync.Mutex readTimeout time.Duration writeTimeout time.Duration + + // Compression fields + compressionEnabled bool + compressionLevel int } -func newWSConnectionWrapper(conn net.Conn, readTimeout, writeTimeout time.Duration) *wsConnectionWrapper { +func newWSConnectionWrapper(conn net.Conn, readTimeout, writeTimeout time.Duration, compressionEnabled bool, compressionLevel int) *wsConnectionWrapper { return &wsConnectionWrapper{ - conn: conn, - readTimeout: readTimeout, - writeTimeout: writeTimeout, + conn: conn, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + compressionEnabled: compressionEnabled, + compressionLevel: compressionLevel, } } @@ -175,9 +191,51 @@ func (c *wsConnectionWrapper) ReadJSON(v any) error { } } - text, err := wsutil.ReadClientText(c.conn) - if err != nil { - return err + var text []byte + var err error + + if c.compressionEnabled { + // Read frames directly and handle compression + controlHandler := wsutil.ControlFrameHandler(c.conn, ws.StateServerSide) + for { + frame, err := ws.ReadFrame(c.conn) + if err != nil { + return err + } + + // Unmask client frames + if frame.Header.Masked { + ws.Cipher(frame.Payload, frame.Header.Mask, 0) + } + + if frame.Header.OpCode.IsControl() { + if err := controlHandler(frame.Header, bytes.NewReader(frame.Payload)); err != nil { + return err + } + continue + } + + if frame.Header.OpCode == ws.OpText || frame.Header.OpCode == ws.OpBinary { + // Check if frame is compressed (RSV1 bit set) + isCompressed, err := wsflate.IsCompressed(frame.Header) + if err != nil { + return err + } + if isCompressed { + frame, err = wsflate.DecompressFrame(frame) + if err != nil { + return err + } + } + text = frame.Payload + break + } + } + } else { + text, err = wsutil.ReadClientText(c.conn) + if err != nil { + return err + } } return json.Unmarshal(text, v) @@ -195,6 +253,10 @@ func (c *wsConnectionWrapper) WriteText(text string) error { } } + if c.compressionEnabled { + return c.writeCompressed([]byte(text)) + } + return wsutil.WriteServerText(c.conn, []byte(text)) } @@ -213,9 +275,32 @@ func (c *wsConnectionWrapper) WriteJSON(v any) error { } } + if c.compressionEnabled { + return c.writeCompressed(data) + } + return wsutil.WriteServerText(c.conn, data) } +// writeCompressed writes data with compression. Must be called with the mutex held. +func (c *wsConnectionWrapper) writeCompressed(data []byte) error { + var buf bytes.Buffer + writer := wsflate.NewWriter(&buf, func(w io.Writer) wsflate.Compressor { + fw, _ := flate.NewWriter(w, c.compressionLevel) + return fw + }) + if _, err := writer.Write(data); err != nil { + return err + } + if err := writer.Flush(); err != nil { + return err + } + + frame := ws.NewFrame(ws.OpText, true, buf.Bytes()) + frame.Header.Rsv = ws.Rsv(true, false, false) // Set RSV1 bit for compression + return ws.WriteFrame(c.conn, frame) +} + func (c *wsConnectionWrapper) WriteCloseFrame(code ws.StatusCode, reason string) error { c.mu.Lock() defer c.mu.Unlock() @@ -267,6 +352,9 @@ type WebsocketHandler struct { disableVariablesRemapping bool apolloCompatibilityFlags config.ApolloCompatibilityFlags + + compressionEnabled bool + compressionLevel int } func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.Request) { @@ -309,7 +397,29 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R return false }, } + + // Configure permessage-deflate compression if enabled + var compressionNegotiated bool + var ext wsflate.Extension + if h.compressionEnabled { + ext = wsflate.Extension{ + Parameters: wsflate.Parameters{ + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + }, + } + upgrader.Negotiate = ext.Negotiate + } + c, _, _, err := upgrader.Upgrade(r, w) + + // Check if compression was negotiated + if h.compressionEnabled && err == nil { + if _, accepted := ext.Accepted(); accepted { + compressionNegotiated = true + } + } + if err != nil { requestLogger.Warn("Websocket upgrade", zap.Error(err)) _ = c.Close() @@ -325,7 +435,7 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R // After successful upgrade, we can't write to the response writer anymore // because it's hijacked by the websocket connection - conn := newWSConnectionWrapper(c, h.readTimeout, h.writeTimeout) + conn := newWSConnectionWrapper(c, h.readTimeout, h.writeTimeout, compressionNegotiated, h.compressionLevel) protocol, err := wsproto.NewProtocol(subProtocol, conn) if err != nil { requestLogger.Error("Create websocket protocol", zap.Error(err)) diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 4e430627c0..3f343f9d8e 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -722,6 +722,16 @@ type WebSocketConfiguration struct { Authentication WebSocketAuthenticationConfiguration `yaml:"authentication,omitempty"` // SetClientInfoFromInitialPayload configuration for the WebSocket Connection ClientInfoFromInitialPayload WebSocketClientInfoFromInitialPayloadConfiguration `yaml:"client_info_from_initial_payload"` + // Compression configuration for WebSocket per-message compression (permessage-deflate) + Compression WebSocketCompressionConfiguration `yaml:"compression,omitempty"` +} + +// WebSocketCompressionConfiguration configures permessage-deflate compression for WebSocket connections +type WebSocketCompressionConfiguration struct { + // Enabled enables permessage-deflate compression for WebSocket connections + Enabled bool `yaml:"enabled" envDefault:"false" env:"WEBSOCKETS_COMPRESSION_ENABLED"` + // Level is the compression level (1-9, where 1 is fastest and 9 is best compression) + Level int `yaml:"level" envDefault:"6" env:"WEBSOCKETS_COMPRESSION_LEVEL"` } type WebSocketClientInfoFromInitialPayloadConfiguration struct { diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 0b5fc698c3..e258486b70 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -560,6 +560,25 @@ } } } + }, + "compression": { + "type": "object", + "description": "Configuration for WebSocket per-message compression (permessage-deflate extension).", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": false, + "description": "Enable permessage-deflate compression for WebSocket connections. When enabled, the server will negotiate compression with clients that support it. The default value is false." + }, + "level": { + "type": "integer", + "default": 6, + "minimum": 1, + "maximum": 9, + "description": "The compression level (1-9). Level 1 is fastest with least compression, level 9 provides best compression but is slowest. The default value is 6." + } + } } } }, diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index 515d08f45d..1e025d74d3 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -404,6 +404,9 @@ websocket: export_token: enabled: true header_key: 'Authorization' + compression: + enabled: true + level: 6 storage_providers: file_system: diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index ce91d417e7..925899d240 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -438,6 +438,10 @@ "NameTargetHeader": "graphql-client-name", "VersionTargetHeader": "graphql-client-version" } + }, + "Compression": { + "Enabled": false, + "Level": 6 } }, "SubgraphErrorPropagation": { diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index 02064e09d5..c502a20bb3 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -818,6 +818,10 @@ "NameTargetHeader": "graphql-client-name", "VersionTargetHeader": "graphql-client-version" } + }, + "Compression": { + "Enabled": true, + "Level": 6 } }, "SubgraphErrorPropagation": {