From 4950364be1f895fc8bb782950b20623fc8324ceb Mon Sep 17 00:00:00 2001 From: blaine-arcjet <146491715+blaine-arcjet@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:25:38 -0700 Subject: [PATCH] feat!: Separate ArcjetRequest and ArcjetRequestDetails types to accept record of headers (#228) This separates our `ArcjetRequest` type, which users provide to the SDK's `protect()` function, and the `ArcjetRequestDetails`, which is an implementation detail of the protocol. This allows us to support various definitions for headers (and other fields in the future) while still normalizing for consumption by local and remote rules internally. Closes #33 --- arcjet/index.ts | 80 +++-- arcjet/test/index.node.test.ts | 532 ++++++++++++++++++++++----------- protocol/index.ts | 5 +- 3 files changed, 420 insertions(+), 197 deletions(-) diff --git a/arcjet/index.ts b/arcjet/index.ts index 6865ac7e7..36e0b8f55 100644 --- a/arcjet/index.ts +++ b/arcjet/index.ts @@ -168,13 +168,14 @@ type LiteralCheck< | boolean | symbol | bigint, -> = IsNever extends false // Must be wider than `never` - ? [T] extends [LiteralType] // Must be narrower than `LiteralType` - ? [LiteralType] extends [T] // Cannot be wider than `LiteralType` - ? false - : true - : false - : false; +> = + IsNever extends false // Must be wider than `never` + ? [T] extends [LiteralType] // Must be narrower than `LiteralType` + ? [LiteralType] extends [T] // Cannot be wider than `LiteralType` + ? false + : true + : false + : false; type IsStringLiteral = LiteralCheck; export interface RemoteClient { @@ -257,10 +258,12 @@ function toString(value: unknown) { return value ? "true" : "false"; } - return ""; + return ""; } -function extraProps(details: ArcjetRequestDetails): Record { +function extraProps( + details: ArcjetRequest, +): Record { const extra: Map = new Map(); for (const [key, value] of Object.entries(details)) { if (isUnknownRequestProperty(key)) { @@ -315,7 +318,7 @@ export function createRemoteClient( query: details.query, // TODO(#208): Re-add body // body: details.body, - extra: extraProps(details), + extra: details.extra, email: typeof details.email === "string" ? details.email : undefined, }, rules: rules.map(ArcjetRuleToProtocol), @@ -364,7 +367,7 @@ export function createRemoteClient( headers: Object.fromEntries(details.headers.entries()), // TODO(#208): Re-add body // body: details.body, - extra: extraProps(details), + extra: details.extra, email: typeof details.email === "string" ? details.email : undefined, }, decision: ArcjetDecisionToProtocol(decision), @@ -584,20 +587,21 @@ export type Product = ArcjetRule[]; // Note: If a user doesn't provide the object literal to our primitives // directly, we fallback to no required props. They can opt-in by adding the // `as const` suffix to the characteristics array. -type PropsForCharacteristic = IsStringLiteral extends true - ? T extends - | "ip.src" - | "http.host" - | "http.method" - | "http.request.uri.path" - | `http.request.headers["${string}"]` - | `http.request.cookie["${string}"]` - | `http.request.uri.args["${string}"]` - ? {} - : T extends string - ? Record - : never - : {}; +type PropsForCharacteristic = + IsStringLiteral extends true + ? T extends + | "ip.src" + | "http.host" + | "http.method" + | "http.request.uri.path" + | `http.request.headers["${string}"]` + | `http.request.cookie["${string}"]` + | `http.request.uri.args["${string}"]` + ? {} + : T extends string + ? Record + : never + : {}; // Rules can specify they require specific props on an ArcjetRequest type PropsForRule = R extends ArcjetRule ? Props : {}; // We theoretically support an arbitrary amount of rule flattening, @@ -625,7 +629,17 @@ export type ExtraProps = Rules extends [] * @property ...extra - Extra data that might be useful for Arcjet. For example, requested tokens are specified as the `requested` property. */ export type ArcjetRequest = Simplify< - Partial & Props + { + [key: string]: unknown; + ip?: string; + method?: string; + protocol?: string; + host?: string; + path?: string; + headers?: Headers | Record; + cookies?: string; + query?: string; + } & Props >; function isLocalRule( @@ -1052,9 +1066,19 @@ export default function arcjet< request = {} as typeof request; } - const details = Object.freeze({ - ...request, + const details: Partial = Object.freeze({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, headers: new ArcjetHeaders(request.headers), + cookies: request.cookies, + query: request.query, + // TODO(#208): Re-add body + // body: request.body, + extra: extraProps(request), + email: typeof request.email === "string" ? request.email : undefined, }); log.time("local"); diff --git a/arcjet/test/index.node.test.ts b/arcjet/test/index.node.test.ts index e3ae8794d..8a8ba2de5 100644 --- a/arcjet/test/index.node.test.ts +++ b/arcjet/test/index.node.test.ts @@ -83,17 +83,15 @@ import arcjet, { // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -type IsEqual = (() => G extends A ? 1 : 2) extends () => G extends B - ? 1 - : 2 - ? true - : false; +type IsEqual = + (() => G extends A ? 1 : 2) extends () => G extends B ? 1 : 2 + ? true + : false; // Type testing utilities type Assert = T; -type Props

= P extends Primitive - ? Props - : never; +type Props

= + P extends Primitive ? Props : never; type RequiredProps

= IsEqual, E>; // Instances of Headers contain symbols that may be different depending @@ -254,7 +252,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -279,12 +279,7 @@ describe("createRemoteClient", () => { expect(router.decide).toHaveBeenCalledWith( new DecideRequest({ details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, fingerprint, @@ -312,7 +307,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -338,12 +335,7 @@ describe("createRemoteClient", () => { expect(router.decide).toHaveBeenCalledWith( new DecideRequest({ details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, fingerprint, @@ -371,7 +363,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -395,12 +389,7 @@ describe("createRemoteClient", () => { expect(router.decide).toHaveBeenCalledWith( new DecideRequest({ details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, fingerprint, @@ -428,7 +417,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, email: "abc@example.com", }; @@ -453,14 +444,8 @@ describe("createRemoteClient", () => { expect(router.decide).toHaveBeenCalledWith( new DecideRequest({ details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, - email: details.email, }, fingerprint, rules: [], @@ -487,7 +472,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, email: "abc@example.com", }; @@ -517,14 +504,8 @@ describe("createRemoteClient", () => { expect(router.decide).toHaveBeenCalledWith( new DecideRequest({ details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, - email: details.email, }, fingerprint, rules: [new Rule()], @@ -551,7 +532,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -591,7 +574,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -630,7 +615,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -669,7 +656,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -711,7 +700,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -759,7 +750,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const router = { @@ -800,7 +793,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, email: "test@example.com", }; @@ -834,14 +829,8 @@ describe("createRemoteClient", () => { sdkVersion: "__ARCJET_SDK_VERSION__", fingerprint, details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, - email: details.email, }, decision: { id: decision.id, @@ -872,7 +861,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [promise, resolve] = deferred(); @@ -905,12 +896,7 @@ describe("createRemoteClient", () => { sdkStack: SDKStack.SDK_STACK_NODEJS, sdkVersion: "__ARCJET_SDK_VERSION__", details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, decision: { @@ -942,7 +928,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [promise, resolve] = deferred(); @@ -975,12 +963,7 @@ describe("createRemoteClient", () => { sdkVersion: "__ARCJET_SDK_VERSION__", fingerprint, details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, decision: { @@ -1019,7 +1002,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [promise, resolve] = deferred(); @@ -1052,12 +1037,7 @@ describe("createRemoteClient", () => { sdkVersion: "__ARCJET_SDK_VERSION__", fingerprint, details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, decision: { @@ -1089,7 +1069,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [promise, resolve] = deferred(); @@ -1118,12 +1100,7 @@ describe("createRemoteClient", () => { sdkVersion: "__ARCJET_SDK_VERSION__", fingerprint, details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, }, decision: { @@ -1155,7 +1132,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, email: "abc@example.com", }; @@ -1202,14 +1181,8 @@ describe("createRemoteClient", () => { sdkVersion: "__ARCJET_SDK_VERSION__", fingerprint, details: { - ip: details.ip, - method: details.method, - protocol: details.protocol, - host: details.host, - path: details.path, - extra: { "extra-test": details["extra-test"] }, + ...details, headers: { "user-agent": "curl/8.1.2" }, - email: details.email, }, decision: { id: decision.id, @@ -1247,7 +1220,9 @@ describe("createRemoteClient", () => { host: "example.com", path: "/", headers: new Headers([["User-Agent", "curl/8.1.2"]]), - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [promise, resolve] = deferred(); @@ -1662,7 +1637,9 @@ describe("Primitive > detectBot", () => { ]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot(options); @@ -1714,7 +1691,9 @@ describe("Primitive > detectBot", () => { ]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot(options); @@ -1766,7 +1745,9 @@ describe("Primitive > detectBot", () => { ]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot(options); @@ -1805,7 +1786,9 @@ describe("Primitive > detectBot", () => { ]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot({ @@ -1861,7 +1844,9 @@ describe("Primitive > detectBot", () => { headers: new Headers([["User-Agent", "curl/8.1.2"]]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot(options); @@ -1912,7 +1897,9 @@ describe("Primitive > detectBot", () => { ]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot(options); @@ -1952,7 +1939,9 @@ describe("Primitive > detectBot", () => { headers: new Headers([["User-Agent", "curl/8.1.2"]]), cookies: "", query: "", - "extra-test": "extra-test-value", + extra: { + "extra-test": "extra-test-value", + }, }; const [rule] = detectBot(options); @@ -2802,6 +2791,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz@example.com", + extra: {}, }; const [rule] = validateEmail(); @@ -2833,6 +2823,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz", + extra: {}, }; const [rule] = validateEmail(); @@ -2864,6 +2855,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz@localhost", + extra: {}, }; const [rule] = validateEmail(); @@ -2895,6 +2887,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz@localhost", + extra: {}, }; const [rule] = validateEmail({ @@ -2928,6 +2921,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "@example.com", + extra: {}, }; const [rule] = validateEmail(); @@ -2959,6 +2953,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz@[127.0.0.1]", + extra: {}, }; const [rule] = validateEmail(); @@ -2990,6 +2985,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz@localhost", + extra: {}, }; const [rule] = validateEmail({ @@ -3023,6 +3019,7 @@ describe("Primitive > validateEmail", () => { cookies: "", query: "", email: "foobarbaz@[127.0.0.1]", + extra: {}, }; const [rule] = validateEmail({ @@ -3343,7 +3340,7 @@ describe("SDK", () => { report: jest.fn(), }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3361,7 +3358,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.conclusion).toEqual("DENY"); expect(allowed.validate).toHaveBeenCalledTimes(1); @@ -3370,7 +3367,7 @@ describe("SDK", () => { expect(denied.protect).toHaveBeenCalledTimes(1); }); - test("works with an empty details object", async () => { + test("works with an empty request object", async () => { const client = { decide: jest.fn(async () => { return new ArcjetAllowDecision({ @@ -3382,7 +3379,7 @@ describe("SDK", () => { report: jest.fn(), }; - const details = {}; + const request = {}; const aj = arcjet({ key: "test-key", @@ -3390,11 +3387,11 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.conclusion).toEqual("ALLOW"); }); - test("does not crash with no details object", async () => { + test("does not crash with no request object", async () => { const client = { decide: jest.fn(async () => { return new ArcjetAllowDecision({ @@ -3429,7 +3426,7 @@ describe("SDK", () => { report: jest.fn(), }; - const details = {}; + const request = {}; const rules: ArcjetRule[][] = []; // We only iterate 4 times because `testRuleMultiple` generates 3 rules @@ -3443,7 +3440,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.conclusion).toEqual("ERROR"); }); @@ -3459,7 +3456,7 @@ describe("SDK", () => { report: jest.fn(), }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3477,7 +3474,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.conclusion).toEqual("DENY"); expect(denied.validate).toHaveBeenCalledTimes(1); @@ -3486,12 +3483,12 @@ describe("SDK", () => { expect(allowed.protect).toHaveBeenCalledTimes(0); }); - test("does not call `client.report()` if the local decision is ALLOW", async () => { + test("accepts plain object of headers", async () => { const client = { decide: jest.fn(async () => { - return new ArcjetErrorDecision({ + return new ArcjetAllowDecision({ ttl: 0, - reason: new ArcjetErrorReason("This decision not under test"), + reason: new ArcjetTestReason(), results: [], }); }), @@ -3504,7 +3501,169 @@ describe("SDK", () => { fingerprint: "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", }; - const details = { + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: { "User-Agent": "curl/8.1.2" }, + "extra-test": "extra-test-value", + }; + + const aj = arcjet({ + key: "test-key", + rules: [], + client, + }); + + const decision = await aj.protect(request); + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.objectContaining(context), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: new Headers(Object.entries(request.headers)), + extra: { + "extra-test": "extra-test-value", + }, + }), + [], + ); + }); + + test("accepts plain object of `raw` headers", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const key = "test-key"; + const context = { + key, + fingerprint: + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + }; + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: { "User-Agent": ["curl/8.1.2", "something"] }, + "extra-test": "extra-test-value", + }; + + const aj = arcjet({ + key: "test-key", + rules: [], + client, + }); + + const decision = await aj.protect(request); + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.objectContaining(context), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: new Headers([ + ["user-agent", "curl/8.1.2"], + ["user-agent", "something"], + ]), + extra: { + "extra-test": "extra-test-value", + }, + }), + [], + ); + }); + + test("converts extra keys with non-string values to string values", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetAllowDecision({ + ttl: 0, + reason: new ArcjetTestReason(), + results: [], + }); + }), + report: jest.fn(), + }; + + const key = "test-key"; + const context = { + key, + fingerprint: + "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", + }; + const request = { + ip: "172.100.1.1", + method: "GET", + protocol: "http", + host: "example.com", + path: "/", + headers: { "User-Agent": "curl/8.1.2" }, + "extra-number": 123, + "extra-false": false, + "extra-true": true, + "extra-unsupported": new Date(), + }; + + const aj = arcjet({ + key: "test-key", + rules: [], + client, + }); + + const decision = await aj.protect(request); + expect(client.decide).toHaveBeenCalledTimes(1); + expect(client.decide).toHaveBeenCalledWith( + expect.objectContaining(context), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: new Headers(Object.entries(request.headers)), + extra: { + "extra-number": "123", + "extra-false": "false", + "extra-true": "true", + "extra-unsupported": "", + }, + }), + [], + ); + }); + + test("does not call `client.report()` if the local decision is ALLOW", async () => { + const client = { + decide: jest.fn(async () => { + return new ArcjetErrorDecision({ + ttl: 0, + reason: new ArcjetErrorReason("This decision not under test"), + results: [], + }); + }), + report: jest.fn(), + }; + + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3516,12 +3675,12 @@ describe("SDK", () => { const allowed = testRuleLocalAllowed(); const aj = arcjet({ - key, + key: "test-key", rules: [[allowed]], client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(client.report).toHaveBeenCalledTimes(0); expect(client.decide).toHaveBeenCalledTimes(1); // TODO: Validate correct `ruleResults` are sent with `decide` when available @@ -3545,7 +3704,7 @@ describe("SDK", () => { fingerprint: "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3562,11 +3721,21 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(client.decide).toHaveBeenCalledTimes(1); expect(client.decide).toHaveBeenCalledWith( expect.objectContaining(context), - expect.objectContaining(details), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: request.headers, + extra: { + "extra-test": "extra-test-value", + }, + }), [rule], ); }); @@ -3589,7 +3758,7 @@ describe("SDK", () => { fingerprint: "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3606,11 +3775,21 @@ describe("SDK", () => { client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(client.report).toHaveBeenCalledTimes(1); expect(client.report).toHaveBeenCalledWith( expect.objectContaining(context), - expect.objectContaining(details), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: request.headers, + extra: { + "extra-test": "extra-test-value", + }, + }), expect.objectContaining({ conclusion: "DENY", }), @@ -3630,8 +3809,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3643,12 +3821,12 @@ describe("SDK", () => { const denied = testRuleLocalDenied(); const aj = arcjet({ - key, + key: "test-key", rules: [[denied]], client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(client.decide).toHaveBeenCalledTimes(0); }); @@ -3670,7 +3848,7 @@ describe("SDK", () => { fingerprint: "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3686,13 +3864,23 @@ describe("SDK", () => { client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(client.report).toHaveBeenCalledTimes(0); expect(client.decide).toHaveBeenCalledTimes(1); expect(client.decide).toHaveBeenCalledWith( expect.objectContaining(context), - expect.objectContaining(details), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: request.headers, + extra: { + "extra-test": "extra-test-value", + }, + }), [], ); }); @@ -3709,8 +3897,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3721,12 +3908,12 @@ describe("SDK", () => { }; const aj = arcjet({ - key, + key: "test-key", rules: [], client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.isErrored()).toBe(false); @@ -3735,7 +3922,7 @@ describe("SDK", () => { expect(decision.conclusion).toEqual("DENY"); - const decision2 = await aj.protect(details); + const decision2 = await aj.protect(request); expect(decision2.isErrored()).toBe(false); expect(client.decide).toHaveBeenCalledTimes(1); @@ -3777,13 +3964,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const context = { - key, - fingerprint: - "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", - }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3794,12 +3975,12 @@ describe("SDK", () => { }; const aj = arcjet({ - key, + key: "test-key", rules: [[testRuleLocalThrow()]], client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(client.report).toHaveBeenCalledTimes(0); expect(client.decide).toHaveBeenCalledTimes(1); @@ -3818,8 +3999,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3837,7 +4017,7 @@ describe("SDK", () => { type: "TEST_RULE_LOCAL_THROW_STRING", priority: 1, validate: jest.fn(), - async protect(context, req) { + async protect(context, details) { errorLogSpy = jest.spyOn(context.log, "error"); throw "Local rule protect failed"; }, @@ -3845,12 +4025,12 @@ describe("SDK", () => { } const aj = arcjet({ - key, + key: "test-key", rules: [[testRuleLocalThrowString()]], client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(errorLogSpy).toHaveBeenCalledTimes(1); expect(errorLogSpy).toHaveBeenCalledWith( @@ -3872,8 +4052,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3891,7 +4070,7 @@ describe("SDK", () => { type: "TEST_RULE_LOCAL_THROW_NULL", priority: 1, validate: jest.fn(), - async protect(context, req) { + async protect(context, details) { errorLogSpy = jest.spyOn(context.log, "error"); throw null; }, @@ -3899,12 +4078,12 @@ describe("SDK", () => { } const aj = arcjet({ - key, + key: "test-key", rules: [[testRuleLocalThrowNull()]], client, }); - const _ = await aj.protect(details); + const _ = await aj.protect(request); expect(errorLogSpy).toHaveBeenCalledTimes(1); expect(errorLogSpy).toHaveBeenCalledWith( @@ -3926,8 +4105,7 @@ describe("SDK", () => { report: jest.fn(), }; - const key = "test-key"; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3938,19 +4116,19 @@ describe("SDK", () => { }; const aj = arcjet({ - key, + key: "test-key", rules: [[testRuleLocalDryRun()]], client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.isDenied()).toBe(false); expect(client.decide).toBeCalledTimes(1); expect(client.report).toBeCalledTimes(1); - const decision2 = await aj.protect(details); + const decision2 = await aj.protect(request); expect(decision2.isDenied()).toBe(false); @@ -3976,7 +4154,7 @@ describe("SDK", () => { fingerprint: "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -3994,14 +4172,24 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.isErrored()).toBe(false); expect(client.decide).toHaveBeenCalledTimes(1); expect(client.decide).toHaveBeenCalledWith( expect.objectContaining(context), - expect.objectContaining(details), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: request.headers, + extra: { + "extra-test": "extra-test-value", + }, + }), [rule], ); }); @@ -4020,7 +4208,7 @@ describe("SDK", () => { fingerprint: "fp_1_ac8547705f1f45c5050f1424700dfa3f6f2f681b550ca4f3c19571585aea7a2c", }; - const details = { + const request = { ip: "172.100.1.1", method: "GET", protocol: "http", @@ -4036,7 +4224,7 @@ describe("SDK", () => { client, }); - const decision = await aj.protect(details); + const decision = await aj.protect(request); expect(decision.isErrored()).toBe(true); @@ -4044,7 +4232,17 @@ describe("SDK", () => { expect(client.report).toHaveBeenCalledTimes(1); expect(client.report).toHaveBeenCalledWith( expect.objectContaining(context), - expect.objectContaining(details), + expect.objectContaining({ + ip: request.ip, + method: request.method, + protocol: request.protocol, + host: request.host, + path: request.path, + headers: request.headers, + extra: { + "extra-test": "extra-test-value", + }, + }), expect.objectContaining({ conclusion: "ERROR", }), diff --git a/protocol/index.ts b/protocol/index.ts index cfdbc79dc..7c799935a 100644 --- a/protocol/index.ts +++ b/protocol/index.ts @@ -371,16 +371,17 @@ export class ArcjetErrorDecision extends ArcjetDecision { } export interface ArcjetRequestDetails { - [key: string]: unknown; ip: string; method: string; protocol: string; host: string; path: string; - // TODO(#215): Allow `Record` and `Record`? headers: Headers; cookies: string; query: string; + extra: { [key: string]: string }; + // TODO: Consider moving email to `extra` map + email?: string; } export type ArcjetRule = {