diff --git a/Cargo.lock b/Cargo.lock index c214f28e1..f62fac11f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -840,6 +840,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -2512,7 +2522,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdf9d64cfcf380606e64f9a0bcf493616b65331199f984151a6fa11a7b3cde38" dependencies = [ "async-io", - "core-foundation", + "core-foundation 0.9.4", "fnv", "futures", "if-addrs", @@ -3497,7 +3507,7 @@ dependencies = [ "rcgen", "ring 0.16.20", "rustls 0.21.12", - "rustls-webpki", + "rustls-webpki 0.101.7", "thiserror 1.0.69", "x509-parser 0.15.1", "yasna", @@ -3733,7 +3743,7 @@ dependencies = [ "thiserror 2.0.11", "tokio", "tokio-stream", - "tokio-tungstenite", + "tokio-tungstenite 0.26.1", "tokio-util", "tracing", "tracing-subscriber", @@ -3788,7 +3798,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-stream", - "tokio-tungstenite", + "tokio-tungstenite 0.20.1", "tokio-util", "tracing", "uint 0.9.5", @@ -5557,10 +5567,23 @@ checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" dependencies = [ "log", "ring 0.17.8", - "rustls-webpki", + "rustls-webpki 0.101.7", "sct", ] +[[package]] +name = "rustls" +version = "0.23.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" +dependencies = [ + "once_cell", + "rustls-pki-types", + "rustls-webpki 0.102.8", + "subtle", + "zeroize", +] + [[package]] name = "rustls-native-certs" version = "0.6.3" @@ -5570,7 +5593,19 @@ dependencies = [ "openssl-probe", "rustls-pemfile", "schannel", - "security-framework", + "security-framework 2.11.1", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework 3.2.0", ] [[package]] @@ -5582,6 +5617,12 @@ dependencies = [ "base64 0.21.7", ] +[[package]] +name = "rustls-pki-types" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" + [[package]] name = "rustls-webpki" version = "0.101.7" @@ -5592,6 +5633,17 @@ dependencies = [ "untrusted 0.9.0", ] +[[package]] +name = "rustls-webpki" +version = "0.102.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +dependencies = [ + "ring 0.17.8", + "rustls-pki-types", + "untrusted 0.9.0", +] + [[package]] name = "rustversion" version = "1.0.19" @@ -6051,7 +6103,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.8.0", - "core-foundation", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +dependencies = [ + "bitflags 2.8.0", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -6971,9 +7036,9 @@ dependencies = [ [[package]] name = "subtle" -version = "2.4.1" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" @@ -7027,7 +7092,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags 2.8.0", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -7226,6 +7291,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" +dependencies = [ + "rustls 0.23.23", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.17" @@ -7246,10 +7321,26 @@ dependencies = [ "futures-util", "log", "rustls 0.21.12", - "rustls-native-certs", + "rustls-native-certs 0.6.3", "tokio", - "tokio-rustls", - "tungstenite", + "tokio-rustls 0.24.1", + "tungstenite 0.20.1", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4bf6fecd69fcdede0ec680aaf474cdab988f9de6bc73d3758f0160e3b7025a" +dependencies = [ + "futures-util", + "log", + "rustls 0.23.23", + "rustls-native-certs 0.8.1", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.1", + "tungstenite 0.26.1", ] [[package]] @@ -7517,6 +7608,27 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413083a99c579593656008130e29255e54dcaae495be556cc26888f211648c24" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.2.0", + "httparse", + "log", + "rand", + "rustls 0.23.23", + "rustls-pki-types", + "sha1", + "thiserror 2.0.11", + "url", + "utf-8", +] + [[package]] name = "tuplex" version = "0.1.2" diff --git a/Cargo.toml b/Cargo.toml index 51e86dcd1..296d5e9c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,14 +39,14 @@ socket2 = { version = "0.5.8", features = ["all"] } str0m = { version = "0.6.2", optional = true } thiserror = "2.0.11" tokio-stream = "0.1.12" -tokio-tungstenite = { version = "0.20.0", features = ["rustls-tls-native-roots"], optional = true } +tokio-tungstenite = { version = "0.26.1", features = ["rustls-tls-native-roots", "url"], optional = true } tokio-util = { version = "0.7.11", features = ["compat", "io", "codec"] } tokio = { version = "1.26.0", features = ["rt", "net", "io-util", "time", "macros", "sync", "parking_lot"] } tracing = { version = "0.1.40", features = ["log"] } hickory-resolver = "0.24.2" uint = "0.10.0" unsigned-varint = { version = "0.8.0", features = ["codec"] } -url = "2.4.0" +url = "2.5.4" webpki = { version = "0.22.4", optional = true } x25519-dalek = "2.0.1" x509-parser = "0.17.0" diff --git a/src/transport/websocket/stream.rs b/src/transport/websocket/stream.rs index 2dd20091c..2941a8726 100644 --- a/src/transport/websocket/stream.rs +++ b/src/transport/websocket/stream.rs @@ -21,7 +21,7 @@ //! Stream implementation for `tokio_tungstenite::WebSocketStream` that implements //! `AsyncRead + AsyncWrite` -use bytes::{Buf, Bytes}; +use bytes::{Buf, Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; @@ -43,9 +43,6 @@ enum State { /// Sink is accepting input. ReadyToSend, - /// Sink is ready to send. - ReadyPending { to_write: Vec }, - /// Flush is pending for the sink. FlushPending, } @@ -53,13 +50,12 @@ enum State { /// Buffered stream which implements `AsyncRead + AsyncWrite` pub(super) struct BufferedStream { /// Write buffer. - write_buffer: Vec, - - /// Write pointer. - write_ptr: usize, + write_buffer: BytesMut, - // Read buffer. - read_buffer: Option, + /// Read buffer. + /// + /// The buffer is taken directly from the WebSocket stream. + read_buffer: Bytes, /// Underlying WebSocket stream. stream: WebSocketStream, @@ -72,9 +68,8 @@ impl BufferedStream { /// Create new [`BufferedStream`]. pub(super) fn new(stream: WebSocketStream) -> Self { Self { - write_buffer: Vec::with_capacity(DEFAULT_BUF_SIZE), - read_buffer: None, - write_ptr: 0usize, + write_buffer: BytesMut::with_capacity(DEFAULT_BUF_SIZE), + read_buffer: Bytes::new(), stream, state: State::ReadyToSend, } @@ -88,7 +83,6 @@ impl futures::AsyncWrite for BufferedStream Poll> { self.write_buffer.extend_from_slice(buf); - self.write_ptr += buf.len(); Poll::Ready(Ok(buf.len())) } @@ -104,38 +98,43 @@ impl futures::AsyncWrite for BufferedStream { - let message = self.write_buffer[..self.write_ptr].to_vec(); - self.state = State::ReadyPending { to_write: message }; - - match futures::ready!(self.stream.poll_ready_unpin(cx)) { - Ok(()) => continue, - Err(_error) => { - return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())); + match self.stream.poll_ready_unpin(cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(_error)) => + return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), + Poll::Pending => { + self.state = State::ReadyToSend; + return Poll::Pending; } } - } - State::ReadyPending { to_write } => { - match self.stream.start_send_unpin(Message::Binary(to_write.clone())) { - Ok(_) => { - self.state = State::FlushPending; - continue; - } + + let message = std::mem::take(&mut self.write_buffer); + match self.stream.start_send_unpin(Message::Binary(message.freeze())) { + Ok(()) => {} Err(_error) => return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), } + + // Transition to flush pending state. + self.state = State::FlushPending; + continue; } - State::FlushPending => match futures::ready!(self.stream.poll_flush_unpin(cx)) { - Ok(_res) => { - self.state = State::ReadyToSend; - self.write_ptr = 0; - self.write_buffer.clear(); - // In the unlikely event that the buffer is too large, we need to bound the - // capacity to avoid unbounded memory usage. - self.write_buffer.shrink_to(DEFAULT_BUF_SIZE); - return Poll::Ready(Ok(())); + + State::FlushPending => { + match self.stream.poll_flush_unpin(cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(_error)) => + return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), + Poll::Pending => { + self.state = State::ReadyToSend; + return Poll::Pending; + } } - Err(_) => return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), - }, + + self.state = State::ReadyToSend; + self.write_buffer = BytesMut::with_capacity(DEFAULT_BUF_SIZE); + return Poll::Ready(Ok(())); + } State::Poisoned => return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), } @@ -157,10 +156,10 @@ impl futures::AsyncRead for BufferedStream buf: &mut [u8], ) -> Poll> { loop { - if self.read_buffer.is_none() { - match self.stream.poll_next_unpin(cx) { + if self.read_buffer.is_empty() { + let next_chunk = match self.stream.poll_next_unpin(cx) { Poll::Ready(Some(Ok(chunk))) => match chunk { - Message::Binary(chunk) => self.read_buffer.replace(chunk.into()), + Message::Binary(chunk) => chunk, _event => return Poll::Ready(Err(std::io::ErrorKind::Unsupported.into())), }, Poll::Ready(Some(Err(_error))) => @@ -168,21 +167,15 @@ impl futures::AsyncRead for BufferedStream Poll::Ready(None) => return Poll::Ready(Ok(0)), Poll::Pending => return Poll::Pending, }; - } - - let buffer = self.read_buffer.as_mut().expect("buffer to exist"); - let bytes_read = buf.len().min(buffer.len()); - let _orig_size = buffer.len(); - buf[..bytes_read].copy_from_slice(&buffer[..bytes_read]); - buffer.advance(bytes_read); - - // TODO: this can't be correct - if !buffer.is_empty() || bytes_read != 0 { - return Poll::Ready(Ok(bytes_read)); - } else { - self.read_buffer.take(); + self.read_buffer = next_chunk; + continue; } + + let len = std::cmp::min(self.read_buffer.len(), buf.len()); + buf[..len].copy_from_slice(&self.read_buffer[..len]); + self.read_buffer.advance(len); + return Poll::Ready(Ok(len)); } } }