From a6d4790e3ff2e31e1237df07bddc8071056c56b8 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Wed, 17 Dec 2025 17:24:21 +0000 Subject: [PATCH] bugfix(client): Ensure that the braid stream can be pooled in Chateau's connection pools. --- src/stream/core.rs | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/src/stream/core.rs b/src/stream/core.rs index 0b2bfb19..13194f01 100644 --- a/src/stream/core.rs +++ b/src/stream/core.rs @@ -1,5 +1,6 @@ //! Core stream type for braid providing [AsyncRead] and [AsyncWrite]. +use chateau::client::pool::PoolableStream; use pin_project::pin_project; use tokio::io::{AsyncRead, AsyncWrite}; @@ -50,14 +51,29 @@ impl HasConnectionInfo for Braid { } macro_rules! dispatch_core { - ($driver:ident.$method:ident($($args:expr),+)) => { + (pin $driver:ident.$method:ident($($args:expr),*)) => { match $driver.project().inner.project() { - BraidCoreProjection::Tcp(stream) => stream.$method($($args),+), - BraidCoreProjection::Duplex(stream) => stream.$method($($args),+), - BraidCoreProjection::Unix(stream) => stream.$method($($args),+), + BraidCoreProjection::Tcp(stream) => stream.$method($($args),*), + BraidCoreProjection::Duplex(stream) => stream.$method($($args),*), + BraidCoreProjection::Unix(stream) => stream.$method($($args),*), } }; + + ($driver:ident.$method:ident($($args:expr),*)) => { + + match &$driver.inner { + BraidCore::Tcp(stream) => stream.$method($($args),*), + BraidCore::Duplex(stream) => stream.$method($($args),*), + BraidCore::Unix(stream) => stream.$method($($args),*), + } + }; +} + +impl PoolableStream for Braid { + fn can_share(&self) -> bool { + dispatch_core!(self.can_share()) + } } impl AsyncRead for Braid { @@ -66,7 +82,7 @@ impl AsyncRead for Braid { cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { - dispatch_core!(self.poll_read(cx, buf)) + dispatch_core!(pin self.poll_read(cx, buf)) } } @@ -76,21 +92,21 @@ impl AsyncWrite for Braid { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { - dispatch_core!(self.poll_write(cx, buf)) + dispatch_core!(pin self.poll_write(cx, buf)) } fn poll_flush( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - dispatch_core!(self.poll_flush(cx)) + dispatch_core!(pin self.poll_flush(cx)) } fn poll_shutdown( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - dispatch_core!(self.poll_shutdown(cx)) + dispatch_core!(pin self.poll_shutdown(cx)) } }