diff --git a/lib/api/api-stream.js b/lib/api/api-stream.js index f33f459f9d4..7560a2e6505 100644 --- a/lib/api/api-stream.js +++ b/lib/api/api-stream.js @@ -1,10 +1,11 @@ 'use strict' -const { finished } = require('stream') +const { finished, PassThrough } = require('stream') const { InvalidArgumentError, InvalidReturnValueError, - RequestAbortedError + RequestAbortedError, + ResponseStatusCodeError } = require('../core/errors') const util = require('../core/util') const { AsyncResource } = require('async_hooks') @@ -16,7 +17,7 @@ class StreamHandler extends AsyncResource { throw new InvalidArgumentError('invalid opts') } - const { signal, method, opaque, body, onInfo, responseHeaders } = opts + const { signal, method, opaque, body, onInfo, responseHeaders, throwOnError } = opts try { if (typeof callback !== 'function') { @@ -57,6 +58,7 @@ class StreamHandler extends AsyncResource { this.trailers = null this.body = body this.onInfo = onInfo || null + this.throwOnError = throwOnError || false if (util.isStream(body)) { body.on('error', (err) => { @@ -76,8 +78,8 @@ class StreamHandler extends AsyncResource { this.context = context } - onHeaders (statusCode, rawHeaders, resume) { - const { factory, opaque, context } = this + onHeaders (statusCode, rawHeaders, resume, statusMessage) { + const { factory, opaque, context, callback } = this if (statusCode < 200) { if (this.onInfo) { @@ -96,6 +98,32 @@ class StreamHandler extends AsyncResource { context }) + if (this.throwOnError && statusCode >= 400) { + const headers = this.responseHeaders === 'raw' ? util.parseRawHeaders(rawHeaders) : util.parseHeaders(rawHeaders) + const chunks = [] + const pt = new PassThrough() + pt + .on('data', (chunk) => chunks.push(chunk)) + .on('end', () => { + const payload = Buffer.concat(chunks).toString('utf8') + this.runInAsyncScope( + callback, + null, + new ResponseStatusCodeError( + `Response status code ${statusCode}${statusMessage ? `: ${statusMessage}` : ''}`, + statusCode, + headers, + payload + ) + ) + }) + .on('error', (err) => { + this.onError(err) + }) + this.res = pt + return + } + if ( !res || typeof res.write !== 'function' || diff --git a/test/async_hooks.js b/test/async_hooks.js index 8a77af35d72..2e8533d2d9b 100644 --- a/test/async_hooks.js +++ b/test/async_hooks.js @@ -157,7 +157,7 @@ test('async hooks client is destroyed', (t) => { const client = new Client(`http://localhost:${server.address().port}`) t.teardown(client.destroy.bind(client)) - client.request({ path: '/', method: 'GET' }, (err, { body }) => { + client.request({ path: '/', method: 'GET', throwOnError: true }, (err, { body }) => { t.error(err) body.resume() body.on('error', (err) => { diff --git a/test/client-stream.js b/test/client-stream.js index 2ff5fa53563..e67727b74c7 100644 --- a/test/client-stream.js +++ b/test/client-stream.js @@ -785,4 +785,63 @@ test('stream legacy needDrain', (t) => { t.pass() }) }) + + test('steam throwOnError', (t) => { + t.plan(2) + + const errStatusCode = 500 + const errMessage = 'Internal Server Error' + + const server = createServer((req, res) => { + res.writeHead(errStatusCode, { 'Content-Type': 'text/plain' }) + res.end(errMessage) + }) + t.teardown(server.close.bind(server)) + + server.listen(0, async () => { + const client = new Client(`http://localhost:${server.address().port}`) + t.teardown(client.close.bind(client)) + + client.stream({ + path: '/', + method: 'GET', + throwOnError: true, + opaque: new PassThrough() + }, ({ opaque: pt }) => { + pt.on('data', () => { + t.fail() + }) + return pt + }, (e) => { + t.equal(e.status, errStatusCode) + t.equal(e.body, errMessage) + t.end() + }) + }) + }) + + test('steam throwOnError=true, error on stream', (t) => { + t.plan(1) + + const server = createServer((req, res) => { + res.end('asd') + }) + t.teardown(server.close.bind(server)) + + server.listen(0, async () => { + const client = new Client(`http://localhost:${server.address().port}`) + t.teardown(client.close.bind(client)) + + client.stream({ + path: '/', + method: 'GET', + throwOnError: true, + opaque: new PassThrough() + }, () => { + throw new Error('asd') + }, (e) => { + t.equal(e.message, 'asd') + }) + }) + }) })