Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a fix so that closed sockets do not cause errors #4548

Merged
merged 4 commits into from
Apr 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 91 additions & 52 deletions lib/virtual-net/src/tcp_pair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,21 @@ use std::{net::SocketAddr, task::Poll};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader};
use virtual_mio::{ArcInterestHandler, InterestType};

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum State {
Alive,
Dead,
Closed,
Shutdown,
}

#[derive(Debug)]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix this

struct SocketBufferState {
buffer: RingBuffer<'static, u8>,
push_handler: Option<ArcInterestHandler>,
pull_handler: Option<ArcInterestHandler>,
wakers: Vec<Waker>,
dead: bool,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix this

state: State,
// This flag prevents a poll write ready storm
halt_immediate_poll_write: bool,
}
Expand All @@ -42,8 +50,12 @@ pub(crate) struct SocketBuffer {

impl Drop for SocketBuffer {
fn drop(&mut self) {
if self.dead_on_drop {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix this

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix this

self.set_dead();
if self.state() == State::Alive {
if self.dead_on_drop {
self.set_state(State::Dead);
} else {
self.set_state(State::Closed);
}
}
}
}
Expand All @@ -56,7 +68,7 @@ impl SocketBuffer {
push_handler: None,
pull_handler: None,
wakers: Vec::new(),
dead: false,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix this

state: State::Alive,
halt_immediate_poll_write: false,
})),
dead_on_drop: false,
Expand All @@ -65,7 +77,7 @@ impl SocketBuffer {

pub fn set_push_handler(&self, mut handler: ArcInterestHandler) {
let mut state = self.state.lock().unwrap();
if state.dead {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix this

if state.state != State::Alive {
handler.push_interest(InterestType::Closed);
}
if !state.buffer.is_empty() {
Expand All @@ -76,7 +88,7 @@ impl SocketBuffer {

pub fn set_pull_handler(&self, mut handler: ArcInterestHandler) {
let mut state = self.state.lock().unwrap();
if state.dead {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix this

if state.state != State::Alive {
handler.push_interest(InterestType::Closed);
}
if !state.buffer.is_full() && state.pull_handler.is_none() {
Expand All @@ -100,33 +112,51 @@ impl SocketBuffer {
if !state.buffer.is_empty() {
return Poll::Ready(Ok(state.buffer.len()));
}
if state.dead {
return Poll::Ready(Ok(0));
}
if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) {
state.wakers.push(cx.waker().clone());
match state.state {
State::Alive => {
if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) {
state.wakers.push(cx.waker().clone());
}
Poll::Pending
}
State::Dead => {
tracing::trace!("poll_read_ready: socket is dead");
Poll::Ready(Err(NetworkError::ConnectionReset))
}
State::Closed | State::Shutdown => {
tracing::trace!("poll_read_ready: socket is closed or shutdown");
Poll::Ready(Ok(0))
}
}
Poll::Pending
}

pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<crate::Result<usize>> {
let mut state = self.state.lock().unwrap();
if state.dead {
return Poll::Ready(Ok(0));
}
if !state.buffer.is_full() && !state.halt_immediate_poll_write {
state.halt_immediate_poll_write = true;
return Poll::Ready(Ok(state.buffer.window()));
}
if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) {
state.wakers.push(cx.waker().clone());
match state.state {
State::Alive => {
if !state.buffer.is_full() && !state.halt_immediate_poll_write {
state.halt_immediate_poll_write = true;
return Poll::Ready(Ok(state.buffer.window()));
}
if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) {
state.wakers.push(cx.waker().clone());
}
Poll::Pending
}
State::Dead => {
tracing::trace!("poll_write_ready: socket is dead");
Poll::Ready(Err(NetworkError::ConnectionReset))
}
State::Closed | State::Shutdown => {
tracing::trace!("poll_write_ready: socket is closed or shutdown");
Poll::Ready(Ok(0))
}
}
Poll::Pending
}

pub fn set_dead(&self) {
fn set_state(&self, new_state: State) {
let mut state = self.state.lock().unwrap();
state.dead = true;
state.state = new_state;
if let Some(handler) = state.pull_handler.as_mut() {
handler.push_interest(InterestType::Closed);
}
Expand All @@ -136,9 +166,9 @@ impl SocketBuffer {
state.wakers.drain(..).for_each(|w| w.wake());
}

pub fn is_dead(&self) -> bool {
fn state(&self) -> State {
let state = self.state.lock().unwrap();
state.dead
state.state
}

pub fn try_send(
Expand All @@ -148,7 +178,7 @@ impl SocketBuffer {
waker: Option<&Waker>,
) -> crate::Result<usize> {
let mut state = self.state.lock().unwrap();
if state.dead {
if state.state != State::Alive {
return Err(NetworkError::ConnectionReset);
}
state.halt_immediate_poll_write = false;
Expand Down Expand Up @@ -211,14 +241,22 @@ impl SocketBuffer {
) -> crate::Result<usize> {
let mut state = self.state.lock().unwrap();
if state.buffer.is_empty() {
if state.dead {
return Err(NetworkError::ConnectionReset);
}

if let Some(waker) = waker {
state.add_waker(waker)
}
return Err(NetworkError::WouldBlock);
return match state.state {
State::Alive => {
if let Some(waker) = waker {
state.add_waker(waker)
}
Err(NetworkError::WouldBlock)
}
State::Dead => {
tracing::trace!("try_read: socket is dead");
return Err(NetworkError::ConnectionReset);
}
State::Closed | State::Shutdown => {
tracing::trace!("try_read: socket is closed or shutdown");
return Ok(0);
}
};
}

let buf: &mut [u8] = unsafe { std::mem::transmute(buf) };
Expand Down Expand Up @@ -272,7 +310,7 @@ impl AsyncWrite for SocketBuffer {
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.set_dead();
self.set_state(State::Shutdown);
Poll::Ready(Ok(()))
}
}
Expand Down Expand Up @@ -334,12 +372,12 @@ impl TcpSocketHalf {
}

pub fn is_active(&self) -> bool {
!self.tx.is_dead()
self.tx.state() == State::Alive
}

pub fn close(&self) -> crate::Result<()> {
self.tx.set_dead();
self.rx.set_dead();
self.tx.set_state(State::Closed);
self.rx.set_state(State::Closed);
Ok(())
}
}
Expand Down Expand Up @@ -374,10 +412,11 @@ impl VirtualSocket for TcpSocketHalf {
}

fn status(&self) -> crate::Result<SocketStatus> {
Ok(if self.tx.is_dead() {
SocketStatus::Closed
} else {
SocketStatus::Opened
Ok(match self.tx.state() {
State::Alive => SocketStatus::Opened,
State::Dead => SocketStatus::Failed,
State::Closed => SocketStatus::Closed,
State::Shutdown => SocketStatus::Closed,
})
}

Expand Down Expand Up @@ -410,8 +449,8 @@ impl VirtualConnectedSocket for TcpSocketHalf {
}

fn close(&mut self) -> crate::Result<()> {
self.tx.set_dead();
self.rx.set_dead();
self.tx.set_state(State::Closed);
self.rx.set_state(State::Closed);
Ok(())
}

Expand Down Expand Up @@ -470,21 +509,21 @@ impl VirtualTcpSocket for TcpSocketHalf {
fn shutdown(&mut self, how: std::net::Shutdown) -> crate::Result<()> {
match how {
std::net::Shutdown::Both => {
self.tx.set_dead();
self.rx.set_dead();
self.tx.set_state(State::Shutdown);
self.rx.set_state(State::Shutdown);
}
std::net::Shutdown::Read => {
self.rx.set_dead();
self.rx.set_state(State::Shutdown);
}
std::net::Shutdown::Write => {
self.tx.set_dead();
self.tx.set_state(State::Shutdown);
}
}
Ok(())
}

fn is_closed(&self) -> bool {
self.tx.is_dead()
self.tx.state() != State::Alive
}
}

Expand Down Expand Up @@ -528,7 +567,7 @@ impl TcpSocketHalfTx {
}

pub fn close(&self) -> crate::Result<()> {
self.tx.set_dead();
self.tx.set_state(State::Closed);
Ok(())
}
}
Expand Down Expand Up @@ -569,7 +608,7 @@ impl TcpSocketHalfRx {
}

pub fn close(&mut self) -> crate::Result<()> {
self.rx.get_mut().set_dead();
self.rx.get_mut().set_state(State::Closed);
Ok(())
}

Expand Down
Loading