Skip to content

Commit

Permalink
feat: Use waitUntil for Report call if available (#1838)
Browse files Browse the repository at this point in the history
This attaches a `waitUntil` function to the SDK context and uses it on the call to Report if it is available. I've added the lookup of the `waitUntil` function on Vercel in Arcjet core, since Vercel supports many frameworks.

Closes #884
  • Loading branch information
blaine-arcjet authored Oct 22, 2024
1 parent 07e68dc commit 2851021
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 3 deletions.
33 changes: 33 additions & 0 deletions arcjet/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,35 @@ function toString(value: unknown) {
return "<unsupported value>";
}

// This is the Symbol that Vercel defines in their infrastructure to access the
// Context (where available). The Context can contain the `waitUntil` function.
// https://github.com/vercel/vercel/blob/930d7fb892dc26f240f2b950d963931c45e1e661/packages/functions/src/get-context.ts#L6
const SYMBOL_FOR_REQ_CONTEXT = Symbol.for("@vercel/request-context");

type WaitUntil = (promise: Promise<unknown>) => void;

function lookupWaitUntil(): WaitUntil | undefined {
const fromSymbol: typeof globalThis & {
[SYMBOL_FOR_REQ_CONTEXT]?: unknown;
} = globalThis;
if (
typeof fromSymbol[SYMBOL_FOR_REQ_CONTEXT] === "object" &&
fromSymbol[SYMBOL_FOR_REQ_CONTEXT] !== null &&
"get" in fromSymbol[SYMBOL_FOR_REQ_CONTEXT] &&
typeof fromSymbol[SYMBOL_FOR_REQ_CONTEXT].get === "function"
) {
const vercelCtx = fromSymbol[SYMBOL_FOR_REQ_CONTEXT].get();
if (
typeof vercelCtx === "object" &&
vercelCtx !== null &&
"waitUntil" in vercelCtx &&
typeof vercelCtx.waitUntil === "function"
) {
return vercelCtx.waitUntil;
}
}
}

function toAnalyzeRequest(request: Partial<ArcjetRequestDetails>) {
const headers: Record<string, string> = {};
if (typeof request.headers !== "undefined") {
Expand Down Expand Up @@ -584,6 +613,7 @@ export type ExtraProps<Rules> = Rules extends []
export type ArcjetAdapterContext = {
[key: string]: unknown;
getBody(): Promise<string | undefined>;
waitUntil?: (promise: Promise<unknown>) => void;
};

/**
Expand Down Expand Up @@ -1239,10 +1269,13 @@ export default function arcjet<
? [...options.characteristics]
: [];

const waitUntil = lookupWaitUntil();

const baseContext = {
key,
log,
characteristics,
waitUntil,
...ctx,
};

Expand Down
10 changes: 7 additions & 3 deletions protocol/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ export function createClient(options: ClientOptions): Client {

log.debug("Report request to %s", baseUrl);

// We use the promise API directly to avoid returning a promise from this function so execution can't be paused with `await`
// TODO(#884): Leverage `waitUntil` if the function is attached to the context
client
// We use the promise API directly to avoid returning a promise from this
// function so execution can't be paused with `await`
const reportPromise = client
.report(reportRequest, {
headers: { Authorization: `Bearer ${context.key}` },
timeoutMs: 2_000, // 2 seconds
Expand All @@ -177,6 +177,10 @@ export function createClient(options: ClientOptions): Client {
.catch((err: unknown) => {
log.info("Encountered problem sending report: %s", errorMessage(err));
});

if (typeof context.waitUntil === "function") {
context.waitUntil(reportPromise);
}
},
});
}
1 change: 1 addition & 0 deletions protocol/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -797,4 +797,5 @@ export type ArcjetContext = {
log: ArcjetLogger;
characteristics: string[];
getBody: () => Promise<string | undefined>;
waitUntil?: (promise: Promise<unknown>) => void;
};
54 changes: 54 additions & 0 deletions protocol/test/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,60 @@ describe("createClient", () => {
expect(decision.isAllowed()).toBe(true);
});

test("calling `report` will use `waitUntil` if available", async () => {
const [promise, resolve] = deferred();

const key = "test-key";
const fingerprint =
"fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c";
const context = {
key,
fingerprint,
runtime: "test",
log,
characteristics: [],
getBody: () => Promise.resolve(undefined),
waitUntil: jest.fn((promise: Promise<unknown>) => {
promise.then(() => resolve());
}),
};
const details = {
ip: "172.100.1.1",
method: "GET",
protocol: "http",
host: "example.com",
path: "/",
headers: new Headers([["User-Agent", "curl/8.1.2"]]),
extra: {
"extra-test": "extra-test-value",
},
email: "[email protected]",
};

const router = {
report: () => {
return new ReportResponse({});
},
};

const client = createClient({
...defaultRemoteClientOptions,
transport: createRouterTransport(({ service }) => {
service(DecideService, router);
}),
});
const decision = new ArcjetAllowDecision({
ttl: 0,
reason: new ArcjetTestReason(),
results: [],
});
client.report(context, details, decision, []);

await promise;

expect(context.waitUntil).toHaveBeenCalledTimes(1);
});

test("calling `report` will make RPC call with ALLOW decision", async () => {
const key = "test-key";
const fingerprint =
Expand Down

0 comments on commit 2851021

Please sign in to comment.