From 265ad67c86379841a5aa821543a01648ccc8c26c Mon Sep 17 00:00:00 2001
From: Sean McArthur <sean@seanmonstar.com>
Date: Mon, 5 Feb 2018 09:56:29 -0800
Subject: [PATCH] fix(client): more reliably detect closed pooled connections
 (#1434)

---
 src/client/cancel.rs     | 148 +++++++++++++++++++++++++++++++++++++++
 src/client/dispatch.rs   |  73 +++++++++++++++++++
 src/client/mod.rs        |  78 +++++++++------------
 src/common/mod.rs        |   2 +
 src/proto/h1/dispatch.rs |  23 +++---
 5 files changed, 264 insertions(+), 60 deletions(-)
 create mode 100644 src/client/cancel.rs
 create mode 100644 src/client/dispatch.rs

diff --git a/src/client/cancel.rs b/src/client/cancel.rs
new file mode 100644
index 0000000000..6ebc01a902
--- /dev/null
+++ b/src/client/cancel.rs
@@ -0,0 +1,148 @@
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, Ordering};
+
+use futures::{Async, Future, Poll};
+use futures::task::{self, Task};
+
+use common::Never;
+
+use self::lock::Lock;
+
+#[derive(Clone)]
+pub struct Cancel {
+    inner: Arc<Inner>,
+}
+
+pub struct Canceled {
+    inner: Arc<Inner>,
+}
+
+struct Inner {
+    is_canceled: AtomicBool,
+    task: Lock<Option<Task>>,
+}
+
+impl Cancel {
+    pub fn new() -> (Cancel, Canceled) {
+        let inner = Arc::new(Inner {
+            is_canceled: AtomicBool::new(false),
+            task: Lock::new(None),
+        });
+        let inner2 = inner.clone();
+        (
+            Cancel {
+                inner: inner,
+            },
+            Canceled {
+                inner: inner2,
+            },
+        )
+    }
+
+    pub fn cancel(&self) {
+        if !self.inner.is_canceled.swap(true, Ordering::SeqCst) {
+            if let Some(mut locked) = self.inner.task.try_lock() {
+                if let Some(task) = locked.take() {
+                    task.notify();
+                }
+            }
+            // if we couldn't take the lock, Canceled was trying to park.
+            // After parking, it will check is_canceled one last time,
+            // so we can just stop here.
+        }
+    }
+
+    pub fn is_canceled(&self) -> bool {
+        self.inner.is_canceled.load(Ordering::SeqCst)
+    }
+}
+
+impl Future for Canceled {
+    type Item = ();
+    type Error = Never;
+
+    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
+        if self.inner.is_canceled.load(Ordering::SeqCst) {
+            Ok(Async::Ready(()))
+        } else {
+            if let Some(mut locked) = self.inner.task.try_lock() {
+                if locked.is_none() {
+                    // it's possible a Cancel just tried to cancel on another thread,
+                    // and we just missed it. Once we have the lock, we should check
+                    // one more time before parking this task and going away.
+                    if self.inner.is_canceled.load(Ordering::SeqCst) {
+                        return Ok(Async::Ready(()));
+                    }
+                    *locked = Some(task::current());
+                }
+                Ok(Async::NotReady)
+            } else {
+                // if we couldn't take the lock, then a Cancel has it.
+                // The *ONLY* reason is because it is in the process of canceling.
+                Ok(Async::Ready(()))
+            }
+        }
+    }
+}
+
+impl Drop for Canceled {
+    fn drop(&mut self) {
+        self.inner.is_canceled.store(true, Ordering::SeqCst);
+    }
+}
+
+
+// a sub module just to protect unsafety
+mod lock {
+    use std::cell::UnsafeCell;
+    use std::ops::{Deref, DerefMut};
+    use std::sync::atomic::{AtomicBool, Ordering};
+
+    pub struct Lock<T> {
+        is_locked: AtomicBool,
+        value: UnsafeCell<T>,
+    }
+
+    impl<T> Lock<T> {
+        pub fn new(val: T) -> Lock<T> {
+            Lock {
+                is_locked: AtomicBool::new(false),
+                value: UnsafeCell::new(val),
+            }
+        }
+
+        pub fn try_lock(&self) -> Option<Locked<T>> {
+            if !self.is_locked.swap(true, Ordering::SeqCst) {
+                Some(Locked { lock: self })
+            } else {
+                None
+            }
+        }
+    }
+
+    unsafe impl<T: Send> Send for Lock<T> {}
+    unsafe impl<T: Send> Sync for Lock<T> {}
+
+    pub struct Locked<'a, T: 'a> {
+        lock: &'a Lock<T>,
+    }
+
+    impl<'a, T> Deref for Locked<'a, T> {
+        type Target = T;
+        fn deref(&self) -> &T {
+            unsafe { &*self.lock.value.get() }
+        }
+    }
+
+    impl<'a, T> DerefMut for Locked<'a, T> {
+        fn deref_mut(&mut self) -> &mut T {
+            unsafe { &mut *self.lock.value.get() }
+        }
+    }
+
+    impl<'a, T> Drop for Locked<'a, T> {
+        fn drop(&mut self) {
+            self.lock.is_locked.store(false, Ordering::SeqCst);
+        }
+    }
+}
diff --git a/src/client/dispatch.rs b/src/client/dispatch.rs
new file mode 100644
index 0000000000..8e74872cce
--- /dev/null
+++ b/src/client/dispatch.rs
@@ -0,0 +1,73 @@
+use futures::{Async, Future, Poll, Stream};
+use futures::sync::{mpsc, oneshot};
+
+use common::Never;
+use super::cancel::{Cancel, Canceled};
+
+pub type Callback<U> = oneshot::Sender<::Result<U>>;
+pub type Promise<U> = oneshot::Receiver<::Result<U>>;
+
+pub fn channel<T, U>() -> (Sender<T, U>, Receiver<T, U>) {
+    let (tx, rx) = mpsc::unbounded();
+    let (cancel, canceled) = Cancel::new();
+    let tx = Sender {
+        cancel: cancel,
+        inner: tx,
+    };
+    let rx = Receiver {
+        canceled: canceled,
+        inner: rx,
+    };
+    (tx, rx)
+}
+
+pub struct Sender<T, U> {
+    cancel: Cancel,
+    inner: mpsc::UnboundedSender<(T, Callback<U>)>,
+}
+
+impl<T, U> Sender<T, U> {
+    pub fn is_closed(&self) -> bool {
+        self.cancel.is_canceled()
+    }
+
+    pub fn cancel(&self) {
+        self.cancel.cancel();
+    }
+
+    pub fn send(&self, val: T) -> Result<Promise<U>, T> {
+        let (tx, rx) = oneshot::channel();
+        self.inner.unbounded_send((val, tx))
+            .map(move |_| rx)
+            .map_err(|e| e.into_inner().0)
+    }
+}
+
+impl<T, U> Clone for Sender<T, U> {
+    fn clone(&self) -> Sender<T, U> {
+        Sender {
+            cancel: self.cancel.clone(),
+            inner: self.inner.clone(),
+        }
+    }
+}
+
+pub struct Receiver<T, U> {
+    canceled: Canceled,
+    inner: mpsc::UnboundedReceiver<(T, Callback<U>)>,
+}
+
+impl<T, U> Stream for Receiver<T, U> {
+    type Item = (T, Callback<U>);
+    type Error = Never;
+
+    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
+        if let Async::Ready(()) = self.canceled.poll()? {
+            return Ok(Async::Ready(None));
+        }
+        self.inner.poll()
+            .map_err(|()| unreachable!("mpsc never errors"))
+    }
+}
+
+//TODO: Drop for Receiver should consume inner
diff --git a/src/client/mod.rs b/src/client/mod.rs
index 538ec53837..b7c665d258 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -1,14 +1,14 @@
 //! HTTP Client
 
-use std::cell::{Cell, RefCell};
+use std::cell::Cell;
 use std::fmt;
 use std::io;
 use std::marker::PhantomData;
 use std::rc::Rc;
 use std::time::Duration;
 
-use futures::{Future, Poll, Stream};
-use futures::future::{self, Executor};
+use futures::{Async, Future, Poll, Stream};
+use futures::future::{self, Either, Executor};
 #[cfg(feature = "compat")]
 use http;
 use tokio::reactor::Handle;
@@ -28,7 +28,10 @@ pub use self::connect::{HttpConnector, Connect};
 
 use self::background::{bg, Background};
 
+mod cancel;
 mod connect;
+//TODO(easy): move cancel and dispatch into common instead
+pub(crate) mod dispatch;
 mod dns;
 mod pool;
 #[cfg(feature = "compat")]
@@ -189,9 +192,6 @@ where C: Connect,
             head.headers.set_pos(0, host);
         }
 
-        use futures::Sink;
-        use futures::sync::{mpsc, oneshot};
-
         let checkout = self.pool.checkout(domain.as_ref());
         let connect = {
             let executor = self.executor.clone();
@@ -199,10 +199,9 @@ where C: Connect,
             let pool_key = Rc::new(domain.to_string());
             self.connector.connect(url)
                 .and_then(move |io| {
-                    // 1 extra slot for possible Close message
-                    let (tx, rx) = mpsc::channel(1);
+                    let (tx, rx) = dispatch::channel();
                     let tx = HyperClient {
-                        tx: RefCell::new(tx),
+                        tx: tx,
                         should_close: Cell::new(true),
                     };
                     let pooled = pool.pooled(pool_key, tx);
@@ -225,33 +224,26 @@ where C: Connect,
         });
 
         let resp = race.and_then(move |client| {
-            use proto::dispatch::ClientMsg;
-
-            let (callback, rx) = oneshot::channel();
-            client.should_close.set(false);
-
-            match client.tx.borrow_mut().start_send(ClientMsg::Request(head, body, callback)) {
-                Ok(_) => (),
-                Err(e) => match e.into_inner() {
-                    ClientMsg::Request(_, _, callback) => {
-                        error!("pooled connection was not ready, this is a hyper bug");
-                        let err = io::Error::new(
-                            io::ErrorKind::BrokenPipe,
-                            "pool selected dead connection",
-                        );
-                        let _ = callback.send(Err(::Error::Io(err)));
-                    },
-                    _ => unreachable!("ClientMsg::Request was just sent"),
+            match client.tx.send((head, body)) {
+                Ok(rx) => {
+                    client.should_close.set(false);
+                    Either::A(rx.then(|res| {
+                        match res {
+                            Ok(Ok(res)) => Ok(res),
+                            Ok(Err(err)) => Err(err),
+                            Err(_) => panic!("dispatch dropped without returning error"),
+                        }
+                    }))
+                },
+                Err(_) => {
+                    error!("pooled connection was not ready, this is a hyper bug");
+                    let err = io::Error::new(
+                        io::ErrorKind::BrokenPipe,
+                        "pool selected dead connection",
+                    );
+                    Either::B(future::err(::Error::Io(err)))
                 }
             }
-
-            rx.then(|res| {
-                match res {
-                    Ok(Ok(res)) => Ok(res),
-                    Ok(Err(err)) => Err(err),
-                    Err(_) => panic!("dispatch dropped without returning error"),
-                }
-            })
         });
 
         FutureResponse(Box::new(resp))
@@ -276,13 +268,8 @@ impl<C, B> fmt::Debug for Client<C, B> {
 }
 
 struct HyperClient<B> {
-    // A sentinel that is usually always true. If this is dropped
-    // while true, this will try to shutdown the dispatcher task.
-    //
-    // This should be set to false whenever it is checked out of the
-    // pool and successfully used to send a request.
     should_close: Cell<bool>,
-    tx: RefCell<::futures::sync::mpsc::Sender<proto::dispatch::ClientMsg<B>>>,
+    tx: dispatch::Sender<proto::dispatch::ClientMsg<B>, ::Response>,
 }
 
 impl<B> Clone for HyperClient<B> {
@@ -296,10 +283,11 @@ impl<B> Clone for HyperClient<B> {
 
 impl<B> self::pool::Ready for HyperClient<B> {
     fn poll_ready(&mut self) -> Poll<(), ()> {
-        self.tx
-            .borrow_mut()
-            .poll_ready()
-            .map_err(|_| ())
+        if self.tx.is_closed() {
+            Err(())
+        } else {
+            Ok(Async::Ready(()))
+        }
     }
 }
 
@@ -307,7 +295,7 @@ impl<B> Drop for HyperClient<B> {
     fn drop(&mut self) {
         if self.should_close.get() {
             self.should_close.set(false);
-            let _ = self.tx.borrow_mut().try_send(proto::dispatch::ClientMsg::Close);
+            self.tx.cancel();
         }
     }
 }
diff --git a/src/common/mod.rs b/src/common/mod.rs
index 795061776f..93363c3ff6 100644
--- a/src/common/mod.rs
+++ b/src/common/mod.rs
@@ -1,3 +1,5 @@
 pub use self::str::ByteStr;
 
 mod str;
+
+pub enum Never {}
diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs
index 27e5b2376c..4b9c5aa1e3 100644
--- a/src/proto/h1/dispatch.rs
+++ b/src/proto/h1/dispatch.rs
@@ -1,7 +1,7 @@
 use std::io;
 
 use futures::{Async, AsyncSink, Future, Poll, Stream};
-use futures::sync::{mpsc, oneshot};
+use futures::sync::oneshot;
 use tokio_io::{AsyncRead, AsyncWrite};
 use tokio_service::Service;
 
@@ -36,12 +36,9 @@ pub struct Client<B> {
     rx: ClientRx<B>,
 }
 
-pub enum ClientMsg<B> {
-    Request(RequestHead, Option<B>, oneshot::Sender<::Result<::Response>>),
-    Close,
-}
+pub type ClientMsg<B> = (RequestHead, Option<B>);
 
-type ClientRx<B> = mpsc::Receiver<ClientMsg<B>>;
+type ClientRx<B> = ::client::dispatch::Receiver<ClientMsg<B>, ::Response>;
 
 impl<D, Bs, I, B, T, K> Dispatcher<D, Bs, I, B, T, K>
 where
@@ -365,7 +362,7 @@ where
 
     fn poll_msg(&mut self) -> Poll<Option<(Self::PollItem, Option<Self::PollBody>)>, ::Error> {
         match self.rx.poll() {
-            Ok(Async::Ready(Some(ClientMsg::Request(head, body, mut cb)))) => {
+            Ok(Async::Ready(Some(((head, body), mut cb)))) => {
                 // check that future hasn't been canceled already
                 match cb.poll_cancel().expect("poll_cancel cannot error") {
                     Async::Ready(()) => {
@@ -378,13 +375,12 @@ where
                     }
                 }
             },
-            Ok(Async::Ready(Some(ClientMsg::Close))) |
             Ok(Async::Ready(None)) => {
                 trace!("client tx closed");
                 // user has dropped sender handle
                 Ok(Async::Ready(None))
            },
             Ok(Async::NotReady) => return Ok(Async::NotReady),
-            Err(()) => unreachable!("mpsc receiver cannot error"),
+            Err(_) => unreachable!("receiver cannot error"),
         }
     }
@@ -404,7 +400,7 @@ where
         if let Some(cb) = self.callback.take() {
             let _ = cb.send(Err(err));
             Ok(())
-        } else if let Ok(Async::Ready(Some(ClientMsg::Request(_, _, cb)))) = self.rx.poll() {
+        } else if let Ok(Async::Ready(Some((_, cb)))) = self.rx.poll() {
             let _ = cb.send(Err(err));
             Ok(())
         } else {
@@ -435,8 +431,6 @@ where
 
 #[cfg(test)]
 mod tests {
-    use futures::Sink;
-
     use super::*;
     use mock::AsyncIo;
     use proto::ClientTransaction;
@@ -447,7 +441,7 @@ mod tests {
         let _ = pretty_env_logger::try_init();
         ::futures::lazy(|| {
             let io = AsyncIo::new_buf(b"HTTP/1.1 200 OK\r\n\r\n".to_vec(), 100);
-            let (mut tx, rx) = mpsc::channel(0);
+            let (tx, rx) = ::client::dispatch::channel();
             let conn = Conn::<_, ::Chunk, ClientTransaction>::new(io, Default::default());
             let mut dispatcher = Dispatcher::new(Client::new(rx), conn);
 
@@ -456,8 +450,7 @@ mod tests {
             let req = RequestHead {
                 subject: ::proto::RequestLine::default(),
                 headers: Default::default(),
             };
-            let (res_tx, res_rx) = oneshot::channel();
-            tx.start_send(ClientMsg::Request(req, None::<::Body>, res_tx)).unwrap();
+            let res_rx = tx.send((req, None::<::Body>)).unwrap();
 
             dispatcher.poll().expect("dispatcher poll 1");
             dispatcher.poll().expect("dispatcher poll 2");
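
The core of this change is the new client::dispatch channel: every queued
request carries its own oneshot callback, so the dispatcher can answer (or
abandon) exactly the request it is holding, and HyperClient::drop signals a
dead pooled connection by calling tx.cancel(). Below is a minimal,
self-contained sketch of that request/callback pattern against futures 0.1;
the u32 request/response types and the doubling "dispatcher" are illustrative
stand-ins, not hyper code:

    extern crate futures;

    use futures::{Future, Stream};
    use futures::sync::{mpsc, oneshot};

    fn main() {
        // Each queued message carries a oneshot::Sender, so the dispatcher
        // can reply to the exact request it is handling.
        let (tx, rx) = mpsc::unbounded::<(u32, oneshot::Sender<u32>)>();

        // Dispatcher side: answer every request by doubling it.
        let dispatcher = rx.for_each(|(req, cb)| {
            // A send error only means the caller went away; ignore it,
            // much as the real dispatcher ignores canceled callbacks.
            let _ = cb.send(req * 2);
            Ok(())
        });

        // Client side: send one request and keep the promise of a reply.
        let (cb, promise) = oneshot::channel();
        tx.unbounded_send((21, cb)).expect("receiver is alive");

        // Dropping the sender closes the channel and ends the dispatcher
        // stream, analogous to the cancel() issued by HyperClient::drop.
        drop(tx);

        dispatcher.wait().expect("dispatch loop");
        assert_eq!(promise.wait().expect("reply sent"), 42);
    }

The same shape explains the patch's pool logic: a Sender whose peer is gone
reports is_closed(), so poll_ready() can reject a dead pooled connection
before a request is ever queued on it.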