Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 118 additions & 39 deletions app/javascript/packages/request/index.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,56 +6,135 @@ import { request } from '.';
describe('request', () => {
const sandbox = useSandbox();

it('includes the CSRF token by default', async () => {
const csrf = 'TYsqyyQ66Y';
const mockGetCSRF = () => csrf;
describe('csrf token header', () => {
it('does not include the CSRF token', async () => {
const csrf = 'TYsqyyQ66Y';
const mockGetCSRF = () => csrf;

sandbox.stub(window, 'fetch').callsFake((url, init = {}) => {
const headers = init.headers as Headers;
expect(headers.get('X-CSRF-Token')).to.equal(csrf);
sandbox.stub(window, 'fetch').callsFake((url, init = {}) => {
const headers = init.headers as Headers;
expect(headers.has('X-CSRF-Token')).to.be.false();

return Promise.resolve(
new Response(JSON.stringify({}), {
status: 200,
}),
);
});
return Promise.resolve(
new Response(JSON.stringify({}), {
status: 200,
}),
);
});

await request('https://example.com', {
csrf: mockGetCSRF,
await request('https://example.com', {
csrf: mockGetCSRF,
});

expect(window.fetch).to.have.been.calledOnce();
});

expect(window.fetch).to.have.been.calledOnce();
});
context('with a GET request', () => {
it('does not include the CSRF token', async () => {
const csrf = 'TYsqyyQ66Y';
const mockGetCSRF = () => csrf;

it('works even if the CSRF token is not found on the page', async () => {
sandbox.stub(window, 'fetch').callsFake(() =>
Promise.resolve(
new Response(JSON.stringify({}), {
status: 200,
}),
),
);
sandbox.stub(window, 'fetch').callsFake((url, init = {}) => {
const headers = init.headers as Headers;
expect(headers.has('X-CSRF-Token')).to.be.false();

await request('https://example.com', {
csrf: () => undefined,
return Promise.resolve(
new Response(JSON.stringify({}), {
status: 200,
}),
);
});

await request('https://example.com', {
csrf: mockGetCSRF,
method: 'GET',
});

expect(window.fetch).to.have.been.calledOnce();
});
});
});

it('does not try to send a csrf when csrf is false', async () => {
sandbox.stub(window, 'fetch').callsFake((url, init = {}) => {
const headers = init.headers as Headers;
expect(headers.get('X-CSRF-Token')).to.be.null();
context('with a HEAD request', () => {
it('does not include the CSRF token', async () => {
const csrf = 'TYsqyyQ66Y';
const mockGetCSRF = () => csrf;

return Promise.resolve(
new Response(JSON.stringify({}), {
status: 200,
}),
);
sandbox.stub(window, 'fetch').callsFake((url, init = {}) => {
const headers = init.headers as Headers;
expect(headers.has('X-CSRF-Token')).to.be.false();

return Promise.resolve(
new Response(JSON.stringify({}), {
status: 200,
}),
);
});

await request('https://example.com', {
csrf: mockGetCSRF,
method: 'HEAD',
});

expect(window.fetch).to.have.been.calledOnce();
});
});

await request('https://example.com', {
csrf: false,
context('with a request method other than exempt methods', () => {
it('includes the CSRF token', async () => {
const csrf = 'TYsqyyQ66Y';
const mockGetCSRF = () => csrf;

sandbox.stub(window, 'fetch').callsFake((url, init = {}) => {
const headers = init.headers as Headers;
expect(headers.get('X-CSRF-Token')).to.equal(csrf);

return Promise.resolve(
new Response(JSON.stringify({}), {
status: 200,
}),
);
});

await request('https://example.com', {
csrf: mockGetCSRF,
method: 'PUT',
});

expect(window.fetch).to.have.been.calledOnce();
});

it('works even if the CSRF token is not found on the page', async () => {
sandbox.stub(window, 'fetch').callsFake(() =>
Promise.resolve(
new Response(JSON.stringify({}), {
status: 200,
}),
),
);

await request('https://example.com', {
csrf: () => undefined,
method: 'PUT',
});
});

it('does not try to send a csrf when csrf is false', async () => {
sandbox.stub(window, 'fetch').callsFake((url, init = {}) => {
const headers = init.headers as Headers;
expect(headers.get('X-CSRF-Token')).to.be.null();

return Promise.resolve(
new Response(JSON.stringify({}), {
status: 200,
}),
);
});

await request('https://example.com', {
csrf: false,
method: 'PUT',
});
});
});
});

Expand Down Expand Up @@ -213,7 +292,7 @@ describe('request', () => {
it('uses response token for next request', async () => {
await request('https://example.com', {});
(window.fetch as SinonStub).resetHistory();
await request('https://example.com', {});
await request('https://example.com', { method: 'PUT' });
expect(window.fetch).to.have.been.calledWith(
sinon.match.string,
sinon.match((init) => init!.headers!.get('x-csrf-token') === 'new-token'),
Expand Down
15 changes: 14 additions & 1 deletion app/javascript/packages/request/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,19 @@ class CSRF {
}
}

/**
* Returns true if the request associated with the given options would require a valid CSRF token,
* or false otherwise.
*
* @see https://github.com/rails/rails/blob/v7.0.5/actionpack/lib/action_controller/metal/request_forgery_protection.rb#L335-L343
*
* @param options Request options
*
* @return Whether the request would require a CSRF token
*/
const isCSRFValidatedRequest = (options: RequestOptions) =>
!!options.method && !['GET', 'HEAD'].includes(options.method.toUpperCase());

export async function request<Response = any>(
url,
options?: Partial<RequestOptions> & { read?: true },
Expand All @@ -67,7 +80,7 @@ export async function request(url: string, options: Partial<RequestOptions> = {}
let { body, headers } = fetchOptions;
headers = new Headers(headers);

if (csrf) {
if (csrf && isCSRFValidatedRequest(fetchOptions)) {
const csrfToken = typeof csrf === 'boolean' ? CSRF.token : csrf();

if (csrfToken) {
Expand Down