diff --git a/muxers/mplex/src/lib.rs b/muxers/mplex/src/lib.rs
index 9ce50f74770..709ed007be2 100644
--- a/muxers/mplex/src/lib.rs
+++ b/muxers/mplex/src/lib.rs
@@ -32,15 +32,15 @@ extern crate unsigned_varint;
 
 mod codec;
 
-use std::{cmp, iter};
+use std::{cmp, iter, mem};
 use std::io::{Error as IoError, ErrorKind as IoErrorKind};
-use std::sync::{atomic::AtomicUsize, atomic::Ordering};
+use std::sync::{atomic::AtomicUsize, atomic::Ordering, Arc};
 use bytes::Bytes;
 use core::{ConnectionUpgrade, Endpoint, StreamMuxer};
 use parking_lot::Mutex;
 use fnv::{FnvHashMap, FnvHashSet};
 use futures::prelude::*;
-use futures::{future, stream::Fuse, task};
+use futures::{executor, future, stream::Fuse, task};
 use tokio_codec::Framed;
 use tokio_io::{AsyncRead, AsyncWrite};
 
@@ -131,12 +131,17 @@ where
         let out = Multiplex {
             inner: Mutex::new(MultiplexInner {
                 error: Ok(()),
-                inner: Framed::new(i, codec::Codec::new()).fuse(),
+                inner: executor::spawn(Framed::new(i, codec::Codec::new()).fuse()),
                 config: self,
                 buffer: Vec::with_capacity(cmp::min(max_buffer_len, 512)),
                 opened_substreams: Default::default(),
                 next_outbound_stream_id: if endpoint == Endpoint::Dialer { 0 } else { 1 },
-                to_notify: Default::default(),
+                notifier_read: Arc::new(Notifier {
+                    to_notify: Mutex::new(Default::default()),
+                }),
+                notifier_write: Arc::new(Notifier {
+                    to_notify: Mutex::new(Default::default()),
+                }),
             })
         };
 
@@ -159,7 +164,7 @@ struct MultiplexInner<C> {
     // Error that happened earlier. Should poison any attempt to use this `MultiplexError`.
     error: Result<(), IoError>,
     // Underlying stream.
-    inner: Fuse<Framed<C, codec::Codec>>,
+    inner: executor::Spawn<Fuse<Framed<C, codec::Codec>>>,
     /// The original configuration.
     config: MplexConfig,
     // Buffer of elements pulled from the stream but not processed yet.
@@ -169,9 +174,30 @@ struct MultiplexInner<C> {
     opened_substreams: FnvHashSet<u32>,
     // Id of the next outgoing substream. Should always increase by two.
    next_outbound_stream_id: u32,
-    /// List of tasks to notify when a new element is inserted in `buffer` or an error happens or
-    /// when the buffer was full and no longer is.
-    to_notify: FnvHashMap<usize, task::Task>,
+    /// List of tasks to notify when a read event happens on the underlying stream.
+    notifier_read: Arc<Notifier>,
+    /// List of tasks to notify when a write event happens on the underlying stream.
+    notifier_write: Arc<Notifier>,
+}
+
+struct Notifier {
+    /// List of tasks to notify.
+    to_notify: Mutex<FnvHashMap<usize, task::Task>>,
+}
+
+impl executor::Notify for Notifier {
+    fn notify(&self, _: usize) {
+        let tasks = mem::replace(&mut *self.to_notify.lock(), Default::default());
+        for (_, task) in tasks {
+            task.notify();
+        }
+    }
+}
+
+// TODO: replace with another system
+static NEXT_TASK_ID: AtomicUsize = AtomicUsize::new(0);
+task_local!{
+    static TASK_ID: usize = NEXT_TASK_ID.fetch_add(1, Ordering::Relaxed)
 }
 
 /// Processes elements in `inner` until one matching `filter` is found.
@@ -190,21 +216,13 @@ where C: AsyncRead + AsyncWrite,
     if let Some((offset, out)) = inner.buffer.iter().enumerate().filter_map(|(n, v)| filter(v).map(|v| (n, v))).next() {
         // The buffer was full and no longer is, so let's notify everything.
         if inner.buffer.len() == inner.config.max_buffer_len {
-            for task in inner.to_notify.drain() {
-                task.1.notify();
-            }
+            executor::Notify::notify(&*inner.notifier_read, 0);
         }
 
         inner.buffer.remove(offset);
         return Ok(Async::Ready(Some(out)));
     }
 
-    // TODO: replace with another system
-    static NEXT_TASK_ID: AtomicUsize = AtomicUsize::new(0);
-    task_local!{
-        static TASK_ID: usize = NEXT_TASK_ID.fetch_add(1, Ordering::Relaxed)
-    }
-
     loop {
         // Check if we reached max buffer length first.
         debug_assert!(inner.buffer.len() <= inner.config.max_buffer_len);
@@ -213,72 +231,57 @@ where C: AsyncRead + AsyncWrite,
             match inner.config.max_buffer_behaviour {
                 MaxBufferBehaviour::CloseAll => {
                     inner.error = Err(IoError::new(IoErrorKind::Other, "reached maximum buffer length"));
-                    for task in inner.to_notify.drain() {
-                        task.1.notify();
-                    }
                     return Err(IoError::new(IoErrorKind::Other, "reached maximum buffer length"));
                 },
                 MaxBufferBehaviour::Block => {
-                    inner.to_notify.insert(TASK_ID.with(|&t| t), task::current());
-                    return Ok(Async::Ready(None));
+                    inner.notifier_read.to_notify.lock().insert(TASK_ID.with(|&t| t), task::current());
+                    return Ok(Async::NotReady);
                 },
             }
         }
 
-        let elem = match inner.inner.poll() {
-            Ok(Async::Ready(item)) => item,
+        let elem = match inner.inner.poll_stream_notify(&inner.notifier_read, 0) {
+            Ok(Async::Ready(Some(item))) => item,
+            Ok(Async::Ready(None)) => return Ok(Async::Ready(None)),
             Ok(Async::NotReady) => {
-                inner.to_notify.insert(TASK_ID.with(|&t| t), task::current());
+                inner.notifier_read.to_notify.lock().insert(TASK_ID.with(|&t| t), task::current());
                 return Ok(Async::NotReady);
             },
             Err(err) => {
                 let err2 = IoError::new(err.kind(), err.to_string());
                 inner.error = Err(err);
-                for task in inner.to_notify.drain() {
-                    task.1.notify();
-                }
                 return Err(err2);
             },
         };
 
-        if let Some(elem) = elem {
-            trace!("Received message: {:?}", elem);
+        trace!("Received message: {:?}", elem);
 
-            // Handle substreams opening/closing.
-            match elem {
-                codec::Elem::Open { substream_id } => {
-                    if (substream_id % 2) == (inner.next_outbound_stream_id % 2) {
-                        inner.error = Err(IoError::new(IoErrorKind::Other, "invalid substream id opened"));
-                        for task in inner.to_notify.drain() {
-                            task.1.notify();
-                        }
-                        return Err(IoError::new(IoErrorKind::Other, "invalid substream id opened"));
-                    }
-
-                    if !inner.opened_substreams.insert(substream_id) {
-                        debug!("Received open message for substream {} which was already open", substream_id)
-                    }
-                },
-                codec::Elem::Close { substream_id, .. } | codec::Elem::Reset { substream_id, .. } => {
-                    inner.opened_substreams.remove(&substream_id);
-                },
-                _ => ()
-            }
+        // Handle substreams opening/closing.
+        match elem {
+            codec::Elem::Open { substream_id } => {
+                if (substream_id % 2) == (inner.next_outbound_stream_id % 2) {
+                    inner.error = Err(IoError::new(IoErrorKind::Other, "invalid substream id opened"));
+                    return Err(IoError::new(IoErrorKind::Other, "invalid substream id opened"));
+                }
 
-            if let Some(out) = filter(&elem) {
-                return Ok(Async::Ready(Some(out)));
-            } else {
-                if inner.opened_substreams.contains(&elem.substream_id()) || elem.is_open_msg() {
-                    inner.buffer.push(elem);
-                    for task in inner.to_notify.drain() {
-                        task.1.notify();
-                    }
-                } else if !elem.is_close_or_reset_msg() {
-                    debug!("Ignored message {:?} because the substream wasn't open", elem);
+                if !inner.opened_substreams.insert(substream_id) {
+                    debug!("Received open message for substream {} which was already open", substream_id)
                 }
-            }
+            },
+            codec::Elem::Close { substream_id, .. } | codec::Elem::Reset { substream_id, .. } => {
+                inner.opened_substreams.remove(&substream_id);
+            },
+            _ => ()
+        }
+
+        if let Some(out) = filter(&elem) {
+            return Ok(Async::Ready(Some(out)));
         } else {
-            return Ok(Async::Ready(None));
+            if inner.opened_substreams.contains(&elem.substream_id()) || elem.is_open_msg() {
+                inner.buffer.push(elem);
+            } else if !elem.is_close_or_reset_msg() {
+                debug!("Ignored message {:?} because the substream wasn't open", elem);
+            }
         }
     }
 }
@@ -287,11 +290,12 @@ where C: AsyncRead + AsyncWrite,
 fn poll_send<C>(inner: &mut MultiplexInner<C>, elem: codec::Elem) -> Poll<(), IoError>
 where C: AsyncRead + AsyncWrite
 {
-    match inner.inner.start_send(elem) {
+    match inner.inner.start_send_notify(elem, &inner.notifier_write, 0) {
         Ok(AsyncSink::Ready) => {
             Ok(Async::Ready(()))
         },
         Ok(AsyncSink::NotReady(_)) => {
+            inner.notifier_write.to_notify.lock().insert(TASK_ID.with(|&t| t), task::current());
             Ok(Async::NotReady)
         },
         Err(err) => Err(err)
@@ -352,14 +356,15 @@ where C: AsyncRead + AsyncWrite
 
     fn poll_outbound(&self, substream: &mut Self::OutboundSubstream) -> Poll<Option<Self::Substream>, IoError> {
         loop {
+            let mut inner = self.inner.lock();
+
             let polling = match substream.state {
                 OutboundSubstreamState::SendElem(ref elem) => {
-                    let mut inner = self.inner.lock();
                     poll_send(&mut inner, elem.clone())
                 },
                 OutboundSubstreamState::Flush => {
-                    let mut inner = self.inner.lock();
-                    inner.inner.poll_complete()
+                    let inner = &mut *inner; // Avoids borrow errors
+                    inner.inner.poll_flush_notify(&inner.notifier_write, 0)
                },
                 OutboundSubstreamState::Done => {
                     panic!("Polling outbound substream after it's been successfully open");
@@ -368,14 +373,19 @@ where C: AsyncRead + AsyncWrite
 
             match polling {
                 Ok(Async::Ready(())) => (),
-                Ok(Async::NotReady) => return Ok(Async::NotReady),
+                Ok(Async::NotReady) => {
+                    inner.notifier_write.to_notify.lock().insert(TASK_ID.with(|&t| t), task::current());
+                    return Ok(Async::NotReady)
+                },
                 Err(err) => {
                     debug!("Failed to open outbound substream {}", substream.num);
-                    self.inner.lock().buffer.retain(|elem| elem.substream_id() != substream.num);
+                    inner.buffer.retain(|elem| elem.substream_id() != substream.num);
                     return Err(err)
                 },
             };
 
+            drop(inner);
+
             // Going to next step.
             match substream.state {
                 OutboundSubstreamState::SendElem(_) => {
@@ -456,9 +466,14 @@ where C: AsyncRead + AsyncWrite
 
     fn flush_substream(&self, _substream: &mut Self::Substream) -> Result<(), IoError> {
         let mut inner = self.inner.lock();
-        match inner.inner.poll_complete() {
+        let inner = &mut *inner; // Avoids borrow errors
+
+        match inner.inner.poll_flush_notify(&inner.notifier_write, 0) {
             Ok(Async::Ready(())) => Ok(()),
-            Ok(Async::NotReady) => Err(IoErrorKind::WouldBlock.into()),
+            Ok(Async::NotReady) => {
+                inner.notifier_write.to_notify.lock().insert(TASK_ID.with(|&t| t), task::current());
+                Err(IoErrorKind::WouldBlock.into())
+            },
             Err(err) => Err(err),
         }
    }
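
For reviewers unfamiliar with the futures 0.1 `Notify` machinery this patch switches to, the following is a minimal, self-contained sketch of the same pattern. It is illustrative, not part of the patch: `Waiters` and `poll_next` are invented names (the patch's `Notifier` plays the same role, with `parking_lot::Mutex` and `FnvHashMap` instead of the std types used here), while `executor::spawn`, `poll_stream_notify`, and `Notify` are the actual futures 0.1 APIs the patch relies on.

// Sketch of the notification pattern adopted by the patch (futures 0.1).
extern crate futures;

use std::collections::HashMap;
use std::mem;
use std::sync::{Arc, Mutex};

use futures::executor::{self, Notify, Spawn};
use futures::prelude::*;
use futures::task;

/// Tasks parked on one direction (read or write) of a shared stream.
struct Waiters {
    to_notify: Mutex<HashMap<usize, task::Task>>,
}

impl Notify for Waiters {
    fn notify(&self, _: usize) {
        // Wake *all* parked tasks; each re-polls and re-parks if it still
        // can't make progress, so no wakeup is lost when several tasks
        // share one muxer.
        let tasks = mem::replace(&mut *self.to_notify.lock().unwrap(), HashMap::new());
        for (_, task) in tasks {
            task.notify();
        }
    }
}

/// Polls a shared stream on behalf of the current task; meant to be called
/// from inside a task. `task_id` stands in for the per-task identifier the
/// patch derives from `task_local!`.
fn poll_next<S>(
    stream: &mut Spawn<S>,
    waiters: &Arc<Waiters>,
    task_id: usize,
) -> Poll<Option<S::Item>, S::Error>
where
    S: Stream,
{
    // `poll_stream_notify` registers the *notifier* with the underlying
    // stream instead of the implicit task context, so on `NotReady` we must
    // park `task::current()` in `waiters` ourselves; `Waiters::notify`
    // wakes it later.
    match stream.poll_stream_notify(waiters, 0) {
        Ok(Async::NotReady) => {
            waiters.to_notify.lock().unwrap().insert(task_id, task::current());
            Ok(Async::NotReady)
        }
        other => other,
    }
}

fn main() {
    let waiters = Arc::new(Waiters { to_notify: Mutex::new(HashMap::new()) });
    // `executor::spawn` wraps the stream so it can be polled with an
    // explicit notifier via `poll_stream_notify`.
    let mut stream = executor::spawn(futures::stream::iter_ok::<_, ()>(vec![1u32, 2, 3]));

    // An always-ready stream never returns `NotReady`, so this just drains it.
    while let Ok(Async::Ready(Some(item))) = stream.poll_stream_notify(&waiters, 0) {
        println!("got {}", item);
    }
}

Keeping distinct `notifier_read` and `notifier_write` sets, as the patch does, means a readiness event on one direction only wakes the tasks actually parked on that direction, rather than every task blocked on the muxer.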