From 25dfcdfb281d5f577439f0698c86833f11b95eb8 Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Wed, 29 May 2024 13:04:34 -0700 Subject: [PATCH 1/2] feat!: Allow ArcjetContext to be extended via new argument to core `protect()` --- arcjet-bun/index.ts | 2 +- arcjet-next/index.ts | 2 +- arcjet-node/index.ts | 2 +- arcjet-sveltekit/index.ts | 2 +- arcjet/README.md | 6 +++- arcjet/index.ts | 23 ++++++++++--- arcjet/test/index.edge.test.ts | 31 +++++++++-------- arcjet/test/index.node.test.ts | 63 ++++++++++++++++++---------------- protocol/index.ts | 1 + 9 files changed, 79 insertions(+), 53 deletions(-) diff --git a/arcjet-bun/index.ts b/arcjet-bun/index.ts index d487e3806..cdff494a3 100644 --- a/arcjet-bun/index.ts +++ b/arcjet-bun/index.ts @@ -197,7 +197,7 @@ function withClient( props ?? {}, ) as ArcjetRequest>; - return aj.protect(req); + return aj.protect({}, req); }, handler( fetch: ( diff --git a/arcjet-next/index.ts b/arcjet-next/index.ts index deffbaec5..ee8bd5dcd 100644 --- a/arcjet-next/index.ts +++ b/arcjet-next/index.ts @@ -300,7 +300,7 @@ function withClient( ExtraProps >; - return aj.protect(req); + return aj.protect({}, req); }, }); } diff --git a/arcjet-node/index.ts b/arcjet-node/index.ts index 79933996e..5c0f90707 100644 --- a/arcjet-node/index.ts +++ b/arcjet-node/index.ts @@ -212,7 +212,7 @@ function withClient( ExtraProps >; - return aj.protect(req); + return aj.protect({}, req); }, }); } diff --git a/arcjet-sveltekit/index.ts b/arcjet-sveltekit/index.ts index 23e4fb786..a05fef206 100644 --- a/arcjet-sveltekit/index.ts +++ b/arcjet-sveltekit/index.ts @@ -215,7 +215,7 @@ function withClient( ExtraProps >; - return aj.protect(req); + return aj.protect({}, req); }, }); } diff --git a/arcjet/README.md b/arcjet/README.md index 1b25ef2a8..cf10b5dad 100644 --- a/arcjet/README.md +++ b/arcjet/README.md @@ -62,6 +62,10 @@ const server = http.createServer(async function ( req: http.IncomingMessage, res: http.ServerResponse, ) { + // Any sort of additional context that might want to be included for the + // execution of `protect()`. This is mostly only useful for writing adapters. + const ctx = {}; + // Construct an object with Arcjet request details const path = new URL(req.url || "", `http://${req.headers.host}`); const details = { @@ -71,7 +75,7 @@ const server = http.createServer(async function ( path: path.pathname, }; - const decision = await aj.protect(details); + const decision = await aj.protect(ctx, details); console.log(decision); if (decision.isDenied()) { diff --git a/arcjet/index.ts b/arcjet/index.ts index 3a2b04cec..9431ee810 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -602,6 +602,14 @@ export type ExtraProps = Rules extends [] ? UnionToIntersection> : never; +/** + * Additional context that can be provided by adapters. + * + * Among other things, this could include the Arcjet API Key if it were only + * available in a runtime handler or IP details provided by a platform. + */ +export type ArcjetAdapterContext = Record; + /** * @property {string} ip - The IP address of the client. * @property {string} method - The HTTP method of the request. @@ -1064,10 +1072,14 @@ export interface Arcjet { * Make a decision about how to handle a request. This will analyze the * request locally where possible and call the Arcjet decision API. * + * @param {ArcjetAdapterContext} ctx - Additional context for this function call. * @param {ArcjetRequest} request - Details about the {@link ArcjetRequest} that Arcjet needs to make a decision. * @returns An {@link ArcjetDecision} indicating Arcjet's decision about the request. */ - protect(request: ArcjetRequest): Promise; + protect( + ctx: ArcjetAdapterContext, + request: ArcjetRequest, + ): Promise; /** * Augments the client with another rule. Useful for varying rules based on @@ -1112,6 +1124,7 @@ export default function arcjet< async function protect( rules: ArcjetRule[], + ctx: ArcjetAdapterContext, request: ArcjetRequest, ) { // This goes against the type definition above, but users might call @@ -1149,7 +1162,7 @@ export default function arcjet< logger.debug("fingerprint (%s): %s", runtime(), fingerprint); logger.timeEnd("fingerprint"); - const context: ArcjetContext = { key, fingerprint }; + const context: ArcjetContext = { key, ...ctx, fingerprint }; if (rules.length < 1) { // TODO(#607): Error if no rules configured after deprecation period @@ -1372,9 +1385,10 @@ export default function arcjet< return withRule(rule); }, async protect( + ctx: ArcjetContext, request: ArcjetRequest>, ): Promise { - return protect(rules, request); + return protect(rules, ctx, request); }, }); } @@ -1387,9 +1401,10 @@ export default function arcjet< return withRule(rule); }, async protect( + ctx: ArcjetContext, request: ArcjetRequest>, ): Promise { - return protect(rootRules, request); + return protect(rootRules, ctx, request); }, }); } diff --git a/arcjet/test/index.edge.test.ts b/arcjet/test/index.edge.test.ts index 25fb19f23..9486f6ae4 100644 --- a/arcjet/test/index.edge.test.ts +++ b/arcjet/test/index.edge.test.ts @@ -69,20 +69,23 @@ describe("Arcjet: Env = Edge runtime", () => { const aj2 = aj.withRule(foobarbaz()); - const decision = await aj2.protect({ - abc: 123, - requested: 1, - email: "", - ip: "", - method: "", - protocol: "", - host: "", - path: "", - headers: new Headers(), - extra: {}, - userId: "user123", - foobar: 123, - }); + const decision = await aj2.protect( + {}, + { + abc: 123, + requested: 1, + email: "", + ip: "", + method: "", + protocol: "", + host: "", + path: "", + headers: new Headers(), + extra: {}, + userId: "user123", + foobar: 123, + }, + ); expect(decision.isErrored()).toBe(false); }); diff --git a/arcjet/test/index.node.test.ts b/arcjet/test/index.node.test.ts index 3c4d539b7..5508c6b30 100644 --- a/arcjet/test/index.node.test.ts +++ b/arcjet/test/index.node.test.ts @@ -3440,7 +3440,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(request); + const decision = await aj.protect({}, request); expect(decision.conclusion).toEqual("DENY"); expect(allowed.validate).toHaveBeenCalledTimes(1); @@ -3469,7 +3469,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(request); + const decision = await aj.protect({}, request); expect(decision.conclusion).toEqual("ALLOW"); }); @@ -3522,7 +3522,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(request); + const decision = await aj.protect({}, request); expect(decision.conclusion).toEqual("ERROR"); }); @@ -3556,7 +3556,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(request); + const decision = await aj.protect({}, request); expect(decision.conclusion).toEqual("DENY"); expect(denied.validate).toHaveBeenCalledTimes(1); @@ -3599,7 +3599,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(request); + const decision = await aj.protect({}, request); expect(client.decide).toHaveBeenCalledTimes(1); expect(client.decide).toHaveBeenCalledWith( expect.objectContaining(context), @@ -3652,7 +3652,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(request); + const decision = await aj.protect({}, request); expect(client.decide).toHaveBeenCalledTimes(1); expect(client.decide).toHaveBeenCalledWith( expect.objectContaining(context), @@ -3711,7 +3711,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(request); + const decision = await aj.protect({}, request); expect(client.decide).toHaveBeenCalledTimes(1); expect(client.decide).toHaveBeenCalledWith( expect.objectContaining(context), @@ -3762,7 +3762,7 @@ describe("SDK", () => { client, }); - const _ = await aj.protect(request); + const _ = await aj.protect({}, request); expect(client.report).toHaveBeenCalledTimes(0); expect(client.decide).toHaveBeenCalledTimes(1); // TODO: Validate correct `ruleResults` are sent with `decide` when available @@ -3803,7 +3803,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(request); + const decision = await aj.protect({}, request); expect(client.decide).toHaveBeenCalledTimes(1); expect(client.decide).toHaveBeenCalledWith( expect.objectContaining(context), @@ -3857,7 +3857,7 @@ describe("SDK", () => { client, }); - const _ = await aj.protect(request); + const _ = await aj.protect({}, request); expect(client.report).toHaveBeenCalledTimes(1); expect(client.report).toHaveBeenCalledWith( expect.objectContaining(context), @@ -3908,7 +3908,7 @@ describe("SDK", () => { client, }); - const _ = await aj.protect(request); + const _ = await aj.protect({}, request); expect(client.decide).toHaveBeenCalledTimes(0); }); @@ -3946,7 +3946,7 @@ describe("SDK", () => { client, }); - const _ = await aj.protect(request); + const _ = await aj.protect({}, request); expect(client.report).toHaveBeenCalledTimes(0); expect(client.decide).toHaveBeenCalledTimes(1); @@ -3995,7 +3995,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(request); + const decision = await aj.protect({}, request); expect(decision.isErrored()).toBe(false); @@ -4004,7 +4004,7 @@ describe("SDK", () => { expect(decision.conclusion).toEqual("DENY"); - const decision2 = await aj.protect(request); + const decision2 = await aj.protect({}, request); expect(decision2.isErrored()).toBe(false); expect(client.decide).toHaveBeenCalledTimes(1); @@ -4062,7 +4062,7 @@ describe("SDK", () => { client, }); - const _ = await aj.protect(request); + const _ = await aj.protect({}, request); expect(client.report).toHaveBeenCalledTimes(0); expect(client.decide).toHaveBeenCalledTimes(1); @@ -4111,7 +4111,7 @@ describe("SDK", () => { client, }); - const _ = await aj.protect(request); + const _ = await aj.protect({}, request); expect(errorLogSpy).toHaveBeenCalledTimes(1); expect(errorLogSpy).toHaveBeenCalledWith( @@ -4163,7 +4163,7 @@ describe("SDK", () => { client, }); - const _ = await aj.protect(request); + const _ = await aj.protect({}, request); expect(errorLogSpy).toHaveBeenCalledTimes(1); expect(errorLogSpy).toHaveBeenCalledWith( @@ -4201,14 +4201,14 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(request); + const decision = await aj.protect({}, request); expect(decision.isDenied()).toBe(false); expect(client.decide).toBeCalledTimes(1); expect(client.report).toBeCalledTimes(1); - const decision2 = await aj.protect(request); + const decision2 = await aj.protect({}, request); expect(decision2.isDenied()).toBe(false); @@ -4252,7 +4252,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(request); + const decision = await aj.protect({}, request); expect(decision.isErrored()).toBe(false); @@ -4304,7 +4304,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(request); + const decision = await aj.protect({}, request); expect(decision.isErrored()).toBe(true); @@ -4393,15 +4393,18 @@ describe("Arcjet: Env = Serverless Node runtime on Vercel", () => { rules: [rateLimit(config)], client, }); - const decision = await aj.protect({ - ip, - method, - protocol, - host, - path, - headers, - "extra-test": "extra-test-value", - }); + const decision = await aj.protect( + {}, + { + ip, + method, + protocol, + host, + path, + headers, + "extra-test": "extra-test-value", + }, + ); // If this fails, check the console an error related to the args passed to // the mocked decide service method above. diff --git a/protocol/index.ts b/protocol/index.ts index 2b4276b1c..000914680 100644 --- a/protocol/index.ts +++ b/protocol/index.ts @@ -745,6 +745,7 @@ export interface ArcjetShieldRule extends ArcjetRule { } export type ArcjetContext = { + [key: string]: unknown; key: string; fingerprint: string; }; From e69d8578aa2f2741b68d0a213995f37ec8499cbd Mon Sep 17 00:00:00 2001 From: Blaine Bublitz Date: Wed, 29 May 2024 14:01:35 -0700 Subject: [PATCH 2/2] add test for overriding key --- arcjet/test/index.node.test.ts | 58 ++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/arcjet/test/index.node.test.ts b/arcjet/test/index.node.test.ts index 5508c6b30..3878acaf5 100644 --- a/arcjet/test/index.node.test.ts +++ b/arcjet/test/index.node.test.ts @@ -4274,6 +4274,64 @@ describe("SDK", () => { ); }); + test("overrides `key` with custom context", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const key = "test-key"; + const context = { + key, + fingerprint: + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + }; + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: new Headers([["User-Agent", "Mozilla/5.0"]]), + "extra-test": "extra-test-value", + }; + + const rule = testRuleRemote(); + + const aj = arcjet({ + key, + rules: [[rule]], + client, + }); + + const decision = await aj.protect({ key: "overridden-key" }, request); + + expect(decision.isErrored()).toBe(false); + + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.objectContaining({ ...context, key: "overridden-key" }), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: request.headers, + extra: { + "extra-test": "extra-test-value", + }, + }), + [rule], + ); + }); + test("reports and returns an ERROR decision if a `client.decide()` fails", async () => { const client = { decide: jest.fn(async () => {