Skip to content

Commit 97c683a

Browse files
committed
feat: Allow user-defined characteristics on rate limit options
1 parent 6701b02 commit 97c683a

File tree

2 files changed

+130
-40
lines changed

2 files changed

+130
-40
lines changed

arcjet/index.ts

+106-35
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ function errorMessage(err: unknown): string {
110110
// https://github.com/sindresorhus/type-fest/blob/964466c9d59c711da57a5297ad954c13132a0001/source/simplify.d.ts
111111
// UnionToIntersection:
112112
// https://github.com/sindresorhus/type-fest/blob/017bf38ebb52df37c297324d97bcc693ec22e920/source/union-to-intersection.d.ts
113+
// IsNever:
114+
// https://github.com/sindresorhus/type-fest/blob/e02f228f6391bb2b26c32a55dfe1e3aa2386d515/source/primitive.d.ts
115+
// LiteralCheck & IsStringLiteral:
116+
// https://github.com/sindresorhus/type-fest/blob/e02f228f6391bb2b26c32a55dfe1e3aa2386d515/source/is-literal.d.ts
113117
//
114118
// Licensed: MIT License Copyright (c) Sindre Sorhus <[email protected]>
115119
// (https://sindresorhus.com)
@@ -148,6 +152,25 @@ type UnionToIntersection<Union> =
148152
? // The `& Union` is to allow indexing by the resulting type
149153
Intersection & Union
150154
: never;
155+
type IsNever<T> = [T] extends [never] ? true : false;
156+
type LiteralCheck<
157+
T,
158+
LiteralType extends
159+
| null
160+
| undefined
161+
| string
162+
| number
163+
| boolean
164+
| symbol
165+
| bigint,
166+
> = IsNever<T> extends false // Must be wider than `never`
167+
? [T] extends [LiteralType] // Must be narrower than `LiteralType`
168+
? [LiteralType] extends [T] // Cannot be wider than `LiteralType`
169+
? false
170+
: true
171+
: false
172+
: false;
173+
type IsStringLiteral<T> = LiteralCheck<T, string>;
151174

152175
export interface RemoteClient {
153176
decide(
@@ -416,30 +439,31 @@ function runtime(): Runtime {
416439
}
417440
}
418441

419-
type TokenBucketRateLimitOptions = {
442+
type TokenBucketRateLimitOptions<Characteristics extends readonly string[]> = {
420443
mode?: ArcjetMode;
421444
match?: string;
422-
characteristics?: string[];
445+
characteristics?: Characteristics;
423446
refillRate: number;
424447
interval: number;
425448
capacity: number;
426449
};
427450

428-
type FixedWindowRateLimitOptions = {
451+
type FixedWindowRateLimitOptions<Characteristics extends readonly string[]> = {
429452
mode?: ArcjetMode;
430453
match?: string;
431-
characteristics?: string[];
454+
characteristics?: Characteristics;
432455
window: string;
433456
max: number;
434457
};
435458

436-
type SlidingWindowRateLimitOptions = {
437-
mode?: ArcjetMode;
438-
match?: string;
439-
characteristics?: string[];
440-
interval: number;
441-
max: number;
442-
};
459+
type SlidingWindowRateLimitOptions<Characteristics extends readonly string[]> =
460+
{
461+
mode?: ArcjetMode;
462+
match?: string;
463+
characteristics?: Characteristics;
464+
interval: number;
465+
max: number;
466+
};
443467

444468
/**
445469
* Bot detection is disabled by default. The `bots` configuration block allows
@@ -549,6 +573,25 @@ type PlainObject = { [key: string]: unknown };
549573
export type Primitive<Props extends PlainObject = {}> = ArcjetRule<Props>[];
550574
export type Product<Props extends PlainObject = {}> = ArcjetRule<Props>[];
551575

576+
// User-defined characteristics alter the required props of an ArcjetRequest
577+
// Note: If a user doesn't provide the object literal to our primitives
578+
// directly, we fallback to no required props. They can opt-in by adding the
579+
// `as const` suffix to the characteristics array.
580+
type PropsForCharacteristic<T> = IsStringLiteral<T> extends true
581+
? T extends
582+
| "ip.src"
583+
| "http.host"
584+
| "http.method"
585+
| "http.request.uri.path"
586+
| `http.request.headers["${string}"]`
587+
| `http.request.cookie["${string}"]`
588+
| `http.request.uri.args["${string}"]`
589+
? {}
590+
: T extends string
591+
? Record<T, string | number | boolean>
592+
: never
593+
: {};
594+
// Rules can specify they require specific props on an ArcjetRequest
552595
type PropsForRule<R> = R extends ArcjetRule<infer Props> ? Props : {};
553596
// We theoretically support an arbitrary amount of rule flattening,
554597
// but one level seems to be easiest; however, this puts a constraint of
@@ -589,10 +632,16 @@ function isLocalRule<Props extends PlainObject>(
589632
);
590633
}
591634

592-
export function tokenBucket(
593-
options?: TokenBucketRateLimitOptions,
594-
...additionalOptions: TokenBucketRateLimitOptions[]
595-
): Primitive<{ requested: number }> {
635+
export function tokenBucket<
636+
const Characteristics extends readonly string[] = [],
637+
>(
638+
options?: TokenBucketRateLimitOptions<Characteristics>,
639+
...additionalOptions: TokenBucketRateLimitOptions<Characteristics>[]
640+
): Primitive<
641+
UnionToIntersection<
642+
{ requested: number } | PropsForCharacteristic<Characteristics[number]>
643+
>
644+
> {
596645
const rules: ArcjetTokenBucketRateLimitRule<{ requested: number }>[] = [];
597646

598647
if (typeof options === "undefined") {
@@ -602,7 +651,7 @@ export function tokenBucket(
602651
for (const opt of [options, ...additionalOptions]) {
603652
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
604653
const match = opt.match;
605-
const characteristics = opt.characteristics;
654+
const characteristics = Array.isArray(opt.characteristics) ? opt.characteristics : undefined;
606655

607656
const refillRate = opt.refillRate;
608657
const interval = opt.interval;
@@ -624,10 +673,14 @@ export function tokenBucket(
624673
return rules;
625674
}
626675

627-
export function fixedWindow(
628-
options?: FixedWindowRateLimitOptions,
629-
...additionalOptions: FixedWindowRateLimitOptions[]
630-
): Primitive {
676+
export function fixedWindow<
677+
const Characteristics extends readonly string[] = [],
678+
>(
679+
options?: FixedWindowRateLimitOptions<Characteristics>,
680+
...additionalOptions: FixedWindowRateLimitOptions<Characteristics>[]
681+
): Primitive<
682+
UnionToIntersection<PropsForCharacteristic<Characteristics[number]>>
683+
> {
631684
const rules: ArcjetFixedWindowRateLimitRule<{}>[] = [];
632685

633686
if (typeof options === "undefined") {
@@ -637,7 +690,9 @@ export function fixedWindow(
637690
for (const opt of [options, ...additionalOptions]) {
638691
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
639692
const match = opt.match;
640-
const characteristics = opt.characteristics;
693+
const characteristics = Array.isArray(opt.characteristics)
694+
? opt.characteristics
695+
: undefined;
641696

642697
const max = opt.max;
643698
const window = opt.window;
@@ -659,19 +714,25 @@ export function fixedWindow(
659714

660715
// This is currently kept for backwards compatibility but should be removed in
661716
// favor of the fixedWindow primitive.
662-
export function rateLimit(
663-
options?: FixedWindowRateLimitOptions,
664-
...additionalOptions: FixedWindowRateLimitOptions[]
665-
): Primitive {
717+
export function rateLimit<const Characteristics extends readonly string[] = []>(
718+
options?: FixedWindowRateLimitOptions<Characteristics>,
719+
...additionalOptions: FixedWindowRateLimitOptions<Characteristics>[]
720+
): Primitive<
721+
UnionToIntersection<PropsForCharacteristic<Characteristics[number]>>
722+
> {
666723
// TODO(#195): We should also have a local rate limit using an in-memory data
667724
// structure if the environment supports it
668725
return fixedWindow(options, ...additionalOptions);
669726
}
670727

671-
export function slidingWindow(
672-
options?: SlidingWindowRateLimitOptions,
673-
...additionalOptions: SlidingWindowRateLimitOptions[]
674-
): Primitive {
728+
export function slidingWindow<
729+
const Characteristics extends readonly string[] = [],
730+
>(
731+
options?: SlidingWindowRateLimitOptions<Characteristics>,
732+
...additionalOptions: SlidingWindowRateLimitOptions<Characteristics>[]
733+
): Primitive<
734+
UnionToIntersection<PropsForCharacteristic<Characteristics[number]>>
735+
> {
675736
const rules: ArcjetSlidingWindowRateLimitRule<{}>[] = [];
676737

677738
if (typeof options === "undefined") {
@@ -681,7 +742,9 @@ export function slidingWindow(
681742
for (const opt of [options, ...additionalOptions]) {
682743
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
683744
const match = opt.match;
684-
const characteristics = opt.characteristics;
745+
const characteristics = Array.isArray(opt.characteristics)
746+
? opt.characteristics
747+
: undefined;
685748

686749
const max = opt.max;
687750
const interval = opt.interval;
@@ -866,15 +929,23 @@ export function detectBot(
866929
return rules;
867930
}
868931

869-
export type ProtectSignupOptions = {
870-
rateLimit?: SlidingWindowRateLimitOptions | SlidingWindowRateLimitOptions[];
932+
export type ProtectSignupOptions<Characteristics extends string[]> = {
933+
rateLimit?:
934+
| SlidingWindowRateLimitOptions<Characteristics>
935+
| SlidingWindowRateLimitOptions<Characteristics>[];
871936
bots?: BotOptions | BotOptions[];
872937
email?: EmailOptions | EmailOptions[];
873938
};
874939

875-
export function protectSignup(
876-
options?: ProtectSignupOptions,
877-
): Product<{ email: string }> {
940+
export function protectSignup<const Characteristics extends string[] = []>(
941+
options?: ProtectSignupOptions<Characteristics>,
942+
): Product<
943+
Simplify<
944+
UnionToIntersection<
945+
{ email: string } | PropsForCharacteristic<Characteristics[number]>
946+
>
947+
>
948+
> {
878949
let rateLimitRules: Primitive<{}> = [];
879950
if (Array.isArray(options?.rateLimit)) {
880951
rateLimitRules = slidingWindow(...options.rateLimit);

arcjet/test/index.edge.test.ts

+24-5
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,28 @@ describe("Arcjet: Env = Edge runtime", () => {
3636
rules: [
3737
// Test rules
3838
foobarbaz(),
39-
tokenBucket({
40-
refillRate: 1,
41-
interval: 1,
42-
capacity: 1,
43-
}),
39+
tokenBucket(
40+
{
41+
characteristics: [
42+
"ip.src",
43+
"http.host",
44+
"http.method",
45+
"http.request.uri.path",
46+
`http.request.headers["abc"]`,
47+
`http.request.cookie["xyz"]`,
48+
`http.request.uri.args["foobar"]`,
49+
],
50+
refillRate: 1,
51+
interval: 1,
52+
capacity: 1,
53+
},
54+
{
55+
characteristics: ["userId"],
56+
refillRate: 1,
57+
interval: 1,
58+
capacity: 1,
59+
},
60+
),
4461
rateLimit({
4562
max: 1,
4663
window: "60s",
@@ -61,6 +78,8 @@ describe("Arcjet: Env = Edge runtime", () => {
6178
path: "",
6279
headers: new Headers(),
6380
extra: {},
81+
userId: "user123",
82+
foobar: 123,
6483
});
6584

6685
expect(decision.isErrored()).toBe(false);

0 commit comments

Comments
 (0)