diff --git a/docs/docs/api/Agent.md b/docs/docs/api/Agent.md index 2a8e30bac14..0131b103975 100644 --- a/docs/docs/api/Agent.md +++ b/docs/docs/api/Agent.md @@ -19,6 +19,7 @@ Returns: `Agent` Extends: [`PoolOptions`](/docs/docs/api/Pool.md#parameter-pooloptions) * **factory** `(origin: URL, opts: Object) => Dispatcher` - Default: `(origin, opts) => new Pool(origin, opts)` +* **maxOrigins** `number` (optional) - Default: `Infinity` - Limits the total number of origins that can receive requests at a time, throwing an `MaxOriginsReachedError` error when attempting to dispatch when the max is reached. If `Infinity`, no limit is enforced. ## Instance Properties diff --git a/lib/core/errors.js b/lib/core/errors.js index 0ac3e190eb6..f8e2d9cb07b 100644 --- a/lib/core/errors.js +++ b/lib/core/errors.js @@ -359,6 +359,22 @@ class SecureProxyConnectionError extends UndiciError { [kSecureProxyConnectionError] = true } +const kMaxOriginsReachedError = Symbol.for('undici.error.UND_ERR_MAX_ORIGINS_REACHED') +class MaxOriginsReachedError extends UndiciError { + constructor (message) { + super(message) + this.name = 'MaxOriginsReachedError' + this.message = message || 'Maximum allowed origins reached' + this.code = 'UND_ERR_MAX_ORIGINS_REACHED' + } + + static [Symbol.hasInstance] (instance) { + return instance && instance[kMaxOriginsReachedError] === true + } + + [kMaxOriginsReachedError] = true +} + module.exports = { AbortError, HTTPParserError, @@ -381,5 +397,6 @@ module.exports = { ResponseExceededMaxSizeError, RequestRetryError, ResponseError, - SecureProxyConnectionError + SecureProxyConnectionError, + MaxOriginsReachedError } diff --git a/lib/dispatcher/agent.js b/lib/dispatcher/agent.js index 6d6cdecef65..781bc8cf0d1 100644 --- a/lib/dispatcher/agent.js +++ b/lib/dispatcher/agent.js @@ -1,6 +1,6 @@ 'use strict' -const { InvalidArgumentError } = require('../core/errors') +const { InvalidArgumentError, MaxOriginsReachedError } = require('../core/errors') const { kClients, kRunning, kClose, kDestroy, kDispatch, kUrl } = require('../core/symbols') const DispatcherBase = require('./dispatcher-base') const Pool = require('./pool') @@ -13,6 +13,7 @@ const kOnConnectionError = Symbol('onConnectionError') const kOnDrain = Symbol('onDrain') const kFactory = Symbol('factory') const kOptions = Symbol('options') +const kOrigins = Symbol('origins') function defaultFactory (origin, opts) { return opts && opts.connections === 1 @@ -21,7 +22,7 @@ function defaultFactory (origin, opts) { } class Agent extends DispatcherBase { - constructor ({ factory = defaultFactory, connect, ...options } = {}) { + constructor ({ factory = defaultFactory, maxOrigins = Infinity, connect, ...options } = {}) { if (typeof factory !== 'function') { throw new InvalidArgumentError('factory must be a function.') } @@ -30,15 +31,20 @@ class Agent extends DispatcherBase { throw new InvalidArgumentError('connect must be a function or an object') } + if (typeof maxOrigins !== 'number' || Number.isNaN(maxOrigins) || maxOrigins <= 0) { + throw new InvalidArgumentError('maxOrigins must be a number greater than 0') + } + super() if (connect && typeof connect !== 'function') { connect = { ...connect } } - this[kOptions] = { ...util.deepClone(options), connect } + this[kOptions] = { ...util.deepClone(options), maxOrigins, connect } this[kFactory] = factory this[kClients] = new Map() + this[kOrigins] = new Set() this[kOnDrain] = (origin, targets) => { this.emit('drain', origin, [this, ...targets]) @@ -73,6 +79,10 @@ class Agent extends DispatcherBase { throw new InvalidArgumentError('opts.origin must be a non-empty string or URL.') } + if (this[kOrigins].size >= this[kOptions].maxOrigins && !this[kOrigins].has(key)) { + throw new MaxOriginsReachedError() + } + const result = this[kClients].get(key) let dispatcher = result && result.dispatcher if (!dispatcher) { @@ -84,6 +94,7 @@ class Agent extends DispatcherBase { this[kClients].delete(key) result.dispatcher.close() } + this[kOrigins].delete(key) } } dispatcher = this[kFactory](opts.origin, this[kOptions]) @@ -105,6 +116,7 @@ class Agent extends DispatcherBase { }) this[kClients].set(key, { count: 0, dispatcher }) + this[kOrigins].add(key) } return dispatcher.dispatch(opts, handler) diff --git a/test/node-test/agent.js b/test/node-test/agent.js index 32977df5989..d9aee46b597 100644 --- a/test/node-test/agent.js +++ b/test/node-test/agent.js @@ -2,6 +2,7 @@ const { describe, test, after } = require('node:test') const assert = require('node:assert/strict') +const { once } = require('node:events') const http = require('node:http') const { PassThrough } = require('node:stream') const { kRunning } = require('../../lib/core/symbols') @@ -40,6 +41,42 @@ test('Agent', t => { p.doesNotThrow(() => new Agent()) }) +test('Agent enforces maxOrigins', async (t) => { + const p = tspl(t, { plan: 1 }) + + const dispatcher = new Agent({ + maxOrigins: 1, + keepAliveMaxTimeout: 100, + keepAliveTimeout: 100 + }) + t.after(() => dispatcher.close()) + + const handler = (_req, res) => { + setTimeout(() => res.end('ok'), 50) + } + + const server1 = http.createServer({ joinDuplicateHeaders: true }, handler) + server1.listen(0) + await once(server1, 'listening') + t.after(closeServerAsPromise(server1)) + + const server2 = http.createServer({ joinDuplicateHeaders: true }, handler) + server2.listen(0) + await once(server2, 'listening') + t.after(closeServerAsPromise(server2)) + + try { + await Promise.all([ + request(`http://localhost:${server1.address().port}`, { dispatcher }), + request(`http://localhost:${server2.address().port}`, { dispatcher }) + ]) + } catch (err) { + p.ok(err instanceof errors.MaxOriginsReachedError) + } + + await p.completed +}) + test('agent should call callback after closing internal pools', async (t) => { const p = tspl(t, { plan: 2 }) @@ -662,8 +699,10 @@ test('stream: fails with invalid onInfo', async (t) => { }) test('constructor validations', t => { - const p = tspl(t, { plan: 1 }) + const p = tspl(t, { plan: 3 }) p.throws(() => new Agent({ factory: 'ASD' }), errors.InvalidArgumentError, 'throws on invalid opts argument') + p.throws(() => new Agent({ maxOrigins: -1 }), errors.InvalidArgumentError, 'maxOrigins must be a number greater than 0') + p.throws(() => new Agent({ maxOrigins: 'foo' }), errors.InvalidArgumentError, 'maxOrigins must be a number greater than 0') }) test('dispatch validations', async t => { diff --git a/test/types/errors.test-d.ts b/test/types/errors.test-d.ts index bdaa4c69204..794d0a5e559 100644 --- a/test/types/errors.test-d.ts +++ b/test/types/errors.test-d.ts @@ -110,6 +110,11 @@ expectAssignable(new errors.SecureProxyConnec expectAssignable<'SecureProxyConnectionError'>(new errors.SecureProxyConnectionError().name) expectAssignable<'UND_ERR_PRX_TLS'>(new errors.SecureProxyConnectionError().code) +expectAssignable(new errors.MaxOriginsReachedError()) +expectAssignable(new errors.MaxOriginsReachedError()) +expectAssignable<'MaxOriginsReachedError'>(new errors.MaxOriginsReachedError().name) +expectAssignable<'UND_ERR_MAX_ORIGINS_REACHED'>(new errors.MaxOriginsReachedError().code) + { // @ts-ignore function f (): errors.HeadersTimeoutError | errors.ConnectTimeoutError { } diff --git a/types/agent.d.ts b/types/agent.d.ts index 8c881481a46..4bb3512c77b 100644 --- a/types/agent.d.ts +++ b/types/agent.d.ts @@ -24,6 +24,7 @@ declare namespace Agent { factory?(origin: string | URL, opts: Object): Dispatcher; interceptors?: { Agent?: readonly Dispatcher.DispatchInterceptor[] } & Pool.Options['interceptors'] + maxOrigins?: number } export interface DispatchOptions extends Dispatcher.DispatchOptions { diff --git a/types/errors.d.ts b/types/errors.d.ts index 0de05787a37..fbf31955611 100644 --- a/types/errors.d.ts +++ b/types/errors.d.ts @@ -153,4 +153,9 @@ declare namespace Errors { name: 'SecureProxyConnectionError' code: 'UND_ERR_PRX_TLS' } + + class MaxOriginsReachedError extends UndiciError { + name: 'MaxOriginsReachedError' + code: 'UND_ERR_MAX_ORIGINS_REACHED' + } }