From 93957828be1252c83275b56f0c7c0bd145a0ceb9 Mon Sep 17 00:00:00 2001 From: Ciel <9755720+cieldeville@users.noreply.github.com> Date: Tue, 2 May 2023 00:00:47 +0200 Subject: [PATCH] fix: include error handling for Express middlewares (#674) Following https://github.com/socketio/engine.io/commit/24786e77c5403b1c4b5a2bc84e2af06f9187f74a. Reference: https://expressjs.com/en/guide/error-handling.html --- lib/server.ts | 67 ++++++++++------- lib/userver.ts | 170 ++++++++++++++++++++++++-------------------- test/middlewares.js | 44 ++++++++++++ 3 files changed, 177 insertions(+), 104 deletions(-) diff --git a/lib/server.ts b/lib/server.ts index 8a21f3f5..1651a2ed 100644 --- a/lib/server.ts +++ b/lib/server.ts @@ -137,7 +137,7 @@ export interface ServerOptions { type Middleware = ( req: IncomingMessage, res: ServerResponse, - next: () => void + next: (err?: any) => void ) => void; export abstract class BaseServer extends EventEmitter { @@ -335,7 +335,7 @@ export abstract class BaseServer extends EventEmitter { protected _applyMiddlewares( req: IncomingMessage, res: ServerResponse, - callback: () => void + callback: (err?: any) => void ) { if (this.middlewares.length === 0) { debug("no middleware to apply, skipping"); @@ -344,7 +344,11 @@ export abstract class BaseServer extends EventEmitter { const apply = (i) => { debug("applying middleware n°%d", i + 1); - this.middlewares[i](req, res, () => { + this.middlewares[i](req, res, (err?: any) => { + if (err) { + return callback(err); + } + if (i + 1 < this.middlewares.length) { apply(i + 1); } else { @@ -655,8 +659,12 @@ export class Server extends BaseServer { } }; - this._applyMiddlewares(req, res, () => { - this.verify(req, false, callback); + this._applyMiddlewares(req, res, (err) => { + if (err) { + callback(Server.errors.BAD_REQUEST, { name: "MIDDLEWARE_FAILURE" }); + } else { + this.verify(req, false, callback); + } }); } @@ -673,32 +681,37 @@ export class Server extends BaseServer { this.prepare(req); const res = new WebSocketResponse(req, socket); + const callback = (errorCode, errorContext) => { + if (errorCode) { + this.emit("connection_error", { + req, + code: errorCode, + message: Server.errorMessages[errorCode], + context: errorContext, + }); + abortUpgrade(socket, errorCode, errorContext); + return; + } - this._applyMiddlewares(req, res as unknown as ServerResponse, () => { - this.verify(req, true, (errorCode, errorContext) => { - if (errorCode) { - this.emit("connection_error", { - req, - code: errorCode, - message: Server.errorMessages[errorCode], - context: errorContext, - }); - abortUpgrade(socket, errorCode, errorContext); - return; - } - - const head = Buffer.from(upgradeHead); - upgradeHead = null; + const head = Buffer.from(upgradeHead); + upgradeHead = null; - // some middlewares (like express-session) wait for the writeHead() call to flush their headers - // see https://github.com/expressjs/session/blob/1010fadc2f071ddf2add94235d72224cf65159c6/index.js#L220-L244 - res.writeHead(); + // some middlewares (like express-session) wait for the writeHead() call to flush their headers + // see https://github.com/expressjs/session/blob/1010fadc2f071ddf2add94235d72224cf65159c6/index.js#L220-L244 + res.writeHead(); - // delegate to ws - this.ws.handleUpgrade(req, socket, head, (websocket) => { - this.onWebSocket(req, socket, websocket); - }); + // delegate to ws + this.ws.handleUpgrade(req, socket, head, (websocket) => { + this.onWebSocket(req, socket, websocket); }); + }; + + this._applyMiddlewares(req, res as unknown as ServerResponse, (err) => { + if (err) { + callback(Server.errors.BAD_REQUEST, { name: "MIDDLEWARE_FAILURE" }); + } else { + this.verify(req, true, callback); + } }); } diff --git a/lib/userver.ts b/lib/userver.ts index 29729bd1..6f4872f2 100644 --- a/lib/userver.ts +++ b/lib/userver.ts @@ -92,7 +92,11 @@ export class uServer extends BaseServer { }); } - override _applyMiddlewares(req: any, res: any, callback: () => void): void { + override _applyMiddlewares( + req: any, + res: any, + callback: (err?: any) => void + ): void { if (this.middlewares.length === 0) { return callback(); } @@ -100,12 +104,12 @@ export class uServer extends BaseServer { // needed to buffer headers until the status is computed req.res = new ResponseWrapper(res); - super._applyMiddlewares(req, req.res, () => { + super._applyMiddlewares(req, req.res, (err) => { // some middlewares (like express-session) wait for the writeHead() call to flush their headers // see https://github.com/expressjs/session/blob/1010fadc2f071ddf2add94235d72224cf65159c6/index.js#L220-L244 req.res.writeHead(); - callback(); + callback(err); }); } @@ -118,28 +122,34 @@ export class uServer extends BaseServer { req.res = res; - this._applyMiddlewares(req, res, () => { - this.verify(req, false, (errorCode, errorContext) => { - if (errorCode !== undefined) { - this.emit("connection_error", { - req, - code: errorCode, - message: Server.errorMessages[errorCode], - context: errorContext, - }); - this.abortRequest(req.res, errorCode, errorContext); - return; - } + const callback = (errorCode, errorContext) => { + if (errorCode !== undefined) { + this.emit("connection_error", { + req, + code: errorCode, + message: Server.errorMessages[errorCode], + context: errorContext, + }); + this.abortRequest(req.res, errorCode, errorContext); + return; + } + + if (req._query.sid) { + debug("setting new request for existing client"); + this.clients[req._query.sid].transport.onRequest(req); + } else { + const closeConnection = (errorCode, errorContext) => + this.abortRequest(res, errorCode, errorContext); + this.handshake(req._query.transport, req, closeConnection); + } + }; - if (req._query.sid) { - debug("setting new request for existing client"); - this.clients[req._query.sid].transport.onRequest(req); - } else { - const closeConnection = (errorCode, errorContext) => - this.abortRequest(res, errorCode, errorContext); - this.handshake(req._query.transport, req, closeConnection); - } - }); + this._applyMiddlewares(req, res, (err) => { + if (err) { + callback(Server.errors.BAD_REQUEST, { name: "MIDDLEWARE_FAILURE" }); + } else { + this.verify(req, false, callback); + } }); } @@ -154,63 +164,69 @@ export class uServer extends BaseServer { req.res = res; - this._applyMiddlewares(req, res, () => { - this.verify(req, true, async (errorCode, errorContext) => { - if (errorCode) { - this.emit("connection_error", { - req, - code: errorCode, - message: Server.errorMessages[errorCode], - context: errorContext, - }); - this.abortRequest(res, errorCode, errorContext); + const callback = async (errorCode, errorContext) => { + if (errorCode) { + this.emit("connection_error", { + req, + code: errorCode, + message: Server.errorMessages[errorCode], + context: errorContext, + }); + this.abortRequest(res, errorCode, errorContext); + return; + } + + const id = req._query.sid; + let transport; + + if (id) { + const client = this.clients[id]; + if (!client) { + debug("upgrade attempt for closed client"); + res.close(); + } else if (client.upgrading) { + debug("transport has already been trying to upgrade"); + res.close(); + } else if (client.upgraded) { + debug("transport had already been upgraded"); + res.close(); + } else { + debug("upgrading existing transport"); + transport = this.createTransport(req._query.transport, req); + client.maybeUpgrade(transport); + } + } else { + transport = await this.handshake( + req._query.transport, + req, + (errorCode, errorContext) => + this.abortRequest(res, errorCode, errorContext) + ); + if (!transport) { return; } + } - const id = req._query.sid; - let transport; - - if (id) { - const client = this.clients[id]; - if (!client) { - debug("upgrade attempt for closed client"); - res.close(); - } else if (client.upgrading) { - debug("transport has already been trying to upgrade"); - res.close(); - } else if (client.upgraded) { - debug("transport had already been upgraded"); - res.close(); - } else { - debug("upgrading existing transport"); - transport = this.createTransport(req._query.transport, req); - client.maybeUpgrade(transport); - } - } else { - transport = await this.handshake( - req._query.transport, - req, - (errorCode, errorContext) => - this.abortRequest(res, errorCode, errorContext) - ); - if (!transport) { - return; - } - } + // calling writeStatus() triggers the flushing of any header added in a middleware + req.res.writeStatus("101 Switching Protocols"); - // calling writeStatus() triggers the flushing of any header added in a middleware - req.res.writeStatus("101 Switching Protocols"); - - res.upgrade( - { - transport, - }, - req.getHeader("sec-websocket-key"), - req.getHeader("sec-websocket-protocol"), - req.getHeader("sec-websocket-extensions"), - context - ); - }); + res.upgrade( + { + transport, + }, + req.getHeader("sec-websocket-key"), + req.getHeader("sec-websocket-protocol"), + req.getHeader("sec-websocket-extensions"), + context + ); + }; + + this._applyMiddlewares(req, res, (err) => { + if (err) { + callback(Server.errors.BAD_REQUEST, { name: "MIDDLEWARE_FAILURE" }); + } else { + this.verify(req, true, callback); + } }); } diff --git a/test/middlewares.js b/test/middlewares.js index a81f1354..4045d5d0 100644 --- a/test/middlewares.js +++ b/test/middlewares.js @@ -247,4 +247,48 @@ describe("middlewares", () => { }); }); }); + + it("should fail on errors (polling)", (done) => { + const engine = listen((port) => { + engine.use((req, res, next) => { + next(new Error("will always fail")); + }); + + request + .get(`http://localhost:${port}/engine.io/`) + .query({ EIO: 4, transport: "polling" }) + .end((err, res) => { + expect(err).to.be.an(Error); + expect(res.status).to.eql(400); + + if (engine.httpServer) { + engine.httpServer.close(); + } + done(); + }); + }); + + it("should fail on errors (websocket)", (done) => { + const engine = listen((port) => { + engine.use((req, res, next) => { + next(new Error("will always fail")); + }); + + engine.on("connection", () => { + done(new Error("should not connect")); + }); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket` + ); + + socket.addEventListener("error", () => { + if (engine.httpServer) { + engine.httpServer.close(); + } + done(); + }); + }); + }); + }); });