Skip to content

Commit

Permalink
fix: Ensure withRule always contains previous rules in the same chain (
Browse files Browse the repository at this point in the history
…#981)

Fixes #972 

When I originally wrote `withRule` in the arcjet core package, I always applied `rootRules` as the base rules. This meant that there was a bug where multiple calls to `withRule` would drop all rules except the rules on the root client and the most recent rule added with `withRule`.

I've changed some internal logic to ensure we always use the previous rule set as the base rules before adding the new rule via `withRule`.

Additionally, I've added functional tests to what were previously only type assertion tests to verify that the correct rules are being set inside the SDK.
  • Loading branch information
blaine-arcjet authored Jun 14, 2024
1 parent cd4621e commit 2ee6581
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 33 deletions.
11 changes: 7 additions & 4 deletions arcjet/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1132,14 +1132,17 @@ export default function arcjet<
}

// This is a separate function so it can be called recursively
function withRule<Rule extends Primitive | Product>(rule: Rule) {
const rules = [...rootRules, ...rule].sort(
function withRule<Rule extends Primitive | Product>(
baseRules: ArcjetRule[],
rule: Rule,
) {
const rules = [...baseRules, ...rule].sort(
(a, b) => a.priority - b.priority,
);

return Object.freeze({
withRule(rule: Primitive | Product) {
return withRule(rule);
return withRule(rules, rule);
},
async protect(
ctx: ArcjetAdapterContext,
Expand All @@ -1152,7 +1155,7 @@ export default function arcjet<

return Object.freeze({
withRule(rule: Primitive | Product) {
return withRule(rule);
return withRule(rootRules, rule);
},
async protect(
ctx: ArcjetAdapterContext,
Expand Down
127 changes: 98 additions & 29 deletions arcjet/test/index.node.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1973,28 +1973,49 @@ describe("SDK", () => {
report: jest.fn(),
};

const key = "test-key";
const request = {
ip: "172.100.1.1",
method: "GET",
protocol: "http",
host: "example.com",
path: "/",
headers: { "User-Agent": "curl/8.1.2" },
"extra-test": "extra-test-value",
userId: "abc123",
requested: 1,
};

const aj = arcjet({
key: "test-key",
key,
rules: [],
client,
log,
});
type WithoutRuleTest = Assert<SDKProps<typeof aj, {}>>;

const aj2 = aj.withRule(
tokenBucket({
characteristics: ["userId"],
refillRate: 60,
interval: 60,
capacity: 120,
}),
);
const tokenBucketRule = tokenBucket({
characteristics: ["userId"],
refillRate: 60,
interval: 60,
capacity: 120,
});

const aj2 = aj.withRule(tokenBucketRule);
type WithRuleTest = Assert<
SDKProps<
typeof aj2,
{ requested: number; userId: string | number | boolean }
>
>;

const _ = await aj2.protect({}, request);
expect(client.decide).toHaveBeenCalledTimes(1);
expect(client.decide).toHaveBeenCalledWith(
expect.anything(),
expect.anything(),
[...tokenBucketRule],
);
});

test("can chain new rules via multiple `withRule` calls", async () => {
Expand All @@ -2009,36 +2030,60 @@ describe("SDK", () => {
report: jest.fn(),
};

const key = "test-key";
const request = {
ip: "172.100.1.1",
method: "GET",
protocol: "http",
host: "example.com",
path: "/",
headers: { "User-Agent": "curl/8.1.2" },
"extra-test": "extra-test-value",
userId: "abc123",
requested: 1,
abc: 123,
};

const aj = arcjet({
key: "test-key",
key,
rules: [],
client,
log,
});
type WithoutRuleTest = Assert<SDKProps<typeof aj, {}>>;

const aj2 = aj.withRule(
tokenBucket({
characteristics: ["userId"],
refillRate: 60,
interval: 60,
capacity: 120,
}),
);
const tokenBucketRule = tokenBucket({
characteristics: ["userId"],
refillRate: 60,
interval: 60,
capacity: 120,
});

const aj2 = aj.withRule(tokenBucketRule);
type WithRuleTestOne = Assert<
SDKProps<
typeof aj2,
{ requested: number; userId: string | number | boolean }
>
>;

const aj3 = aj2.withRule(testRuleProps());
const testRule = testRuleProps();

const aj3 = aj2.withRule(testRule);
type WithRuleTestTwo = Assert<
SDKProps<
typeof aj3,
{ requested: number; userId: string | number | boolean; abc: number }
>
>;

const _ = await aj3.protect({}, request);
expect(client.decide).toHaveBeenCalledTimes(1);
expect(client.decide).toHaveBeenCalledWith(
expect.anything(),
expect.anything(),
[...tokenBucketRule, ...testRule],
);
});

test("creates different augmented clients when `withRule` not chained", async () => {
Expand All @@ -2053,31 +2098,55 @@ describe("SDK", () => {
report: jest.fn(),
};

const key = "test-key";
const request = {
ip: "172.100.1.1",
method: "GET",
protocol: "http",
host: "example.com",
path: "/",
headers: { "User-Agent": "curl/8.1.2" },
"extra-test": "extra-test-value",
userId: "abc123",
requested: 1,
abc: 123,
};

const aj = arcjet({
key: "test-key",
key,
rules: [],
client,
log,
});
type WithoutRuleTest = Assert<SDKProps<typeof aj, {}>>;

const aj2 = aj.withRule(
tokenBucket({
characteristics: ["userId"],
refillRate: 60,
interval: 60,
capacity: 120,
}),
);
const tokenBucketRule = tokenBucket({
characteristics: ["userId"],
refillRate: 60,
interval: 60,
capacity: 120,
});

const aj2 = aj.withRule(tokenBucketRule);
type WithRuleTestOne = Assert<
SDKProps<
typeof aj2,
{ requested: number; userId: string | number | boolean }
>
>;

const aj3 = aj.withRule(testRuleProps());
const testRule = testRuleProps();

const aj3 = aj.withRule(testRule);
type WithRuleTestTwo = Assert<SDKProps<typeof aj3, { abc: number }>>;

const _ = await aj3.protect({}, request);
expect(client.decide).toHaveBeenCalledTimes(1);
expect(client.decide).toHaveBeenCalledWith(
expect.anything(),
expect.anything(),
[...testRule],
);
});

test("creates a new Arcjet SDK with only local rules", () => {
Expand Down

0 comments on commit 2ee6581

Please sign in to comment.