Skip to content

Commit 42c6c8e

Browse files
authored
server: improve parsing of websocket connection URL, and add tests (#1288)
* add unit tests for parsing websocket connection url * add e2e test for websocket reconnects * account for base_url and make parsing more robust
1 parent b30709d commit 42c6c8e

File tree

3 files changed

+73
-4
lines changed

3 files changed

+73
-4
lines changed

server/clientmanager.ts

+19-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import {
1818
import { ClientNotFoundInRoomException, MissingToken } from "./exceptions";
1919
import { MySession, OttWebsocketError, AuthToken, ClientId } from "../common/models/types";
2020
import roommanager from "./roommanager";
21-
import { ANNOUNCEMENT_CHANNEL } from "../common/constants";
21+
import { ANNOUNCEMENT_CHANNEL, ROOM_NAME_REGEX } from "../common/constants";
2222
import tokens, { SessionInfo } from "./auth/tokens";
2323
import { RoomStateSyncable } from "./room";
2424
import { Gauge } from "prom-client";
@@ -77,8 +77,12 @@ export function shutdown() {
7777
*/
7878
async function onDirectConnect(socket: WebSocket, req: express.Request) {
7979
try {
80-
const connectUrl = new URL(`ws://${req.headers.host}${req.url}`);
81-
const roomName = connectUrl.pathname.split("/").slice(-1)[0];
80+
const roomName = parseWebsocketConnectionUrl(req);
81+
if (!ROOM_NAME_REGEX.test(roomName)) {
82+
log.warn("Rejecting connection because the room name was invalid");
83+
socket.close(OttWebsocketError.INVALID_CONNECTION_URL, "Invalid room name");
84+
return;
85+
}
8286
log.debug(`connection received: ${roomName}, waiting for auth token...`);
8387
const client = new DirectClient(roomName, socket);
8488
addClient(client);
@@ -88,6 +92,18 @@ async function onDirectConnect(socket: WebSocket, req: express.Request) {
8892
}
8993
}
9094

95+
/**
96+
* Extract the room name from the websocket connection url.
97+
* @returns Room name
98+
*/
99+
export function parseWebsocketConnectionUrl(req: express.Request): string {
100+
const connectUrl = new URL(req.url, `ws://${req.headers.host ?? "localhost"}`);
101+
const base_url = conf.get("base_url");
102+
const adjustedPath = base_url ? connectUrl.pathname.replace(base_url, "") : connectUrl.pathname;
103+
const roomName = adjustedPath.split("/").slice(3)[0];
104+
return roomName;
105+
}
106+
91107
export function addClient(client: Client) {
92108
connections.push(client);
93109
client.on("auth", onClientAuth);

server/tests/unit/clientmanager.spec.ts

+30-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import clientmanager from "../../clientmanager";
1+
import clientmanager, { parseWebsocketConnectionUrl } from "../../clientmanager";
22
import {
33
BalancerConnection,
44
BalancerConnectionEventHandlers,
@@ -13,6 +13,8 @@ import { buildClients } from "../../redisclient";
1313
import { Result, ok } from "../../../common/result";
1414
import roommanager from "../../roommanager";
1515
import { loadModels } from "../../models";
16+
import type { Request } from "express";
17+
import { loadConfigFile, conf } from "../../ott-config";
1618

1719
class TestClient extends Client {
1820
sendRawMock = jest.fn();
@@ -50,6 +52,7 @@ class BalancerConnectionMock extends BalancerConnection {
5052

5153
describe("ClientManager", () => {
5254
beforeAll(async () => {
55+
loadConfigFile();
5356
loadModels();
5457
await buildClients();
5558
await clientmanager.setup();
@@ -66,6 +69,32 @@ describe("ClientManager", () => {
6669
await roommanager.unloadRoom("foo");
6770
});
6871

72+
it.each([
73+
["/api/room/foo", "foo"],
74+
["/api/room/foo/", "foo"],
75+
["/api/room/foo/bar", "foo"],
76+
["/api/room/foo?reconnect=true", "foo"],
77+
])(`should parse room name from %s`, (path, roomName) => {
78+
const got = parseWebsocketConnectionUrl({ url: path, headers: {} } as Request);
79+
expect(got).toEqual(roomName);
80+
});
81+
82+
it.each([
83+
["/base", "/api/room/foo", "foo"],
84+
["/base", "/api/room/foo/", "foo"],
85+
["/base", "/api/room/foo/bar", "foo"],
86+
["/base", "/api/room/foo?reconnect=true", "foo"],
87+
["/base/base2", "/api/room/foo", "foo"],
88+
["/base/base2", "/api/room/foo/", "foo"],
89+
["/base/base2", "/api/room/foo/bar", "foo"],
90+
["/base/base2", "/api/room/foo?reconnect=true", "foo"],
91+
])(`should parse room name when base url is %s from %s`, (baseurl, path, roomName) => {
92+
conf.set("base_url", baseurl);
93+
const got = parseWebsocketConnectionUrl({ url: baseurl + path, headers: {} } as Request);
94+
expect(got).toEqual(roomName);
95+
conf.set("base_url", "/");
96+
});
97+
6998
it("should add clients", () => {
7099
const client = new TestClient("foo");
71100
clientmanager.addClient(client);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
describe("Websocket connection", () => {
2+
beforeEach(() => {
3+
cy.ottEnsureToken();
4+
cy.ottResetRateLimit();
5+
cy.ottRequest({ method: "POST", url: "/api/room/generate" }).then(resp => {
6+
// @ts-expect-error Cypress doesn't know how to respect this return type
7+
cy.visit(`/room/${resp.body.room}`);
8+
});
9+
});
10+
11+
it("should connect to the websocket", () => {
12+
cy.get("#connectStatus").should("contain", "Connected");
13+
});
14+
15+
it("should connect to the websocket on reconnect", () => {
16+
cy.get("#connectStatus").should("contain", "Connected");
17+
cy.get("button").eq(0).focus(); // focus something so keyboard shortcuts work
18+
cy.realPress(["Control", "Shift", "F12"]);
19+
cy.get("button").contains("Disconnect Me").click();
20+
cy.scrollTo("top");
21+
cy.get("#connectStatus").should("contain", "Connecting");
22+
cy.get("#connectStatus").should("contain", "Connected");
23+
});
24+
})

0 commit comments

Comments
 (0)