From 6e031d6d1ee6ab6a6163b6a8385be5ded625037e Mon Sep 17 00:00:00 2001
From: Aryan Tikarya
Date: Sun, 26 Jan 2025 21:26:54 +0530
Subject: [PATCH 1/4] refactor: Shared to use internal mutability

---
 yamux/src/connection.rs        | 152 ++++++++++++------------
 yamux/src/connection/stream.rs | 207 +++++++++++++++++++++++----------
 2 files changed, 223 insertions(+), 136 deletions(-)

diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs
index 24c53c7..32ba968 100644
--- a/yamux/src/connection.rs
+++ b/yamux/src/connection.rs
@@ -282,7 +282,7 @@ struct Active<T> {
     socket: Fuse<frame::Io<T>>,
     next_id: u32,

-    streams: IntMap<StreamId, Arc<Mutex<stream::Shared>>>,
+    streams: IntMap<StreamId, stream::Shared>,
     stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
     no_streams_waker: Option<Waker>,

@@ -507,9 +507,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
         let s = self.streams.remove(&stream_id).expect("stream not found");
         log::trace!("{}: removing dropped stream {}", self.id, stream_id);

-        let frame = {
-            let mut shared = s.lock();
-            let frame = match shared.update_state(self.id, stream_id, State::Closed) {
+        let frame = s.with_mut(|inner| {
+            let frame = match inner.update_state(self.id, stream_id, State::Closed) {
                 // The stream was dropped without calling `poll_close`.
                 // We reset the stream to inform the remote of the closure.
                 State::Open { .. } => {
@@ -541,14 +540,15 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
                 // remote end has already done so in the past.
                 State::Closed => None,
             };
-            if let Some(w) = shared.reader.take() {
+            if let Some(w) = inner.reader.take() {
                 w.wake()
             }
-            if let Some(w) = shared.writer.take() {
+            if let Some(w) = inner.writer.take() {
                 w.wake()
             }
+
             frame
-        };
+        });
         frame.map(Into::into)
     }

@@ -565,10 +565,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
             && matches!(frame.header().tag(), Tag::Data | Tag::WindowUpdate)
         {
             let id = frame.header().stream_id();
-            if let Some(stream) = self.streams.get(&id) {
-                stream
-                    .lock()
-                    .update_state(self.id, id, State::Open { acknowledged: true });
+            if let Some(shared) = self.streams.get(&id) {
+                shared.update_state(self.id, id, State::Open { acknowledged: true });
             }
             if let Some(waker) = self.new_outbound_stream_waker.take() {
                 waker.wake();
@@ -590,14 +588,15 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
         if frame.header().flags().contains(header::RST) {
             // stream reset
             if let Some(s) = self.streams.get_mut(&stream_id) {
-                let mut shared = s.lock();
-                shared.update_state(self.id, stream_id, State::Closed);
-                if let Some(w) = shared.reader.take() {
-                    w.wake()
-                }
-                if let Some(w) = shared.writer.take() {
-                    w.wake()
-                }
+                s.with_mut(|inner| {
+                    inner.update_state(self.id, stream_id, State::Closed);
+                    if let Some(w) = inner.reader.take() {
+                        w.wake()
+                    }
+                    if let Some(w) = inner.writer.take() {
+                        w.wake()
+                    }
+                });
             }
             return Action::None;
         }
@@ -628,35 +627,40 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
             }
             let stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
             {
-                let mut shared = stream.shared();
-                if is_finish {
-                    shared.update_state(self.id, stream_id, State::RecvClosed);
-                }
-                shared.consume_receive_window(frame.body_len());
-                shared.buffer.push(frame.into_body());
+                stream.shared().with_mut(|inner| {
+                    if is_finish {
+                        inner.update_state(self.id, stream_id, State::RecvClosed);
+                    }
+                    inner.consume_receive_window(frame.body_len());
+                    inner.buffer.push(frame.into_body());
+                })
             }
             self.streams.insert(stream_id, stream.clone_shared());
             return Action::New(stream);
         }

-        if let Some(s) = self.streams.get_mut(&stream_id) {
-            let mut shared = s.lock();
-            if frame.body_len() > shared.receive_window() {
-                log::error!(
-                    "{}/{}: frame body larger than window of stream",
-                    self.id,
-                    stream_id
-                );
-                return Action::Terminate(Frame::protocol_error());
-            }
-            if is_finish {
-                shared.update_state(self.id, stream_id, State::RecvClosed);
-            }
-            shared.consume_receive_window(frame.body_len());
-            shared.buffer.push(frame.into_body());
-            if let Some(w) = shared.reader.take() {
-                w.wake()
-            }
+        if let Some(shared) = self.streams.get_mut(&stream_id) {
+            let action = shared.with_mut(|inner| {
+                if frame.body_len() > inner.receive_window() {
+                    log::error!(
+                        "{}/{}: frame body larger than window of stream",
+                        self.id,
+                        stream_id
+                    );
+                    Action::Terminate(Frame::protocol_error())
+                } else {
+                    if is_finish {
+                        inner.update_state(self.id, stream_id, State::RecvClosed);
+                    }
+                    inner.consume_receive_window(frame.body_len());
+                    inner.buffer.push(frame.into_body());
+                    if let Some(w) = inner.reader.take() {
+                        w.wake()
+                    }
+                    Action::None
+                }
+            });
+            return action;
         } else {
             log::trace!(
                 "{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}",
@@ -681,15 +685,16 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {

         if frame.header().flags().contains(header::RST) {
             // stream reset
-            if let Some(s) = self.streams.get_mut(&stream_id) {
-                let mut shared = s.lock();
-                shared.update_state(self.id, stream_id, State::Closed);
-                if let Some(w) = shared.reader.take() {
-                    w.wake()
-                }
-                if let Some(w) = shared.writer.take() {
-                    w.wake()
-                }
+            if let Some(shared) = self.streams.get_mut(&stream_id) {
+                shared.with_mut(|inner| {
+                    inner.update_state(self.id, stream_id, State::Closed);
+                    if let Some(w) = inner.reader.take() {
+                        w.wake()
+                    }
+                    if let Some(w) = inner.writer.take() {
+                        w.wake()
+                    }
+                });
             }
             return Action::None;
         }
@@ -723,19 +728,21 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
             return Action::New(stream);
         }

-        if let Some(s) = self.streams.get_mut(&stream_id) {
-            let mut shared = s.lock();
-            shared.increase_send_window_by(frame.header().credit());
-            if is_finish {
-                shared.update_state(self.id, stream_id, State::RecvClosed);
+        if let Some(shared) = self.streams.get_mut(&stream_id) {
+            shared.with_mut(|inner| {
+                inner.increase_send_window_by(frame.header().credit());
+                if is_finish {
+                    inner.update_state(self.id, stream_id, State::RecvClosed);
+
+                    if let Some(w) = inner.reader.take() {
+                        w.wake()
+                    }
+                }

-                if let Some(w) = shared.reader.take() {
+                if let Some(w) = inner.writer.take() {
                     w.wake()
                 }
-            }
-            if let Some(w) = shared.writer.take() {
-                w.wake()
-            }
+            });
         } else {
             log::trace!(
                 "{}/{}: window update for unknown stream, possibly dropped earlier: {:?}",
@@ -848,7 +855,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
                 Mode::Client => id.is_client(),
                 Mode::Server => id.is_server(),
             })
-            .filter(|(_, s)| s.lock().is_pending_ack())
+            .filter(|(_, s)| s.is_pending_ack())
            .count()
     }

@@ -867,15 +874,16 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
 impl<T> Active<T> {
     /// Close and drop all `Stream`s and wake any pending `Waker`s.
     fn drop_all_streams(&mut self) {
-        for (id, s) in self.streams.drain() {
-            let mut shared = s.lock();
-            shared.update_state(self.id, id, State::Closed);
-            if let Some(w) = shared.reader.take() {
-                w.wake()
-            }
-            if let Some(w) = shared.writer.take() {
-                w.wake()
-            }
+        for (id, shared) in self.streams.drain() {
+            shared.with_mut(|inner| {
+                inner.update_state(self.id, id, State::Closed);
+                if let Some(w) = inner.reader.take() {
+                    w.wake()
+                }
+                if let Some(w) = inner.writer.take() {
+                    w.wake()
+                }
+            });
         }
     }
 }
diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs
index 1f48e1b..a01f689 100644
--- a/yamux/src/connection/stream.rs
+++ b/yamux/src/connection/stream.rs
@@ -8,6 +8,7 @@
 // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
 // at https://opensource.org/licenses/MIT.

+use crate::chunks::Chunk;
 use crate::connection::rtt::Rtt;
 use crate::frame::header::ACK;
 use crate::{
@@ -26,7 +27,8 @@ use futures::{
     io::{AsyncRead, AsyncWrite},
     ready, SinkExt,
 };
-use parking_lot::{Mutex, MutexGuard};
+
+use parking_lot::Mutex;
 use std::{
     fmt, io,
     pin::Pin,
@@ -96,7 +98,7 @@ pub struct Stream {
     config: Arc<Config>,
     sender: mpsc::Sender<StreamCommand>,
     flag: Flag,
-    shared: Arc<Mutex<Shared>>,
+    shared: Shared,
 }

 impl fmt::Debug for Stream {
@@ -130,13 +132,13 @@ impl Stream {
             config: config.clone(),
             sender,
             flag: Flag::Ack,
-            shared: Arc::new(Mutex::new(Shared::new(
+            shared: Shared::new(
                 DEFAULT_CREDIT,
                 send_window,
                 accumulated_max_stream_windows,
                 rtt,
                 config,
-            ))),
+            ),
         }
     }

@@ -154,13 +156,13 @@ impl Stream {
             config: config.clone(),
             sender,
             flag: Flag::Syn,
-            shared: Arc::new(Mutex::new(Shared::new(
+            shared: Shared::new(
                 DEFAULT_CREDIT,
                 DEFAULT_CREDIT,
                 accumulated_max_stream_windows,
                 rtt,
                 config,
-            ))),
+            ),
         }
     }

@@ -170,23 +172,24 @@ impl Stream {
     }

     pub fn is_write_closed(&self) -> bool {
-        matches!(self.shared().state(), State::SendClosed)
+        matches!(self.shared.state(), State::SendClosed)
     }

     pub fn is_closed(&self) -> bool {
-        matches!(self.shared().state(), State::Closed)
+        matches!(self.shared.state(), State::Closed)
     }

     /// Whether we are still waiting for the remote to acknowledge this stream.
     pub fn is_pending_ack(&self) -> bool {
-        self.shared().is_pending_ack()
+        self.shared.is_pending_ack()
     }

-    pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> {
-        self.shared.lock()
+    /// Returns a reference to the `Shared` concurrency wrapper.
+    pub(crate) fn shared(&self) -> &Shared {
+        &self.shared
     }

-    pub(crate) fn clone_shared(&self) -> Arc<Mutex<Shared>> {
+    pub(crate) fn clone_shared(&self) -> Shared {
         self.shared.clone()
     }

@@ -213,7 +216,7 @@ impl Stream {
     /// Send new credit to the sending side via a window update message if
     /// permitted.
     fn send_window_update(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
-        if !self.shared.lock().state.can_read() {
+        if !self.shared.state().can_read() {
             return Poll::Ready(Ok(()));
         }

@@ -222,7 +225,7 @@ impl Stream {
             .poll_ready(cx)
             .map_err(|_| self.write_zero_err())?);

-        let Some(credit) = self.shared.lock().next_window_update() else {
+        let Some(credit) = self.shared.next_window_update() else {
             return Poll::Ready(Ok(()));
         };

@@ -262,9 +265,7 @@ impl futures::stream::Stream for Stream {
             Poll::Pending => {}
         }

-        let mut shared = self.shared();
-
-        if let Some(bytes) = shared.buffer.pop() {
+        if let Some(bytes) = self.shared.pop_buffer() {
             let off = bytes.offset();
             let mut vec = bytes.into_vec();
             if off != 0 {
@@ -283,14 +284,14 @@ impl futures::stream::Stream for Stream {
         }

         // Buffer is empty, let's check if we can expect to read more data.
-        if !shared.state().can_read() {
+        if !self.shared.state().can_read() {
             log::debug!("{}/{}: eof", self.conn, self.id);
             return Poll::Ready(None); // stream has been reset
         }

         // Since we have no more data at this point, we want to be woken up
         // by the connection when more becomes available for us.
-        shared.reader = Some(cx.waker().clone());
+        self.shared.set_reader_waker(Some(cx.waker().clone()));

         Poll::Pending
     }
@@ -316,37 +317,48 @@ impl AsyncRead for Stream {
         }

         // Copy data from stream buffer.
-        let mut shared = self.shared();
         let mut n = 0;
-        while let Some(chunk) = shared.buffer.front_mut() {
-            if chunk.is_empty() {
-                shared.buffer.pop();
-                continue;
+
+        let can_read = self.shared.with_mut(|inner| {
+            while let Some(chunk) = inner.buffer.front_mut() {
+                if chunk.is_empty() {
+                    inner.buffer.pop();
+                    continue;
+                }
+                let k = std::cmp::min(chunk.len(), buf.len() - n);
+                buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
+                n += k;
+                chunk.advance(k);
+                if n == buf.len() {
+                    break;
+                }
             }
-            let k = std::cmp::min(chunk.len(), buf.len() - n);
-            buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
-            n += k;
-            chunk.advance(k);
-            if n == buf.len() {
-                break;
+
+            if n > 0 {
+                return true;
             }
-        }
+
+            // Buffer is empty, let's check if we can expect to read more data.
+            if !inner.state.can_read() {
+                return false; // No more data available
+            }
+
+            // Since we have no more data at this point, we want to be woken up
+            // by the connection when more becomes available for us.
+            inner.reader = Some(cx.waker().clone());
+            true
+        });

         if n > 0 {
             log::trace!("{}/{}: read {} bytes", self.conn, self.id, n);
             return Poll::Ready(Ok(n));
         }

-        // Buffer is empty, let's check if we can expect to read more data.
-        if !shared.state().can_read() {
+        if !can_read {
             log::debug!("{}/{}: eof", self.conn, self.id);
             return Poll::Ready(Ok(0)); // stream has been reset
         }

-        // Since we have no more data at this point, we want to be woken up
-        // by the connection when more becomes available for us.
-        shared.reader = Some(cx.waker().clone());
         Poll::Pending
     }
 }
@@ -361,28 +372,39 @@ impl AsyncWrite for Stream {
         ready!(self
             .sender
             .poll_ready(cx)
             .map_err(|_| self.write_zero_err())?);
-        let body = {
-            let mut shared = self.shared();
-            if !shared.state().can_write() {
+
+        let result = self.shared.with_mut(|inner| {
+            if !inner.state.can_write() {
                 log::debug!("{}/{}: can no longer write", self.conn, self.id);
-                return Poll::Ready(Err(self.write_zero_err()));
+                // Return an error
+                return Err(self.write_zero_err());
             }
-            if shared.send_window() == 0 {
+
+            let window = inner.send_window();
+            if window == 0 {
                 log::trace!("{}/{}: no more credit left", self.conn, self.id);
-                shared.writer = Some(cx.waker().clone());
-                return Poll::Pending;
+                inner.writer = Some(cx.waker().clone());
+                return Ok(None); // means we are Pending
             }
-            let k = std::cmp::min(
-                shared.send_window(),
-                buf.len().try_into().unwrap_or(u32::MAX),
-            );
+
+            let k = std::cmp::min(window, buf.len().try_into().unwrap_or(u32::MAX));
+
             let k = std::cmp::min(
                 k,
                 self.config.split_send_size.try_into().unwrap_or(u32::MAX),
             );
-            shared.consume_send_window(k);
-            Vec::from(&buf[..k as usize])
+
+            inner.consume_send_window(k);
+            let body = Some(Vec::from(&buf[..k as usize]));
+            Ok(body)
+        });
+
+        let body = match result {
+            Err(e) => return Poll::Ready(Err(e)), // can't write
+            Ok(None) => return Poll::Pending, // no credit => Pending
+            Ok(Some(b)) => b, // we have a body
         };
+
         let n = body.len();
         let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left();
         self.add_flag(frame.header_mut());
@@ -393,8 +415,9 @@ impl AsyncWrite for Stream {
         // a) to be consistent with outbound streams
         // b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test.
         if frame.header().flags().contains(ACK) {
-            self.shared()
-                .update_state(self.conn, self.id, State::Open { acknowledged: true });
+            self.shared.with_mut(|inner| {
+                inner.update_state(self.conn, self.id, State::Open { acknowledged: true });
+            });
         }

         let cmd = StreamCommand::SendFrame(frame);
@@ -429,14 +452,74 @@ impl AsyncWrite for Stream {
         self.sender
             .start_send(cmd)
             .map_err(|_| self.write_zero_err())?;
-        self.shared()
-            .update_state(self.conn, self.id, State::SendClosed);
+        self.shared.with_mut(|inner| {
+            inner.update_state(self.conn, self.id, State::SendClosed);
+        });
         Poll::Ready(Ok(()))
     }
 }

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub(crate) struct Shared {
+    inner: Arc<Mutex<SharedInner>>,
+}
+
+impl Shared {
+    fn new(
+        receive_window: u32,
+        send_window: u32,
+        accumulated_max_stream_windows: Arc<Mutex<usize>>,
+        rtt: Rtt,
+        config: Arc<Config>,
+    ) -> Self {
+        Self {
+            inner: Arc::new(Mutex::new(SharedInner::new(
+                receive_window,
+                send_window,
+                accumulated_max_stream_windows,
+                rtt,
+                config,
+            ))),
+        }
+    }
+
+    pub fn state(&self) -> State {
+        self.inner.lock().state
+    }
+
+    pub fn pop_buffer(&self) -> Option<Chunk> {
+        self.with_mut(|inner| inner.buffer.pop())
+    }
+
+    pub fn is_pending_ack(&self) -> bool {
+        self.inner.lock().is_pending_ack()
+    }
+
+    pub fn next_window_update(&self) -> Option<u32> {
+        self.inner.lock().next_window_update()
+    }
+
+    pub fn set_reader_waker(&self, waker: Option<Waker>) {
+        self.with_mut(|inner| {
+            inner.reader = waker;
+        });
+    }
+
+    pub fn update_state(&self, cid: connection::Id, sid: StreamId, next: State) -> State {
+        self.with_mut(|inner| inner.update_state(cid, sid, next))
+    }
+
+    pub fn with_mut<F, R>(&self, f: F) -> R
+    where
+        F: FnOnce(&mut SharedInner) -> R,
+    {
+        let mut guard = self.inner.lock();
+        f(&mut guard)
+    }
+}
+
+#[derive(Debug)]
+pub(crate) struct SharedInner {
     state: State,
     flow_controller: FlowController,
     pub(crate) buffer: Chunks,
@@ -444,7 +527,7 @@ pub(crate) struct Shared {
     pub(crate) writer: Option<Waker>,
 }

-impl Shared {
+impl SharedInner {
     fn new(
         receive_window: u32,
         send_window: u32,
@@ -452,7 +535,7 @@ impl Shared {
         rtt: Rtt,
         config: Arc<Config>,
     ) -> Self {
-        Shared {
+        Self {
             state: State::Open {
                 acknowledged: false,
             },
@@ -469,10 +552,6 @@ impl Shared {
         }
     }

-    pub(crate) fn state(&self) -> State {
-        self.state
-    }
-
     /// Update the stream state and return the state before it was updated.
     pub(crate) fn update_state(
         &mut self,
@@ -509,14 +588,14 @@ impl Shared {
         current // Return the previous stream state for informational purposes.
     }

-    pub(crate) fn next_window_update(&mut self) -> Option<u32> {
+    fn next_window_update(&mut self) -> Option<u32> {
         self.flow_controller.next_window_update(self.buffer.len())
     }

     /// Whether we are still waiting for the remote to acknowledge this stream.
-    pub fn is_pending_ack(&self) -> bool {
+    fn is_pending_ack(&self) -> bool {
         matches!(
-            self.state(),
+            self.state,
             State::Open {
                 acknowledged: false
             }

From f7ad5e9bc1e2d8cfbbb893e327ccfdca06e6de8c Mon Sep 17 00:00:00 2001
From: Aryan Tikarya
Date: Fri, 31 Jan 2025 01:09:43 +0530
Subject: [PATCH 2/4] address comments

---
 yamux/src/connection.rs        |  33 +++++-----
 yamux/src/connection/stream.rs | 116 ++++++++++++---------------------
 2 files changed, 58 insertions(+), 91 deletions(-)

diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs
index 32ba968..2110051 100644
--- a/yamux/src/connection.rs
+++ b/yamux/src/connection.rs
@@ -504,7 +504,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
     }

     fn on_drop_stream(&mut self, stream_id: StreamId) -> Option<Frame<()>> {
-        let s = self.streams.remove(&stream_id).expect("stream not found");
+        let mut s = self.streams.remove(&stream_id).expect("stream not found");
         log::trace!("{}: removing dropped stream {}", self.id, stream_id);

         let frame = s.with_mut(|inner| {
@@ -565,7 +565,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
             && matches!(frame.header().tag(), Tag::Data | Tag::WindowUpdate)
         {
             let id = frame.header().stream_id();
-            if let Some(shared) = self.streams.get(&id) {
+            if let Some(shared) = self.streams.get_mut(&id) {
                 shared.update_state(self.id, id, State::Open { acknowledged: true });
             }
             if let Some(waker) = self.new_outbound_stream_waker.take() {
@@ -625,16 +625,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
             log::error!("{}: maximum number of streams reached", self.id);
             return Action::Terminate(Frame::internal_error());
         }
-        let stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
-        {
-            stream.shared().with_mut(|inner| {
-                if is_finish {
-                    inner.update_state(self.id, stream_id, State::RecvClosed);
-                }
-                inner.consume_receive_window(frame.body_len());
-                inner.buffer.push(frame.into_body());
-            })
-        }
+        let mut stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
+        stream.shared_mut().with_mut(|inner| {
+            if is_finish {
+                inner.update_state(self.id, stream_id, State::RecvClosed);
+            }
+            inner.consume_receive_window(frame.body_len());
+            inner.buffer.push(frame.into_body());
+        });
         self.streams.insert(stream_id, stream.clone_shared());
         return Action::New(stream);
     }
@@ -660,7 +658,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
                     Action::None
                 }
             });
-            return action;
+            action
         } else {
             log::trace!(
                 "{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}",
@@ -675,9 +673,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
             // termination for the remote.
             //
             // See https://github.com/paritytech/yamux/issues/110 for details.
+            Action::None
         }
-
-        Action::None
     }

     fn on_window_update(&mut self, frame: &Frame<WindowUpdate>) -> Action {
@@ -717,11 +714,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
         }

         let credit = frame.header().credit() + DEFAULT_CREDIT;
-        let stream = self.make_new_inbound_stream(stream_id, credit);
+        let mut stream = self.make_new_inbound_stream(stream_id, credit);

         if is_finish {
             stream
-                .shared()
+                .shared_mut()
                 .update_state(self.id, stream_id, State::RecvClosed);
         }
         self.streams.insert(stream_id, stream.clone_shared());
@@ -874,7 +871,7 @@ impl<T> Active<T> {
 impl<T> Active<T> {
     /// Close and drop all `Stream`s and wake any pending `Waker`s.
     fn drop_all_streams(&mut self) {
-        for (id, shared) in self.streams.drain() {
+        for (id, mut shared) in self.streams.drain() {
             shared.with_mut(|inner| {
                 inner.update_state(self.id, id, State::Closed);
                 if let Some(w) = inner.reader.take() {
diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs
index a01f689..0916235 100644
--- a/yamux/src/connection/stream.rs
+++ b/yamux/src/connection/stream.rs
@@ -7,8 +7,6 @@
 // as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
 // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
 // at https://opensource.org/licenses/MIT.
-
-use crate::chunks::Chunk;
 use crate::connection::rtt::Rtt;
 use crate::frame::header::ACK;
 use crate::{
@@ -26,7 +26,7 @@ use futures::{
     ready, SinkExt,
 };

-use parking_lot::Mutex;
+use parking_lot::{Mutex, MutexGuard};
 use std::{
     fmt, io,
     pin::Pin,
@@ -179,14 +177,12 @@ impl Stream {
         matches!(self.shared.state(), State::Closed)
     }

-    /// Whether we are still waiting for the remote to acknowledge this stream.
     pub fn is_pending_ack(&self) -> bool {
         self.shared.is_pending_ack()
     }

-    /// Returns a reference to the `Shared` concurrency wrapper.
-    pub(crate) fn shared(&self) -> &Shared {
-        &self.shared
+    pub(crate) fn shared_mut(&mut self) -> &mut Shared {
+        &mut self.shared
     }

     pub(crate) fn clone_shared(&self) -> Shared {
@@ -265,7 +261,8 @@ impl futures::stream::Stream for Stream {
             Poll::Pending => {}
         }

-        if let Some(bytes) = self.shared.pop_buffer() {
+        let mut shared = self.shared.lock();
+        if let Some(bytes) = shared.buffer.pop() {
             let off = bytes.offset();
             let mut vec = bytes.into_vec();
             if off != 0 {
@@ -276,23 +273,21 @@ impl futures::stream::Stream for Stream {
                 log::debug!(
                     "{}/{}: chunk has been partially consumed",
                     self.conn,
-                    self.id
+                    self.id,
                 );
                 vec = vec.split_off(off)
             }
             return Poll::Ready(Some(Ok(Packet(vec))));
         }

-        // Buffer is empty, let's check if we can expect to read more data.
-        if !self.shared.state().can_read() {
+        if !shared.state.can_read() {
             log::debug!("{}/{}: eof", self.conn, self.id);
             return Poll::Ready(None); // stream has been reset
         }

         // Since we have no more data at this point, we want to be woken up
         // by the connection when more becomes available for us.
-        self.shared.set_reader_waker(Some(cx.waker().clone()));
-
+        shared.reader = Some(cx.waker().clone());
         Poll::Pending
     }
 }
@@ -317,48 +312,37 @@ impl AsyncRead for Stream {
         }

         // Copy data from stream buffer.
+        let mut shared = self.shared.lock();
         let mut n = 0;
-
-        let can_read = self.shared.with_mut(|inner| {
-            while let Some(chunk) = inner.buffer.front_mut() {
-                if chunk.is_empty() {
-                    inner.buffer.pop();
-                    continue;
-                }
-                let k = std::cmp::min(chunk.len(), buf.len() - n);
-                buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
-                n += k;
-                chunk.advance(k);
-                if n == buf.len() {
-                    break;
-                }
+        while let Some(chunk) = shared.buffer.front_mut() {
+            if chunk.is_empty() {
+                shared.buffer.pop();
+                continue;
             }
-
-            if n > 0 {
-                return true;
+            let k = std::cmp::min(chunk.len(), buf.len() - n);
+            buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
+            n += k;
+            chunk.advance(k);
+            if n == buf.len() {
+                break;
             }
-
-            // Buffer is empty, let's check if we can expect to read more data.
-            if !inner.state.can_read() {
-                return false; // No more data available
-            }
-
-            // Since we have no more data at this point, we want to be woken up
-            // by the connection when more becomes available for us.
-            inner.reader = Some(cx.waker().clone());
-            true
-        });
+        }

         if n > 0 {
             log::trace!("{}/{}: read {} bytes", self.conn, self.id, n);
             return Poll::Ready(Ok(n));
         }

-        if !can_read {
+        // Buffer is empty, let's check if we can expect to read more data.
+        if !shared.state.can_read() {
             log::debug!("{}/{}: eof", self.conn, self.id);
             return Poll::Ready(Ok(0)); // stream has been reset
         }

+        // Since we have no more data at this point, we want to be woken up
+        // by the connection when more becomes available for us.
+        shared.reader = Some(cx.waker().clone());
         Poll::Pending
     }
 }
@@ -372,36 +356,29 @@ impl AsyncWrite for Stream {
             .poll_ready(cx)
             .map_err(|_| self.write_zero_err())?);

-        let result = self.shared.with_mut(|inner| {
-            if !inner.state.can_write() {
+        let body = {
+            let mut shared = self.shared.lock();
+            if !shared.state.can_write() {
                 log::debug!("{}/{}: can no longer write", self.conn, self.id);
                 // Return an error
-                return Err(self.write_zero_err());
+                return Poll::Ready(Err(self.write_zero_err()));
             }

-            let window = inner.send_window();
+            let window = shared.send_window();
             if window == 0 {
                 log::trace!("{}/{}: no more credit left", self.conn, self.id);
-                inner.writer = Some(cx.waker().clone());
-                return Ok(None); // means we are Pending
+                shared.writer = Some(cx.waker().clone());
+                return Poll::Pending;
             }

             let k = std::cmp::min(window, buf.len().try_into().unwrap_or(u32::MAX));
-
             let k = std::cmp::min(
                 k,
                 self.config.split_send_size.try_into().unwrap_or(u32::MAX),
             );

-            inner.consume_send_window(k);
-            let body = Some(Vec::from(&buf[..k as usize]));
-            Ok(body)
-        });
-
-        let body = match result {
-            Err(e) => return Poll::Ready(Err(e)), // can't write
-            Ok(None) => return Poll::Pending, // no credit => Pending
-            Ok(Some(b)) => b, // we have a body
+            shared.consume_send_window(k);
+            Vec::from(&buf[..k as usize])
         };

         let n = body.len();
@@ -415,9 +392,8 @@ impl AsyncWrite for Stream {
         // a) to be consistent with outbound streams
         // b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test.
         if frame.header().flags().contains(ACK) {
-            self.shared.with_mut(|inner| {
-                inner.update_state(self.conn, self.id, State::Open { acknowledged: true });
-            });
+            self.shared
+                .update_state(self.conn, self.id, State::Open { acknowledged: true });
         }

         let cmd = StreamCommand::SendFrame(frame);
@@ -452,9 +428,8 @@ impl AsyncWrite for Stream {
         self.sender
             .start_send(cmd)
             .map_err(|_| self.write_zero_err())?;
-        self.shared.with_mut(|inner| {
-            inner.update_state(self.conn, self.id, State::SendClosed);
-        });
+        self.shared
+            .update_state(self.conn, self.id, State::SendClosed);
         Poll::Ready(Ok(()))
     }
 }
@@ -486,11 +461,7 @@ impl Shared {
     pub fn state(&self) -> State {
         self.inner.lock().state
     }

-    pub fn pop_buffer(&self) -> Option<Chunk> {
-        self.with_mut(|inner| inner.buffer.pop())
-    }
-
     pub fn is_pending_ack(&self) -> bool {
         self.inner.lock().is_pending_ack()
     }
@@ -498,18 +469,16 @@ impl Shared {
     pub fn next_window_update(&self) -> Option<u32> {
         self.inner.lock().next_window_update()
     }

-    pub fn set_reader_waker(&self, waker: Option<Waker>) {
-        self.with_mut(|inner| {
-            inner.reader = waker;
-        });
+    pub fn update_state(&self, cid: connection::Id, sid: StreamId, next: State) -> State {
+        self.inner.lock().update_state(cid, sid, next)
     }

-    pub fn update_state(&self, cid: connection::Id, sid: StreamId, next: State) -> State {
-        self.with_mut(|inner| inner.update_state(cid, sid, next))
+    pub fn lock(&self) -> MutexGuard<'_, SharedInner> {
+        self.inner.lock()
     }

-    pub fn with_mut<F, R>(&self, f: F) -> R
+    pub fn with_mut<F, R>(&mut self, f: F) -> R
     where
         F: FnOnce(&mut SharedInner) -> R,
     {

From 142719cbb180b70c8a18acc614f09dce42067654 Mon Sep 17 00:00:00 2001
From: Aryan Tikarya
Date: Sat, 8 Feb 2025 21:04:21 +0530
Subject: [PATCH 3/4] refactor: remove lock and use with_mut

---
 yamux/src/connection/stream.rs | 207 +++++++++++++++++----------------
 1 file changed, 109 insertions(+), 98 deletions(-)

diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs
index 0916235..6701347 100644
--- a/yamux/src/connection/stream.rs
+++ b/yamux/src/connection/stream.rs
@@ -26,7 +26,7 @@ use futures::{
     ready, SinkExt,
 };

-use parking_lot::{Mutex, MutexGuard};
+use parking_lot::Mutex;
 use std::{
     fmt, io,
     pin::Pin,
@@ -189,8 +189,8 @@ impl Stream {
         self.shared.clone()
     }

-    fn write_zero_err(&self) -> io::Error {
-        let msg = format!("{}/{}: connection is closed", self.conn, self.id);
+    fn write_zero_err(conn: connection::Id, id: StreamId) -> io::Error {
+        let msg = format!("{}/{}: connection is closed", conn, id);
         io::Error::new(io::ErrorKind::WriteZero, msg)
     }

@@ -219,9 +219,9 @@ impl Stream {
         ready!(self
             .sender
             .poll_ready(cx)
-            .map_err(|_| self.write_zero_err())?);
+            .map_err(|_| Stream::write_zero_err(self.conn, self.id))?);

-        let Some(credit) = self.shared.next_window_update() else {
+        let Some(credit) = self.shared_mut().next_window_update() else {
             return Poll::Ready(Ok(()));
         };

@@ -230,7 +230,7 @@ impl Stream {
         let cmd = StreamCommand::SendFrame(frame);
         self.sender
             .start_send(cmd)
-            .map_err(|_| self.write_zero_err())?;
+            .map_err(|_| Stream::write_zero_err(self.conn, self.id))?;
         Poll::Ready(Ok(()))
     }

@@ -261,33 +261,36 @@ impl futures::stream::Stream for Stream {
             Poll::Pending => {}
         }

-        let mut shared = self.shared.lock();
-        if let Some(bytes) = shared.buffer.pop() {
-            let off = bytes.offset();
-            let mut vec = bytes.into_vec();
-            if off != 0 {
-                // This should generally not happen when the stream is used only as
-                // a `futures::stream::Stream` since the whole point of this impl is
-                // to consume chunks atomically. It may perhaps happen when mixing
-                // this impl and the `AsyncRead` one.
-                log::debug!(
-                    "{}/{}: chunk has been partially consumed",
-                    self.conn,
-                    self.id,
-                );
-                vec = vec.split_off(off)
+        let Self {
+            id, conn, shared, ..
+        } = self.get_mut();
+        let polling_state = shared.with_mut(|inner| {
+            if let Some(bytes) = inner.buffer.pop() {
+                let off = bytes.offset();
+                let mut vec = bytes.into_vec();
+                if off != 0 {
+                    // This should generally not happen when the stream is used only as
+                    // a `futures::stream::Stream` since the whole point of this impl is
+                    // to consume chunks atomically. It may perhaps happen when mixing
+                    // this impl and the `AsyncRead` one.
+                    log::debug!("{}/{}: chunk has been partially consumed", conn, id,);
+                    vec = vec.split_off(off)
+                }
+                return Poll::Ready(Some(Ok(Packet(vec))));
             }
-            return Poll::Ready(Some(Ok(Packet(vec))));
-        }

-        if !shared.state.can_read() {
-            log::debug!("{}/{}: eof", self.conn, self.id);
-            return Poll::Ready(None); // stream has been reset
-        }
+            // Buffer is empty, let's check if we can expect to read more data.
+            if !inner.state.can_read() {
+                log::debug!("{}/{}: eof", conn, id);
+                return Poll::Ready(None); // stream has been reset
+            }

-        // Since we have no more data at this point, we want to be woken up
-        // by the connection when more becomes available for us.
-        shared.reader = Some(cx.waker().clone());
-        Poll::Pending
+            // Since we have no more data at this point, we want to be woken up
+            // by the connection when more becomes available for us.
+            inner.reader = Some(cx.waker().clone());
+            Poll::Pending
+        });
+
+        polling_state
     }
 }
@@ -312,37 +315,41 @@ impl AsyncRead for Stream {
         }

         // Copy data from stream buffer.
-        let mut shared = self.shared.lock();
-        let mut n = 0;
-        while let Some(chunk) = shared.buffer.front_mut() {
-            if chunk.is_empty() {
-                shared.buffer.pop();
-                continue;
+        let Self {
+            id, conn, shared, ..
+        } = self.get_mut();
+        let poll_state = shared.with_mut(|inner| {
+            let mut n = 0;
+            while let Some(chunk) = inner.buffer.front_mut() {
+                if chunk.is_empty() {
+                    inner.buffer.pop();
+                    continue;
+                }
+                let k = std::cmp::min(chunk.len(), buf.len() - n);
+                buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
+                n += k;
+                chunk.advance(k);
+                if n == buf.len() {
+                    break;
+                }
             }
-            let k = std::cmp::min(chunk.len(), buf.len() - n);
-            buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
-            n += k;
-            chunk.advance(k);
-            if n == buf.len() {
-                break;
+
+            if n > 0 {
+                log::trace!("{}/{}: read {} bytes", conn, id, n);
+                return Poll::Ready(Ok(n));
             }
-        }

-        if n > 0 {
-            log::trace!("{}/{}: read {} bytes", self.conn, self.id, n);
-            return Poll::Ready(Ok(n));
-        }
+            // Buffer is empty, let's check if we can expect to read more data.
+            if !inner.state.can_read() {
+                log::debug!("{}/{}: eof", conn, id);
+                return Poll::Ready(Ok(0)); // stream has been reset
+            }

-        // Buffer is empty, let's check if we can expect to read more data.
-        if !shared.state.can_read() {
-            log::debug!("{}/{}: eof", self.conn, self.id);
-            return Poll::Ready(Ok(0)); // stream has been reset
-        }
+            // Since we have no more data at this point, we want to be woken up
+            // by the connection when more becomes available for us.
+            inner.reader = Some(cx.waker().clone());
+            Poll::Pending
+        });

-        // Since we have no more data at this point, we want to be woken up
-        // by the connection when more becomes available for us.
-        shared.reader = Some(cx.waker().clone());
-        Poll::Pending
+        poll_state
     }
 }
@@ -354,52 +361,58 @@ impl AsyncWrite for Stream {
         ready!(self
             .sender
             .poll_ready(cx)
-            .map_err(|_| self.write_zero_err())?);
+            .map_err(|_| Stream::write_zero_err(self.conn, self.id))?);

-        let body = {
-            let mut shared = self.shared.lock();
-            if !shared.state.can_write() {
-                log::debug!("{}/{}: can no longer write", self.conn, self.id);
-                // Return an error
-                return Poll::Ready(Err(self.write_zero_err()));
+        let stream = self.as_mut().get_mut();
+        let result = stream.shared.with_mut(|inner| {
+            if !inner.state.can_write() {
+                log::debug!("{}/{}: can no longer write", stream.conn, stream.id);
+                return Err(Stream::write_zero_err(stream.conn, stream.id));
             }

-            let window = shared.send_window();
+            let window = inner.send_window();
             if window == 0 {
-                log::trace!("{}/{}: no more credit left", self.conn, self.id);
-                shared.writer = Some(cx.waker().clone());
-                return Poll::Pending;
+                log::trace!("{}/{}: no more credit left", stream.conn, stream.id);
+                inner.writer = Some(cx.waker().clone());
+                return Ok(None);
             }

             let k = std::cmp::min(window, buf.len().try_into().unwrap_or(u32::MAX));
             let k = std::cmp::min(
                 k,
-                self.config.split_send_size.try_into().unwrap_or(u32::MAX),
+                stream.config.split_send_size.try_into().unwrap_or(u32::MAX),
             );

-            shared.consume_send_window(k);
-            Vec::from(&buf[..k as usize])
+            inner.consume_send_window(k);
+            Ok(Some(Vec::from(&buf[..k as usize])))
+        });
+
+        let body = match result {
+            Err(e) => return Poll::Ready(Err(e)),
+            Ok(None) => return Poll::Pending,
+            Ok(Some(b)) => b,
         };

         let n = body.len();
-        let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left();
-        self.add_flag(frame.header_mut());
+        let mut frame = Frame::data(stream.id, body)
+            .expect("body <= u32::MAX")
+            .left();
+
+        stream.add_flag(frame.header_mut());

-        log::trace!("{}/{}: write {} bytes", self.conn, self.id, n);
+        log::trace!("{}/{}: write {} bytes", stream.conn, stream.id, n);

-        // technically, the frame hasn't been sent yet on the wire but from the perspective of this data structure, we've queued the frame for sending
-        // We are tracking this information:
-        // a) to be consistent with outbound streams
-        // b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test.
         if frame.header().flags().contains(ACK) {
-            self.shared
-                .update_state(self.conn, self.id, State::Open { acknowledged: true });
+            stream
+                .shared
+                .update_state(stream.conn, stream.id, State::Open { acknowledged: true });
         }

         let cmd = StreamCommand::SendFrame(frame);
-        self.sender
+        stream
+            .sender
             .start_send(cmd)
-            .map_err(|_| self.write_zero_err())?;
+            .map_err(|_| Stream::write_zero_err(stream.conn, stream.id))?;
         Poll::Ready(Ok(n))
     }

     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
         self.sender
             .poll_flush_unpin(cx)
-            .map_err(|_| self.write_zero_err())
+            .map_err(|_| Stream::write_zero_err(self.conn, self.id))
     }
@@ -417,7 +430,7 @@ impl AsyncWrite for Stream {
         ready!(self
             .sender
             .poll_ready(cx)
-            .map_err(|_| self.write_zero_err())?);
+            .map_err(|_| Stream::write_zero_err(self.conn, self.id))?);
         let ack = if self.flag == Flag::Ack {
             self.flag = Flag::None;
             true
@@ -428,9 +441,11 @@ impl AsyncWrite for Stream {
         let cmd = StreamCommand::CloseStream { ack };
         self.sender
             .start_send(cmd)
-            .map_err(|_| self.write_zero_err())?;
-        self.shared
-            .update_state(self.conn, self.id, State::SendClosed);
+            .map_err(|_| Stream::write_zero_err(self.conn, self.id))?;
+        let Self {
+            id, conn, shared, ..
+        } = self.get_mut();
+        shared.update_state(*conn, *id, State::SendClosed);
         Poll::Ready(Ok(()))
     }
 }
@@ -465,17 +480,13 @@ impl Shared {
     pub fn is_pending_ack(&self) -> bool {
         self.inner.lock().is_pending_ack()
     }

-    pub fn next_window_update(&self) -> Option<u32> {
-        self.inner.lock().next_window_update()
-    }
-
-    pub fn update_state(&self, cid: connection::Id, sid: StreamId, next: State) -> State {
-        self.inner.lock().update_state(cid, sid, next)
+    pub fn next_window_update(&mut self) -> Option<u32> {
+        self.with_mut(|inner| inner.next_window_update())
     }

-    pub fn lock(&self) -> MutexGuard<'_, SharedInner> {
-        self.inner.lock()
+    pub fn update_state(&mut self, cid: connection::Id, sid: StreamId, next: State) -> State {
+        self.with_mut(|inner| inner.update_state(cid, sid, next))
     }

     pub fn with_mut<F, R>(&mut self, f: F) -> R
     where
         F: FnOnce(&mut SharedInner) -> R,

From 46604fabd4480a4f190005fd50ad93b8161df556 Mon Sep 17 00:00:00 2001
From: Aryan Tikarya
Date: Tue, 11 Feb 2025 15:22:38 +0530
Subject: [PATCH 4/4] address comments

---
 yamux/src/connection/stream.rs | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs
index 6701347..5eead28 100644
--- a/yamux/src/connection/stream.rs
+++ b/yamux/src/connection/stream.rs
@@ -403,6 +403,10 @@ impl AsyncWrite for Stream {

         log::trace!("{}/{}: write {} bytes", stream.conn, stream.id, n);

+        // technically, the frame hasn't been sent yet on the wire but from the perspective of this data structure, we've queued the frame for sending
+        // We are tracking this information:
+        // a) to be consistent with outbound streams
+        // b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test.
         if frame.header().flags().contains(ACK) {
             stream
                 .shared
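
A quick end-to-end illustration of the `with_mut` pattern this series converges on. This is a minimal sketch and not part of the patches: `SharedInner` is reduced to two fields here, while the real type in yamux/src/connection/stream.rs also carries the stream state, flow controller and writer waker.

    use parking_lot::Mutex;
    use std::sync::Arc;
    use std::task::Waker;

    // Simplified stand-in for the `SharedInner` in connection/stream.rs.
    #[derive(Debug)]
    struct SharedInner {
        buffer: Vec<u8>,
        reader: Option<Waker>,
    }

    // Cloneable handle: the connection and the `Stream` each hold a clone
    // pointing at the same mutex-protected state.
    #[derive(Clone, Debug)]
    struct Shared {
        inner: Arc<Mutex<SharedInner>>,
    }

    impl Shared {
        // Mirrors the final signature from PATCH 3/4: callers never see the
        // `MutexGuard`, so the lock is scoped to the closure and cannot be
        // held across an `.await` or leak out of an early `return`.
        fn with_mut<F, R>(&mut self, f: F) -> R
        where
            F: FnOnce(&mut SharedInner) -> R,
        {
            let mut guard = self.inner.lock();
            f(&mut guard)
        }
    }

    fn main() {
        let mut connection_side = Shared {
            inner: Arc::new(Mutex::new(SharedInner {
                buffer: Vec::new(),
                reader: None,
            })),
        };
        let mut stream_side = connection_side.clone();

        // One side buffers data, the other drains it; both go through with_mut.
        connection_side.with_mut(|inner| inner.buffer.extend_from_slice(b"data"));
        let drained = stream_side.with_mut(|inner| std::mem::take(&mut inner.buffer));
        assert_eq!(drained, b"data");
    }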