From ccadfabb345139f6320861752872e7454b0feea0 Mon Sep 17 00:00:00 2001 From: enisdenjo Date: Wed, 11 Sep 2024 09:38:15 -0700 Subject: [PATCH] feat(gw): header propagation plugin --- .changeset/nervous-kangaroos-smell.md | 6 + packages/fusion/runtime/src/utils.ts | 3 + .../serve-runtime/src/createGatewayRuntime.ts | 5 + packages/serve-runtime/src/index.ts | 2 +- .../src/plugins/useForwardHeaders.ts | 28 ---- .../src/plugins/usePropagateHeaders.ts | 81 ++++++++++ packages/serve-runtime/src/types.ts | 6 + .../tests/propagateHeaders.spec.ts | 142 ++++++++++++++++++ .../tests/useForwardHeaders.spec.ts | 126 ---------------- 9 files changed, 244 insertions(+), 155 deletions(-) create mode 100644 .changeset/nervous-kangaroos-smell.md delete mode 100644 packages/serve-runtime/src/plugins/useForwardHeaders.ts create mode 100644 packages/serve-runtime/src/plugins/usePropagateHeaders.ts create mode 100644 packages/serve-runtime/tests/propagateHeaders.spec.ts delete mode 100644 packages/serve-runtime/tests/useForwardHeaders.spec.ts diff --git a/.changeset/nervous-kangaroos-smell.md b/.changeset/nervous-kangaroos-smell.md new file mode 100644 index 000000000000..e12c311b1cbe --- /dev/null +++ b/.changeset/nervous-kangaroos-smell.md @@ -0,0 +1,6 @@ +--- +'@graphql-mesh/fusion-runtime': patch +'@graphql-mesh/serve-runtime': patch +--- + +Header Propagation diff --git a/packages/fusion/runtime/src/utils.ts b/packages/fusion/runtime/src/utils.ts index 5c60d17b1331..779473065d96 100644 --- a/packages/fusion/runtime/src/utils.ts +++ b/packages/fusion/runtime/src/utils.ts @@ -124,6 +124,8 @@ function getTransportExecutor({ ); } +export const subgraphNameByExecutionRequest = new WeakMap(); + /** * This function creates a executor factory that uses the transport packages, * and wraps them with the hooks @@ -145,6 +147,7 @@ export function getOnSubgraphExecute({ }) { const subgraphExecutorMap = new Map(); return function onSubgraphExecute(subgraphName: string, executionRequest: ExecutionRequest) { + subgraphNameByExecutionRequest.set(executionRequest, subgraphName); let executor: Executor = subgraphExecutorMap.get(subgraphName); // If the executor is not initialized yet, initialize it if (executor == null) { diff --git a/packages/serve-runtime/src/createGatewayRuntime.ts b/packages/serve-runtime/src/createGatewayRuntime.ts index cf6b17a5428f..87e2783d670d 100644 --- a/packages/serve-runtime/src/createGatewayRuntime.ts +++ b/packages/serve-runtime/src/createGatewayRuntime.ts @@ -70,6 +70,7 @@ import { useCompleteSubscriptionsOnSchemaChange } from './plugins/useCompleteSub import { useContentEncoding } from './plugins/useContentEncoding.js'; import { useCustomAgent } from './plugins/useCustomAgent.js'; import { useFetchDebug } from './plugins/useFetchDebug.js'; +import { usePropagateHeaders } from './plugins/usePropagateHeaders.js'; import { useRequestId } from './plugins/useRequestId.js'; import { useSubgraphExecuteDebug } from './plugins/useSubgraphExecuteDebug.js'; import { useUpstreamCancel } from './plugins/useUpstreamCancel.js'; @@ -819,6 +820,10 @@ export function createGatewayRuntime = Reco extraPlugins.push(useHmacUpstreamSignature(config.hmacSignature)); } + if (config.propagateHeaders) { + extraPlugins.push(usePropagateHeaders(config.propagateHeaders)); + } + const yoga = createYoga({ fetchAPI: config.fetchAPI, logging: logger, diff --git a/packages/serve-runtime/src/index.ts b/packages/serve-runtime/src/index.ts index 0faaaf506299..e52b47732f47 100644 --- a/packages/serve-runtime/src/index.ts +++ b/packages/serve-runtime/src/index.ts @@ -3,7 +3,7 @@ export * from './types.js'; export * from './plugins/useCustomFetch.js'; export * from './plugins/useStaticFiles.js'; export * from './getProxyExecutor.js'; -export * from './plugins/useForwardHeaders.js'; +export * from './plugins/usePropagateHeaders.js'; export * from '@whatwg-node/disposablestack'; export type { ResolveUserFn, ValidateUserFn } from '@envelop/generic-auth'; export * from '@graphql-mesh/hmac-upstream-signature'; diff --git a/packages/serve-runtime/src/plugins/useForwardHeaders.ts b/packages/serve-runtime/src/plugins/useForwardHeaders.ts deleted file mode 100644 index c4a7bcb7b052..000000000000 --- a/packages/serve-runtime/src/plugins/useForwardHeaders.ts +++ /dev/null @@ -1,28 +0,0 @@ -import type { GatewayPlugin } from '../types'; - -export interface ForwardHeadersPluginOptions { - headerNames: string[]; -} - -export function useForwardHeaders(headerNames: string[]): GatewayPlugin { - return { - onFetch({ options, setOptions, context }) { - if (context.request?.headers) { - const forwardedHeaders: Record = {}; - for (const headerName of headerNames) { - const headerValue = context.request.headers.get(headerName); - if (headerValue) { - forwardedHeaders[headerName] = headerValue; - } - } - setOptions({ - ...options, - headers: { - ...forwardedHeaders, - ...options.headers, - }, - }); - } - }, - }; -} diff --git a/packages/serve-runtime/src/plugins/usePropagateHeaders.ts b/packages/serve-runtime/src/plugins/usePropagateHeaders.ts new file mode 100644 index 000000000000..da10241be584 --- /dev/null +++ b/packages/serve-runtime/src/plugins/usePropagateHeaders.ts @@ -0,0 +1,81 @@ +import { mapMaybePromise } from '@envelop/core'; +import { subgraphNameByExecutionRequest } from '@graphql-mesh/fusion-runtime'; +import type { TransportEntry } from '@graphql-mesh/transport-common'; +import type { OnFetchHookDone } from '@graphql-mesh/types'; +import type { MaybePromise } from '@graphql-tools/utils'; +import type { GatewayPlugin } from '../types'; + +interface FromClientToSubgraphsPayload { + request: Request; + subgraphName: string; +} + +interface FromSubgraphsToClientPayload { + response: Response; + subgraphName: string; +} + +export interface PropagateHeadersOpts { + fromClientToSubgraphs?: ( + payload: FromClientToSubgraphsPayload, + ) => Record | void | Promise | void>; + fromSubgraphsToClient?: ( + payload: FromSubgraphsToClientPayload, + ) => Record | void | Promise | void>; +} + +export function usePropagateHeaders(opts: PropagateHeadersOpts): GatewayPlugin { + const resHeadersByRequest = new WeakMap>(); + return { + onFetch({ executionRequest, context, options, setOptions }) { + const subgraphName = subgraphNameByExecutionRequest.get(executionRequest); + if (subgraphName != null) { + let job: Promise | void; + if (opts.fromClientToSubgraphs) { + job = mapMaybePromise( + opts.fromClientToSubgraphs({ + request: context.request, + subgraphName, + }), + headers => + setOptions({ + ...options, + headers: { + ...headers, + ...options.headers, + }, + }), + ); + } + return mapMaybePromise(job, (): OnFetchHookDone => { + if (opts.fromSubgraphsToClient) { + return function onFetchDone({ response }) { + return mapMaybePromise( + opts.fromSubgraphsToClient({ + response, + subgraphName, + }), + headers => { + if (headers) { + resHeadersByRequest.set(context.request, headers); + } + }, + ); + }; + } + }); + } + }, + onResponse({ response, request }) { + const headers = resHeadersByRequest.get(request); + if (headers) { + for (const key in headers) { + const value = headers[key]; + if (value) { + response.headers.set(key, value); + } + } + } + }, + }; +} diff --git a/packages/serve-runtime/src/types.ts b/packages/serve-runtime/src/types.ts index debf3ee23d4b..76bc1f808d73 100644 --- a/packages/serve-runtime/src/types.ts +++ b/packages/serve-runtime/src/types.ts @@ -13,6 +13,7 @@ import type { useGenericAuth } from '@envelop/generic-auth'; import type { Transports, UnifiedGraphPlugin } from '@graphql-mesh/fusion-runtime'; import type { HMACUpstreamSignatureOptions } from '@graphql-mesh/hmac-upstream-signature'; import type useMeshResponseCache from '@graphql-mesh/plugin-response-cache'; +import type { PropagateHeadersOpts } from '@graphql-mesh/serve-cli'; import type { TransportEntry } from '@graphql-mesh/transport-common'; import type { KeyValueCache, @@ -473,4 +474,9 @@ interface GatewayConfigBase> { * Enable WebHooks handling */ webhooks?: boolean; + + /** + * Header Propagation + */ + propagateHeaders?: PropagateHeadersOpts; } diff --git a/packages/serve-runtime/tests/propagateHeaders.spec.ts b/packages/serve-runtime/tests/propagateHeaders.spec.ts new file mode 100644 index 000000000000..84fd66beeb58 --- /dev/null +++ b/packages/serve-runtime/tests/propagateHeaders.spec.ts @@ -0,0 +1,142 @@ +import { createSchema, createYoga, type Plugin } from 'graphql-yoga'; +import { useCustomFetch } from '@graphql-mesh/serve-runtime'; +import { createGatewayRuntime } from '../src/createGatewayRuntime'; + +describe('usePropagateHeaders', () => { + describe('From Client to the Subgraphs', () => { + const requestTrackerPlugin = { + onParams: jest.fn((() => {}) as Plugin['onParams']), + }; + const upstream = createYoga({ + schema: createSchema({ + typeDefs: /* GraphQL */ ` + type Query { + hello: String + } + `, + resolvers: { + Query: { + hello: () => 'world', + }, + }, + }), + plugins: [requestTrackerPlugin], + logging: !!process.env.DEBUG, + }); + beforeEach(() => { + requestTrackerPlugin.onParams.mockClear(); + }); + it('forwards specified headers', async () => { + await using serveRuntime = createGatewayRuntime({ + proxy: { + endpoint: 'http://localhost:4001/graphql', + }, + propagateHeaders: { + fromClientToSubgraphs({ request }) { + return { + 'x-my-header': request.headers.get('x-my-header'), + 'x-my-other': request.headers.get('x-my-other'), + }; + }, + fromSubgraphsToClient({ response }) { + return { + 'set-cookies': response.headers.get('set-cookies'), + }; + }, + }, + plugins: () => [useCustomFetch(upstream.fetch)], + logging: !!process.env.DEBUG, + }); + const response = await serveRuntime.fetch('http://localhost:4000/graphql', { + method: 'POST', + headers: { + 'x-my-header': 'my-value', + 'x-my-other': 'other-value', + 'x-extra-header': 'extra-value', + 'content-type': 'application/json', + }, + body: JSON.stringify({ + query: /* GraphQL */ ` + query { + hello + } + `, + extensions: { + randomThing: 'randomValue', + }, + }), + }); + + const resJson = await response.json(); + expect(resJson).toEqual({ + data: { + hello: 'world', + }, + }); + + // The first call is for the introspection + expect(requestTrackerPlugin.onParams).toHaveBeenCalledTimes(2); + const onParamsPayload = requestTrackerPlugin.onParams.mock.calls[1][0]; + // Do not pass extensions + expect(onParamsPayload.params.extensions).toBeUndefined(); + const headersObj = Object.fromEntries(onParamsPayload.request.headers.entries()); + expect(headersObj['x-my-header']).toBe('my-value'); + expect(headersObj['x-my-other']).toBe('other-value'); + expect(headersObj['x-extra-header']).toBeUndefined(); + }); + it("forwards specified headers but doesn't override the provided headers", async () => { + await using serveRuntime = createGatewayRuntime({ + logging: !!process.env.DEBUG, + proxy: { + endpoint: 'http://localhost:4001/graphql', + headers: { + 'x-my-header': 'my-value', + 'x-extra-header': 'extra-value', + }, + }, + propagateHeaders: { + fromClientToSubgraphs({ request }) { + return { + 'x-my-header': request.headers.get('x-my-header')!, + 'x-my-other': request.headers.get('x-my-other')!, + }; + }, + }, + plugins: () => [useCustomFetch(upstream.fetch)], + maskedErrors: false, + }); + const response = await serveRuntime.fetch('http://localhost:4000/graphql', { + method: 'POST', + headers: { + 'content-type': 'application/json', + 'x-my-header': 'my-new-value', + 'x-my-other': 'other-value', + }, + body: JSON.stringify({ + query: /* GraphQL */ ` + query { + hello + } + `, + }), + }); + + const resJson = await response.json(); + expect(resJson).toEqual({ + data: { + hello: 'world', + }, + }); + + // The first call is for the introspection + expect(requestTrackerPlugin.onParams).toHaveBeenCalledTimes(2); + const onParamsPayload = requestTrackerPlugin.onParams.mock.calls[1][0]; + // Do not pass extensions + expect(onParamsPayload.params.extensions).toBeUndefined(); + const headersObj = Object.fromEntries(onParamsPayload.request.headers.entries()); + expect(headersObj['x-my-header']).toBe('my-value'); + expect(headersObj['x-extra-header']).toBe('extra-value'); + expect(headersObj['x-my-other']).toBe('other-value'); + }); + }); +}); diff --git a/packages/serve-runtime/tests/useForwardHeaders.spec.ts b/packages/serve-runtime/tests/useForwardHeaders.spec.ts deleted file mode 100644 index 463d9c008f27..000000000000 --- a/packages/serve-runtime/tests/useForwardHeaders.spec.ts +++ /dev/null @@ -1,126 +0,0 @@ -import { createSchema, createYoga, type Plugin } from 'graphql-yoga'; -import { useCustomFetch } from '@graphql-mesh/serve-runtime'; -import { createGatewayRuntime } from '../src/createGatewayRuntime'; -import { useForwardHeaders } from '../src/plugins/useForwardHeaders'; - -describe('useForwardHeaders', () => { - const requestTrackerPlugin = { - onParams: jest.fn((() => {}) as Plugin['onParams']), - }; - const upstream = createYoga({ - schema: createSchema({ - typeDefs: /* GraphQL */ ` - type Query { - hello: String - } - `, - resolvers: { - Query: { - hello: () => 'world', - }, - }, - }), - plugins: [requestTrackerPlugin], - logging: !!process.env.DEBUG, - }); - beforeEach(() => { - requestTrackerPlugin.onParams.mockClear(); - }); - it('forwards specified headers', async () => { - await using serveRuntime = createGatewayRuntime({ - proxy: { - endpoint: 'http://localhost:4001/graphql', - }, - plugins: () => [ - useCustomFetch(upstream.fetch), - useForwardHeaders(['x-my-header', 'x-my-other']), - ], - logging: !!process.env.DEBUG, - }); - const response = await serveRuntime.fetch('http://localhost:4000/graphql', { - method: 'POST', - headers: { - 'x-my-header': 'my-value', - 'x-my-other': 'other-value', - 'x-extra-header': 'extra-value', - 'content-type': 'application/json', - }, - body: JSON.stringify({ - query: /* GraphQL */ ` - query { - hello - } - `, - extensions: { - randomThing: 'randomValue', - }, - }), - }); - - const resJson = await response.json(); - expect(resJson).toEqual({ - data: { - hello: 'world', - }, - }); - - // The first call is for the introspection - expect(requestTrackerPlugin.onParams).toHaveBeenCalledTimes(2); - const onParamsPayload = requestTrackerPlugin.onParams.mock.calls[1][0]; - // Do not pass extensions - expect(onParamsPayload.params.extensions).toBeUndefined(); - const headersObj = Object.fromEntries(onParamsPayload.request.headers.entries()); - expect(headersObj['x-my-header']).toBe('my-value'); - expect(headersObj['x-my-other']).toBe('other-value'); - expect(headersObj['x-extra-header']).toBeUndefined(); - }); - it("forwards specified headers but doesn't override the provided headers", async () => { - await using serveRuntime = createGatewayRuntime({ - logging: !!process.env.DEBUG, - proxy: { - endpoint: 'http://localhost:4001/graphql', - headers: { - 'x-my-header': 'my-value', - 'x-extra-header': 'extra-value', - }, - }, - plugins: () => [ - useCustomFetch(upstream.fetch), - useForwardHeaders(['x-my-header', 'x-my-other']), - ], - maskedErrors: false, - }); - const response = await serveRuntime.fetch('http://localhost:4000/graphql', { - method: 'POST', - headers: { - 'content-type': 'application/json', - 'x-my-header': 'my-new-value', - 'x-my-other': 'other-value', - }, - body: JSON.stringify({ - query: /* GraphQL */ ` - query { - hello - } - `, - }), - }); - - const resJson = await response.json(); - expect(resJson).toEqual({ - data: { - hello: 'world', - }, - }); - - // The first call is for the introspection - expect(requestTrackerPlugin.onParams).toHaveBeenCalledTimes(2); - const onParamsPayload = requestTrackerPlugin.onParams.mock.calls[1][0]; - // Do not pass extensions - expect(onParamsPayload.params.extensions).toBeUndefined(); - const headersObj = Object.fromEntries(onParamsPayload.request.headers.entries()); - expect(headersObj['x-my-header']).toBe('my-value'); - expect(headersObj['x-extra-header']).toBe('extra-value'); - expect(headersObj['x-my-other']).toBe('other-value'); - }); -});