diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index 24c53c7..2110051 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -282,7 +282,7 @@ struct Active { socket: Fuse>, next_id: u32, - streams: IntMap>>, + streams: IntMap, stream_receivers: SelectAll>>, no_streams_waker: Option, @@ -504,12 +504,11 @@ impl Active { } fn on_drop_stream(&mut self, stream_id: StreamId) -> Option> { - let s = self.streams.remove(&stream_id).expect("stream not found"); + let mut s = self.streams.remove(&stream_id).expect("stream not found"); log::trace!("{}: removing dropped stream {}", self.id, stream_id); - let frame = { - let mut shared = s.lock(); - let frame = match shared.update_state(self.id, stream_id, State::Closed) { + let frame = s.with_mut(|inner| { + let frame = match inner.update_state(self.id, stream_id, State::Closed) { // The stream was dropped without calling `poll_close`. // We reset the stream to inform the remote of the closure. State::Open { .. } => { @@ -541,14 +540,15 @@ impl Active { // remote end has already done so in the past. State::Closed => None, }; - if let Some(w) = shared.reader.take() { + if let Some(w) = inner.reader.take() { w.wake() } - if let Some(w) = shared.writer.take() { + if let Some(w) = inner.writer.take() { w.wake() } + frame - }; + }); frame.map(Into::into) } @@ -565,10 +565,8 @@ impl Active { && matches!(frame.header().tag(), Tag::Data | Tag::WindowUpdate) { let id = frame.header().stream_id(); - if let Some(stream) = self.streams.get(&id) { - stream - .lock() - .update_state(self.id, id, State::Open { acknowledged: true }); + if let Some(shared) = self.streams.get_mut(&id) { + shared.update_state(self.id, id, State::Open { acknowledged: true }); } if let Some(waker) = self.new_outbound_stream_waker.take() { waker.wake(); @@ -590,14 +588,15 @@ impl Active { if frame.header().flags().contains(header::RST) { // stream reset if let Some(s) = self.streams.get_mut(&stream_id) { - let mut shared = s.lock(); - shared.update_state(self.id, stream_id, State::Closed); - if let Some(w) = shared.reader.take() { - w.wake() - } - if let Some(w) = shared.writer.take() { - w.wake() - } + s.with_mut(|inner| { + inner.update_state(self.id, stream_id, State::Closed); + if let Some(w) = inner.reader.take() { + w.wake() + } + if let Some(w) = inner.writer.take() { + w.wake() + } + }); } return Action::None; } @@ -626,37 +625,40 @@ impl Active { log::error!("{}: maximum number of streams reached", self.id); return Action::Terminate(Frame::internal_error()); } - let stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT); - { - let mut shared = stream.shared(); + let mut stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT); + stream.shared_mut().with_mut(|inner| { if is_finish { - shared.update_state(self.id, stream_id, State::RecvClosed); + inner.update_state(self.id, stream_id, State::RecvClosed); } - shared.consume_receive_window(frame.body_len()); - shared.buffer.push(frame.into_body()); - } + inner.consume_receive_window(frame.body_len()); + inner.buffer.push(frame.into_body()); + }); self.streams.insert(stream_id, stream.clone_shared()); return Action::New(stream); } - if let Some(s) = self.streams.get_mut(&stream_id) { - let mut shared = s.lock(); - if frame.body_len() > shared.receive_window() { - log::error!( - "{}/{}: frame body larger than window of stream", - self.id, - stream_id - ); - return Action::Terminate(Frame::protocol_error()); - } - if is_finish { - shared.update_state(self.id, stream_id, State::RecvClosed); - } - shared.consume_receive_window(frame.body_len()); - shared.buffer.push(frame.into_body()); - if let Some(w) = shared.reader.take() { - w.wake() - } + if let Some(shared) = self.streams.get_mut(&stream_id) { + let action = shared.with_mut(|inner| { + if frame.body_len() > inner.receive_window() { + log::error!( + "{}/{}: frame body larger than window of stream", + self.id, + stream_id + ); + Action::Terminate(Frame::protocol_error()) + } else { + if is_finish { + inner.update_state(self.id, stream_id, State::RecvClosed); + } + inner.consume_receive_window(frame.body_len()); + inner.buffer.push(frame.into_body()); + if let Some(w) = inner.reader.take() { + w.wake() + } + Action::None + } + }); + action } else { log::trace!( "{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}", @@ -671,9 +673,8 @@ impl Active { // termination for the remote. // // See https://github.com/paritytech/yamux/issues/110 for details. + Action::None } - - Action::None } fn on_window_update(&mut self, frame: &Frame) -> Action { @@ -681,15 +682,16 @@ impl Active { if frame.header().flags().contains(header::RST) { // stream reset - if let Some(s) = self.streams.get_mut(&stream_id) { - let mut shared = s.lock(); - shared.update_state(self.id, stream_id, State::Closed); - if let Some(w) = shared.reader.take() { - w.wake() - } - if let Some(w) = shared.writer.take() { - w.wake() - } + if let Some(shared) = self.streams.get_mut(&stream_id) { + shared.with_mut(|inner| { + inner.update_state(self.id, stream_id, State::Closed); + if let Some(w) = inner.reader.take() { + w.wake() + } + if let Some(w) = inner.writer.take() { + w.wake() + } + }); } return Action::None; } @@ -712,30 +714,32 @@ impl Active { } let credit = frame.header().credit() + DEFAULT_CREDIT; - let stream = self.make_new_inbound_stream(stream_id, credit); + let mut stream = self.make_new_inbound_stream(stream_id, credit); if is_finish { stream - .shared() + .shared_mut() .update_state(self.id, stream_id, State::RecvClosed); } self.streams.insert(stream_id, stream.clone_shared()); return Action::New(stream); } - if let Some(s) = self.streams.get_mut(&stream_id) { - let mut shared = s.lock(); - shared.increase_send_window_by(frame.header().credit()); - if is_finish { - shared.update_state(self.id, stream_id, State::RecvClosed); + if let Some(shared) = self.streams.get_mut(&stream_id) { + shared.with_mut(|inner| { + inner.increase_send_window_by(frame.header().credit()); + if is_finish { + inner.update_state(self.id, stream_id, State::RecvClosed); + + if let Some(w) = inner.reader.take() { + w.wake() + } + } - if let Some(w) = shared.reader.take() { + if let Some(w) = inner.writer.take() { w.wake() } - } - if let Some(w) = shared.writer.take() { - w.wake() - } + }); } else { log::trace!( "{}/{}: window update for unknown stream, possibly dropped earlier: {:?}", @@ -848,7 +852,7 @@ impl Active { Mode::Client => id.is_client(), Mode::Server => id.is_server(), }) - .filter(|(_, s)| s.lock().is_pending_ack()) + .filter(|(_, s)| s.is_pending_ack()) .count() } @@ -867,15 +871,16 @@ impl Active { impl Active { /// Close and drop all `Stream`s and wake any pending `Waker`s. fn drop_all_streams(&mut self) { - for (id, s) in self.streams.drain() { - let mut shared = s.lock(); - shared.update_state(self.id, id, State::Closed); - if let Some(w) = shared.reader.take() { - w.wake() - } - if let Some(w) = shared.writer.take() { - w.wake() - } + for (id, mut shared) in self.streams.drain() { + shared.with_mut(|inner| { + inner.update_state(self.id, id, State::Closed); + if let Some(w) = inner.reader.take() { + w.wake() + } + if let Some(w) = inner.writer.take() { + w.wake() + } + }); } } } diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs index 1f48e1b..5eead28 100644 --- a/yamux/src/connection/stream.rs +++ b/yamux/src/connection/stream.rs @@ -7,7 +7,6 @@ // as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0 // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license // at https://opensource.org/licenses/MIT. - use crate::connection::rtt::Rtt; use crate::frame::header::ACK; use crate::{ @@ -26,7 +25,8 @@ use futures::{ io::{AsyncRead, AsyncWrite}, ready, SinkExt, }; -use parking_lot::{Mutex, MutexGuard}; + +use parking_lot::Mutex; use std::{ fmt, io, pin::Pin, @@ -96,7 +96,7 @@ pub struct Stream { config: Arc, sender: mpsc::Sender, flag: Flag, - shared: Arc>, + shared: Shared, } impl fmt::Debug for Stream { @@ -130,13 +130,13 @@ impl Stream { config: config.clone(), sender, flag: Flag::Ack, - shared: Arc::new(Mutex::new(Shared::new( + shared: Shared::new( DEFAULT_CREDIT, send_window, accumulated_max_stream_windows, rtt, config, - ))), + ), } } @@ -154,13 +154,13 @@ impl Stream { config: config.clone(), sender, flag: Flag::Syn, - shared: Arc::new(Mutex::new(Shared::new( + shared: Shared::new( DEFAULT_CREDIT, DEFAULT_CREDIT, accumulated_max_stream_windows, rtt, config, - ))), + ), } } @@ -170,28 +170,27 @@ impl Stream { } pub fn is_write_closed(&self) -> bool { - matches!(self.shared().state(), State::SendClosed) + matches!(self.shared.state(), State::SendClosed) } pub fn is_closed(&self) -> bool { - matches!(self.shared().state(), State::Closed) + matches!(self.shared.state(), State::Closed) } - /// Whether we are still waiting for the remote to acknowledge this stream. pub fn is_pending_ack(&self) -> bool { - self.shared().is_pending_ack() + self.shared.is_pending_ack() } - pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> { - self.shared.lock() + pub(crate) fn shared_mut(&mut self) -> &mut Shared { + &mut self.shared } - pub(crate) fn clone_shared(&self) -> Arc> { + pub(crate) fn clone_shared(&self) -> Shared { self.shared.clone() } - fn write_zero_err(&self) -> io::Error { - let msg = format!("{}/{}: connection is closed", self.conn, self.id); + fn write_zero_err(conn: connection::Id, id: StreamId) -> io::Error { + let msg = format!("{}/{}: connection is closed", conn, id); io::Error::new(io::ErrorKind::WriteZero, msg) } @@ -213,16 +212,16 @@ impl Stream { /// Send new credit to the sending side via a window update message if /// permitted. fn send_window_update(&mut self, cx: &mut Context) -> Poll> { - if !self.shared.lock().state.can_read() { + if !self.shared.state().can_read() { return Poll::Ready(Ok(())); } ready!(self .sender .poll_ready(cx) - .map_err(|_| self.write_zero_err())?); + .map_err(|_| Stream::write_zero_err(self.conn, self.id))?); - let Some(credit) = self.shared.lock().next_window_update() else { + let Some(credit) = self.shared_mut().next_window_update() else { return Poll::Ready(Ok(())); }; @@ -231,7 +230,7 @@ impl Stream { let cmd = StreamCommand::SendFrame(frame); self.sender .start_send(cmd) - .map_err(|_| self.write_zero_err())?; + .map_err(|_| Stream::write_zero_err(self.conn, self.id))?; Poll::Ready(Ok(())) } @@ -262,37 +261,35 @@ impl futures::stream::Stream for Stream { Poll::Pending => {} } - let mut shared = self.shared(); - - if let Some(bytes) = shared.buffer.pop() { - let off = bytes.offset(); - let mut vec = bytes.into_vec(); - if off != 0 { - // This should generally not happen when the stream is used only as - // a `futures::stream::Stream` since the whole point of this impl is - // to consume chunks atomically. It may perhaps happen when mixing - // this impl and the `AsyncRead` one. - log::debug!( - "{}/{}: chunk has been partially consumed", - self.conn, - self.id - ); - vec = vec.split_off(off) + let Self { + id, conn, shared, .. + } = self.get_mut(); + let polling_state = shared.with_mut(|inner| { + if let Some(bytes) = inner.buffer.pop() { + let off = bytes.offset(); + let mut vec = bytes.into_vec(); + if off != 0 { + // This should generally not happen when the stream is used only as + // a `futures::stream::Stream` since the whole point of this impl is + // to consume chunks atomically. It may perhaps happen when mixing + // this impl and the `AsyncRead` one. + log::debug!("{}/{}: chunk has been partially consumed", conn, id,); + vec = vec.split_off(off) + } + return Poll::Ready(Some(Ok(Packet(vec)))); } - return Poll::Ready(Some(Ok(Packet(vec)))); - } - - // Buffer is empty, let's check if we can expect to read more data. - if !shared.state().can_read() { - log::debug!("{}/{}: eof", self.conn, self.id); - return Poll::Ready(None); // stream has been reset - } - - // Since we have no more data at this point, we want to be woken up - // by the connection when more becomes available for us. - shared.reader = Some(cx.waker().clone()); + // Buffer is empty, let's check if we can expect to read more data. + if !inner.state.can_read() { + log::debug!("{}/{}: eof", conn, id); + return Poll::Ready(None); // stream has been reset + } + // Since we have no more data at this point, we want to be woken up + // by the connection when more becomes available for us. + inner.reader = Some(cx.waker().clone()); + Poll::Pending + }); - Poll::Pending + polling_state } } @@ -316,38 +313,43 @@ impl AsyncRead for Stream { } // Copy data from stream buffer. - let mut shared = self.shared(); - let mut n = 0; - while let Some(chunk) = shared.buffer.front_mut() { - if chunk.is_empty() { - shared.buffer.pop(); - continue; + let Self { + id, conn, shared, .. + } = self.get_mut(); + let poll_state = shared.with_mut(|inner| { + let mut n = 0; + while let Some(chunk) = inner.buffer.front_mut() { + if chunk.is_empty() { + inner.buffer.pop(); + continue; + } + let k = std::cmp::min(chunk.len(), buf.len() - n); + buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]); + n += k; + chunk.advance(k); + if n == buf.len() { + break; + } } - let k = std::cmp::min(chunk.len(), buf.len() - n); - buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]); - n += k; - chunk.advance(k); - if n == buf.len() { - break; - } - } - if n > 0 { - log::trace!("{}/{}: read {} bytes", self.conn, self.id, n); - return Poll::Ready(Ok(n)); - } + if n > 0 { + log::trace!("{}/{}: read {} bytes", conn, id, n); + return Poll::Ready(Ok(n)); + } - // Buffer is empty, let's check if we can expect to read more data. - if !shared.state().can_read() { - log::debug!("{}/{}: eof", self.conn, self.id); - return Poll::Ready(Ok(0)); // stream has been reset - } + // Buffer is empty, let's check if we can expect to read more data. + if !inner.state.can_read() { + log::debug!("{}/{}: eof", conn, id); + return Poll::Ready(Ok(0)); // stream has been reset + } - // Since we have no more data at this point, we want to be woken up - // by the connection when more becomes available for us. - shared.reader = Some(cx.waker().clone()); + // Since we have no more data at this point, we want to be woken up + // by the connection when more becomes available for us. + inner.reader = Some(cx.waker().clone()); + Poll::Pending + }); - Poll::Pending + poll_state } } @@ -360,54 +362,69 @@ impl AsyncWrite for Stream { ready!(self .sender .poll_ready(cx) - .map_err(|_| self.write_zero_err())?); - let body = { - let mut shared = self.shared(); - if !shared.state().can_write() { - log::debug!("{}/{}: can no longer write", self.conn, self.id); - return Poll::Ready(Err(self.write_zero_err())); + .map_err(|_| Stream::write_zero_err(self.conn, self.id))?); + + let stream = self.as_mut().get_mut(); + let result = stream.shared.with_mut(|inner| { + if !inner.state.can_write() { + log::debug!("{}/{}: can no longer write", stream.conn, stream.id); + return Err(Stream::write_zero_err(stream.conn, stream.id)); } - if shared.send_window() == 0 { - log::trace!("{}/{}: no more credit left", self.conn, self.id); - shared.writer = Some(cx.waker().clone()); - return Poll::Pending; + + let window = inner.send_window(); + if window == 0 { + log::trace!("{}/{}: no more credit left", stream.conn, stream.id); + inner.writer = Some(cx.waker().clone()); + return Ok(None); } - let k = std::cmp::min( - shared.send_window(), - buf.len().try_into().unwrap_or(u32::MAX), - ); + + let k = std::cmp::min(window, buf.len().try_into().unwrap_or(u32::MAX)); let k = std::cmp::min( k, - self.config.split_send_size.try_into().unwrap_or(u32::MAX), + stream.config.split_send_size.try_into().unwrap_or(u32::MAX), ); - shared.consume_send_window(k); - Vec::from(&buf[..k as usize]) + + inner.consume_send_window(k); + Ok(Some(Vec::from(&buf[..k as usize]))) + }); + + let body = match result { + Err(e) => return Poll::Ready(Err(e)), + Ok(None) => return Poll::Pending, + Ok(Some(b)) => b, }; + let n = body.len(); - let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left(); - self.add_flag(frame.header_mut()); - log::trace!("{}/{}: write {} bytes", self.conn, self.id, n); + let mut frame = Frame::data(stream.id, body) + .expect("body <= u32::MAX") + .left(); + + stream.add_flag(frame.header_mut()); + + log::trace!("{}/{}: write {} bytes", stream.conn, stream.id, n); // technically, the frame hasn't been sent yet on the wire but from the perspective of this data structure, we've queued the frame for sending // We are tracking this information: // a) to be consistent with outbound streams // b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test. if frame.header().flags().contains(ACK) { - self.shared() - .update_state(self.conn, self.id, State::Open { acknowledged: true }); + stream + .shared + .update_state(stream.conn, stream.id, State::Open { acknowledged: true }); } let cmd = StreamCommand::SendFrame(frame); - self.sender + stream + .sender .start_send(cmd) - .map_err(|_| self.write_zero_err())?; + .map_err(|_| Stream::write_zero_err(stream.conn, stream.id))?; Poll::Ready(Ok(n)) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.sender .poll_flush_unpin(cx) - .map_err(|_| self.write_zero_err()) + .map_err(|_| Stream::write_zero_err(self.conn, self.id)) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { @@ -417,7 +434,7 @@ impl AsyncWrite for Stream { ready!(self .sender .poll_ready(cx) - .map_err(|_| self.write_zero_err())?); + .map_err(|_| Stream::write_zero_err(self.conn, self.id))?); let ack = if self.flag == Flag::Ack { self.flag = Flag::None; true @@ -428,15 +445,66 @@ impl AsyncWrite for Stream { let cmd = StreamCommand::CloseStream { ack }; self.sender .start_send(cmd) - .map_err(|_| self.write_zero_err())?; - self.shared() - .update_state(self.conn, self.id, State::SendClosed); + .map_err(|_| Stream::write_zero_err(self.conn, self.id))?; + let Self { + id, conn, shared, .. + } = self.get_mut(); + shared.update_state(*conn, *id, State::SendClosed); Poll::Ready(Ok(())) } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub(crate) struct Shared { + inner: Arc>, +} + +impl Shared { + fn new( + receive_window: u32, + send_window: u32, + accumulated_max_stream_windows: Arc>, + rtt: Rtt, + config: Arc, + ) -> Self { + Self { + inner: Arc::new(Mutex::new(SharedInner::new( + receive_window, + send_window, + accumulated_max_stream_windows, + rtt, + config, + ))), + } + } + + pub fn state(&self) -> State { + self.inner.lock().state + } + + pub fn is_pending_ack(&self) -> bool { + self.inner.lock().is_pending_ack() + } + + pub fn next_window_update(&mut self) -> Option { + self.with_mut(|inner| inner.next_window_update()) + } + + pub fn update_state(&mut self, cid: connection::Id, sid: StreamId, next: State) -> State { + self.with_mut(|inner| inner.update_state(cid, sid, next)) + } + + pub fn with_mut(&mut self, f: F) -> R + where + F: FnOnce(&mut SharedInner) -> R, + { + let mut guard = self.inner.lock(); + f(&mut guard) + } +} + +#[derive(Debug)] +pub(crate) struct SharedInner { state: State, flow_controller: FlowController, pub(crate) buffer: Chunks, @@ -444,7 +512,7 @@ pub(crate) struct Shared { pub(crate) writer: Option, } -impl Shared { +impl SharedInner { fn new( receive_window: u32, send_window: u32, @@ -452,7 +520,7 @@ impl Shared { rtt: Rtt, config: Arc, ) -> Self { - Shared { + Self { state: State::Open { acknowledged: false, }, @@ -469,10 +537,6 @@ impl Shared { } } - pub(crate) fn state(&self) -> State { - self.state - } - /// Update the stream state and return the state before it was updated. pub(crate) fn update_state( &mut self, @@ -509,14 +573,14 @@ impl Shared { current // Return the previous stream state for informational purposes. } - pub(crate) fn next_window_update(&mut self) -> Option { + fn next_window_update(&mut self) -> Option { self.flow_controller.next_window_update(self.buffer.len()) } /// Whether we are still waiting for the remote to acknowledge this stream. - pub fn is_pending_ack(&self) -> bool { + fn is_pending_ack(&self) -> bool { matches!( - self.state(), + self.state, State::Open { acknowledged: false }