diff --git a/benchmark/express-ws.js b/benchmark/express-ws.js index 97208bb..56fe2b6 100644 --- a/benchmark/express-ws.js +++ b/benchmark/express-ws.js @@ -1,5 +1,5 @@ const express = require('express') -const Protocol = require('fast-ws-server/js/ws-protocol/fast-ws') +const Protocol = require('fast-ws-server/ws/fast-ws') const { Readable } = require('stream') const app = express() @@ -35,6 +35,10 @@ app.get('/stream', (req, res) => { stream.pipe(res) }) +app.post('/stream', (req, res) => { + req.pipe(res) +}) + app.use('/', express.static('static')) console.time('STARTUP') diff --git a/benchmark/fast-ws.js b/benchmark/fast-ws.js index 4fd4e75..7a5c47b 100644 --- a/benchmark/fast-ws.js +++ b/benchmark/fast-ws.js @@ -27,6 +27,14 @@ app.get('/stream', (req, res) => { stream.pipe(res) }) +app.post('/stream', (req, res) => { + req.bodyStream.pipe(res) +}) + +app.post('/stream/send', async (req, res) => { + res.send(await req.body) +}) + app.serve('/') console.time('STARTUP') diff --git a/benchmark/nanoexpress.js b/benchmark/nanoexpress.js index 7a52620..a84d3b7 100644 --- a/benchmark/nanoexpress.js +++ b/benchmark/nanoexpress.js @@ -1,6 +1,6 @@ const nanoexpress = require('nanoexpress') const staticServe = require('@nanoexpress/middleware-static-serve/cjs') -const Protocol = require('fast-ws-server/js/ws-protocol/fast-ws') +const Protocol = require('fast-ws-server/ws/fast-ws') const app = nanoexpress() @@ -31,6 +31,10 @@ app.get('/hello/:name', async (req, res) => { res.end(`Hello ${name}`) }) +app.post('/stream', (req, res) => { + res.send(req.body) +}) + app.use('/', staticServe('./static')) app.listen(3000) diff --git a/benchmark/post.lua b/benchmark/post.lua new file mode 100644 index 0000000..21e783e --- /dev/null +++ b/benchmark/post.lua @@ -0,0 +1,3 @@ +wrk.method = "POST" +wrk.body = string.rep("-TEST_STRING-", 8192) +wrk.headers["Content-Type"] = "text/plain" diff --git a/packages/server/js/connection.js b/packages/server/js/connection.js index ed86739..94d4d87 100644 --- a/packages/server/js/connection.js +++ b/packages/server/js/connection.js @@ -9,8 +9,6 @@ const { cache, templateEngine, maxBodySize } = require('./constants') const methodsWithBody = ['POST', 'PUT', 'PATCH', 'OPTIONS'] -const emptyBuffer = Buffer.alloc(0) - class Connection { constructor (app, request, response, wsContext) { this.app = app @@ -24,11 +22,13 @@ class Connection { this.headers[k] = v } }) + this.url = this.request.getUrl() + this.method = this.request.getMethod().toUpperCase() this.rawQuery = this.request.getQuery() this._req_info = {} this._method = null this._body = null - this._body_stream = null + this._bodyStream = null this._reject_data = null this._on_aborted = [] this._on_writable = null @@ -42,25 +42,16 @@ class Connection { this._on_writable ? this._on_writable(offset) : true) this.wsContext = wsContext this._remote_address = null + this.processBodyData() } static create (app, request, response, wsContext = null) { return new Connection(app, request, response, wsContext) } - bodyDataStream () { - if (!this.response) { - throw new ServerError({ code: 'SERVER_INVALID_CONNECTION' }) - } - if (!methodsWithBody.includes(this.method)) { - throw new ServerError({ - code: 'SERVER_INVALID_OPERATE', - message: `The method "${this.method}" should not have body.` - }) - } - if (this._body_stream !== null) { - return this._body_stream - } + processBodyData () { + if (!this.response) return + if (!methodsWithBody.includes(this.method)) return const contentLength = this.headers['content-length'] // Verify Content-Length if (!contentLength) { @@ -70,68 +61,83 @@ class Connection { } else if (this.bodyLimit && Number(contentLength) > this.bodyLimit) { throw new ServerError({ code: 'CLIENT_LENGTH_TOO_LARGE', message: '', httpCode: 413 }) } - const length = Number(contentLength || 0) - const chunks = [] - let received = 0 - let isEnd = false - let error = null - let callback = null - this._body_stream = new Readable({ - read (size) { - if (error) { - this.destroy(error) - } else { - if (!chunks.length && !isEnd) { - callback = () => { - if (error) { - this.destroy(error) - } else { - const chunk = chunks.shift() - this.push(chunk) - } - } - } else { - const chunk = chunks.shift() - if (!chunk && isEnd) this.push(null) - else this.push(chunk || emptyBuffer) - } - } - } - }) - this._body_stream.bodyLength = length + this.bodyLength = Number(contentLength || 0) + this._buffer = [] + this._dataEnd = false + this._received = 0 + this._onData = null + this._dataError = null this.onAborted(() => { - isEnd = true - error = new ServerError({ code: 'CONNECTION_ABORTED' }) - if (callback) { - callback() + this._dataEnd = true + this._dataError = new ServerError({ code: 'CONNECTION_ABORTED' }) + if (this._onData) { + this._onData() } }) this.response.onData((chunk, isLast) => { - if (isEnd) return - received += chunk.byteLength - if (length && received > length) { - isEnd = true - error = new ServerError({ code: 'CLIENT_BAD_REQUEST', httpCode: 400 }) - } else if (length && isLast && received < length) { - isEnd = true - error = new ServerError({ code: 'CLIENT_BAD_REQUEST', httpCode: 400 }) + if (this._dataEnd) return + this._received += chunk.byteLength + if (this.bodyLength && this._received > this.bodyLength) { + this._dataEnd = true + this._dataError = new ServerError({ code: 'CLIENT_BAD_REQUEST', httpCode: 400 }) + } else if (this.bodyLength && isLast && this._received < this.bodyLength) { + this._dataEnd = true + this._dataError = new ServerError({ code: 'CLIENT_BAD_REQUEST', httpCode: 400 }) } else { - chunks.push(Buffer.from(chunk)) - isEnd = isLast + // Copy buffer to avoid memory release + this._buffer.push(Buffer.from(Buffer.from(chunk))) + this._dataEnd = isLast } - if (callback) { - callback() + if (this._onData) { + this._onData() } }) - return this._body_stream + } + + bodyDataStream () { + if (!this.response) { + throw new ServerError({ code: 'SERVER_INVALID_CONNECTION' }) + } + if (!methodsWithBody.includes(this.method)) { + throw new ServerError({ + code: 'SERVER_INVALID_OPERATE', + message: `The method "${this.method}" should not have body.` + }) + } + if (this._bodyStream !== null) { + return this._bodyStream + } + const readData = (callback) => { + if (this._dataError) { + callback(this._dataError) + } else { + if (!this._buffer.length && !this._dataEnd) { + this._onData = () => readData(callback) + } else { + this._onData = null + const chunk = this._buffer.shift() + if (!chunk && this._dataEnd) callback(null, null) + else callback(null, Buffer.from(chunk)) + } + } + } + this._bodyStream = new Readable({ + read () { + readData((err, chunk) => { + if (err) this.destroy(err) + else this.push(chunk) + }) + } + }) + this._bodyStream.bodyLength = this.bodyLength + return this._bodyStream } bodyData () { if (this._body !== null) { return this._body } - const type = this.headers['content-type'] - if (!type) return null + const type = this.headers['content-type'] || 'application/octet-stream' const stream = this.bodyDataStream() this._body = new Promise((resolve, reject) => { let data = null @@ -185,14 +191,6 @@ class Connection { ) } - get url () { - return this.getInfo('url', () => this.request.getUrl()) - } - - get method () { - return this.getInfo('method', () => this.request.getMethod().toUpperCase()) - } - getInfo (name, valueFn) { if (!this._req_info[name]) this._req_info[name] = valueFn() return this._req_info[name] diff --git a/packages/server/js/response.js b/packages/server/js/response.js index 6fee479..6d77d71 100644 --- a/packages/server/js/response.js +++ b/packages/server/js/response.js @@ -264,7 +264,7 @@ class Response extends Writable { } send (data, contentType = null) { - if (!contentType && !this._headers['content-type']) { + if (!contentType && !this._headers['content-type'] && typeof data === 'string') { this._headers['content-type'] = data.includes('') ? 'text/html' : 'text/plain' } else if (contentType) { this._headers['content-type'] = contentType diff --git a/packages/server/package.json b/packages/server/package.json index ed88108..335e63b 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -12,7 +12,8 @@ "./request": "./js/request.js", "./response": "./js/response.js", "./error": "./js/errors.js", - "./ws-base": "./js/ws-protocol/basic.js" + "./ws-base": "./js/ws-protocol/basic.js", + "./ws/": "./js/ws-protocol/" }, "repository": { "type": "git", diff --git a/test/cases/http-pipe-body-stream-later-1.js b/test/cases/http-pipe-body-stream-later-1.js new file mode 100644 index 0000000..0213a59 --- /dev/null +++ b/test/cases/http-pipe-body-stream-later-1.js @@ -0,0 +1,16 @@ +const axios = require('axios') + +module.exports = async function ({ HTTP_PORT }) { + const body = '__TEST__STRING__'.repeat(4096) + + const res = await axios.post(`http://localhost:${HTTP_PORT}/stream/body-later-1`, body) + if (res.status !== 200) { + throw new Error(`Response ${res.status}`) + } + if (res.headers['content-length'] !== body.length.toString()) { + throw new Error('Content-Length mismatch') + } + if (res.data !== body) { + throw new Error('Response data mismatch') + } +} diff --git a/test/cases/http-pipe-body-stream-later-2.js b/test/cases/http-pipe-body-stream-later-2.js new file mode 100644 index 0000000..7d10502 --- /dev/null +++ b/test/cases/http-pipe-body-stream-later-2.js @@ -0,0 +1,16 @@ +const axios = require('axios') + +module.exports = async function ({ HTTP_PORT }) { + const body = '__TEST__STRING__'.repeat(4096) + + const res = await axios.post(`http://localhost:${HTTP_PORT}/stream/body-later-2`, body) + if (res.status !== 200) { + throw new Error(`Response ${res.status}`) + } + if (res.headers['content-length'] !== body.length.toString()) { + throw new Error('Content-Length mismatch') + } + if (res.data !== body) { + throw new Error('Response data mismatch') + } +} diff --git a/test/prepare/app.js b/test/prepare/app.js index 9d9ede8..af8ea0a 100644 --- a/test/prepare/app.js +++ b/test/prepare/app.js @@ -114,6 +114,15 @@ module.exports = function (app) { req.bodyStream.pipe(res) }) + app.post('/stream/body-later-1', (req, res) => { + setTimeout(() => req.bodyStream.pipe(res), 100) + }) + + app.post('/stream/body-later-2', (req, res) => { + const stream = req.bodyStream + setTimeout(() => stream.pipe(res), 100) + }) + app.get('/stream/error', (req, res) => { const stream = new Stream.Readable({ read: () => '',