Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"version": 1,
"body": "subscription { currentTime { unixTime } }"
}
231 changes: 231 additions & 0 deletions router-tests/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"time"

"go.uber.org/zap"
"go.uber.org/zap/zapcore"

"github.com/buger/jsonparser"
"github.com/gorilla/websocket"
Expand Down Expand Up @@ -2069,6 +2070,236 @@ func TestWebSockets(t *testing.T) {
xEnv.WaitForSubscriptionCount(0, time.Second*5)
})
})

t.Run("unknown operation gets rejected when safelist is enabled", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithPersistedOperationsConfig(config.PersistedOperationsConfig{
Safelist: config.SafelistConfiguration{Enabled: true},
}),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"graphql-client-name": "my-client"}`))
err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{
ID: "1",
Type: "subscribe",
Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 1) { id } }"}`),
})
require.NoError(t, err)
var res testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &res)
require.NoError(t, err)
require.Equal(t, "error", res.Type)
require.Equal(t, "1", res.ID)
require.JSONEq(t, `[{"message":"operation '9a41d21da2823195ad42c11d51e9ad3345824abdabf567b3615a235843a1fcc7' for client 'my-client' not found"}]`,
string(res.Payload))
Comment thread
endigma marked this conversation as resolved.

require.NoError(t, conn.Close())
})
})

t.Run("known hash passes when safelist is enabled", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithPersistedOperationsConfig(config.PersistedOperationsConfig{
Safelist: config.SafelistConfiguration{Enabled: true},
}),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"graphql-client-name": "my-client"}`))
err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{
ID: "1",
Type: "subscribe",
Payload: []byte(`{"extensions":{"persistedQuery":{"version":1,"sha256Hash":"6e94d99132b544a0d7522696a7d35643d56a26c7b8c2e0df29e2b9935636628c"}}}`),
})
require.NoError(t, err)
var res testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &res)
require.NoError(t, err)
require.Equal(t, "next", res.Type)
require.Equal(t, "1", res.ID)
require.Contains(t, string(res.Payload), `"data"`)
require.Contains(t, string(res.Payload), `"currentTime"`)

require.NoError(t, conn.Close())
})
})

t.Run("unknown operation gets logged when log_unknown is enabled", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithPersistedOperationsConfig(config.PersistedOperationsConfig{
LogUnknown: true,
Safelist: config.SafelistConfiguration{Enabled: false},
}),
},
LogObservation: testenv.LogObservationConfig{
Enabled: true,
LogLevel: zapcore.InfoLevel,
},
}, func(t *testing.T, xEnv *testenv.Environment) {
conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"graphql-client-name": "my-client"}`))
err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{
ID: "1",
Type: "subscribe",
Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp } }"}`),
})
require.NoError(t, err)

var res testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &res)
require.NoError(t, err)
require.Equal(t, "next", res.Type)
require.Equal(t, "1", res.ID)
require.Contains(t, string(res.Payload), `"data"`)

// Verify the warning was logged
logEntries := xEnv.Observer().FilterMessageSnippet("Unknown persisted operation found").All()
require.Len(t, logEntries, 1)
requestContext := logEntries[0].ContextMap()
require.Contains(t, requestContext["query"], "subscription { currentTime { unixTime timeStamp } }")
require.Equal(t, "8ad544bda5b2ad7a59481e31fb6fa62705fd072b20fdaadba4f3908d01f2c132", requestContext["sha256Hash"])

require.NoError(t, conn.Close())
})
})

t.Run("operation with matching hash passes", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) {
conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"graphql-client-name": "my-client"}`))

err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{
ID: "1",
Type: "subscribe",
Payload: []byte(`{
"query": "subscription { currentTime { unixTime } }",
"extensions": {
"persistedQuery": {
"version": 1,
"sha256Hash": "6e94d99132b544a0d7522696a7d35643d56a26c7b8c2e0df29e2b9935636628c"
}
}
}`),
})
require.NoError(t, err)

var res testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &res)
require.NoError(t, err)
require.Equal(t, "next", res.Type)
require.Equal(t, "1", res.ID)
require.Contains(t, string(res.Payload), `"data"`)
require.Contains(t, string(res.Payload), `"currentTime"`)

require.NoError(t, conn.Close())
})
})

t.Run("operation with mismatched hash is rejected", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) {
conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"graphql-client-name": "my-client"}`))

err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{
ID: "1",
Type: "subscribe",
Payload: []byte(`{
"query": "subscription { currentTime { unixTime } }",
"extensions": {
"persistedQuery": {
"version": 1,
"sha256Hash": "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"
}
}
}`),
})
require.NoError(t, err)

var res testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &res)
require.NoError(t, err)
require.Equal(t, "error", res.Type)
require.Equal(t, "1", res.ID)
require.Contains(t, string(res.Payload), "persistedQuery sha256 hash does not match query body")

require.NoError(t, conn.Close())
})
})

