diff --git a/sqlx-core/src/net/socket/buffered.rs b/sqlx-core/src/net/socket/buffered.rs index 25e1276432..6785e70879 100644 --- a/sqlx-core/src/net/socket/buffered.rs +++ b/sqlx-core/src/net/socket/buffered.rs @@ -1,9 +1,9 @@ +use crate::error::Error; use crate::net::Socket; use bytes::BytesMut; +use std::ops::ControlFlow; use std::{cmp, io}; -use crate::error::Error; - use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode}; // Tokio, async-std, and std all use this as the default capacity for their buffered I/O. @@ -45,8 +45,39 @@ impl BufferedSocket { } } - pub async fn read_buffered(&mut self, len: usize) -> io::Result { - self.read_buf.read(len, &mut self.socket).await + pub async fn read_buffered(&mut self, len: usize) -> Result { + self.try_read(|buf| { + Ok(if buf.len() < len { + ControlFlow::Continue(len) + } else { + ControlFlow::Break(buf.split_to(len)) + }) + }) + .await + } + + /// Retryable read operation. + /// + /// The callback should check the contents of the buffer passed to it and either: + /// + /// * Remove a full message from the buffer and return [`ControlFlow::Break`], or: + /// * Return [`ControlFlow::Continue`] with the expected _total_ length of the buffer, + /// _without_ modifying it. + /// + /// Cancel-safe as long as the callback does not modify the passed `BytesMut` + /// before returning [`ControlFlow::Continue`]. + pub async fn try_read(&mut self, mut try_read: F) -> Result + where + F: FnMut(&mut BytesMut) -> Result, Error>, + { + loop { + let read_len = match try_read(&mut self.read_buf.read)? { + ControlFlow::Continue(read_len) => read_len, + ControlFlow::Break(ret) => return Ok(ret), + }; + + self.read_buf.read(read_len, &mut self.socket).await?; + } } pub fn write_buffer(&self) -> &WriteBuffer { @@ -244,7 +275,7 @@ impl WriteBuffer { } impl ReadBuffer { - async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result { + async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<()> { // Because of how `BytesMut` works, we should only be shifting capacity back and forth // between `read` and `available` unless we have to read an oversize message. while self.read.len() < len { @@ -266,7 +297,7 @@ impl ReadBuffer { self.advance(read); } - Ok(self.drain(len)) + Ok(()) } fn reserve(&mut self, amt: usize) { @@ -279,10 +310,6 @@ impl ReadBuffer { self.read.unsplit(self.available.split_to(amt)); } - fn drain(&mut self, amt: usize) -> BytesMut { - self.read.split_to(amt) - } - fn shrink(&mut self) { if self.available.capacity() > DEFAULT_BUF_SIZE { // `BytesMut` doesn't have a way to shrink its capacity, diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index c1e163c704..bf3a6d4b1c 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -13,11 +13,14 @@ use super::inner::{is_beyond_max_lifetime, DecrementSizeGuard, PoolInner}; use crate::pool::options::PoolConnectionMetadata; use std::future::Future; +const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5); + /// A connection managed by a [`Pool`][crate::pool::Pool]. /// /// Will be returned to the pool on-drop. pub struct PoolConnection { live: Option>, + close_on_drop: bool, pub(crate) pool: Arc>, } @@ -85,6 +88,16 @@ impl PoolConnection { floating.inner.raw.close().await } + /// Close this connection on-drop, instead of returning it to the pool. + /// + /// May be used in cases where waiting for the [`.close()`][Self::close] call + /// to complete is unacceptable, but you still want the connection to be closed gracefully + /// so that the server can clean up resources. + #[inline(always)] + pub fn close_on_drop(&mut self) { + self.close_on_drop = true; + } + /// Detach this connection from the pool, allowing it to open a replacement. /// /// Note that if your application uses a single shared pool, this @@ -140,6 +153,27 @@ impl PoolConnection { } } } + + fn take_and_close(&mut self) -> impl Future + Send + 'static { + // float the connection in the pool before we move into the task + // in case the returned `Future` isn't executed, like if it's spawned into a dying runtime + // https://github.com/launchbadge/sqlx/issues/1396 + // Type hints seem to be broken by `Option` combinators in IntelliJ Rust right now (6/22). + let floating = self.live.take().map(|live| live.float(self.pool.clone())); + + let pool = self.pool.clone(); + + async move { + if let Some(floating) = floating { + // Don't hold the connection forever if it hangs while trying to close + crate::rt::timeout(CLOSE_ON_DROP_TIMEOUT, floating.close()) + .await + .ok(); + } + + pool.min_connections_maintenance(None).await; + } + } } impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection { @@ -164,6 +198,11 @@ impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection Drop for PoolConnection { fn drop(&mut self) { + if self.close_on_drop { + crate::rt::spawn(self.take_and_close()); + return; + } + // We still need to spawn a task to maintain `min_connections`. if self.live.is_some() || self.pool.options.min_connections > 0 { crate::rt::spawn(self.return_to_pool()); @@ -221,6 +260,7 @@ impl Floating> { guard.cancel(); PoolConnection { live: Some(inner), + close_on_drop: false, pool, } } diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs index 7817399925..f165899248 100644 --- a/sqlx-postgres/src/connection/stream.rs +++ b/sqlx-postgres/src/connection/stream.rs @@ -1,11 +1,11 @@ use std::collections::BTreeMap; -use std::ops::{Deref, DerefMut}; +use std::ops::{ControlFlow, Deref, DerefMut}; use std::str::FromStr; use futures_channel::mpsc::UnboundedSender; use futures_util::SinkExt; use log::Level; -use sqlx_core::bytes::{Buf, Bytes}; +use sqlx_core::bytes::Buf; use crate::connection::tls::MaybeUpgradeTls; use crate::error::Error; @@ -77,16 +77,45 @@ impl PgStream { } pub(crate) async fn recv_unchecked(&mut self) -> Result { - // all packets in postgres start with a 5-byte header - // this header contains the message type and the total length of the message - let mut header: Bytes = self.inner.read(5).await?; + // NOTE: to not break everything, this should be cancel-safe; + // DO NOT modify `buf` unless a full message has been read + self.inner + .try_read(|buf| { + // all packets in postgres start with a 5-byte header + // this header contains the message type and the total length of the message + let Some(mut header) = buf.get(..5) else { + return Ok(ControlFlow::Continue(5)); + }; + + let format = BackendMessageFormat::try_from_u8(header.get_u8())?; + + let message_len = header.get_u32() as usize; + + let expected_len = message_len + .checked_add(1) + // this shouldn't really happen but is mostly a sanity check + .ok_or_else(|| { + err_protocol!("message_len + 1 overflows usize: {message_len}") + })?; + + if buf.len() < expected_len { + return Ok(ControlFlow::Continue(expected_len)); + } + + // `buf` SHOULD NOT be modified ABOVE this line + + // pop off the format code since it's not counted in `message_len` + buf.advance(1); - let format = BackendMessageFormat::try_from_u8(header.get_u8())?; - let size = (header.get_u32() - 4) as usize; + // consume the message, including the length prefix + let mut contents = buf.split_to(message_len).freeze(); - let contents = self.inner.read(size).await?; + // cut off the length prefix + contents.advance(4); - Ok(ReceivedMessage { format, contents }) + Ok(ControlFlow::Break(ReceivedMessage { format, contents })) + }) + .await } // Get the next message from the server diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index 43bd3c8ff5..3e647d6340 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -262,8 +262,11 @@ impl PgListener { if (err.kind() == io::ErrorKind::ConnectionAborted || err.kind() == io::ErrorKind::UnexpectedEof) => { - self.buffer_tx = self.connection().await?.stream.notifications.take(); - self.connection = None; + if let Some(mut conn) = self.connection.take() { + self.buffer_tx = conn.stream.notifications.take(); + // Close the connection in a background task, so we can continue. + conn.close_on_drop(); + } // lost connection return Ok(None);