Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protect WebSocket endpoints from untrusted origin requests #134

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Changed

- Deny WebSocket connections from unknown origins
- Let firmware update command look for bootloader if no Senso in regular mode is found
- Update build system and development environment

Expand Down
1 change: 1 addition & 0 deletions src/dividat-driver/flex/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ var webSocketUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// Check is performed by top-level HTTP middleware, and not repeated here.
return true
},
}
1 change: 1 addition & 0 deletions src/dividat-driver/senso/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ var webSocketUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// Check is performed by top-level HTTP middleware, and not repeated here.
return true
},
}
39 changes: 29 additions & 10 deletions src/dividat-driver/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ func Start(logger *logrus.Logger, origins []string) context.CancelFunc {
// Log Server
logServer := logging.NewLogServer()
logger.AddHook(logServer)
http.Handle("/log", corsHeaders(origins, logServer))

baseLog := logger.WithFields(logrus.Fields{
"version": version,
Expand All @@ -46,22 +45,25 @@ func Start(logger *logrus.Logger, origins []string) context.CancelFunc {

baseLog.Info("Dividat Driver starting")

// Setup log endpoint
http.Handle("/log", originMiddleware(origins, baseLog, logServer))

// Setup a context
ctx, cancel := context.WithCancel(context.Background())

// Setup Senso
sensoHandle := senso.New(ctx, baseLog.WithField("package", "senso"))
http.Handle("/senso", corsHeaders(origins, sensoHandle))
http.Handle("/senso", originMiddleware(origins, baseLog, sensoHandle))

// Setup SensingTex reader
flexHandle := flex.New(ctx, baseLog.WithField("package", "flex"))
http.Handle("/flex", corsHeaders(origins, flexHandle))
http.Handle("/flex", originMiddleware(origins, baseLog, flexHandle))

// Setup RFID scanner
rfidHandle := rfid.NewHandle(ctx, baseLog.WithField("package", "rfid"))
// net/http performs a redirect from `/rfid` if only `/rfid/` is mounted
http.Handle("/rfid", corsHeaders(origins, rfidHandle))
http.Handle("/rfid/", corsHeaders(origins, rfidHandle))
http.Handle("/rfid", originMiddleware(origins, baseLog, rfidHandle))
http.Handle("/rfid/", originMiddleware(origins, baseLog, rfidHandle))

// Create a logger for server
log := baseLog.WithField("package", "server")
Expand All @@ -80,7 +82,7 @@ func Start(logger *logrus.Logger, origins []string) context.CancelFunc {
"os": systemInfo.Os,
"arch": systemInfo.Arch,
})
http.Handle("/", corsHeaders(origins, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Handle("/", originMiddleware(origins, baseLog, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write(rootMsg)
})))
Expand All @@ -107,17 +109,34 @@ func Start(logger *logrus.Logger, origins []string) context.CancelFunc {
return cancel
}

// Middleware for CORS headers, to be applied to any route that should be accessible from browser apps.
func corsHeaders(origins []string, next http.Handler) http.Handler {
// Middleware to ensure browser requests come from permissible origins.
//
// This protects anyone running the driver from malicious websites connecting
// to the loopback address. In order to protect WebSocket endpoints, for which
// CORS pre-flight requests are not performed, we fully deny requests from
// unknown origins instead of just withholding CORS headers.
func originMiddleware(origins []string, log *logrus.Entry, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if len(r.Header["Origin"]) == 1 && contains(origins, r.Header["Origin"][0]) {
w.Header().Set("Access-Control-Allow-Origin", r.Header["Origin"][0])
origin := r.Header.Get("Origin")

// Check whether a request was made from a permissible origin.
// An absent Origin header indicates a non-browser request and is permissible.
if origin != "" && !contains(origins, origin) {
log.WithField("origin", r.Header.Get("Origin")).Info("Denying request from untrusted origin.")
w.WriteHeader(403)
return
}

// Set CORS/Private Network Access headers
if origin != "" {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Private-Network", "true")
}

// Announce that `Origin` header value may affect response
w.Header().Set("Vary", "Origin")

// Greenlight pre-flight requests, forward all other requests
if r.Method == "OPTIONS" {
w.WriteHeader(200)
return
Expand Down
89 changes: 89 additions & 0 deletions test/cors.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/* eslint-env mocha */

const { wait, getJSON, startDriver, connectWS } = require('./utils')
const expect = require('chai').expect

const httpEndpoints = [
'http://127.0.0.1:8382',
'http://127.0.0.1:8382/log',
'http://127.0.0.1:8382/rfid/readers'
]

const wsEndpoints = [
'ws://127.0.0.1:8382/senso',
'ws://127.0.0.1:8382/flex'
]

const permissibleOrigins = [ 'https://test-origin.xyz', 'http://127.0.0.1:8000' ]
const untrustedOrigin = 'https://foreign-origin.xyz'

let driver

beforeEach(async () => {
let code = 0
driver = startDriver(...permissibleOrigins.flatMap(origin => [ '--permissible-origin', origin ])).on('exit', (c) => {
code = c
})
await wait(500)
expect(code).to.be.equal(0)
driver.removeAllListeners()
})

afterEach(() => {
driver.kill()
})

// No `Origin` header

it('can make HTTP requests with no origin set', async () => {
return Promise.all(httpEndpoints.map(url =>
fetch(url).then((response) => { expect(response.status, `GET ${url}`).to.equal(200) })
))
})

it('can connect to WebSocket endpoints with no origin set', async () => {
return Promise.all(wsEndpoints.map(async (url) => {
let connected = await connectWS(url).then(_ => true).catch(_ => false)
expect(connected, `WS ${url}`).to.be.true
}))
})

// Known `Origin` header

it('can make HTTP requests with known origin set', async () => {
return Promise.all(httpEndpoints.flatMap(url =>
permissibleOrigins.map(origin =>
fetch(url, { headers: { Origin: origin } }).then((response) => {
expect(response.status, `GET ${url} (Origin: ${origin})`).to.equal(200)
expect(response.headers.get('Access-Control-Allow-Origin'), `Access-Control-Allow-Origin ${url}`).to.equal(origin)
expect(response.headers.get('Access-Control-Allow-Private-Network'), `Access-Control-Allow-Private-Network ${url}`).to.equal('true')
})
)
))
})

it('can connect to WebSocket endpoints with known origin set', async () => {
return Promise.all(wsEndpoints.flatMap(url =>
permissibleOrigins.map(async (origin) => {
let connected = await connectWS(url, { headers: { Origin: origin } }).then(_ => true).catch(_ => false)
expect(connected, `WS ${url} (Origin: ${origin})`).to.be.true
})
))
})

// Unknown `Origin` header

it('can not make HTTP requests with unknown origin set', async () => {
return Promise.all(httpEndpoints.map(url =>
fetch(url, { headers: { Origin: untrustedOrigin } }).then((response) => {
expect(response.status, `GET ${url}`).to.equal(403)
})
))
})

it('can not connect to WebSocket endpoints with unknown origin set', async () => {
return Promise.all(wsEndpoints.map(async (url) => {
let connected = await connectWS(url, { headers: { Origin: untrustedOrigin } }).then(_ => true).catch(_ => false)
expect(connected, `WS ${url}`).to.be.false
}))
})
5 changes: 4 additions & 1 deletion test/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ describe('General functionality', () => {
require('./general')
})

describe('CORS and PNA protection', () => {
require('./cors')
})

describe('Senso', () => {
require('./senso')
})


describe('RFID', () => {
require('./rfid')
})
8 changes: 4 additions & 4 deletions test/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ module.exports = {
})
},

startDriver: function () {
return spawn('bin/dividat-driver')
startDriver: function (...args) {
return spawn('bin/dividat-driver', args)
// useful for debugging:
// return spawn('bin/dividat-driver', [], {stdio: 'inherit'})
},

connectWS: function (url) {
connectWS: function (url, opts) {
return new Promise((resolve, reject) => {
const ws = new WebSocket(url)
const ws = new WebSocket(url, opts)
ws.on('open', () => {
ws.removeAllListeners()
resolve(ws)
Expand Down