diff --git a/router-tests/go.mod b/router-tests/go.mod index 28d13da9e5..1c5da767d7 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -7,6 +7,7 @@ require ( github.com/MicahParks/jwkset v0.11.0 github.com/buger/jsonparser v1.1.1 github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 + github.com/coder/websocket v1.8.13 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.1 @@ -58,7 +59,6 @@ require ( github.com/cep21/circuit/v4 v4.0.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cilium/ebpf v0.16.0 // indirect - github.com/coder/websocket v1.8.13 // indirect github.com/containerd/cgroups/v3 v3.0.2 // indirect github.com/containerd/stargz-snapshotter/estargz v0.16.3 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 6d1d884ed3..875850d370 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -2357,11 +2357,7 @@ type GraphQLError struct { const maxSocketRetries = 5 -func (e *Environment) GraphQLWebsocketDialWithRetry(header http.Header, query url.Values) (*websocket.Conn, *http.Response, error) { - dialer := websocket.Dialer{ - Subprotocols: []string{"graphql-transport-ws"}, - } - +func (e *Environment) graphQLWebsocketDialWithRetry(dialer websocket.Dialer, header http.Header, query url.Values) (*websocket.Conn, *http.Response, error) { waitBetweenRetriesInMs := rand.Intn(10) timeToSleep := time.Duration(waitBetweenRetriesInMs) * time.Millisecond @@ -2392,6 +2388,21 @@ func (e *Environment) GraphQLWebsocketDialWithRetry(header http.Header, query ur return nil, nil, err } +func (e *Environment) GraphQLWebsocketDialWithRetry(header http.Header, query url.Values) (*websocket.Conn, *http.Response, error) { + return e.graphQLWebsocketDialWithRetry(websocket.Dialer{ + Subprotocols: []string{"graphql-transport-ws"}, + }, header, query) +} + +// GraphQLWebsocketDialWithCompressionRetry is like GraphQLWebsocketDialWithRetry but enables +// permessage-deflate compression negotiation on the client side. +func (e *Environment) GraphQLWebsocketDialWithCompressionRetry(header http.Header, query url.Values) (*websocket.Conn, *http.Response, error) { + return e.graphQLWebsocketDialWithRetry(websocket.Dialer{ + Subprotocols: []string{"graphql-transport-ws"}, + EnableCompression: true, + }, header, query) +} + func (e *Environment) InitGraphQLWebSocketConnection(header http.Header, query url.Values, initialPayload json.RawMessage) *websocket.Conn { conn, _, err := e.GraphQLWebsocketDialWithRetry(header, query) require.NoError(e.t, err) @@ -2409,6 +2420,25 @@ func (e *Environment) InitGraphQLWebSocketConnection(header http.Header, query u return conn } +// InitGraphQLWebSocketConnectionWithCompression initializes a WebSocket connection with +// permessage-deflate compression negotiation enabled on the client side. +func (e *Environment) InitGraphQLWebSocketConnectionWithCompression(header http.Header, query url.Values, initialPayload json.RawMessage) (*websocket.Conn, *http.Response) { + conn, resp, err := e.GraphQLWebsocketDialWithCompressionRetry(header, query) + require.NoError(e.t, err) + e.t.Cleanup(func() { + _ = conn.Close() + }) + err = conn.WriteJSON(WebSocketMessage{ + Type: "connection_init", + Payload: initialPayload, + }) + require.NoError(e.t, err) + var ack WebSocketMessage + require.NoError(e.t, ReadAndCheckJSON(e.t, conn, &ack)) + require.Equal(e.t, "connection_ack", ack.Type) + return conn, resp +} + func (e *Environment) GraphQLSubscriptionOverSSE(ctx context.Context, request GraphQLRequest, handler func(data string)) { req, err := e.newGraphQLRequestOverGET(e.GraphQLRequestURL(), request) if err != nil { diff --git a/router-tests/websocket_test.go b/router-tests/websocket_test.go index 169b9fbc05..d82cc3f31d 100644 --- a/router-tests/websocket_test.go +++ b/router-tests/websocket_test.go @@ -1,6 +1,7 @@ package integration import ( + "context" "crypto/sha256" "encoding/json" "errors" @@ -20,6 +21,7 @@ import ( "go.uber.org/zap/zapcore" "github.com/buger/jsonparser" + coderws "github.com/coder/websocket" "github.com/gorilla/websocket" "github.com/hasura/go-graphql-client" "github.com/hasura/go-graphql-client/pkg/jsonutil" @@ -2392,6 +2394,397 @@ func TestWebSockets(t *testing.T) { }) }) + t.Run("compression enabled on server and client", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = true + cfg.Compression.Level = 6 + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Use the compression-enabled dialer + conn, resp := xEnv.InitGraphQLWebSocketConnectionWithCompression(nil, nil, nil) + + // Check that compression was negotiated via the Sec-WebSocket-Extensions header. + // gorilla always requests server_no_context_takeover and client_no_context_takeover, + // so the server must mirror both flags in the response. + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.Contains(t, extensions, "permessage-deflate", "Expected compression to be negotiated") + require.Contains(t, extensions, "server_no_context_takeover", "Expected server to mirror server_no_context_takeover") + require.Contains(t, extensions, "client_no_context_takeover", "Expected server to mirror client_no_context_takeover") + + // Verify the connection works correctly with compression + err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"{ employees { id } }"}`), + }) + require.NoError(t, err) + + var res testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &res) + require.NoError(t, err) + require.Equal(t, "next", res.Type) + require.Equal(t, "1", res.ID) + require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, string(res.Payload)) + + var complete testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &complete) + require.NoError(t, err) + require.Equal(t, "complete", complete.Type) + require.Equal(t, "1", complete.ID) + + xEnv.WaitForSubscriptionCount(0, time.Second*5) + }) + }) + + t.Run("compression negotiation does not add no_context_takeover when not requested", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = true + cfg.Compression.Level = 6 + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + wsURL, err := url.Parse(xEnv.GraphQLWebSocketSubscriptionURL()) + require.NoError(t, err) + + switch wsURL.Scheme { + case "ws": + wsURL.Scheme = "http" + case "wss": + wsURL.Scheme = "https" + default: + t.Fatalf("unexpected websocket scheme: %s", wsURL.Scheme) + } + + req, err := http.NewRequest(http.MethodGet, wsURL.String(), nil) + require.NoError(t, err) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Sec-WebSocket-Protocol", "graphql-transport-ws") + req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate") + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.Contains(t, extensions, "permessage-deflate", "Expected compression to be negotiated") + require.NotContains(t, extensions, "server_no_context_takeover", "Expected server not to add server_no_context_takeover when not requested") + require.NotContains(t, extensions, "client_no_context_takeover", "Expected server not to add client_no_context_takeover when not requested") + }) + }) + + t.Run("compression disabled on server but enabled on client", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = false + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Use the compression-enabled dialer, but server has compression disabled + conn, resp := xEnv.InitGraphQLWebSocketConnectionWithCompression(nil, nil, nil) + + // Check that compression was NOT negotiated + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.NotContains(t, extensions, "permessage-deflate", "Expected compression NOT to be negotiated when disabled on server") + + // Verify the connection still works correctly without compression + err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"{ employees { id } }"}`), + }) + require.NoError(t, err) + + var res testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &res) + require.NoError(t, err) + require.Equal(t, "next", res.Type) + require.Equal(t, "1", res.ID) + require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, string(res.Payload)) + + var complete testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &complete) + require.NoError(t, err) + require.Equal(t, "complete", complete.Type) + require.Equal(t, "1", complete.ID) + + xEnv.WaitForSubscriptionCount(0, time.Second*5) + }) + }) + + t.Run("compression enabled on server but client does not support it", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = true + cfg.Compression.Level = 6 + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Use a standard dialer WITHOUT compression (the default) + conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) + + // Verify the connection works correctly without compression even + // though the server has compression enabled. + err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"{ employees { id } }"}`), + }) + require.NoError(t, err) + + var res testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &res) + require.NoError(t, err) + require.Equal(t, "next", res.Type) + require.Equal(t, "1", res.ID) + require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, string(res.Payload)) + + var complete testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &complete) + require.NoError(t, err) + require.Equal(t, "complete", complete.Type) + require.Equal(t, "1", complete.ID) + + xEnv.WaitForSubscriptionCount(0, time.Second*5) + }) + }) + + t.Run("compression with custom level", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = true + cfg.Compression.Level = 9 // Best compression + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + conn, resp := xEnv.InitGraphQLWebSocketConnectionWithCompression(nil, nil, nil) + + // Check that compression was negotiated + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.Contains(t, extensions, "permessage-deflate", "Expected compression to be negotiated") + + // Run a subscription query to verify it works with max compression + err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"{ employees { id details { forename surname } } }"}`), + }) + require.NoError(t, err) + + var res testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &res) + require.NoError(t, err) + require.Equal(t, "next", res.Type) + require.Equal(t, "1", res.ID) + require.JSONEq(t, `{"data":{"employees":[{"id":1,"details":{"forename":"Jens","surname":"Neuse"}},{"id":2,"details":{"forename":"Dustin","surname":"Deus"}},{"id":3,"details":{"forename":"Stefan","surname":"Avram"}},{"id":4,"details":{"forename":"Björn","surname":"Schwenzer"}},{"id":5,"details":{"forename":"Sergiy","surname":"Petrunin"}},{"id":7,"details":{"forename":"Suvij","surname":"Surya"}},{"id":8,"details":{"forename":"Nithin","surname":"Kumar"}},{"id":10,"details":{"forename":"Eelco","surname":"Wiersma"}},{"id":11,"details":{"forename":"Alexandra","surname":"Neuse"}},{"id":12,"details":{"forename":"David","surname":"Stutt"}}]}}`, string(res.Payload)) + + var complete testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &complete) + require.NoError(t, err) + require.Equal(t, "complete", complete.Type) + + xEnv.WaitForSubscriptionCount(0, time.Second*5) + }) + }) + + t.Run("compression negotiation includes window bits in response", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = true + cfg.Compression.Level = 6 + cfg.Compression.ClientMaxWindowBits = 12 + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + wsURL, err := url.Parse(xEnv.GraphQLWebSocketSubscriptionURL()) + require.NoError(t, err) + + switch wsURL.Scheme { + case "ws": + wsURL.Scheme = "http" + case "wss": + wsURL.Scheme = "https" + default: + t.Fatalf("unexpected websocket scheme: %s", wsURL.Scheme) + } + + req, err := http.NewRequest(http.MethodGet, wsURL.String(), nil) + require.NoError(t, err) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Sec-WebSocket-Protocol", "graphql-transport-ws") + req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; client_max_window_bits=15") + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.Contains(t, extensions, "permessage-deflate", "Expected compression to be negotiated") + require.Contains(t, extensions, "server_max_window_bits=15", "Expected server_max_window_bits=15 in response") + require.Contains(t, extensions, "client_max_window_bits=12", "Expected client_max_window_bits=12 (server config is more restrictive)") + }) + }) + + t.Run("compression negotiation respects client offering smaller window bits than server default", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = true + cfg.Compression.Level = 6 + // Server allows up to 15, but the client will offer 10. + cfg.Compression.ClientMaxWindowBits = 15 + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + wsURL, err := url.Parse(xEnv.GraphQLWebSocketSubscriptionURL()) + require.NoError(t, err) + + switch wsURL.Scheme { + case "ws": + wsURL.Scheme = "http" + case "wss": + wsURL.Scheme = "https" + default: + t.Fatalf("unexpected websocket scheme: %s", wsURL.Scheme) + } + + req, err := http.NewRequest(http.MethodGet, wsURL.String(), nil) + require.NoError(t, err) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Sec-WebSocket-Protocol", "graphql-transport-ws") + // Client offers a window smaller than the server's configured maximum. + req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; client_max_window_bits=10") + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.Contains(t, extensions, "permessage-deflate", "Expected compression to be negotiated") + require.Contains(t, extensions, "server_max_window_bits=15", "Expected server_max_window_bits=15") + require.Contains(t, extensions, "client_max_window_bits=10", "Expected server to honour the client's smaller window bits") + }) + }) + + t.Run("compression with small client_max_window_bits works end to end", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = true + cfg.Compression.Level = 6 + cfg.Compression.ClientMaxWindowBits = 9 + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // gorilla's Dialer with EnableCompression sends client_max_window_bits (no value), + // so the server responds with its configured value of 9. + conn, resp := xEnv.InitGraphQLWebSocketConnectionWithCompression(nil, nil, nil) + + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.Contains(t, extensions, "permessage-deflate", "Expected compression to be negotiated") + + // Run a query to verify data round-trips correctly with the smaller window. + err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"{ employees { id } }"}`), + }) + require.NoError(t, err) + + var res testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &res) + require.NoError(t, err) + require.Equal(t, "next", res.Type) + require.Equal(t, "1", res.ID) + require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, string(res.Payload)) + + var complete testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &complete) + require.NoError(t, err) + require.Equal(t, "complete", complete.Type) + require.Equal(t, "1", complete.ID) + + xEnv.WaitForSubscriptionCount(0, time.Second*5) + }) + }) + + t.Run("compression negotiation omits client_max_window_bits when client does not offer it", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = true + cfg.Compression.Level = 6 + cfg.Compression.ClientMaxWindowBits = 10 + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + wsURL, err := url.Parse(xEnv.GraphQLWebSocketSubscriptionURL()) + require.NoError(t, err) + + switch wsURL.Scheme { + case "ws": + wsURL.Scheme = "http" + case "wss": + wsURL.Scheme = "https" + default: + t.Fatalf("unexpected websocket scheme: %s", wsURL.Scheme) + } + + req, err := http.NewRequest(http.MethodGet, wsURL.String(), nil) + require.NoError(t, err) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Sec-WebSocket-Protocol", "graphql-transport-ws") + // Offer permessage-deflate WITHOUT client_max_window_bits. + req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate") + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.Contains(t, extensions, "permessage-deflate", "Expected compression to be negotiated") + require.Contains(t, extensions, "server_max_window_bits=15", "Expected server_max_window_bits=15 in response") + // RFC 7692 §7.1.2.2: client_max_window_bits must NOT appear when the client didn't offer it. + require.NotContains(t, extensions, "client_max_window_bits", "Expected no client_max_window_bits when client didn't offer it") + }) + }) + } func TestFlakyWebSockets(t *testing.T) { @@ -3067,3 +3460,210 @@ func handleCountEmpSubscription(t *testing.T, wsWriteCh chan<- wsJSONMessage, ws } } } + +// TestWebSocketsCoderClient validates the server's WebSocket handling using the +// coder/websocket library (successor to nhooyr.io/websocket) as an alternative +// to the gorilla/websocket client used in the rest of this file. +func TestWebSocketsCoderClient(t *testing.T) { + t.Run("basic subscription", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, _, err := coderws.Dial(ctx, xEnv.GraphQLWebSocketSubscriptionURL(), &coderws.DialOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.CloseNow() + }) + + // connection_init + err = conn.Write(ctx, coderws.MessageText, []byte(`{"type":"connection_init"}`)) + require.NoError(t, err) + + // connection_ack + _, ackData, err := conn.Read(ctx) + require.NoError(t, err) + var ack testenv.WebSocketMessage + require.NoError(t, json.Unmarshal(ackData, &ack)) + require.Equal(t, "connection_ack", ack.Type) + + // subscribe + subscribePayload, err := json.Marshal(testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: json.RawMessage(`{"query":"{ employees { id } }"}`), + }) + require.NoError(t, err) + err = conn.Write(ctx, coderws.MessageText, subscribePayload) + require.NoError(t, err) + + // next + _, nextData, err := conn.Read(ctx) + require.NoError(t, err) + var next testenv.WebSocketMessage + require.NoError(t, json.Unmarshal(nextData, &next)) + require.Equal(t, "next", next.Type) + require.Equal(t, "1", next.ID) + require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, string(next.Payload)) + + // complete + _, completeData, err := conn.Read(ctx) + require.NoError(t, err) + var complete testenv.WebSocketMessage + require.NoError(t, json.Unmarshal(completeData, &complete)) + require.Equal(t, "complete", complete.Type) + require.Equal(t, "1", complete.ID) + }) + }) + + t.Run("subscription with compression", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = true + cfg.Compression.Level = 6 + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, resp, err := coderws.Dial(ctx, xEnv.GraphQLWebSocketSubscriptionURL(), &coderws.DialOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + CompressionMode: coderws.CompressionContextTakeover, + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.CloseNow() + }) + + // Verify compression was negotiated. + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.Contains(t, extensions, "permessage-deflate", "Expected compression to be negotiated") + + // connection_init + err = conn.Write(ctx, coderws.MessageText, []byte(`{"type":"connection_init"}`)) + require.NoError(t, err) + + // connection_ack + _, ackData, err := conn.Read(ctx) + require.NoError(t, err) + var ack testenv.WebSocketMessage + require.NoError(t, json.Unmarshal(ackData, &ack)) + require.Equal(t, "connection_ack", ack.Type) + + // subscribe + subscribePayload, err := json.Marshal(testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: json.RawMessage(`{"query":"{ employees { id } }"}`), + }) + require.NoError(t, err) + err = conn.Write(ctx, coderws.MessageText, subscribePayload) + require.NoError(t, err) + + // next + _, nextData, err := conn.Read(ctx) + require.NoError(t, err) + var next testenv.WebSocketMessage + require.NoError(t, json.Unmarshal(nextData, &next)) + require.Equal(t, "next", next.Type) + require.Equal(t, "1", next.ID) + require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, string(next.Payload)) + + // complete + _, completeData, err := conn.Read(ctx) + require.NoError(t, err) + var complete testenv.WebSocketMessage + require.NoError(t, json.Unmarshal(completeData, &complete)) + require.Equal(t, "complete", complete.Type) + require.Equal(t, "1", complete.ID) + }) + }) + + t.Run("multiple queries over single context takeover connection", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Compression.Enabled = true + cfg.Compression.Level = 6 + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, resp, err := coderws.Dial(ctx, xEnv.GraphQLWebSocketSubscriptionURL(), &coderws.DialOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + CompressionMode: coderws.CompressionContextTakeover, + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.CloseNow() + }) + + // Verify context takeover was negotiated (no no_context_takeover flags). + extensions := resp.Header.Get("Sec-WebSocket-Extensions") + require.Contains(t, extensions, "permessage-deflate") + require.NotContains(t, extensions, "server_no_context_takeover") + require.NotContains(t, extensions, "client_no_context_takeover") + + // connection_init / connection_ack + err = conn.Write(ctx, coderws.MessageText, []byte(`{"type":"connection_init"}`)) + require.NoError(t, err) + _, ackData, err := conn.Read(ctx) + require.NoError(t, err) + var ack testenv.WebSocketMessage + require.NoError(t, json.Unmarshal(ackData, &ack)) + require.Equal(t, "connection_ack", ack.Type) + + // Helper: run a one-shot query and return the payload. + runQuery := func(id, query, expectedPayload string) { + t.Helper() + sub, err := json.Marshal(testenv.WebSocketMessage{ + ID: id, + Type: "subscribe", + Payload: json.RawMessage(fmt.Sprintf(`{"query":"%s"}`, query)), + }) + require.NoError(t, err) + err = conn.Write(ctx, coderws.MessageText, sub) + require.NoError(t, err) + + // next + _, data, err := conn.Read(ctx) + require.NoError(t, err) + var msg testenv.WebSocketMessage + require.NoError(t, json.Unmarshal(data, &msg)) + require.Equal(t, "next", msg.Type) + require.Equal(t, id, msg.ID) + require.JSONEq(t, expectedPayload, string(msg.Payload)) + + // complete + _, data, err = conn.Read(ctx) + require.NoError(t, err) + var comp testenv.WebSocketMessage + require.NoError(t, json.Unmarshal(data, &comp)) + require.Equal(t, "complete", comp.Type) + require.Equal(t, id, comp.ID) + } + + expectedEmployeeIDs := `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}` + expectedEmployeeDetails := `{"data":{"employees":[{"id":1,"details":{"forename":"Jens","surname":"Neuse"}},{"id":2,"details":{"forename":"Dustin","surname":"Deus"}},{"id":3,"details":{"forename":"Stefan","surname":"Avram"}},{"id":4,"details":{"forename":"Björn","surname":"Schwenzer"}},{"id":5,"details":{"forename":"Sergiy","surname":"Petrunin"}},{"id":7,"details":{"forename":"Suvij","surname":"Surya"}},{"id":8,"details":{"forename":"Nithin","surname":"Kumar"}},{"id":10,"details":{"forename":"Eelco","surname":"Wiersma"}},{"id":11,"details":{"forename":"Alexandra","surname":"Neuse"}},{"id":12,"details":{"forename":"David","surname":"Stutt"}}]}}` + + // Run three queries sequentially on the same connection. + // The server's writeCompressedWithContextTakeover and + // decompressWithContextTakeover paths maintain dictionary + // state across messages. If the dictionary gets corrupted, + // decompression will fail and the reads above will error. + runQuery("1", "{ employees { id } }", expectedEmployeeIDs) + runQuery("2", "{ employees { id details { forename surname } } }", expectedEmployeeDetails) + // Third query repeats the first — exercises dictionary reuse + // with identical content after intervening different content. + runQuery("3", "{ employees { id } }", expectedEmployeeIDs) + }) + }) +} diff --git a/router/core/websocket.go b/router/core/websocket.go index 0aa5ca5588..eb2da71469 100644 --- a/router/core/websocket.go +++ b/router/core/websocket.go @@ -2,10 +2,12 @@ package core import ( "bytes" + "compress/flate" "context" "encoding/json" "errors" "fmt" + "io" "net" "net/http" "regexp" @@ -16,7 +18,9 @@ import ( "github.com/buger/jsonparser" "github.com/go-chi/chi/v5/middleware" + "github.com/gobwas/httphead" "github.com/gobwas/ws" + "github.com/gobwas/ws/wsflate" "github.com/gobwas/ws/wsutil" "github.com/gorilla/websocket" "github.com/tidwall/gjson" @@ -42,6 +46,17 @@ var ( errClientTerminatedConnection = errors.New("client terminated connection") ) +type compressionMode struct { + enabled bool + level int + serverContextTakeover bool + clientContextTakeover bool + // clientWindowBits is the negotiated LZ77 window size (8-15) used by the + // client for compression. The decompression dictionary is sized as + // 1 << clientWindowBits. Default is 15 (32 KB, the DEFLATE maximum). + clientWindowBits int +} + type WebsocketMiddlewareOptions struct { OperationProcessor *OperationProcessor OperationBlocker *OperationBlocker @@ -87,6 +102,17 @@ func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions disableVariablesRemapping: opts.DisableVariablesRemapping, apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, } + if opts.WebSocketConfiguration != nil && opts.WebSocketConfiguration.Compression.Enabled { + handler.compression.enabled = true + handler.compression.level = opts.WebSocketConfiguration.Compression.Level + if handler.compression.level < 1 || handler.compression.level > 9 { + handler.compression.level = flate.DefaultCompression + } + handler.compression.clientWindowBits = opts.WebSocketConfiguration.Compression.ClientMaxWindowBits + if handler.compression.clientWindowBits < 8 || handler.compression.clientWindowBits > 15 { + handler.compression.clientWindowBits = 15 + } + } if opts.WebSocketConfiguration != nil && opts.WebSocketConfiguration.AbsintheProtocol.Enabled { handler.absintheHandlerEnabled = true handler.absintheHandlerPath = opts.WebSocketConfiguration.AbsintheProtocol.HandlerPath @@ -156,31 +182,201 @@ type wsConnectionWrapper struct { mu sync.Mutex readTimeout time.Duration writeTimeout time.Duration + + // Compression and takeover mode negotiated for this connection. + compression compressionMode + + // Persistent compression state (only used with context takeover) + compressBuf *bytes.Buffer + compressor *flate.Writer + + // Persistent decompression state (only used with context takeover) + decompressor io.ReadCloser + decompressDict []byte } -func newWSConnectionWrapper(conn net.Conn, readTimeout, writeTimeout time.Duration) *wsConnectionWrapper { - return &wsConnectionWrapper{ +func newWSConnectionWrapper(conn net.Conn, readTimeout, writeTimeout time.Duration, compression compressionMode) (*wsConnectionWrapper, error) { + w := &wsConnectionWrapper{ conn: conn, readTimeout: readTimeout, writeTimeout: writeTimeout, + compression: compression, + } + + // Initialize persistent compression state only if context takeover is enabled + if compression.enabled && compression.serverContextTakeover { + w.compressBuf = new(bytes.Buffer) + var err error + w.compressor, err = flate.NewWriter(w.compressBuf, compression.level) + if err != nil { + return nil, fmt.Errorf("failed to create flate compressor: %w", err) + } } + + if compression.enabled && compression.clientContextTakeover { + w.decompressor = flate.NewReader(bytes.NewReader(nil)) + w.decompressDict = make([]byte, 0, 1< 0 { - err := c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) +func (c *wsConnectionWrapper) readDataFrames() ([]byte, ws.OpCode, bool, error) { + // Read frames directly and handle compression, buffering fragmented messages. + controlHandler := wsutil.ControlFrameHandler(c.conn, ws.StateServerSide) + var ( + frame ws.Frame + payload []byte + isCompressed bool + op ws.OpCode + started bool + err error + ) + + for { + frame, err = ws.ReadFrame(c.conn) if err != nil { - return err + return nil, 0, false, err + } + + // RFC 6455 §5.1: all client-to-server frames MUST be masked. + if !frame.Header.Masked { + return nil, 0, false, fmt.Errorf("received unmasked frame (opcode %v)", frame.Header.OpCode) + } + ws.Cipher(frame.Payload, frame.Header.Mask, 0) + + if frame.Header.OpCode.IsControl() { + if err := controlHandler(frame.Header, bytes.NewReader(frame.Payload)); err != nil { + return nil, 0, false, err + } + continue + } + + if !started { + // First data frame must be text or binary. + if frame.Header.OpCode != ws.OpText && frame.Header.OpCode != ws.OpBinary { + continue + } + op = frame.Header.OpCode + started = true + // Per RFC 7692, the RSV1 compression bit is only set on the first frame. + isCompressed, err = wsflate.IsCompressed(frame.Header) + if err != nil { + return nil, 0, false, err + } + } else if frame.Header.OpCode != ws.OpContinuation { + // After the first frame, we expect continuation frames until FIN. + return nil, 0, false, fmt.Errorf("unexpected opcode %v while waiting for continuation", frame.Header.OpCode) + } + + // Buffer the payload from this frame. + payload = append(payload, frame.Payload...) + + // Check if this is the final frame. + if frame.Header.Fin { + return payload, op, isCompressed, nil } } +} - text, err := wsutil.ReadClientText(c.conn) +func (c *wsConnectionWrapper) decompressPayload(payload []byte, op ws.OpCode) ([]byte, error) { + if c.compression.clientContextTakeover { + // Use persistent decompressor with dictionary for context takeover. + return c.decompressWithContextTakeover(payload) + } + + // No context takeover - decompress independently. + frame := ws.NewFrame(op, true, payload) + frame.Header.Rsv = ws.Rsv(true, false, false) + frame, err := wsflate.DecompressFrame(frame) if err != nil { - return err + return nil, err } + return frame.Payload, nil +} - return json.Unmarshal(text, v) +// decompressWithContextTakeover decompresses data using the persistent decompressor, +// maintaining dictionary state across messages for better decompression. +func (c *wsConnectionWrapper) decompressWithContextTakeover(compressed []byte) ([]byte, error) { + // Per RFC 7692, append the DEFLATE tail expected by wsflate reader semantics. + // wsflate uses a 9-byte "read tail" (not just the 4-byte sync marker), which + // avoids premature EOF on some streams. + compressed = append( + compressed, + 0x00, 0x00, 0xff, 0xff, // sync flush marker + 0x01, 0x00, 0x00, 0xff, 0xff, // empty stored block + ) + + // Reset the decompressor to read from the new compressed data, using the accumulated dictionary + if resetter, ok := c.decompressor.(flate.Resetter); ok { + if err := resetter.Reset(bytes.NewReader(compressed), c.decompressDict); err != nil { + return nil, err + } + } + + // Read all decompressed data + var decompressed bytes.Buffer + if _, err := io.Copy(&decompressed, c.decompressor); err != nil { + return nil, err + } + + // Update the dictionary with the decompressed data (keep last 32KB per DEFLATE spec) + c.updateDecompressDict(decompressed.Bytes()) + + return decompressed.Bytes(), nil +} + +// updateDecompressDict updates the decompression dictionary with new data. +// The dictionary is capped to the negotiated client window size (1 << clientWindowBits). +func (c *wsConnectionWrapper) updateDecompressDict(data []byte) { + maxDictSize := 1 << c.compression.clientWindowBits + + if len(data) >= maxDictSize { + c.decompressDict = make([]byte, maxDictSize) + copy(c.decompressDict, data[len(data)-maxDictSize:]) + } else { + c.decompressDict = append(c.decompressDict, data...) + if len(c.decompressDict) > maxDictSize { + c.decompressDict = c.decompressDict[len(c.decompressDict)-maxDictSize:] + } + } } func (c *wsConnectionWrapper) WriteText(text string) error { @@ -195,6 +391,10 @@ func (c *wsConnectionWrapper) WriteText(text string) error { } } + if c.compression.enabled { + return c.writeCompressed([]byte(text)) + } + return wsutil.WriteServerText(c.conn, []byte(text)) } @@ -213,9 +413,72 @@ func (c *wsConnectionWrapper) WriteJSON(v any) error { } } + if c.compression.enabled { + return c.writeCompressed(data) + } + return wsutil.WriteServerText(c.conn, data) } +// writeCompressed writes data with compression. Must be called with the mutex held. +func (c *wsConnectionWrapper) writeCompressed(data []byte) error { + if c.compression.serverContextTakeover { + return c.writeCompressedWithContextTakeover(data) + } + return c.writeCompressedNoContextTakeover(data) +} + +// writeCompressedNoContextTakeover compresses data without preserving dictionary state. +// A fresh flate.Writer is created per message via wsflate.NewWriter to honour +// c.compression.level and produce RFC 7692-compliant Z_SYNC_FLUSH framing +// (wsflate.CompressFrame hardcodes level 9). +func (c *wsConnectionWrapper) writeCompressedNoContextTakeover(data []byte) error { + var buf bytes.Buffer + writer := wsflate.NewWriter(&buf, func(w io.Writer) wsflate.Compressor { + fw, _ := flate.NewWriter(w, c.compression.level) + return fw + }) + if _, err := writer.Write(data); err != nil { + return err + } + if err := writer.Flush(); err != nil { + return err + } + + compressed := buf.Bytes() + frame := ws.NewFrame(ws.OpText, true, compressed) + frame.Header.Rsv = ws.Rsv(true, false, false) // Set RSV1 bit for compression + return ws.WriteFrame(c.conn, frame) +} + +// writeCompressedWithContextTakeover compresses data while preserving dictionary state +// between messages for better compression ratios. +func (c *wsConnectionWrapper) writeCompressedWithContextTakeover(data []byte) error { + // Reset buffer but NOT the compressor - this preserves the dictionary + c.compressBuf.Reset() + + if _, err := c.compressor.Write(data); err != nil { + return err + } + if err := c.compressor.Flush(); err != nil { + return err + } + + // Per RFC 7692, remove the trailing sync marker (0x00 0x00 0xff 0xff) when present. + compressed := c.compressBuf.Bytes() + if len(compressed) >= 4 && + compressed[len(compressed)-4] == 0x00 && + compressed[len(compressed)-3] == 0x00 && + compressed[len(compressed)-2] == 0xff && + compressed[len(compressed)-1] == 0xff { + compressed = compressed[:len(compressed)-4] + } + + frame := ws.NewFrame(ws.OpText, true, compressed) + frame.Header.Rsv = ws.Rsv(true, false, false) // Set RSV1 bit for compression + return ws.WriteFrame(c.conn, frame) +} + func (c *wsConnectionWrapper) WriteCloseFrame(code ws.StatusCode, reason string) error { c.mu.Lock() defer c.mu.Unlock() @@ -233,6 +496,12 @@ func (c *wsConnectionWrapper) WriteCloseFrame(code ws.StatusCode, reason string) func (c *wsConnectionWrapper) Close() error { c.mu.Lock() defer c.mu.Unlock() + if c.compressor != nil { + _ = c.compressor.Close() + } + if c.decompressor != nil { + _ = c.decompressor.Close() + } return c.conn.Close() } @@ -267,6 +536,94 @@ type WebsocketHandler struct { disableVariablesRemapping bool apolloCompatibilityFlags config.ApolloCompatibilityFlags + + compression compressionMode +} + +func (h *WebsocketHandler) configureCompressionNegotiation(upgrader *ws.HTTPUpgrader) *wsflate.Extension { + if !h.compression.enabled { + return nil + } + + ext := &wsflate.Extension{ + Parameters: wsflate.Parameters{ + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + // Accept any client offer for server_max_window_bits (up to 15). + ServerMaxWindowBits: 15, + }, + } + upgrader.Negotiate = func(opt httphead.Option) (accept httphead.Option, err error) { + accept, err = ext.Negotiate(opt) + if err != nil || accept.Size() == 0 { + return accept, err + } + + params, accepted := ext.Accepted() + if !accepted { + return accept, nil + } + + response := wsflate.Parameters{ + // Mirror no_context_takeover only when explicitly requested by the client. + ServerNoContextTakeover: params.ServerNoContextTakeover, + ClientNoContextTakeover: params.ClientNoContextTakeover, + // Go's compress/flate always uses a 32 KB window (bits=15). + ServerMaxWindowBits: 15, + } + + // RFC 7692 §7.1.2.2: only include client_max_window_bits in the + // response when the client included it in the offer. + if params.ClientMaxWindowBits.Defined() { + configBits := wsflate.WindowBits(h.compression.clientWindowBits) + // Use the more restrictive of the client's offer and server config. + // ClientMaxWindowBits == 1 means "parameter present, no value" — + // the client accepts any server-chosen value. + if params.ClientMaxWindowBits == 1 || configBits < params.ClientMaxWindowBits { + response.ClientMaxWindowBits = configBits + } else { + response.ClientMaxWindowBits = params.ClientMaxWindowBits + } + } + + return response.Option(), nil + } + + return ext +} + +func resolveNegotiatedCompression(base compressionMode, ext *wsflate.Extension, upgradeErr error) compressionMode { + if ext == nil || upgradeErr != nil { + return compressionMode{ + enabled: false, + level: base.level, + } + } + params, accepted := ext.Accepted() + if !accepted { + return compressionMode{ + enabled: false, + level: base.level, + } + } + // Derive the effective client window bits from the negotiation. + // If the client offered client_max_window_bits, use min(offer, config); + // otherwise fall back to the configured default. + clientWindowBits := base.clientWindowBits + if params.ClientMaxWindowBits.Defined() && params.ClientMaxWindowBits > 1 { + if int(params.ClientMaxWindowBits) < clientWindowBits { + clientWindowBits = int(params.ClientMaxWindowBits) + } + } + + // Context takeover remains enabled when no_context_takeover is not requested. + return compressionMode{ + enabled: true, + level: base.level, + serverContextTakeover: !params.ServerNoContextTakeover, + clientContextTakeover: !params.ClientNoContextTakeover, + clientWindowBits: clientWindowBits, + } } func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.Request) { @@ -309,7 +666,12 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R return false }, } + + compressionExt := h.configureCompressionNegotiation(&upgrader) + c, _, _, err := upgrader.Upgrade(r, w) + connectionCompression := resolveNegotiatedCompression(h.compression, compressionExt, err) + if err != nil { requestLogger.Warn("Websocket upgrade", zap.Error(err)) _ = c.Close() @@ -325,7 +687,12 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R // After successful upgrade, we can't write to the response writer anymore // because it's hijacked by the websocket connection - conn := newWSConnectionWrapper(c, h.readTimeout, h.writeTimeout) + conn, err := newWSConnectionWrapper(c, h.readTimeout, h.writeTimeout, connectionCompression) + if err != nil { + requestLogger.Error("Create websocket connection wrapper", zap.Error(err)) + _ = c.Close() + return + } protocol, err := wsproto.NewProtocol(subProtocol, conn) if err != nil { requestLogger.Error("Create websocket protocol", zap.Error(err)) @@ -802,7 +1169,12 @@ func (h *WebSocketConnectionHandler) requestError(err error) error { return err } h.logger.Warn("Handling websocket connection", zap.Error(err)) - return h.conn.WriteText(err.Error()) + // Keep websocket protocol-compliant on init/read failures. Plain text payloads + // are interpreted as GraphQL messages by clients and cause JSON parse errors. + if closeErr := h.protocol.Close(ws.StatusProtocolError, err.Error()); closeErr != nil { + return closeErr + } + return err } func (h *WebSocketConnectionHandler) writeErrorMessage(operationID string, err error) error { diff --git a/router/core/websocket_test.go b/router/core/websocket_test.go new file mode 100644 index 0000000000..e33f3f8d6c --- /dev/null +++ b/router/core/websocket_test.go @@ -0,0 +1,674 @@ +package core + +import ( + "bytes" + "compress/flate" + "encoding/json" + "fmt" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsflate" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockConn is a mock net.Conn for testing WebSocket operations +type mockConn struct { + readBuf *bytes.Buffer + writeBuf *bytes.Buffer + mu sync.Mutex + closed bool +} + +func newMockConn() *mockConn { + return &mockConn{ + readBuf: new(bytes.Buffer), + writeBuf: new(bytes.Buffer), + } +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + return m.readBuf.Read(b) +} + +func (m *mockConn) Write(b []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + return m.writeBuf.Write(b) +} + +func (m *mockConn) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + m.closed = true + return nil +} + +func (m *mockConn) LocalAddr() net.Addr { return nil } +func (m *mockConn) RemoteAddr() net.Addr { return nil } +func (m *mockConn) SetDeadline(t time.Time) error { return nil } +func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +// writeFrame writes a WebSocket frame to the mock connection's read buffer +func (m *mockConn) writeFrame(frame ws.Frame) error { + m.mu.Lock() + defer m.mu.Unlock() + return ws.WriteFrame(m.readBuf, frame) +} + +// getWrittenBytes returns the bytes written to the mock connection +func (m *mockConn) getWrittenBytes() []byte { + m.mu.Lock() + defer m.mu.Unlock() + return m.writeBuf.Bytes() +} + +// TestWsConnectionWrapper_NoContextTakeover tests compression without context takeover +func TestWsConnectionWrapper_NoContextTakeover(t *testing.T) { + t.Run("write compressed message without context takeover", func(t *testing.T) { + conn := newMockConn() + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: true, level: 6, clientWindowBits: 15}) + require.NoError(t, err) + + // Write a test message + testData := map[string]string{"message": "hello world"} + err = wrapper.WriteJSON(testData) + require.NoError(t, err) + + // Read the frame from the mock connection + writtenBytes := conn.getWrittenBytes() + require.NotEmpty(t, writtenBytes) + + // Parse the frame + frame, err := ws.ReadFrame(bytes.NewReader(writtenBytes)) + require.NoError(t, err) + + // Verify RSV1 bit is set (compression) + isCompressed, err := wsflate.IsCompressed(frame.Header) + require.NoError(t, err) + assert.True(t, isCompressed, "Frame should be compressed") + + // Decompress and verify content + decompressed, err := wsflate.DecompressFrame(frame) + require.NoError(t, err) + + var result map[string]string + err = json.Unmarshal(decompressed.Payload, &result) + require.NoError(t, err) + assert.Equal(t, testData, result) + }) + + t.Run("read compressed message without context takeover", func(t *testing.T) { + conn := newMockConn() + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: true, level: 6, clientWindowBits: 15}) + require.NoError(t, err) + + // Prepare a compressed message + testData := map[string]string{"message": "hello world"} + jsonData, _ := json.Marshal(testData) + + // Compress the data + compressed, err := compressData(jsonData) + require.NoError(t, err) + + // Create a compressed frame (client frame - masked) + frame := ws.NewFrame(ws.OpText, true, compressed) + frame.Header.Rsv = ws.Rsv(true, false, false) + frame.Header.Masked = true + frame.Header.Mask = [4]byte{1, 2, 3, 4} + ws.Cipher(frame.Payload, frame.Header.Mask, 0) + + // Write to mock connection's read buffer + err = conn.writeFrame(frame) + require.NoError(t, err) + + // Read and verify + var result map[string]string + err = wrapper.ReadJSON(&result) + require.NoError(t, err) + assert.Equal(t, testData, result) + }) +} + +// TestWsConnectionWrapper_ContextTakeover tests compression with context takeover +func TestWsConnectionWrapper_ContextTakeover(t *testing.T) { + t.Run("write multiple messages with server context takeover shows compression benefit", func(t *testing.T) { + conn := newMockConn() + // Enable server context takeover + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: true, level: 6, serverContextTakeover: true, clientWindowBits: 15}) + require.NoError(t, err) + + // Write multiple similar messages - with context takeover, + // subsequent messages should reference patterns from earlier ones + messages := []map[string]string{ + {"type": "next", "id": "1", "data": "first message with some repeated content"}, + {"type": "next", "id": "2", "data": "second message with some repeated content"}, + {"type": "next", "id": "3", "data": "third message with some repeated content"}, + } + + var compressedSizes []int + + for _, msg := range messages { + conn.writeBuf.Reset() // Clear for each message + + err := wrapper.WriteJSON(msg) + require.NoError(t, err) + + writtenBytes := conn.getWrittenBytes() + compressedSizes = append(compressedSizes, len(writtenBytes)) + } + + // With context takeover, later messages should be smaller + // because they can reference patterns from earlier messages + t.Logf("Compressed sizes with context takeover: %v", compressedSizes) + + // The second and third messages should be smaller than the first + // due to dictionary reuse + assert.Less(t, compressedSizes[1], compressedSizes[0], + "Second message should be smaller due to context takeover") + assert.Less(t, compressedSizes[2], compressedSizes[0], + "Third message should be smaller due to context takeover") + }) + + t.Run("write multiple messages without server context takeover for comparison", func(t *testing.T) { + conn := newMockConn() + // Disable server context takeover + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: true, level: 6, clientWindowBits: 15}) + require.NoError(t, err) + + messages := []map[string]string{ + {"type": "next", "id": "1", "data": "first message with some repeated content"}, + {"type": "next", "id": "2", "data": "second message with some repeated content"}, + {"type": "next", "id": "3", "data": "third message with some repeated content"}, + } + + var compressedSizes []int + + for _, msg := range messages { + conn.writeBuf.Reset() + + err := wrapper.WriteJSON(msg) + require.NoError(t, err) + + writtenBytes := conn.getWrittenBytes() + compressedSizes = append(compressedSizes, len(writtenBytes)) + } + + t.Logf("Compressed sizes without context takeover: %v", compressedSizes) + + // Without context takeover, all messages should be similar size + // since each is compressed independently + sizeDiff12 := abs(compressedSizes[1] - compressedSizes[0]) + sizeDiff13 := abs(compressedSizes[2] - compressedSizes[0]) + + // Allow some variation but messages should be roughly same size + assert.Less(t, sizeDiff12, 10, "Messages without context takeover should be similar size") + assert.Less(t, sizeDiff13, 10, "Messages without context takeover should be similar size") + }) + + t.Run("read compressed messages without context takeover", func(t *testing.T) { + conn := newMockConn() + // Disable client context takeover - each message compressed independently + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: true, level: 6, clientWindowBits: 15}) + require.NoError(t, err) + + // Prepare multiple independently compressed messages + messages := []map[string]string{ + {"type": "subscribe", "id": "1"}, + {"type": "subscribe", "id": "2"}, + {"type": "subscribe", "id": "3"}, + } + + for _, msg := range messages { + conn.readBuf.Reset() + + jsonData, _ := json.Marshal(msg) + // Compress without context takeover (independent) + compressed, err := compressData(jsonData) + require.NoError(t, err) + + // Create frame + frame := ws.NewFrame(ws.OpText, true, compressed) + frame.Header.Rsv = ws.Rsv(true, false, false) + frame.Header.Masked = true + frame.Header.Mask = [4]byte{1, 2, 3, 4} + ws.Cipher(frame.Payload, frame.Header.Mask, 0) + + err = conn.writeFrame(frame) + require.NoError(t, err) + + // Read and verify + var result map[string]string + err = wrapper.ReadJSON(&result) + require.NoError(t, err) + assert.Equal(t, msg, result) + } + }) +} + +// TestWsConnectionWrapper_FragmentedFrames tests handling of fragmented WebSocket frames +func TestWsConnectionWrapper_FragmentedFrames(t *testing.T) { + t.Run("read fragmented uncompressed message", func(t *testing.T) { + conn := newMockConn() + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: true, level: 6, clientWindowBits: 15}) + require.NoError(t, err) + + // Prepare a message that will be sent in fragments + testData := map[string]string{"message": "this is a longer message that will be fragmented"} + jsonData, _ := json.Marshal(testData) + + // Split into 3 fragments + fragmentSize := len(jsonData) / 3 + fragments := [][]byte{ + jsonData[:fragmentSize], + jsonData[fragmentSize : 2*fragmentSize], + jsonData[2*fragmentSize:], + } + + // First fragment (not FIN, OpText) + frame1 := ws.NewFrame(ws.OpText, false, fragments[0]) + frame1.Header.Masked = true + frame1.Header.Mask = [4]byte{1, 2, 3, 4} + ws.Cipher(frame1.Payload, frame1.Header.Mask, 0) + + // Middle fragment (not FIN, OpContinuation) + frame2 := ws.NewFrame(ws.OpContinuation, false, fragments[1]) + frame2.Header.Masked = true + frame2.Header.Mask = [4]byte{5, 6, 7, 8} + ws.Cipher(frame2.Payload, frame2.Header.Mask, 0) + + // Final fragment (FIN, OpContinuation) + frame3 := ws.NewFrame(ws.OpContinuation, true, fragments[2]) + frame3.Header.Masked = true + frame3.Header.Mask = [4]byte{9, 10, 11, 12} + ws.Cipher(frame3.Payload, frame3.Header.Mask, 0) + + // Write all frames + err = conn.writeFrame(frame1) + require.NoError(t, err) + err = conn.writeFrame(frame2) + require.NoError(t, err) + err = conn.writeFrame(frame3) + require.NoError(t, err) + + // Read should reassemble the fragments + var result map[string]string + err = wrapper.ReadJSON(&result) + require.NoError(t, err) + assert.Equal(t, testData, result) + }) + + t.Run("read fragmented compressed message", func(t *testing.T) { + conn := newMockConn() + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: true, level: 6, clientWindowBits: 15}) + require.NoError(t, err) + + // Prepare and compress a message + testData := map[string]string{"message": "this is a compressed message that will be fragmented"} + jsonData, _ := json.Marshal(testData) + compressed, err := compressData(jsonData) + require.NoError(t, err) + + // Split compressed data into 2 fragments + midPoint := len(compressed) / 2 + fragments := [][]byte{ + compressed[:midPoint], + compressed[midPoint:], + } + + // First fragment (not FIN, OpText, RSV1 set for compression) + frame1 := ws.NewFrame(ws.OpText, false, fragments[0]) + frame1.Header.Rsv = ws.Rsv(true, false, false) // RSV1 only on first frame + frame1.Header.Masked = true + frame1.Header.Mask = [4]byte{1, 2, 3, 4} + ws.Cipher(frame1.Payload, frame1.Header.Mask, 0) + + // Final fragment (FIN, OpContinuation, RSV1 NOT set per RFC 7692) + frame2 := ws.NewFrame(ws.OpContinuation, true, fragments[1]) + frame2.Header.Masked = true + frame2.Header.Mask = [4]byte{5, 6, 7, 8} + ws.Cipher(frame2.Payload, frame2.Header.Mask, 0) + + // Write frames + err = conn.writeFrame(frame1) + require.NoError(t, err) + err = conn.writeFrame(frame2) + require.NoError(t, err) + + // Read should reassemble and decompress + var result map[string]string + err = wrapper.ReadJSON(&result) + require.NoError(t, err) + assert.Equal(t, testData, result) + }) +} + +// TestWsConnectionWrapper_ContextTakeoverDictionary tests dictionary accumulation +func TestWsConnectionWrapper_ContextTakeoverDictionary(t *testing.T) { + t.Run("server context takeover compressor maintains state", func(t *testing.T) { + conn := newMockConn() + // Enable server context takeover + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: true, level: 6, serverContextTakeover: true, clientWindowBits: 15}) + require.NoError(t, err) + + // Verify compressor is initialized + assert.NotNil(t, wrapper.compressor, "Compressor should be initialized for context takeover") + assert.NotNil(t, wrapper.compressBuf, "Compress buffer should be initialized") + + // Write a message + err = wrapper.WriteJSON(map[string]string{"test": "data"}) + require.NoError(t, err) + + // Compressor should still be valid (not nil) after use + assert.NotNil(t, wrapper.compressor, "Compressor should persist after use") + }) + + t.Run("client context takeover decompressor is initialized", func(t *testing.T) { + conn := newMockConn() + // Enable client context takeover + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: true, level: 6, clientContextTakeover: true, clientWindowBits: 15}) + require.NoError(t, err) + + // Verify decompressor is initialized + assert.NotNil(t, wrapper.decompressor, "Decompressor should be initialized for context takeover") + assert.NotNil(t, wrapper.decompressDict, "Decompress dictionary should be initialized") + }) + + t.Run("no context takeover does not initialize persistent state", func(t *testing.T) { + conn := newMockConn() + // Disable context takeover + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: true, level: 6, clientWindowBits: 15}) + require.NoError(t, err) + + // Verify persistent state is not initialized + assert.Nil(t, wrapper.compressor, "Compressor should not be initialized without server context takeover") + assert.Nil(t, wrapper.compressBuf, "Compress buffer should not be initialized") + assert.Nil(t, wrapper.decompressor, "Decompressor should not be initialized without client context takeover") + assert.Nil(t, wrapper.decompressDict, "Decompress dictionary should not be initialized") + }) + + t.Run("decompress with client context takeover handles wsflate read tail framing", func(t *testing.T) { + conn := newMockConn() + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: true, level: 6, clientContextTakeover: true, clientWindowBits: 15}) + require.NoError(t, err) + + var compressBuf bytes.Buffer + compressor, err := flate.NewWriter(&compressBuf, 6) + require.NoError(t, err) + t.Cleanup(func() { + _ = compressor.Close() + }) + + compressMessage := func(msg []byte) []byte { + compressBuf.Reset() + _, err := compressor.Write(msg) + require.NoError(t, err) + require.NoError(t, compressor.Flush()) + + compressed := append([]byte(nil), compressBuf.Bytes()...) + + // Match PMCE framing semantics: sender strips the 0x00 0x00 0xff 0xff tail. + require.GreaterOrEqual(t, len(compressed), 4) + require.Equal(t, []byte{0x00, 0x00, 0xff, 0xff}, compressed[len(compressed)-4:]) + return compressed[:len(compressed)-4] + } + + // Second message is highly repetitive so it should benefit from context takeover. + msg1 := []byte(`{"type":"next","id":"1","payload":"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}`) + msg2 := []byte(`{"type":"next","id":"2","payload":"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}`) + + compressed1 := compressMessage(msg1) + got1, err := wrapper.decompressWithContextTakeover(compressed1) + require.NoError(t, err) + assert.Equal(t, msg1, got1) + + compressed2 := compressMessage(msg2) + got2, err := wrapper.decompressWithContextTakeover(compressed2) + require.NoError(t, err) + assert.Equal(t, msg2, got2) + }) +} + +// TestWsConnectionWrapper_CompressionDisabled tests behavior when compression is disabled +func TestWsConnectionWrapper_CompressionDisabled(t *testing.T) { + t.Run("write uncompressed when compression disabled", func(t *testing.T) { + conn := newMockConn() + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: false, level: 6, clientWindowBits: 15}) + require.NoError(t, err) + + testData := map[string]string{"message": "hello world"} + err = wrapper.WriteJSON(testData) + require.NoError(t, err) + + writtenBytes := conn.getWrittenBytes() + require.NotEmpty(t, writtenBytes) + + // Parse the frame + frame, err := ws.ReadFrame(bytes.NewReader(writtenBytes)) + require.NoError(t, err) + + // Verify RSV1 bit is NOT set (no compression) + isCompressed, err := wsflate.IsCompressed(frame.Header) + require.NoError(t, err) + assert.False(t, isCompressed, "Frame should not be compressed") + }) +} + +// TestWsConnectionWrapper_WriteText tests WriteText method +func TestWsConnectionWrapper_WriteText(t *testing.T) { + t.Run("write text with compression", func(t *testing.T) { + conn := newMockConn() + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: true, level: 6, clientWindowBits: 15}) + require.NoError(t, err) + + testText := `{"type":"connection_ack"}` + err = wrapper.WriteText(testText) + require.NoError(t, err) + + // Verify frame is compressed + frame, err := ws.ReadFrame(bytes.NewReader(conn.getWrittenBytes())) + require.NoError(t, err) + + isCompressed, err := wsflate.IsCompressed(frame.Header) + require.NoError(t, err) + assert.True(t, isCompressed) + + // Decompress and verify + decompressed, err := wsflate.DecompressFrame(frame) + require.NoError(t, err) + assert.Equal(t, testText, string(decompressed.Payload)) + }) + + t.Run("write text without compression", func(t *testing.T) { + conn := newMockConn() + wrapper, err := newWSConnectionWrapper(conn, 0, 0, compressionMode{enabled: false, level: 6, clientWindowBits: 15}) + require.NoError(t, err) + + testText := `{"type":"connection_ack"}` + err = wrapper.WriteText(testText) + require.NoError(t, err) + + // Verify frame is not compressed + frame, err := ws.ReadFrame(bytes.NewReader(conn.getWrittenBytes())) + require.NoError(t, err) + + isCompressed, err := wsflate.IsCompressed(frame.Header) + require.NoError(t, err) + assert.False(t, isCompressed) + assert.Equal(t, testText, string(frame.Payload)) + }) +} + +// TestResolveNegotiatedCompression tests the resolveNegotiatedCompression function +func TestResolveNegotiatedCompression(t *testing.T) { + base := compressionMode{enabled: true, level: 6, clientWindowBits: 15} + + t.Run("returns disabled when ext is nil", func(t *testing.T) { + result := resolveNegotiatedCompression(base, nil, nil) + assert.False(t, result.enabled) + assert.Equal(t, 6, result.level) + }) + + t.Run("returns disabled when upgrade error occurs", func(t *testing.T) { + ext := &wsflate.Extension{} + result := resolveNegotiatedCompression(base, ext, fmt.Errorf("upgrade failed")) + assert.False(t, result.enabled) + assert.Equal(t, 6, result.level) + }) + + t.Run("returns disabled when compression not accepted by client", func(t *testing.T) { + // Extension exists (server supports compression) but Accepted() returns false + // because the client never offered permessage-deflate. + ext := &wsflate.Extension{ + Parameters: wsflate.Parameters{ + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + }, + } + // Without calling ext.Negotiate, ext.Accepted() returns false. + result := resolveNegotiatedCompression(base, ext, nil) + assert.False(t, result.enabled, "compression must be disabled when the client did not negotiate it") + assert.Equal(t, 6, result.level) + assert.False(t, result.serverContextTakeover) + assert.False(t, result.clientContextTakeover) + }) + + t.Run("returns enabled with context takeover when accepted without no_context_takeover", func(t *testing.T) { + ext := &wsflate.Extension{ + Parameters: wsflate.Parameters{ + ServerNoContextTakeover: false, + ClientNoContextTakeover: false, + }, + } + // Simulate successful negotiation by calling Negotiate with a valid offer. + offer := wsflate.Parameters{ + ServerNoContextTakeover: false, + ClientNoContextTakeover: false, + }.Option() + _, _ = ext.Negotiate(offer) + + result := resolveNegotiatedCompression(base, ext, nil) + assert.True(t, result.enabled) + assert.Equal(t, 6, result.level) + assert.True(t, result.serverContextTakeover) + assert.True(t, result.clientContextTakeover) + assert.Equal(t, 15, result.clientWindowBits) + }) + + t.Run("returns enabled without context takeover when no_context_takeover negotiated", func(t *testing.T) { + ext := &wsflate.Extension{ + Parameters: wsflate.Parameters{ + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + }, + } + offer := wsflate.Parameters{ + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + }.Option() + _, _ = ext.Negotiate(offer) + + result := resolveNegotiatedCompression(base, ext, nil) + assert.True(t, result.enabled) + assert.Equal(t, 6, result.level) + assert.False(t, result.serverContextTakeover) + assert.False(t, result.clientContextTakeover) + assert.Equal(t, 15, result.clientWindowBits) + }) + + t.Run("uses client_max_window_bits when client offers a concrete value", func(t *testing.T) { + ext := &wsflate.Extension{ + Parameters: wsflate.Parameters{ + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + ServerMaxWindowBits: 15, + }, + } + offer := wsflate.Parameters{ + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + ClientMaxWindowBits: 12, + }.Option() + _, _ = ext.Negotiate(offer) + + result := resolveNegotiatedCompression(base, ext, nil) + assert.True(t, result.enabled) + assert.Equal(t, 12, result.clientWindowBits, "should use client's offered value when it is more restrictive") + }) + + t.Run("uses server config when client offers larger window bits", func(t *testing.T) { + configBase := compressionMode{enabled: true, level: 6, clientWindowBits: 10} + ext := &wsflate.Extension{ + Parameters: wsflate.Parameters{ + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + ServerMaxWindowBits: 15, + }, + } + offer := wsflate.Parameters{ + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + ClientMaxWindowBits: 14, + }.Option() + _, _ = ext.Negotiate(offer) + + result := resolveNegotiatedCompression(configBase, ext, nil) + assert.True(t, result.enabled) + assert.Equal(t, 10, result.clientWindowBits, "should use server config when it is more restrictive") + }) + + t.Run("uses config default when client does not offer client_max_window_bits", func(t *testing.T) { + configBase := compressionMode{enabled: true, level: 6, clientWindowBits: 12} + ext := &wsflate.Extension{ + Parameters: wsflate.Parameters{ + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + ServerMaxWindowBits: 15, + }, + } + // Offer without client_max_window_bits. + offer := wsflate.Parameters{ + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + }.Option() + _, _ = ext.Negotiate(offer) + + result := resolveNegotiatedCompression(configBase, ext, nil) + assert.True(t, result.enabled) + assert.Equal(t, 12, result.clientWindowBits, "should fall back to server config when client doesn't offer the param") + }) +} + +// Helper functions + +// abs returns the absolute value of an integer +func abs(x int) int { + if x < 0 { + return -x + } + return x +} + +// compressData compresses data using deflate (without context takeover) +func compressData(data []byte) ([]byte, error) { + var buf bytes.Buffer + writer := wsflate.NewWriter(&buf, func(w io.Writer) wsflate.Compressor { + fw, _ := flate.NewWriter(w, 6) + return fw + }) + if _, err := writer.Write(data); err != nil { + return nil, err + } + if err := writer.Flush(); err != nil { + return nil, err + } + return buf.Bytes(), nil +} diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 5baa086249..ad48b23246 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -752,6 +752,21 @@ type WebSocketConfiguration struct { Authentication WebSocketAuthenticationConfiguration `yaml:"authentication,omitempty"` // SetClientInfoFromInitialPayload configuration for the WebSocket Connection ClientInfoFromInitialPayload WebSocketClientInfoFromInitialPayloadConfiguration `yaml:"client_info_from_initial_payload"` + // Compression configuration for WebSocket per-message compression (permessage-deflate) + Compression WebSocketCompressionConfiguration `yaml:"compression,omitempty"` +} + +// WebSocketCompressionConfiguration configures permessage-deflate compression for WebSocket connections +type WebSocketCompressionConfiguration struct { + // Enabled enables permessage-deflate compression for WebSocket connections + Enabled bool `yaml:"enabled" envDefault:"false" env:"WEBSOCKETS_COMPRESSION_ENABLED"` + // Level is the compression level (1-9, where 1 is fastest and 9 is best compression) + Level int `yaml:"level" envDefault:"6" env:"WEBSOCKETS_COMPRESSION_LEVEL"` + // ClientMaxWindowBits limits the LZ77 sliding window size (8-15) that the client + // may use when compressing messages. Smaller values reduce server memory for + // decompression at the cost of compression ratio. Default is 15 (32 KB). + // Only included in the negotiation response when the client offers the parameter. + ClientMaxWindowBits int `yaml:"client_max_window_bits" envDefault:"15" env:"WEBSOCKETS_COMPRESSION_CLIENT_MAX_WINDOW_BITS"` } type WebSocketClientInfoFromInitialPayloadConfiguration struct { diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index e9335caf1c..baad9edca0 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -599,6 +599,32 @@ } } } + }, + "compression": { + "type": "object", + "description": "Configuration for WebSocket per-message compression (permessage-deflate extension).", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": false, + "description": "Enable permessage-deflate compression for WebSocket connections. When enabled, the server will negotiate compression with clients that support it. The default value is false." + }, + "level": { + "type": "integer", + "default": 6, + "minimum": 1, + "maximum": 9, + "description": "The compression level (1-9). Level 1 is fastest with least compression, level 9 provides best compression but is slowest. The default value is 6." + }, + "client_max_window_bits": { + "type": "integer", + "default": 15, + "minimum": 8, + "maximum": 15, + "description": "Limits the LZ77 sliding window size (8-15) that clients may use when compressing messages. Smaller values reduce server memory for decompression at the cost of compression ratio. The parameter is only included in the negotiation response when the client offers it. The default value is 15 (32 KB window)." + } + } } } }, diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index b089b39eac..42e6f3c281 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -411,6 +411,10 @@ websocket: export_token: enabled: true header_key: 'Authorization' + compression: + enabled: true + level: 6 + client_max_window_bits: 15 storage_providers: file_system: diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index f53b401f76..27f47df6b6 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -465,6 +465,11 @@ "NameTargetHeader": "graphql-client-name", "VersionTargetHeader": "graphql-client-version" } + }, + "Compression": { + "Enabled": false, + "Level": 6, + "ClientMaxWindowBits": 15 } }, "SubgraphErrorPropagation": { diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index 70091fcae3..f71c1fe294 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -857,6 +857,11 @@ "NameTargetHeader": "graphql-client-name", "VersionTargetHeader": "graphql-client-version" } + }, + "Compression": { + "Enabled": true, + "Level": 6, + "ClientMaxWindowBits": 15 } }, "SubgraphErrorPropagation": {