Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions router/core/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,17 @@ type WebSocketConnectionHandler struct {
forwardUpgradeRequestHeaders bool
forwardUpgradeRequestQueryParams bool
forwardInitialPayload bool

hashHeadersWhitelist []string
}

func NewWebsocketConnectionHandler(ctx context.Context, opts WebSocketConnectionHandlerOptions) *WebSocketConnectionHandler {
var hashHeadersWhitelist []string

if opts.Config != nil {
hashHeadersWhitelist = opts.Config.HashHeadersWhitelist
}

return &WebSocketConnectionHandler{
ctx: ctx,
operationProcessor: opts.OperationProcessor,
Expand All @@ -580,6 +588,7 @@ func NewWebsocketConnectionHandler(ctx context.Context, opts WebSocketConnection
forwardUpgradeRequestHeaders: opts.Config != nil && opts.Config.ForwardUpgradeHeaders,
forwardUpgradeRequestQueryParams: opts.Config != nil && opts.Config.ForwardUpgradeQueryParams,
forwardInitialPayload: opts.Config != nil && opts.Config.ForwardInitialPayload,
hashHeadersWhitelist: hashHeadersWhitelist,
}
}

Expand Down Expand Up @@ -795,12 +804,15 @@ func (h *WebsocketHandler) HandleMessage(handler *WebSocketConnectionHandler, ms

func (h *WebSocketConnectionHandler) Initialize() (err error) {
h.logger.Debug("Websocket connection", zap.String("protocol", h.protocol.Subprotocol()))

// Initialize the protocol and get the initial payload
h.initialPayload, err = h.protocol.Initialize()
if err != nil {
h.logger.Error("Initializing websocket connection", zap.Error(err))
_ = h.requestError(fmt.Errorf("error initializing session"))
return err
}

if h.forwardUpgradeRequestQueryParams {
query := h.r.URL.Query()
if len(query) != 0 {
Expand All @@ -810,24 +822,37 @@ func (h *WebSocketConnectionHandler) Initialize() (err error) {
}
}
}

if h.forwardUpgradeRequestHeaders {
header := make(http.Header, len(h.r.Header))
header := make(http.Header)
for k, v := range h.r.Header {
if h.ignoreHeader(k) {
continue
if !h.ignoreHeader(k) && h.shouldHashHeader(k) {
header[k] = v
}
header[k] = v
}
if len(header) > 0 {
h.upgradeRequestHeaders, err = json.Marshal(header)
}
if err != nil {
return err
if err != nil {
return err
}
}
}

return nil
}

func (h *WebSocketConnectionHandler) shouldHashHeader(header string) bool {
if len(h.hashHeadersWhitelist) == 0 {
return true // If no whitelist is provided, allow all headers by default
}
for _, allowed := range h.hashHeadersWhitelist {
if header == allowed {
return true
}
}
return false
}

func (h *WebSocketConnectionHandler) ignoreHeader(k string) bool {
switch k {
case "Sec-Websocket-Protocol",
Expand Down
186 changes: 186 additions & 0 deletions router/core/websocket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package core

import (
"bufio"
"bytes"
"context"
"encoding/json"
"net"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"github.com/wundergraph/cosmo/router/internal/wsproto"
"github.com/wundergraph/cosmo/router/pkg/config"
"go.uber.org/zap"
)

type MockConn struct {
Comment thread
glower marked this conversation as resolved.
net.Conn
readBuf bytes.Buffer
writeBuf bytes.Buffer
}

func (c *MockConn) Read(b []byte) (n int, err error) {
return c.readBuf.Read(b)
}

func (c *MockConn) Write(b []byte) (n int, err error) {
return c.writeBuf.Write(b)
}

func newMockConn() *MockConn {
return &MockConn{}
}

func newTestWebSocketConnectionHandler(opts WebSocketConnectionHandlerOptions) *WebSocketConnectionHandler {
return NewWebsocketConnectionHandler(context.Background(), opts)
}

func newTestWebSocketConnectionHandlerOptions() WebSocketConnectionHandlerOptions {
return WebSocketConnectionHandlerOptions{
Config: &config.WebSocketConfiguration{
ForwardUpgradeHeaders: true,
},
Logger: zap.NewNop(),
Connection: &wsConnectionWrapper{
conn: newMockConn(),
rw: bufio.NewReadWriter(bufio.NewReader(newMockConn()), bufio.NewWriter(newMockConn())),
},
Request: &http.Request{
Header: http.Header{
"Test-Header": {"value1"},
"Sec-Websocket-Key": {"key"},
"Sec-Websocket-Version": {"13"},
"Sec-Websocket-Protocol": {"graphql-ws"},
},
},
}
}

type mockProtocol struct{}

func (p *mockProtocol) Subprotocol() string {
return "mock-protocol"
}

func (p *mockProtocol) Initialize() (json.RawMessage, error) {
return json.RawMessage(`{}`), nil
}

func (p *mockProtocol) ReadMessage() (*wsproto.Message, error) {
return nil, nil
}

func (p *mockProtocol) WriteGraphQLData(id string, data json.RawMessage, extensions json.RawMessage) error {
return nil
}

func (p *mockProtocol) WriteGraphQLErrors(id string, errors json.RawMessage, extensions json.RawMessage) error {
return nil
}

func (p *mockProtocol) Pong(msg *wsproto.Message) error {
return nil
}

func (p *mockProtocol) Done(id string) error {
return nil
}

func TestShouldHashHeader(t *testing.T) {
opts := newTestWebSocketConnectionHandlerOptions()
handler := newTestWebSocketConnectionHandler(opts)

tests := []struct {
name string
whitelist []string
header string
expectedShouldHash bool
}{
{
name: "No whitelist specified, allow all headers",
whitelist: []string{},
header: "Test-Header",
expectedShouldHash: true,
},
{
name: "No whitelist specified, allow all headers (unknown header)",
whitelist: []string{},
header: "Unknown-Header",
expectedShouldHash: true,
},
{
name: "Whitelist specified, header in whitelist",
whitelist: []string{"Test-Header", "Allowed-Header"},
header: "Test-Header",
expectedShouldHash: true,
},
{
name: "Whitelist specified, header not in whitelist",
whitelist: []string{"Test-Header", "Allowed-Header"},
header: "Unknown-Header",
expectedShouldHash: false,
},
{
name: "Whitelist specified, another header in whitelist",
whitelist: []string{"Test-Header", "Allowed-Header"},
header: "Allowed-Header",
expectedShouldHash: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler.hashHeadersWhitelist = tt.whitelist
result := handler.shouldHashHeader(tt.header)
assert.Equal(t, tt.expectedShouldHash, result)
})
}
}

func TestWebSocketConnectionHandler_Initialize(t *testing.T) {
opts := newTestWebSocketConnectionHandlerOptions()
handler := newTestWebSocketConnectionHandler(opts)

// Mock protocol to avoid actual network operations
handler.protocol = &mockProtocol{}

tests := []struct {
name string
whitelist []string
expectedHeaders []string
unexpectedHeaders []string
}{
{
name: "No whitelist specified, include all headers except ignored",
whitelist: []string{},
expectedHeaders: []string{"Test-Header"},
unexpectedHeaders: []string{"Sec-Websocket-Key", "Sec-Websocket-Version"},
},
{
name: "Whitelist specified, include only whitelisted headers",
whitelist: []string{"Test-Header"},
expectedHeaders: []string{"Test-Header"},
unexpectedHeaders: []string{"Sec-Websocket-Key", "Sec-Websocket-Version", "Unknown-Header"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler.hashHeadersWhitelist = tt.whitelist

err := handler.Initialize()
assert.NoError(t, err)
headers := string(handler.upgradeRequestHeaders)

for _, expectedHeader := range tt.expectedHeaders {
assert.Contains(t, headers, expectedHeader)
}

for _, unexpectedHeader := range tt.unexpectedHeaders {
assert.NotContains(t, headers, unexpectedHeader)
}
})
}
}
2 changes: 2 additions & 0 deletions router/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ type WebSocketConfiguration struct {
ForwardUpgradeQueryParams bool `yaml:"forward_upgrade_query_params" default:"true" envconfig:"WEBSOCKETS_FORWARD_UPGRADE_QUERY_PARAMS"`
// ForwardInitialPayload true if the Router should forward the initial payload of a Subscription Request to the Subgraph
ForwardInitialPayload bool `yaml:"forward_initial_payload" default:"true" envconfig:"WEBSOCKETS_FORWARD_INITIAL_PAYLOAD"`
// HashHeadersWhitelist is a list of headers that should be hashed during the WebSocket upgrade process
HashHeadersWhitelist []string `yaml:"hash_headers_whitelist" envconfig:"WEBSOCKETS_HASH_HEADERS_WHITELIST"`
}

type AnonymizeIpConfiguration struct {
Expand Down
7 changes: 7 additions & 0 deletions router/pkg/config/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@
"type": "boolean",
"default": true,
"description": "Forward the initial payload in the extensions payload when starting a subscription on a Subgraph. The default value is true."
},
"hash_headers_whitelist": {
"type": "array",
"description": "A list of headers that should be hashed during the WebSocket upgrade process.",
"items": {
"type": "string"
}
}
}
},
Expand Down
5 changes: 4 additions & 1 deletion router/pkg/config/fixtures/full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -241,4 +241,7 @@ websocket:
handler_path: /absinthe/socket
forward_initial_payload: true
forward_upgrade_headers: true
forward_upgrade_query_params: true
forward_upgrade_query_params: true
hash_headers_whitelist:
- "Test-Header"
- "Another-Header"
3 changes: 2 additions & 1 deletion router/pkg/config/testdata/config_defaults.json
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@
},
"ForwardUpgradeHeaders": true,
"ForwardUpgradeQueryParams": true,
"ForwardInitialPayload": true
"ForwardInitialPayload": true,
"HashHeadersWhitelist": []
},
"SubgraphErrorPropagation": {
"Enabled": false,
Expand Down
6 changes: 5 additions & 1 deletion router/pkg/config/testdata/config_full.json
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,11 @@
},
"ForwardUpgradeHeaders": true,
"ForwardUpgradeQueryParams": true,
"ForwardInitialPayload": true
"ForwardInitialPayload": true,
"HashHeadersWhitelist": [
"Test-Header",
"Another-Header"
]
},
"SubgraphErrorPropagation": {
"Enabled": false,
Expand Down