diff --git a/crates/optimism/flashblocks/src/ws/stream.rs b/crates/optimism/flashblocks/src/ws/stream.rs index 8a8438b0878..26626102d31 100644 --- a/crates/optimism/flashblocks/src/ws/stream.rs +++ b/crates/optimism/flashblocks/src/ws/stream.rs @@ -1,5 +1,8 @@ use crate::FlashBlock; -use futures_util::{stream::SplitStream, FutureExt, Stream, StreamExt}; +use futures_util::{ + stream::{SplitSink, SplitStream}, + FutureExt, Sink, Stream, StreamExt, +}; use std::{ fmt::{Debug, Formatter}, future::Future, @@ -9,7 +12,7 @@ use std::{ use tokio::net::TcpStream; use tokio_tungstenite::{ connect_async, - tungstenite::{Error, Message}, + tungstenite::{Bytes, Error, Message}, MaybeTlsStream, WebSocketStream, }; use tracing::debug; @@ -21,15 +24,16 @@ use url::Url; /// /// If the connection fails, the error is returned and connection retried. The number of retries is /// unbounded. -pub struct WsFlashBlockStream { +pub struct WsFlashBlockStream { ws_url: Url, state: State, connector: Connector, - connect: ConnectFuture, + connect: ConnectFuture, stream: Option, + sink: Option, } -impl WsFlashBlockStream { +impl WsFlashBlockStream { /// Creates a new websocket stream over `ws_url`. pub fn new(ws_url: Url) -> Self { Self { @@ -38,11 +42,12 @@ impl WsFlashBlockStream { connector: WsConnector, connect: Box::pin(async move { Err(Error::ConnectionClosed)? }), stream: None, + sink: None, } } } -impl WsFlashBlockStream { +impl WsFlashBlockStream { /// Creates a new websocket stream over `ws_url`. pub fn with_connector(ws_url: Url, connector: C) -> Self { Self { @@ -51,60 +56,73 @@ impl WsFlashBlockStream { connector, connect: Box::pin(async move { Err(Error::ConnectionClosed)? }), stream: None, + sink: None, } } } -impl Stream for WsFlashBlockStream +impl Stream for WsFlashBlockStream where - S: Stream> + Unpin, - C: WsConnect + Clone + Send + Sync + 'static + Unpin, + Str: Stream> + Unpin, + S: Sink + Send + Sync + Unpin, + C: WsConnect + Clone + Send + Sync + 'static + Unpin, { type Item = eyre::Result; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.state == State::Initial { - self.connect(); - } + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + loop { + if this.state == State::Initial { + this.connect(); + } - if self.state == State::Connect { - match ready!(self.connect.poll_unpin(cx)) { - Ok(stream) => self.stream(stream), - Err(err) => { - self.state = State::Initial; + if this.state == State::Connect { + match ready!(this.connect.poll_unpin(cx)) { + Ok((sink, stream)) => this.stream(sink, stream), + Err(err) => { + this.state = State::Initial; - return Poll::Ready(Some(Err(err))); + return Poll::Ready(Some(Err(err))); + } } } - } - loop { - let Some(msg) = ready!(self - .stream - .as_mut() - .expect("Stream state should be unreachable without stream") - .poll_next_unpin(cx)) - else { - return Poll::Ready(None); - }; - - match msg { - Ok(Message::Binary(bytes)) => return Poll::Ready(Some(FlashBlock::decode(bytes))), - Ok(Message::Ping(_) | Message::Pong(_)) => { - // can ginore for now + while let State::Stream(pong) = &mut this.state { + if pong.is_some() { + let mut sink = Pin::new(this.sink.as_mut().unwrap()); + let _ = ready!(sink.as_mut().poll_ready(cx)); + if let Some(pong) = pong.take() { + let _ = sink.as_mut().start_send(pong); + } + let _ = ready!(sink.as_mut().poll_flush(cx)); } - Ok(msg) => { - debug!("Received unexpected message: {:?}", msg); + + let Some(msg) = ready!(this + .stream + .as_mut() + .expect("Stream state should be unreachable without stream") + .poll_next_unpin(cx)) + else { + return Poll::Ready(None); + }; + + match msg { + Ok(Message::Binary(bytes)) => { + return Poll::Ready(Some(FlashBlock::decode(bytes))) + } + Ok(Message::Ping(bytes)) => this.ping(bytes), + Ok(msg) => debug!("Received unexpected message: {:?}", msg), + Err(err) => return Poll::Ready(Some(Err(err.into()))), } - Err(err) => return Poll::Ready(Some(Err(err.into()))), } } } } -impl WsFlashBlockStream +impl WsFlashBlockStream where - C: WsConnect + Clone + Send + Sync + 'static, + C: WsConnect + Clone + Send + Sync + 'static, { fn connect(&mut self) { let ws_url = self.ws_url.clone(); @@ -115,14 +133,21 @@ where self.state = State::Connect; } - fn stream(&mut self, stream: S) { + fn stream(&mut self, sink: S, stream: Stream) { + self.sink.replace(sink); self.stream.replace(stream); - self.state = State::Stream; + self.state = State::Stream(None); + } + + fn ping(&mut self, pong: Bytes) { + if let State::Stream(current) = &mut self.state { + current.replace(Message::Pong(pong)); + } } } -impl Debug for WsFlashBlockStream { +impl Debug for WsFlashBlockStream { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("FlashBlockStream") .field("ws_url", &self.ws_url) @@ -139,13 +164,14 @@ enum State { #[default] Initial, Connect, - Stream, + Stream(Option), } -type WsStream = WebSocketStream>; -type WssStream = SplitStream; -type ConnectFuture = - Pin> + Send + Sync + 'static>>; +type Ws = WebSocketStream>; +type WsStream = SplitStream; +type WsSink = SplitSink; +type ConnectFuture = + Pin> + Send + Sync + 'static>>; /// The `WsConnect` trait allows for connecting to a websocket. /// @@ -160,13 +186,16 @@ pub trait WsConnect { /// An associated `Stream` of [`Message`]s wrapped in a [`Result`] that this connection returns. type Stream; + /// An associated `Sink` of [`Message`]s that this connection sends. + type Sink; + /// Asynchronously connects to a websocket hosted on `ws_url`. /// /// See the [`WsConnect`] documentation for details. fn connect( &mut self, ws_url: Url, - ) -> impl Future> + Send + Sync; + ) -> impl Future> + Send + Sync; } /// Establishes a secure websocket subscription. @@ -176,12 +205,13 @@ pub trait WsConnect { pub struct WsConnector; impl WsConnect for WsConnector { - type Stream = WssStream; + type Stream = WsStream; + type Sink = WsSink; - async fn connect(&mut self, ws_url: Url) -> eyre::Result { + async fn connect(&mut self, ws_url: Url) -> eyre::Result<(WsSink, WsStream)> { let (stream, _response) = connect_async(ws_url.as_str()).await?; - Ok(stream.split().1) + Ok(stream.split()) } } @@ -231,14 +261,47 @@ mod tests { } } + #[derive(Clone)] + struct NoopSink; + + impl Sink for NoopSink { + type Error = (); + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + unimplemented!() + } + + fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> { + unimplemented!() + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + unimplemented!() + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + unimplemented!() + } + } + impl WsConnect for FakeConnector { type Stream = FakeStream; + type Sink = NoopSink; fn connect( &mut self, _ws_url: Url, - ) -> impl Future> + Send + Sync { - future::ready(Ok(self.0.clone())) + ) -> impl Future> + Send + Sync { + future::ready(Ok((NoopSink, self.0.clone()))) } } @@ -254,11 +317,12 @@ mod tests { impl WsConnect for FailingConnector { type Stream = FakeStream; + type Sink = NoopSink; fn connect( &mut self, _ws_url: Url, - ) -> impl Future> + Send + Sync { + ) -> impl Future> + Send + Sync { future::ready(Err(eyre::eyre!("{}", &self.0))) } }