-
Notifications
You must be signed in to change notification settings - Fork 824
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added a fix so that closed sockets do not cause errors #4548
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,13 +15,21 @@ use std::{net::SocketAddr, task::Poll}; | |
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader}; | ||
use virtual_mio::{ArcInterestHandler, InterestType}; | ||
|
||
#[derive(Debug, PartialEq, Eq, Clone, Copy)] | ||
enum State { | ||
Alive, | ||
Dead, | ||
Closed, | ||
Shutdown, | ||
} | ||
|
||
#[derive(Debug)] | ||
struct SocketBufferState { | ||
buffer: RingBuffer<'static, u8>, | ||
push_handler: Option<ArcInterestHandler>, | ||
pull_handler: Option<ArcInterestHandler>, | ||
wakers: Vec<Waker>, | ||
dead: bool, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix this |
||
state: State, | ||
// This flag prevents a poll write ready storm | ||
halt_immediate_poll_write: bool, | ||
} | ||
|
@@ -42,8 +50,12 @@ pub(crate) struct SocketBuffer { | |
|
||
impl Drop for SocketBuffer { | ||
fn drop(&mut self) { | ||
if self.dead_on_drop { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix this |
||
self.set_dead(); | ||
if self.state() == State::Alive { | ||
if self.dead_on_drop { | ||
self.set_state(State::Dead); | ||
} else { | ||
self.set_state(State::Closed); | ||
} | ||
} | ||
} | ||
} | ||
|
@@ -56,7 +68,7 @@ impl SocketBuffer { | |
push_handler: None, | ||
pull_handler: None, | ||
wakers: Vec::new(), | ||
dead: false, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix this |
||
state: State::Alive, | ||
halt_immediate_poll_write: false, | ||
})), | ||
dead_on_drop: false, | ||
|
@@ -65,7 +77,7 @@ impl SocketBuffer { | |
|
||
pub fn set_push_handler(&self, mut handler: ArcInterestHandler) { | ||
let mut state = self.state.lock().unwrap(); | ||
if state.dead { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix this |
||
if state.state != State::Alive { | ||
handler.push_interest(InterestType::Closed); | ||
} | ||
if !state.buffer.is_empty() { | ||
|
@@ -76,7 +88,7 @@ impl SocketBuffer { | |
|
||
pub fn set_pull_handler(&self, mut handler: ArcInterestHandler) { | ||
let mut state = self.state.lock().unwrap(); | ||
if state.dead { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix this |
||
if state.state != State::Alive { | ||
handler.push_interest(InterestType::Closed); | ||
} | ||
if !state.buffer.is_full() && state.pull_handler.is_none() { | ||
|
@@ -100,33 +112,51 @@ impl SocketBuffer { | |
if !state.buffer.is_empty() { | ||
return Poll::Ready(Ok(state.buffer.len())); | ||
} | ||
if state.dead { | ||
return Poll::Ready(Ok(0)); | ||
} | ||
if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) { | ||
state.wakers.push(cx.waker().clone()); | ||
match state.state { | ||
State::Alive => { | ||
if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) { | ||
state.wakers.push(cx.waker().clone()); | ||
} | ||
Poll::Pending | ||
} | ||
State::Dead => { | ||
tracing::trace!("poll_read_ready: socket is dead"); | ||
Poll::Ready(Err(NetworkError::ConnectionReset)) | ||
} | ||
State::Closed | State::Shutdown => { | ||
tracing::trace!("poll_read_ready: socket is closed or shutdown"); | ||
Poll::Ready(Ok(0)) | ||
} | ||
} | ||
Poll::Pending | ||
} | ||
|
||
pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> { | ||
let mut state = self.state.lock().unwrap(); | ||
if state.dead { | ||
return Poll::Ready(Ok(0)); | ||
} | ||
if !state.buffer.is_full() && !state.halt_immediate_poll_write { | ||
state.halt_immediate_poll_write = true; | ||
return Poll::Ready(Ok(state.buffer.window())); | ||
} | ||
if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) { | ||
state.wakers.push(cx.waker().clone()); | ||
match state.state { | ||
State::Alive => { | ||
if !state.buffer.is_full() && !state.halt_immediate_poll_write { | ||
state.halt_immediate_poll_write = true; | ||
return Poll::Ready(Ok(state.buffer.window())); | ||
} | ||
if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) { | ||
state.wakers.push(cx.waker().clone()); | ||
} | ||
Poll::Pending | ||
} | ||
State::Dead => { | ||
tracing::trace!("poll_write_ready: socket is dead"); | ||
Poll::Ready(Err(NetworkError::ConnectionReset)) | ||
} | ||
State::Closed | State::Shutdown => { | ||
tracing::trace!("poll_write_ready: socket is closed or shutdown"); | ||
Poll::Ready(Ok(0)) | ||
} | ||
} | ||
Poll::Pending | ||
} | ||
|
||
pub fn set_dead(&self) { | ||
fn set_state(&self, new_state: State) { | ||
let mut state = self.state.lock().unwrap(); | ||
state.dead = true; | ||
state.state = new_state; | ||
if let Some(handler) = state.pull_handler.as_mut() { | ||
handler.push_interest(InterestType::Closed); | ||
} | ||
|
@@ -136,9 +166,9 @@ impl SocketBuffer { | |
state.wakers.drain(..).for_each(|w| w.wake()); | ||
} | ||
|
||
pub fn is_dead(&self) -> bool { | ||
fn state(&self) -> State { | ||
let state = self.state.lock().unwrap(); | ||
state.dead | ||
state.state | ||
} | ||
|
||
pub fn try_send( | ||
|
@@ -148,7 +178,7 @@ impl SocketBuffer { | |
waker: Option<&Waker>, | ||
) -> crate::Result<usize> { | ||
let mut state = self.state.lock().unwrap(); | ||
if state.dead { | ||
if state.state != State::Alive { | ||
return Err(NetworkError::ConnectionReset); | ||
} | ||
state.halt_immediate_poll_write = false; | ||
|
@@ -211,14 +241,22 @@ impl SocketBuffer { | |
) -> crate::Result<usize> { | ||
let mut state = self.state.lock().unwrap(); | ||
if state.buffer.is_empty() { | ||
if state.dead { | ||
return Err(NetworkError::ConnectionReset); | ||
} | ||
|
||
if let Some(waker) = waker { | ||
state.add_waker(waker) | ||
} | ||
return Err(NetworkError::WouldBlock); | ||
return match state.state { | ||
State::Alive => { | ||
if let Some(waker) = waker { | ||
state.add_waker(waker) | ||
} | ||
Err(NetworkError::WouldBlock) | ||
} | ||
State::Dead => { | ||
tracing::trace!("try_read: socket is dead"); | ||
return Err(NetworkError::ConnectionReset); | ||
} | ||
State::Closed | State::Shutdown => { | ||
tracing::trace!("try_read: socket is closed or shutdown"); | ||
return Ok(0); | ||
} | ||
}; | ||
} | ||
|
||
let buf: &mut [u8] = unsafe { std::mem::transmute(buf) }; | ||
|
@@ -272,7 +310,7 @@ impl AsyncWrite for SocketBuffer { | |
} | ||
|
||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { | ||
self.set_dead(); | ||
self.set_state(State::Shutdown); | ||
Poll::Ready(Ok(())) | ||
} | ||
} | ||
|
@@ -334,12 +372,12 @@ impl TcpSocketHalf { | |
} | ||
|
||
pub fn is_active(&self) -> bool { | ||
!self.tx.is_dead() | ||
self.tx.state() == State::Alive | ||
} | ||
|
||
pub fn close(&self) -> crate::Result<()> { | ||
self.tx.set_dead(); | ||
self.rx.set_dead(); | ||
self.tx.set_state(State::Closed); | ||
self.rx.set_state(State::Closed); | ||
Ok(()) | ||
} | ||
} | ||
|
@@ -374,10 +412,11 @@ impl VirtualSocket for TcpSocketHalf { | |
} | ||
|
||
fn status(&self) -> crate::Result<SocketStatus> { | ||
Ok(if self.tx.is_dead() { | ||
SocketStatus::Closed | ||
} else { | ||
SocketStatus::Opened | ||
Ok(match self.tx.state() { | ||
State::Alive => SocketStatus::Opened, | ||
State::Dead => SocketStatus::Failed, | ||
State::Closed => SocketStatus::Closed, | ||
State::Shutdown => SocketStatus::Closed, | ||
}) | ||
} | ||
|
||
|
@@ -410,8 +449,8 @@ impl VirtualConnectedSocket for TcpSocketHalf { | |
} | ||
|
||
fn close(&mut self) -> crate::Result<()> { | ||
self.tx.set_dead(); | ||
self.rx.set_dead(); | ||
self.tx.set_state(State::Closed); | ||
self.rx.set_state(State::Closed); | ||
Ok(()) | ||
} | ||
|
||
|
@@ -470,21 +509,21 @@ impl VirtualTcpSocket for TcpSocketHalf { | |
fn shutdown(&mut self, how: std::net::Shutdown) -> crate::Result<()> { | ||
match how { | ||
std::net::Shutdown::Both => { | ||
self.tx.set_dead(); | ||
self.rx.set_dead(); | ||
self.tx.set_state(State::Shutdown); | ||
self.rx.set_state(State::Shutdown); | ||
} | ||
std::net::Shutdown::Read => { | ||
self.rx.set_dead(); | ||
self.rx.set_state(State::Shutdown); | ||
} | ||
std::net::Shutdown::Write => { | ||
self.tx.set_dead(); | ||
self.tx.set_state(State::Shutdown); | ||
} | ||
} | ||
Ok(()) | ||
} | ||
|
||
fn is_closed(&self) -> bool { | ||
self.tx.is_dead() | ||
self.tx.state() != State::Alive | ||
} | ||
} | ||
|
||
|
@@ -528,7 +567,7 @@ impl TcpSocketHalfTx { | |
} | ||
|
||
pub fn close(&self) -> crate::Result<()> { | ||
self.tx.set_dead(); | ||
self.tx.set_state(State::Closed); | ||
Ok(()) | ||
} | ||
} | ||
|
@@ -569,7 +608,7 @@ impl TcpSocketHalfRx { | |
} | ||
|
||
pub fn close(&mut self) -> crate::Result<()> { | ||
self.rx.get_mut().set_dead(); | ||
self.rx.get_mut().set_state(State::Closed); | ||
Ok(()) | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix this