Skip to content

Commit

Permalink
server: unify accept error handling (#1882)
Browse files Browse the repository at this point in the history
  • Loading branch information
djc authored Aug 23, 2024
1 parent f321d6a commit c3be20c
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions tonic/src/transport/server/incoming.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
use super::service::ServerIo;
#[cfg(feature = "tls")]
use super::service::TlsAcceptor;
#[cfg(not(feature = "tls"))]
use std::io;
use std::{
io,
net::{SocketAddr, TcpListener as StdTcpListener},
ops::ControlFlow,
pin::{pin, Pin},
task::{ready, Context, Poll},
time::Duration,
};

use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream},
Expand All @@ -17,6 +15,10 @@ use tokio_stream::wrappers::TcpListenerStream;
use tokio_stream::{Stream, StreamExt};
use tracing::warn;

use super::service::ServerIo;
#[cfg(feature = "tls")]
use super::service::TlsAcceptor;

#[cfg(not(feature = "tls"))]
pub(crate) fn tcp_incoming<IO, IE>(
incoming: impl Stream<Item = Result<IO, IE>>,
Expand All @@ -31,15 +33,9 @@ where
while let Some(item) = incoming.next().await {
yield match item {
Ok(_) => item.map(ServerIo::new_io)?,
Err(e) => {
let e = e.into();
tracing::debug!(error = %e, "accept loop error");
if let Some(e) = e.downcast_ref::<io::Error>() {
if e.kind() == io::ErrorKind::ConnectionAborted {
continue;
}
}
Err(e)?
Err(e) => match handle_accept_error(e) {
ControlFlow::Continue(()) => continue,
ControlFlow::Break(e) => Err(e)?,
}
}
}
Expand Down Expand Up @@ -78,8 +74,9 @@ where
yield io;
}

SelectOutput::Err(e) => {
tracing::debug!(error = %e, "accept loop error");
SelectOutput::Err(e) => match handle_accept_error(e) {
ControlFlow::Continue(()) => continue,
ControlFlow::Break(e) => Err(e)?,
}

SelectOutput::Done => {
Expand All @@ -90,6 +87,18 @@ where
}
}

fn handle_accept_error(e: impl Into<crate::Error>) -> ControlFlow<crate::Error> {
let e = e.into();
tracing::debug!(error = %e, "accept loop error");
if let Some(e) = e.downcast_ref::<io::Error>() {
if e.kind() == io::ErrorKind::ConnectionAborted {
return ControlFlow::Continue(());
}
}

ControlFlow::Break(e)
}

#[cfg(feature = "tls")]
async fn select<IO: 'static, IE>(
incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
Expand Down

0 comments on commit c3be20c

Please sign in to comment.