t.Run("cache poisoning is tried but prevented", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
ApqConfig: config.AutomaticPersistedQueriesConfig{
Enabled: true,
},
}, func(t *testing.T, xEnv *testenv.Environment) {
conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"graphql-client-name": "my-client"}`))

query1 := "subscription { employeeUpdated(employeeID: 3) { id } }"
hashOfQuery2 := "6e94d99132b544a0d7522696a7d35643d56a26c7b8c2e0df29e2b9935636628c"

err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{
ID: "1",
Type: "subscribe",
Payload: fmt.Appendf(nil, `{
"query": "%s",
"extensions": {
"persistedQuery": {
"version": 1,
"sha256Hash": "%s"
}
}
}`, query1, hashOfQuery2),
})
require.NoError(t, err)

var res1 testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &res1)
require.NoError(t, err)
require.Equal(t, "error", res1.Type)
require.Equal(t, "1", res1.ID)
require.Contains(t, string(res1.Payload), "persistedQuery sha256 hash does not match query body")

// even though we got an error we challenge that and still try
// out if our malicious query was cached with hash2.
err = testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{
ID: "2",
Type: "subscribe",
Payload: fmt.Appendf(nil, `{
"extensions": {
"persistedQuery": {
"version": 1,
"sha256Hash": "%s"
}
}
}`, hashOfQuery2),
})
require.NoError(t, err)
Comment thread
endigma marked this conversation as resolved.

var res2 testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &res2)
require.NoError(t, err)
require.Equal(t, "next", res2.Type)
require.Equal(t, "2", res2.ID)

// we expect the response to look like what we asked for via query2
require.Contains(t, string(res2.Payload), "currentTime")
require.Contains(t, string(res2.Payload), "unixTime")

require.NoError(t, conn.Close())
})
})

Comment thread
endigma marked this conversation as resolved.
}

func TestFlakyWebSockets(t *testing.T) {
Expand Down
48 changes: 46 additions & 2 deletions router/core/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/wundergraph/graphql-go-tools/v2/pkg/netpoll"

"github.com/wundergraph/cosmo/router/internal/expr"
"github.com/wundergraph/cosmo/router/internal/persistedoperation"
"github.com/wundergraph/cosmo/router/internal/wsproto"
"github.com/wundergraph/cosmo/router/pkg/authentication"
"github.com/wundergraph/cosmo/router/pkg/config"
Expand Down Expand Up @@ -836,11 +837,40 @@ func (h *WebSocketConnectionHandler) parseAndPlan(registration *SubscriptionRegi
isApq bool
)

if operationKit.parsedOperation.IsPersistedOperation {
skipParse, isApq, err = operationKit.FetchPersistedOperation(h.ctx, h.clientInfo)
if h.shouldComputeOperationSha256(operationKit) {
err = operationKit.ComputeOperationSha256()
if err != nil {
return nil, nil, err
}

// Ensure if operation has both hash and query, that the hash matches the query
if operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery.HasHash() && operationKit.parsedOperation.Request.Query != "" {
if operationKit.parsedOperation.Sha256Hash != operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash {
return nil, nil, errors.New("persistedQuery sha256 hash does not match query body")
}
}

if h.operationBlocker.safelistEnabled || h.operationBlocker.logUnknownOperationsEnabled {
// Set the request hash to the parsed hash, to see if it matches a persisted operation
operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery = &GraphQLRequestExtensionsPersistedQuery{
Sha256Hash: operationKit.parsedOperation.Sha256Hash,
}
}
}

if operationKit.parsedOperation.IsPersistedOperation || h.operationBlocker.safelistEnabled || h.operationBlocker.logUnknownOperationsEnabled {
skipParse, isApq, err = operationKit.FetchPersistedOperation(h.ctx, h.clientInfo)
if err != nil {
var poNotFoundErr *persistedoperation.PersistentOperationNotFoundError
if h.operationBlocker.logUnknownOperationsEnabled && errors.As(err, &poNotFoundErr) {
h.logger.Warn("Unknown persisted operation found", zap.String("query", operationKit.parsedOperation.Request.Query), zap.String("sha256Hash", poNotFoundErr.Sha256Hash))
if h.operationBlocker.safelistEnabled {
return nil, nil, err
}
} else {
return nil, nil, err
}
}
}

// If the persistent operation is already in the cache, we skip the parse step
Expand Down Expand Up @@ -1212,6 +1242,20 @@ func (h *WebSocketConnectionHandler) ignoreHeader(k string) bool {
return h.forwardUpgradeHeaders.withStaticAllowList || h.forwardUpgradeHeaders.withRegexAllowList
}

func (h *WebSocketConnectionHandler) shouldComputeOperationSha256(operationKit *OperationKit) bool {
hasPersistedHash := operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery.HasHash()

if hasPersistedHash && operationKit.parsedOperation.Request.Query != "" {
return true
}

if !hasPersistedHash && (h.operationBlocker.safelistEnabled || h.operationBlocker.logUnknownOperationsEnabled) {
return true
}

return false
}

func (h *WebSocketConnectionHandler) Complete(rw *websocketResponseWriter) {
h.subscriptions.Delete(rw.id)
err := rw.protocol.Complete(rw.id)
Expand Down
Loading