diff --git a/mm2src/mm2_net/Cargo.toml b/mm2src/mm2_net/Cargo.toml index fd42e034ac..5627f003ff 100644 --- a/mm2src/mm2_net/Cargo.toml +++ b/mm2src/mm2_net/Cargo.toml @@ -49,7 +49,7 @@ wasm-bindgen-futures = "0.4.21" web-sys = { version = "0.3.55", features = ["console", "CloseEvent", "DomException", "ErrorEvent", "IdbDatabase", "IdbCursor", "IdbCursorWithValue", "IdbFactory", "IdbIndex", "IdbIndexParameters", "IdbObjectStore", "IdbObjectStoreParameters", "IdbOpenDbRequest", "IdbKeyRange", "IdbTransaction", "IdbTransactionMode", - "IdbVersionChangeEvent", "MessageEvent", "MessagePort", "ReadableStreamDefaultReader", "ReadableStream", "SharedWorker", "WebSocket"] } + "IdbVersionChangeEvent", "MessageEvent", "MessagePort", "ReadableStreamDefaultReader", "ReadableStream", "SharedWorker", "Url", "WebSocket"] } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] diff --git a/mm2src/mm2_net/src/wasm/wasm_ws.rs b/mm2src/mm2_net/src/wasm/wasm_ws.rs index 661babef3c..a0a7bbb067 100644 --- a/mm2src/mm2_net/src/wasm/wasm_ws.rs +++ b/mm2src/mm2_net/src/wasm/wasm_ws.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use wasm_bindgen::prelude::*; use wasm_bindgen::JsCast; -use web_sys::{CloseEvent, DomException, MessageEvent, WebSocket}; +use web_sys::{CloseEvent, DomException, MessageEvent, Url, WebSocket}; const NORMAL_CLOSURE_CODE: u16 = 1000; @@ -50,8 +50,17 @@ pub enum InitWsError { impl InitWsError { fn from_ws_new_err(e: JsValue, url: &str) -> InitWsError { let reason = stringify_js_error(&e); - match e.dyn_ref::().map(DomException::code) { - Some(DomException::SYNTAX_ERR) => InitWsError::InvalidUrl { + + // Check for TypeError + if reason.contains("URL constructor") { + return InitWsError::InvalidUrl { + url: url.to_owned(), + reason, + }; + }; + + match e.dyn_ref::().map(DomException::name) { + Some(ref name) if name == "SyntaxError" => InitWsError::InvalidUrl { url: url.to_owned(), reason, }, @@ -312,6 +321,7 @@ struct WebSocketImpl { impl WebSocketImpl { fn init(url: &str) -> InitWsResult<(WebSocketImpl, WsTransportReceiver)> { + Self::validate_websocket_url(url)?; let ws = WebSocket::new(url).map_to_mm(|e| InitWsError::from_ws_new_err(e, url))?; let (tx, rx) = mpsc::channel(1024); @@ -355,6 +365,20 @@ impl WebSocketImpl { }, } } + + fn validate_websocket_url(url: &str) -> Result<(), MmError> { + let parsed_url = Url::new(url).map_to_mm(|e| InitWsError::from_ws_new_err(e, url))?; + + let scheme = parsed_url.protocol(); + if scheme != "ws:" && scheme != "wss:" { + return MmError::err(InitWsError::InvalidUrl { + url: url.to_string(), + reason: "URL must use 'ws' or 'wss' scheme".to_string(), + }); + } + + Ok(()) + } } impl Drop for WebSocketImpl {