diff --git a/Changelog.md b/Changelog.md index a345cc5..44b2654 100644 --- a/Changelog.md +++ b/Changelog.md @@ -4,6 +4,7 @@ ### Changed +- Deny WebSocket connections from unknown origins - Let firmware update command look for bootloader if no Senso in regular mode is found - Update build system and development environment diff --git a/src/dividat-driver/flex/websocket.go b/src/dividat-driver/flex/websocket.go index 3427073..bbec7c8 100644 --- a/src/dividat-driver/flex/websocket.go +++ b/src/dividat-driver/flex/websocket.go @@ -124,6 +124,7 @@ var webSocketUpgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { + // Check is performed by top-level HTTP middleware, and not repeated here. return true }, } diff --git a/src/dividat-driver/senso/websocket.go b/src/dividat-driver/senso/websocket.go index 4d0bb72..833c9e0 100644 --- a/src/dividat-driver/senso/websocket.go +++ b/src/dividat-driver/senso/websocket.go @@ -323,6 +323,7 @@ var webSocketUpgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { + // Check is performed by top-level HTTP middleware, and not repeated here. return true }, } diff --git a/src/dividat-driver/server/main.go b/src/dividat-driver/server/main.go index 33a4425..94f5bfe 100644 --- a/src/dividat-driver/server/main.go +++ b/src/dividat-driver/server/main.go @@ -26,7 +26,6 @@ func Start(logger *logrus.Logger, origins []string) context.CancelFunc { // Log Server logServer := logging.NewLogServer() logger.AddHook(logServer) - http.Handle("/log", corsHeaders(origins, logServer)) baseLog := logger.WithFields(logrus.Fields{ "version": version, @@ -46,22 +45,25 @@ func Start(logger *logrus.Logger, origins []string) context.CancelFunc { baseLog.Info("Dividat Driver starting") + // Setup log endpoint + http.Handle("/log", originMiddleware(origins, baseLog, logServer)) + // Setup a context ctx, cancel := context.WithCancel(context.Background()) // Setup Senso sensoHandle := senso.New(ctx, baseLog.WithField("package", "senso")) - http.Handle("/senso", corsHeaders(origins, sensoHandle)) + http.Handle("/senso", originMiddleware(origins, baseLog, sensoHandle)) // Setup SensingTex reader flexHandle := flex.New(ctx, baseLog.WithField("package", "flex")) - http.Handle("/flex", corsHeaders(origins, flexHandle)) + http.Handle("/flex", originMiddleware(origins, baseLog, flexHandle)) // Setup RFID scanner rfidHandle := rfid.NewHandle(ctx, baseLog.WithField("package", "rfid")) // net/http performs a redirect from `/rfid` if only `/rfid/` is mounted - http.Handle("/rfid", corsHeaders(origins, rfidHandle)) - http.Handle("/rfid/", corsHeaders(origins, rfidHandle)) + http.Handle("/rfid", originMiddleware(origins, baseLog, rfidHandle)) + http.Handle("/rfid/", originMiddleware(origins, baseLog, rfidHandle)) // Create a logger for server log := baseLog.WithField("package", "server") @@ -80,7 +82,7 @@ func Start(logger *logrus.Logger, origins []string) context.CancelFunc { "os": systemInfo.Os, "arch": systemInfo.Arch, }) - http.Handle("/", corsHeaders(origins, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Handle("/", originMiddleware(origins, baseLog, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Write(rootMsg) }))) @@ -107,17 +109,34 @@ func Start(logger *logrus.Logger, origins []string) context.CancelFunc { return cancel } -// Middleware for CORS headers, to be applied to any route that should be accessible from browser apps. -func corsHeaders(origins []string, next http.Handler) http.Handler { +// Middleware to ensure browser requests come from permissible origins. +// +// This protects anyone running the driver from malicious websites connecting +// to the loopback address. In order to protect WebSocket endpoints, for which +// CORS pre-flight requests are not performed, we fully deny requests from +// unknown origins instead of just withholding CORS headers. +func originMiddleware(origins []string, log *logrus.Entry, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if len(r.Header["Origin"]) == 1 && contains(origins, r.Header["Origin"][0]) { - w.Header().Set("Access-Control-Allow-Origin", r.Header["Origin"][0]) + origin := r.Header.Get("Origin") + + // Check whether a request was made from a permissible origin. + // An absent Origin header indicates a non-browser request and is permissible. + if origin != "" && !contains(origins, origin) { + log.WithField("origin", r.Header.Get("Origin")).Info("Denying request from untrusted origin.") + w.WriteHeader(403) + return + } + + // Set CORS/Private Network Access headers + if origin != "" { + w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Private-Network", "true") } // Announce that `Origin` header value may affect response w.Header().Set("Vary", "Origin") + // Greenlight pre-flight requests, forward all other requests if r.Method == "OPTIONS" { w.WriteHeader(200) return diff --git a/test/cors.js b/test/cors.js new file mode 100644 index 0000000..7e835a2 --- /dev/null +++ b/test/cors.js @@ -0,0 +1,89 @@ +/* eslint-env mocha */ + +const { wait, getJSON, startDriver, connectWS } = require('./utils') +const expect = require('chai').expect + +const httpEndpoints = [ + 'http://127.0.0.1:8382', + 'http://127.0.0.1:8382/log', + 'http://127.0.0.1:8382/rfid/readers' +] + +const wsEndpoints = [ + 'ws://127.0.0.1:8382/senso', + 'ws://127.0.0.1:8382/flex' +] + +const permissibleOrigins = [ 'https://test-origin.xyz', 'http://127.0.0.1:8000' ] +const untrustedOrigin = 'https://foreign-origin.xyz' + +let driver + +beforeEach(async () => { + let code = 0 + driver = startDriver(...permissibleOrigins.flatMap(origin => [ '--permissible-origin', origin ])).on('exit', (c) => { + code = c + }) + await wait(500) + expect(code).to.be.equal(0) + driver.removeAllListeners() +}) + +afterEach(() => { + driver.kill() +}) + +// No `Origin` header + +it('can make HTTP requests with no origin set', async () => { + return Promise.all(httpEndpoints.map(url => + fetch(url).then((response) => { expect(response.status, `GET ${url}`).to.equal(200) }) + )) +}) + +it('can connect to WebSocket endpoints with no origin set', async () => { + return Promise.all(wsEndpoints.map(async (url) => { + let connected = await connectWS(url).then(_ => true).catch(_ => false) + expect(connected, `WS ${url}`).to.be.true + })) +}) + +// Known `Origin` header + +it('can make HTTP requests with known origin set', async () => { + return Promise.all(httpEndpoints.flatMap(url => + permissibleOrigins.map(origin => + fetch(url, { headers: { Origin: origin } }).then((response) => { + expect(response.status, `GET ${url} (Origin: ${origin})`).to.equal(200) + expect(response.headers.get('Access-Control-Allow-Origin'), `Access-Control-Allow-Origin ${url}`).to.equal(origin) + expect(response.headers.get('Access-Control-Allow-Private-Network'), `Access-Control-Allow-Private-Network ${url}`).to.equal('true') + }) + ) + )) +}) + +it('can connect to WebSocket endpoints with known origin set', async () => { + return Promise.all(wsEndpoints.flatMap(url => + permissibleOrigins.map(async (origin) => { + let connected = await connectWS(url, { headers: { Origin: origin } }).then(_ => true).catch(_ => false) + expect(connected, `WS ${url} (Origin: ${origin})`).to.be.true + }) + )) +}) + +// Unknown `Origin` header + +it('can not make HTTP requests with unknown origin set', async () => { + return Promise.all(httpEndpoints.map(url => + fetch(url, { headers: { Origin: untrustedOrigin } }).then((response) => { + expect(response.status, `GET ${url}`).to.equal(403) + }) + )) +}) + +it('can not connect to WebSocket endpoints with unknown origin set', async () => { + return Promise.all(wsEndpoints.map(async (url) => { + let connected = await connectWS(url, { headers: { Origin: untrustedOrigin } }).then(_ => true).catch(_ => false) + expect(connected, `WS ${url}`).to.be.false + })) +}) diff --git a/test/index.js b/test/index.js index d41eac9..9302ff0 100644 --- a/test/index.js +++ b/test/index.js @@ -3,11 +3,14 @@ describe('General functionality', () => { require('./general') }) +describe('CORS and PNA protection', () => { + require('./cors') +}) + describe('Senso', () => { require('./senso') }) - describe('RFID', () => { require('./rfid') }) diff --git a/test/utils.js b/test/utils.js index 5613efd..2a92ce2 100644 --- a/test/utils.js +++ b/test/utils.js @@ -10,15 +10,15 @@ module.exports = { }) }, - startDriver: function () { - return spawn('bin/dividat-driver') + startDriver: function (...args) { + return spawn('bin/dividat-driver', args) // useful for debugging: // return spawn('bin/dividat-driver', [], {stdio: 'inherit'}) }, - connectWS: function (url) { + connectWS: function (url, opts) { return new Promise((resolve, reject) => { - const ws = new WebSocket(url) + const ws = new WebSocket(url, opts) ws.on('open', () => { ws.removeAllListeners() resolve(ws)