From 28510216334e2b66fc19a7ee51e741fb59a20607 Mon Sep 17 00:00:00 2001 From: blaine-arcjet <146491715+blaine-arcjet@users.noreply.github.com> Date: Tue, 22 Oct 2024 16:31:03 -0700 Subject: [PATCH] feat: Use `waitUntil` for Report call if available (#1838) This attaches a `waitUntil` function to the SDK context and uses it on the call to Report if it is available. I've added the lookup of the `waitUntil` function on Vercel in Arcjet core, since Vercel supports many frameworks. Closes #884 --- arcjet/index.ts | 33 ++++++++++++++++++++++ protocol/client.ts | 10 +++++-- protocol/index.ts | 1 + protocol/test/client.test.ts | 54 ++++++++++++++++++++++++++++++++++++ 4 files changed, 95 insertions(+), 3 deletions(-) diff --git a/arcjet/index.ts b/arcjet/index.ts index 407b52c5a..b9f6f39b6 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -216,6 +216,35 @@ function toString(value: unknown) { return ""; } +// This is the Symbol that Vercel defines in their infrastructure to access the +// Context (where available). The Context can contain the `waitUntil` function. +// https://github.com/vercel/vercel/blob/930d7fb892dc26f240f2b950d963931c45e1e661/packages/functions/src/get-context.ts#L6 +const SYMBOL_FOR_REQ_CONTEXT = Symbol.for("@vercel/request-context"); + +type WaitUntil = (promise: Promise) => void; + +function lookupWaitUntil(): WaitUntil | undefined { + const fromSymbol: typeof globalThis & { + [SYMBOL_FOR_REQ_CONTEXT]?: unknown; + } = globalThis; + if ( + typeof fromSymbol[SYMBOL_FOR_REQ_CONTEXT] === "object" && + fromSymbol[SYMBOL_FOR_REQ_CONTEXT] !== null && + "get" in fromSymbol[SYMBOL_FOR_REQ_CONTEXT] && + typeof fromSymbol[SYMBOL_FOR_REQ_CONTEXT].get === "function" + ) { + const vercelCtx = fromSymbol[SYMBOL_FOR_REQ_CONTEXT].get(); + if ( + typeof vercelCtx === "object" && + vercelCtx !== null && + "waitUntil" in vercelCtx && + typeof vercelCtx.waitUntil === "function" + ) { + return vercelCtx.waitUntil; + } + } +} + function toAnalyzeRequest(request: Partial) { const headers: Record = {}; if (typeof request.headers !== "undefined") { @@ -584,6 +613,7 @@ export type ExtraProps = Rules extends [] export type ArcjetAdapterContext = { [key: string]: unknown; getBody(): Promise; + waitUntil?: (promise: Promise) => void; }; /** @@ -1239,10 +1269,13 @@ export default function arcjet< ? [...options.characteristics] : []; + const waitUntil = lookupWaitUntil(); + const baseContext = { key, log, characteristics, + waitUntil, ...ctx, }; diff --git a/protocol/client.ts b/protocol/client.ts index 6a547192b..e1876f823 100644 --- a/protocol/client.ts +++ b/protocol/client.ts @@ -155,9 +155,9 @@ export function createClient(options: ClientOptions): Client { log.debug("Report request to %s", baseUrl); - // We use the promise API directly to avoid returning a promise from this function so execution can't be paused with `await` - // TODO(#884): Leverage `waitUntil` if the function is attached to the context - client + // We use the promise API directly to avoid returning a promise from this + // function so execution can't be paused with `await` + const reportPromise = client .report(reportRequest, { headers: { Authorization: `Bearer ${context.key}` }, timeoutMs: 2_000, // 2 seconds @@ -177,6 +177,10 @@ export function createClient(options: ClientOptions): Client { .catch((err: unknown) => { log.info("Encountered problem sending report: %s", errorMessage(err)); }); + + if (typeof context.waitUntil === "function") { + context.waitUntil(reportPromise); + } }, }); } diff --git a/protocol/index.ts b/protocol/index.ts index 1ca292757..1e5d366ad 100644 --- a/protocol/index.ts +++ b/protocol/index.ts @@ -797,4 +797,5 @@ export type ArcjetContext = { log: ArcjetLogger; characteristics: string[]; getBody: () => Promise; + waitUntil?: (promise: Promise) => void; }; diff --git a/protocol/test/client.test.ts b/protocol/test/client.test.ts index 72dbdf29a..245c37cdf 100644 --- a/protocol/test/client.test.ts +++ b/protocol/test/client.test.ts @@ -687,6 +687,60 @@ describe("createClient", () => { expect(decision.isAllowed()).toBe(true); }); + test("calling `report` will use `waitUntil` if available", async () => { + const [promise, resolve] = deferred(); + + const key = "test-key"; + const fingerprint = + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c"; + const context = { + key, + fingerprint, + runtime: "test", + log, + characteristics: [], + getBody: () => Promise.resolve(undefined), + waitUntil: jest.fn((promise: Promise) => { + promise.then(() => resolve()); + }), + }; + const details = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "curl/8.1.2"]]), + extra: { + "extra-test": "extra-test-value", + }, + email: "test@example.com", + }; + + const router = { + report: () => { + return new ReportResponse({}); + }, + }; + + const client = createClient({ + ...defaultRemoteClientOptions, + transport: createRouterTransport(({ service }) => { + service(DecideService, router); + }), + }); + const decision = new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + client.report(context, details, decision, []); + + await promise; + + expect(context.waitUntil).toHaveBeenCalledTimes(1); + }); + test("calling `report` will make RPC call with ALLOW decision", async () => { const key = "test-key"; const fingerprint =