diff --git a/router-tests/websocket_test.go b/router-tests/websocket_test.go index 334e2a2646..5eb16bacb5 100644 --- a/router-tests/websocket_test.go +++ b/router-tests/websocket_test.go @@ -847,6 +847,80 @@ func TestWebSockets(t *testing.T) { xEnv.WaitForSubscriptionCount(0, time.Second*5) }) }) + t.Run("can use auth context in header expressions for subgraph requests", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + tokenDecoder, _ := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), []authentication.JWKSConfig{toJWKSConfig(authServer.JWKSURL(), time.Second*5)}) + authOptions := authentication.HttpHeaderAuthenticatorOptions{ + Name: JwksName, + TokenDecoder: tokenDecoder, + } + authenticator, err := authentication.NewHttpHeaderAuthenticator(authOptions) + require.NoError(t, err) + authenticators := []authentication.Authenticator{authenticator} + + headerRules := config.HeaderRules{ + All: &config.GlobalHeaderRule{ + Request: []*config.RequestHeaderRule{ + { + Operation: config.HeaderRuleOperationSet, + Name: "x-authenticated", + Expression: "request.auth.isAuthenticated ? '1' : '0'", + }, + { + Operation: config.HeaderRuleOperationSet, + Name: "x-favorite-animal", + Expression: "request.auth.claims.favorite_animal", + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithHeaderRules(headerRules), + core.WithAccessController(core.NewAccessController(authenticators, false)), + }, + Subgraphs: testenv.SubgraphsConfig{ + GlobalMiddleware: func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("x-authenticated") + require.Equal(t, "1", authHeader) + + favoriteAnimalHeader := r.Header.Get("x-favorite-animal") + require.Equal(t, "bear", favoriteAnimalHeader) + next.ServeHTTP(w, r) + }) + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.Token(map[string]any{ + "favorite_animal": "bear", + }) + require.NoError(t, err) + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + conn := xEnv.InitGraphQLWebSocketConnection(header, nil, nil) + err = testenv.WSWriteJSON(t, conn, &testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp } }"}`), + }) + require.NoError(t, err) + var res testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &res) + require.NoError(t, err) + require.Equal(t, "next", res.Type) + require.Equal(t, "1", res.ID) + require.NoError(t, conn.Close()) + xEnv.WaitForSubscriptionCount(0, time.Second*5) + }) + + }) t.Run("subscription with header propagation sse subgraph post", func(t *testing.T) { t.Parallel() diff --git a/router/core/websocket.go b/router/core/websocket.go index a2ff249fd7..fd618f50c6 100644 --- a/router/core/websocket.go +++ b/router/core/websocket.go @@ -985,6 +985,10 @@ func (h *WebSocketConnectionHandler) executeSubscription(registration *Subscript w: nil, r: registration.clientRequest, }) + + if origCtx := getRequestContext(h.request.Context()); origCtx != nil { + reqContext.expressionContext = *origCtx.expressionContext.Clone() + } resolveCtx = resolveCtx.WithContext(withRequestContext(h.ctx, reqContext)) if h.graphqlHandler.authorizer != nil { resolveCtx = WithAuthorizationExtension(resolveCtx)