diff --git a/.changeset/spotty-masks-beg.md b/.changeset/spotty-masks-beg.md new file mode 100644 index 0000000000..6e436efef5 --- /dev/null +++ b/.changeset/spotty-masks-beg.md @@ -0,0 +1,6 @@ +--- +"@react-router/dev": minor +"react-router": minor +--- + +Add additional layer of CSRF protection by rejecting submissions to UI routes from external origins. If you need to permit access to specific external origins, you can specify them in the `react-router.config.ts` config `allowedActionOrigins` field. diff --git a/integration/vite-presets-test.ts b/integration/vite-presets-test.ts index 6bf1d96221..502902d703 100644 --- a/integration/vite-presets-test.ts +++ b/integration/vite-presets-test.ts @@ -238,6 +238,7 @@ test.describe("Vite / presets", async () => { "serverBundles", "serverModuleFormat", "ssr", + "allowedActionOrigins", "unstable_routeConfig", ]); diff --git a/packages/react-router-dev/config/config.ts b/packages/react-router-dev/config/config.ts index 1a0892ec88..604985deaa 100644 --- a/packages/react-router-dev/config/config.ts +++ b/packages/react-router-dev/config/config.ts @@ -211,6 +211,12 @@ export type ReactRouterConfig = { * SPA without server-rendering. Default's to `true`. */ ssr?: boolean; + + /** + * The allowed origins for actions / mutations. Does not apply to routes + * without a component. micromatch glob patterns are supported. + */ + allowedActionOrigins?: string[]; }; export type ResolvedReactRouterConfig = Readonly<{ @@ -277,6 +283,11 @@ export type ResolvedReactRouterConfig = Readonly<{ * SPA without server-rendering. Default's to `true`. */ ssr: boolean; + /** + * The allowed origins for actions / mutations. Does not apply to routes + * without a component. micromatch glob patterns are supported. + */ + allowedActionOrigins: string[] | false; /** * The resolved array of route config entries exported from `routes.ts` */ @@ -645,6 +656,8 @@ async function resolveConfig({ userAndPresetConfigs.future?.v8_viteEnvironmentApi ?? false, }; + let allowedActionOrigins = userAndPresetConfigs.allowedActionOrigins ?? false; + let reactRouterConfig: ResolvedReactRouterConfig = deepFreeze({ appDirectory, basename, @@ -658,6 +671,7 @@ async function resolveConfig({ serverBundles, serverModuleFormat, ssr, + allowedActionOrigins, unstable_routeConfig: routeConfig, } satisfies ResolvedReactRouterConfig); diff --git a/packages/react-router-dev/typegen/generate.ts b/packages/react-router-dev/typegen/generate.ts index 72b5b46fb1..a18cd5f198 100644 --- a/packages/react-router-dev/typegen/generate.ts +++ b/packages/react-router-dev/typegen/generate.ts @@ -48,6 +48,7 @@ export function generateServerBuild(ctx: Context): VirtualFile { export const routeDiscovery: ServerBuild["routeDiscovery"]; export const routes: ServerBuild["routes"]; export const ssr: ServerBuild["ssr"]; + export const allowedActionOrigins: ServerBuild["allowedActionOrigins"]; export const unstable_getCriticalCss: ServerBuild["unstable_getCriticalCss"]; } `; diff --git a/packages/react-router-dev/vite/plugin.ts b/packages/react-router-dev/vite/plugin.ts index 6ed5990bbe..62935c15bc 100644 --- a/packages/react-router-dev/vite/plugin.ts +++ b/packages/react-router-dev/vite/plugin.ts @@ -871,7 +871,9 @@ export const reactRouterVitePlugin: ReactRouterVitePlugin = () => { } ` : "" - }`; + } + export const allowedActionOrigins = ${JSON.stringify(ctx.reactRouterConfig.allowedActionOrigins)}; + `; }; let loadViteManifest = async (directory: string) => { diff --git a/packages/react-router/lib/actions.ts b/packages/react-router/lib/actions.ts new file mode 100644 index 0000000000..0118cbc2fb --- /dev/null +++ b/packages/react-router/lib/actions.ts @@ -0,0 +1,122 @@ +export function throwIfPotentialCSRFAttack( + headers: Headers, + allowedActionOrigins: string[] | undefined, +) { + let originHeader = headers.get("origin"); + let originDomain = + typeof originHeader === "string" && originHeader !== "null" + ? new URL(originHeader).host + : originHeader; + let host = parseHostHeader(headers); + + if (originDomain && (!host || originDomain !== host.value)) { + if (!isAllowedOrigin(originDomain, allowedActionOrigins)) { + if (host) { + // This seems to be an CSRF attack. We should not proceed with the action. + throw new Error( + `${host.type} header does not match \`origin\` header from a forwarded ` + + `action request. Aborting the action.`, + ); + } else { + // This is an attack. We should not proceed with the action. + throw new Error( + "`x-forwarded-host` or `host` headers are not provided. One of these " + + "is needed to compare the `origin` header from a forwarded action " + + "request. Aborting the action.", + ); + } + } + } +} + +// Implementation of micromatch by Next.js https://github.com/vercel/next.js/blob/ea927b583d24f42e538001bf13370e38c91d17bf/packages/next/src/server/app-render/csrf-protection.ts#L6 +function matchWildcardDomain(domain: string, pattern: string) { + const domainParts = domain.split("."); + const patternParts = pattern.split("."); + + if (patternParts.length < 1) { + // pattern is empty and therefore invalid to match against + return false; + } + + if (domainParts.length < patternParts.length) { + // domain has too few segments and thus cannot match + return false; + } + + // Prevent wildcards from matching entire domains (e.g. '**' or '*.com') + // This ensures wildcards can only match subdomains, not the main domain + if ( + patternParts.length === 1 && + (patternParts[0] === "*" || patternParts[0] === "**") + ) { + return false; + } + + while (patternParts.length) { + const patternPart = patternParts.pop(); + const domainPart = domainParts.pop(); + + switch (patternPart) { + case "": { + // invalid pattern. pattern segments must be non empty + return false; + } + case "*": { + // wildcard matches anything so we continue if the domain part is non-empty + if (domainPart) { + continue; + } else { + return false; + } + } + case "**": { + // if this is not the last item in the pattern the pattern is invalid + if (patternParts.length > 0) { + return false; + } + // recursive wildcard matches anything so we terminate here if the domain part is non empty + return domainPart !== undefined; + } + case undefined: + default: { + if (domainPart !== patternPart) { + return false; + } + } + } + } + + // We exhausted the pattern. If we also exhausted the domain we have a match + return domainParts.length === 0; +} + +function isAllowedOrigin( + originDomain: string, + allowedActionOrigins: string[] | undefined = [], +) { + return allowedActionOrigins.some( + (allowedOrigin) => + allowedOrigin && + (allowedOrigin === originDomain || + matchWildcardDomain(originDomain, allowedOrigin)), + ); +} + +function parseHostHeader(headers: Headers) { + let forwardedHostHeader = headers.get("x-forwarded-host"); + let forwardedHostValue = forwardedHostHeader?.split(",")[0]?.trim(); + let hostHeader = headers.get("host"); + + return forwardedHostValue + ? { + type: "x-forwarded-host", + value: forwardedHostValue, + } + : hostHeader + ? { + type: "host", + value: hostHeader, + } + : undefined; +} diff --git a/packages/react-router/lib/rsc/server.rsc.ts b/packages/react-router/lib/rsc/server.rsc.ts index 87d8d1b06c..21a385311c 100644 --- a/packages/react-router/lib/rsc/server.rsc.ts +++ b/packages/react-router/lib/rsc/server.rsc.ts @@ -38,6 +38,7 @@ import { } from "../router/utils"; import { getDocumentHeadersImpl } from "../server-runtime/headers"; import { SINGLE_FETCH_REDIRECT_STATUS } from "../dom/ssr/single-fetch"; +import { throwIfPotentialCSRFAttack } from "../actions"; import type { RouteMatch, RouteObject } from "../context"; import invariant from "../server-runtime/invariant"; @@ -331,6 +332,7 @@ export type LoadServerActionFunction = (id: string) => Promise; * @category RSC * @mode data * @param opts Options + * @param opts.allowedActionOrigins Origin patterns that are allowed to execute actions. * @param opts.basename The basename to use when matching the request. * @param opts.createTemporaryReferenceSet A function that returns a temporary * reference set for the request, used to track temporary references in the [RSC](https://react.dev/reference/rsc/server-components) @@ -361,6 +363,7 @@ export type LoadServerActionFunction = (id: string) => Promise; * data for hydration. */ export async function matchRSCServerRequest({ + allowedActionOrigins, createTemporaryReferenceSet, basename, decodeReply, @@ -373,6 +376,7 @@ export async function matchRSCServerRequest({ routes, generateResponse, }: { + allowedActionOrigins?: string[]; createTemporaryReferenceSet: () => unknown; basename?: string; decodeReply?: DecodeReplyFunction; @@ -477,6 +481,7 @@ export async function matchRSCServerRequest({ onError, generateResponse, temporaryReferences, + allowedActionOrigins, ); // The front end uses this to know whether a 4xx/5xx status came from app code // or never reached the origin server @@ -754,6 +759,7 @@ async function generateRenderResponse( }, ) => Response, temporaryReferences: unknown, + allowedActionOrigins: string[] | undefined, ): Promise { // If this is a RR submission, we just want the `actionData` but don't want // to call any loaders or render any components back in the response - that @@ -799,6 +805,8 @@ async function generateRenderResponse( let formState: unknown; let skipRevalidation = false; if (request.method === "POST") { + throwIfPotentialCSRFAttack(request.headers, allowedActionOrigins); + ctx.runningAction = true; let result = await processServerAction( request, diff --git a/packages/react-router/lib/server-runtime/build.ts b/packages/react-router/lib/server-runtime/build.ts index c18e7e3d85..4a1a6200e3 100644 --- a/packages/react-router/lib/server-runtime/build.ts +++ b/packages/react-router/lib/server-runtime/build.ts @@ -12,11 +12,7 @@ import type { import type { ServerRouteManifest } from "./routes"; import type { AppLoadContext } from "./data"; import type { MiddlewareEnabled } from "../types/future"; -import type { - unstable_InstrumentRequestHandlerFunction, - unstable_InstrumentRouteFunction, - unstable_ServerInstrumentation, -} from "../router/instrumentation"; +import type { unstable_ServerInstrumentation } from "../router/instrumentation"; type OptionalCriticalCss = CriticalCss | undefined; @@ -46,6 +42,7 @@ export interface ServerBuild { mode: "lazy" | "initial"; manifestPath: string; }; + allowedActionOrigins?: string[] | false; } export interface HandleDocumentRequestFunction { diff --git a/packages/react-router/lib/server-runtime/server.ts b/packages/react-router/lib/server-runtime/server.ts index 8ce88bf45f..14d1c8b4e5 100644 --- a/packages/react-router/lib/server-runtime/server.ts +++ b/packages/react-router/lib/server-runtime/server.ts @@ -38,6 +38,7 @@ import type { MiddlewareEnabled } from "../types/future"; import { getManifestPath } from "../dom/ssr/fog-of-war"; import type { unstable_InstrumentRequestHandlerFunction } from "../router/instrumentation"; import { instrumentHandler } from "../router/instrumentation"; +import { throwIfPotentialCSRFAttack } from "../actions"; export type RequestHandler = ( request: Request, @@ -481,6 +482,14 @@ async function handleDocumentRequest( criticalCss?: CriticalCss, ) { try { + if (request.method === "POST") { + throwIfPotentialCSRFAttack( + request.headers, + Array.isArray(build.allowedActionOrigins) + ? build.allowedActionOrigins + : [], + ); + } let result = await staticHandler.query(request, { requestContext: loadContext, generateMiddlewareResponse: build.future.v8_middleware diff --git a/packages/react-router/lib/server-runtime/single-fetch.ts b/packages/react-router/lib/server-runtime/single-fetch.ts index 52e76df5ff..944afa948b 100644 --- a/packages/react-router/lib/server-runtime/single-fetch.ts +++ b/packages/react-router/lib/server-runtime/single-fetch.ts @@ -23,6 +23,7 @@ import { sanitizeError, sanitizeErrors } from "./errors"; import { ServerMode } from "./mode"; import { getDocumentHeaders } from "./headers"; import type { ServerBuild } from "./build"; +import { throwIfPotentialCSRFAttack } from "../actions"; // Add 304 for server side - that is not included in the client side logic // because the browser should fill those responses with the cached data @@ -42,6 +43,13 @@ export async function singleFetchAction( handleError: (err: unknown) => void, ): Promise { try { + throwIfPotentialCSRFAttack( + request.headers, + Array.isArray(build.allowedActionOrigins) + ? build.allowedActionOrigins + : [], + ); + let handlerRequest = new Request(handlerUrl, { method: request.method, body: request.body,