Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add withRule API for adding adhoc rules #245

Merged
merged 8 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 129 additions & 101 deletions arcjet-next/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import arcjet, {
RemoteClientOptions,
defaultBaseUrl,
createRemoteClient,
Arcjet,
} from "arcjet";
import findIP from "@arcjet/ip";

Expand Down Expand Up @@ -160,6 +161,10 @@ function cookiesToString(cookies?: ArcjetNextRequest["cookies"]): string {
.join("; ");
}

/**
* The ArcjetNext client provides a public `protect()` method to
* make a decision about how a Next.js request should be handled.
*/
export interface ArcjetNext<Props extends PlainObject> {
get runtime(): Runtime;
/**
Expand All @@ -178,124 +183,147 @@ export interface ArcjetNext<Props extends PlainObject> {
// that is required if the ExtraProps aren't strictly an empty object
...props: Props extends WithoutCustomProps ? [] : [Props]
): Promise<ArcjetDecision>;
}

/**
* This is the main class for Arcjet when using Next.js. It provides several
* methods for protecting Next.js routes depending on whether they are using the
* Edge or Serverless Functions runtime.
*/
/**
* Create a new Arcjet Next client. If possible, call this outside of the
* request context so it persists across requests.
*
* @param key - The key to identify the site in Arcjet.
* @param options - Arcjet configuration options to apply to all requests.
* These can be overriden on a per-request basis by providing them to the
* `protect()` or `protectApi` methods.
*/
export default function arcjetNext<const Rules extends (Primitive | Product)[]>(
options: ArcjetOptions<Rules>,
): ArcjetNext<Simplify<ExtraProps<Rules>>> {
const client = options.client ?? createNextRemoteClient();
/**
* Augments the client with another rule. Useful for varying rules based on
* criteria in your handler—e.g. different rate limit for logged in users.
*
* @param rule The rule to add to this execution.
* @returns An augmented {@link ArcjetNext} client.
*/
withRule<Rule extends Primitive | Product>(
rule: Rule,
): ArcjetNext<Simplify<Props & ExtraProps<Rule>>>;
}

const aj = arcjet({ ...options, client });
function toArcjetRequest<Props extends PlainObject>(
request: ArcjetNextRequest,
props: Props,
): ArcjetRequest<Props> {
// We construct an ArcjetHeaders to normalize over Headers
const headers = new ArcjetHeaders(request.headers);

const ip = findIP(request, headers);
const method = request.method ?? "";
const host = headers.get("host") ?? "";
let path = "";
let query = "";
let protocol = "";
// TODO(#36): nextUrl has formatting logic when you `toString` but
// we don't account for that here
if (typeof request.nextUrl !== "undefined") {
path = request.nextUrl.pathname ?? "";
if (typeof request.nextUrl.search !== "undefined") {
query = request.nextUrl.search;
}
if (typeof request.nextUrl.protocol !== "undefined") {
protocol = request.nextUrl.protocol;
}
} else {
if (typeof request.socket?.encrypted !== "undefined") {
protocol = request.socket.encrypted ? "https:" : "http:";
} else {
protocol = "http:";
}
// Do some very simple validation, but also try/catch around URL parsing
if (
typeof request.url !== "undefined" &&
request.url !== "" &&
host !== ""
) {
try {
const url = new URL(request.url, `${protocol}//${host}`);
path = url.pathname;
query = url.search;
protocol = url.protocol;
} catch {
// If the parsing above fails, just set the path as whatever url we
// received.
// TODO(#216): Add logging to arcjet-next
path = request.url ?? "";
}
} else {
path = request.url ?? "";
}
}
const cookies = cookiesToString(request.cookies);

const extra: { [key: string]: string } = {};

// If we're running on Vercel, we can add some extra information
if (process.env["VERCEL"]) {
// Vercel ID https://vercel.com/docs/concepts/edge-network/headers
extra["vercel-id"] = headers.get("x-vercel-id") ?? "";
// Vercel deployment URL
// https://vercel.com/docs/concepts/edge-network/headers
extra["vercel-deployment-url"] =
headers.get("x-vercel-deployment-url") ?? "";
// Vercel git commit SHA
// https://vercel.com/docs/concepts/projects/environment-variables/system-environment-variables
extra["vercel-git-commit-sha"] = process.env["VERCEL_GIT_COMMIT_SHA"] ?? "";
extra["vercel-git-commit-sha"] = process.env["VERCEL_GIT_COMMIT_SHA"] ?? "";
}
return {
...props,
...extra,
ip,
method,
protocol,
host,
path,
headers,
cookies,
query,
};
}

function withClient<const Rules extends (Primitive | Product)[]>(
aj: Arcjet<ExtraProps<Rules>>,
): ArcjetNext<ExtraProps<Rules>> {
return Object.freeze({
get runtime() {
return aj.runtime;
},
withRule(rule: Primitive | Product) {
const client = aj.withRule(rule);
return withClient(client);
},
async protect(
request: ArcjetNextRequest,
...[props]: ExtraProps<Rules> extends WithoutCustomProps
? []
: [ExtraProps<Rules>]
): Promise<ArcjetDecision> {
// We construct an ArcjetHeaders to normalize over Headers
const headers = new ArcjetHeaders(request.headers);

const ip = findIP(request, headers);
const method = request.method ?? "";
const host = headers.get("host") ?? "";
let path = "";
let query = "";
let protocol = "";
// TODO(#36): nextUrl has formatting logic when you `toString` but
// we don't account for that here
if (typeof request.nextUrl !== "undefined") {
path = request.nextUrl.pathname ?? "";
if (typeof request.nextUrl.search !== "undefined") {
query = request.nextUrl.search;
}
if (typeof request.nextUrl.protocol !== "undefined") {
protocol = request.nextUrl.protocol;
}
} else {
if (typeof request.socket?.encrypted !== "undefined") {
protocol = request.socket.encrypted ? "https:" : "http:";
} else {
protocol = "http:";
}
// Do some very simple validation, but also try/catch around URL parsing
if (
typeof request.url !== "undefined" &&
request.url !== "" &&
host !== ""
) {
try {
const url = new URL(request.url, `${protocol}//${host}`);
path = url.pathname;
query = url.search;
protocol = url.protocol;
} catch {
// If the parsing above fails, just set the path as whatever url we
// received.
// TODO(#216): Add logging to arcjet-next
path = request.url ?? "";
}
} else {
path = request.url ?? "";
}
}
const cookies = cookiesToString(request.cookies);

const extra: { [key: string]: string } = {};

// If we're running on Vercel, we can add some extra information
if (process.env["VERCEL"]) {
// Vercel ID https://vercel.com/docs/concepts/edge-network/headers
extra["vercel-id"] = headers.get("x-vercel-id") ?? "";
// Vercel deployment URL
// https://vercel.com/docs/concepts/edge-network/headers
extra["vercel-deployment-url"] =
headers.get("x-vercel-deployment-url") ?? "";
// Vercel git commit SHA
// https://vercel.com/docs/concepts/projects/environment-variables/system-environment-variables
extra["vercel-git-commit-sha"] =
process.env["VERCEL_GIT_COMMIT_SHA"] ?? "";
extra["vercel-git-commit-sha"] =
process.env["VERCEL_GIT_COMMIT_SHA"] ?? "";
}

const decision = await aj.protect({
...props,
ip,
method,
protocol,
host,
path,
headers,
cookies,
query,
...extra,
// TODO(#220): The generic manipulations get really mad here, so we just cast it
} as ArcjetRequest<ExtraProps<Rules>>);

return decision;
// TODO(#220): The generic manipulations get really mad here, so we cast
// Further investigation makes it seem like it has something to do with
// the definition of `props` in the signature but it's hard to track down
const req = toArcjetRequest(request, props ?? {}) as ArcjetRequest<
ExtraProps<Rules>
>;

return aj.protect(req);
},
});
}

/**
* Create a new {@link ArcjetNext} client. Always build your initial client
* outside of a request handler so it persists across requests. If you need to
* augment a client inside a handler, call the `withRule()` function on the base
* client.
*
* @param options - Arcjet configuration options to apply to all requests.
*/
export default function arcjetNext<const Rules extends (Primitive | Product)[]>(
options: ArcjetOptions<Rules>,
): ArcjetNext<Simplify<ExtraProps<Rules>>> {
const client = options.client ?? createNextRemoteClient();

const aj = arcjet({ ...options, client });

return withClient(aj);
}

/**
* Protects your Next.js application using Arcjet middleware.
*
Expand Down
Loading
Loading