Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 118 additions & 54 deletions crates/optimism/flashblocks/src/ws/stream.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::FlashBlock;
use futures_util::{stream::SplitStream, FutureExt, Stream, StreamExt};
use futures_util::{
stream::{SplitSink, SplitStream},
FutureExt, Sink, Stream, StreamExt,
};
use std::{
fmt::{Debug, Formatter},
future::Future,
Expand All @@ -9,7 +12,7 @@ use std::{
use tokio::net::TcpStream;
use tokio_tungstenite::{
connect_async,
tungstenite::{Error, Message},
tungstenite::{Bytes, Error, Message},
MaybeTlsStream, WebSocketStream,
};
use tracing::debug;
Expand All @@ -21,15 +24,16 @@ use url::Url;
///
/// If the connection fails, the error is returned and connection retried. The number of retries is
/// unbounded.
pub struct WsFlashBlockStream<Stream, Connector> {
pub struct WsFlashBlockStream<Stream, Sink, Connector> {
ws_url: Url,
state: State,
connector: Connector,
connect: ConnectFuture<Stream>,
connect: ConnectFuture<Sink, Stream>,
stream: Option<Stream>,
sink: Option<Sink>,
}

impl WsFlashBlockStream<WssStream, WsConnector> {
impl WsFlashBlockStream<WsStream, WsSink, WsConnector> {
/// Creates a new websocket stream over `ws_url`.
pub fn new(ws_url: Url) -> Self {
Self {
Expand All @@ -38,11 +42,12 @@ impl WsFlashBlockStream<WssStream, WsConnector> {
connector: WsConnector,
connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
stream: None,
sink: None,
}
}
}

impl<S, C> WsFlashBlockStream<S, C> {
impl<Stream, S, C> WsFlashBlockStream<Stream, S, C> {
/// Creates a new websocket stream over `ws_url`.
pub fn with_connector(ws_url: Url, connector: C) -> Self {
Self {
Expand All @@ -51,60 +56,73 @@ impl<S, C> WsFlashBlockStream<S, C> {
connector,
connect: Box::pin(async move { Err(Error::ConnectionClosed)? }),
stream: None,
sink: None,
}
}
}

impl<S, C> Stream for WsFlashBlockStream<S, C>
impl<Str, S, C> Stream for WsFlashBlockStream<Str, S, C>
where
S: Stream<Item = Result<Message, Error>> + Unpin,
C: WsConnect<Stream = S> + Clone + Send + Sync + 'static + Unpin,
Str: Stream<Item = Result<Message, Error>> + Unpin,
S: Sink<Message> + Send + Sync + Unpin,
C: WsConnect<Stream = Str, Sink = S> + Clone + Send + Sync + 'static + Unpin,
{
type Item = eyre::Result<FlashBlock>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.state == State::Initial {
self.connect();
}
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();

loop {
if this.state == State::Initial {
this.connect();
}

if self.state == State::Connect {
match ready!(self.connect.poll_unpin(cx)) {
Ok(stream) => self.stream(stream),
Err(err) => {
self.state = State::Initial;
if this.state == State::Connect {
match ready!(this.connect.poll_unpin(cx)) {
Ok((sink, stream)) => this.stream(sink, stream),
Err(err) => {
this.state = State::Initial;

return Poll::Ready(Some(Err(err)));
return Poll::Ready(Some(Err(err)));
}
}
}
}

loop {
let Some(msg) = ready!(self
.stream
.as_mut()
.expect("Stream state should be unreachable without stream")
.poll_next_unpin(cx))
else {
return Poll::Ready(None);
};

match msg {
Ok(Message::Binary(bytes)) => return Poll::Ready(Some(FlashBlock::decode(bytes))),
Ok(Message::Ping(_) | Message::Pong(_)) => {
// can ginore for now
while let State::Stream(pong) = &mut this.state {
if pong.is_some() {
let mut sink = Pin::new(this.sink.as_mut().unwrap());
let _ = ready!(sink.as_mut().poll_ready(cx));
if let Some(pong) = pong.take() {
let _ = sink.as_mut().start_send(pong);
}
let _ = ready!(sink.as_mut().poll_flush(cx));
}
Ok(msg) => {
debug!("Received unexpected message: {:?}", msg);

let Some(msg) = ready!(this
.stream
.as_mut()
.expect("Stream state should be unreachable without stream")
.poll_next_unpin(cx))
else {
return Poll::Ready(None);
};

match msg {
Ok(Message::Binary(bytes)) => {
return Poll::Ready(Some(FlashBlock::decode(bytes)))
}
Ok(Message::Ping(bytes)) => this.ping(bytes),
Ok(msg) => debug!("Received unexpected message: {:?}", msg),
Err(err) => return Poll::Ready(Some(Err(err.into()))),
}
Err(err) => return Poll::Ready(Some(Err(err.into()))),
}
}
}
}

impl<S, C> WsFlashBlockStream<S, C>
impl<Stream, S, C> WsFlashBlockStream<Stream, S, C>
where
C: WsConnect<Stream = S> + Clone + Send + Sync + 'static,
C: WsConnect<Stream = Stream, Sink = S> + Clone + Send + Sync + 'static,
{
fn connect(&mut self) {
let ws_url = self.ws_url.clone();
Expand All @@ -115,14 +133,21 @@ where
self.state = State::Connect;
}

fn stream(&mut self, stream: S) {
fn stream(&mut self, sink: S, stream: Stream) {
self.sink.replace(sink);
self.stream.replace(stream);

self.state = State::Stream;
self.state = State::Stream(None);
}

fn ping(&mut self, pong: Bytes) {
if let State::Stream(current) = &mut self.state {
current.replace(Message::Pong(pong));
}
}
}

impl<S: Debug, C: Debug> Debug for WsFlashBlockStream<S, C> {
impl<Stream: Debug, S: Debug, C: Debug> Debug for WsFlashBlockStream<Stream, S, C> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FlashBlockStream")
.field("ws_url", &self.ws_url)
Expand All @@ -139,13 +164,14 @@ enum State {
#[default]
Initial,
Connect,
Stream,
Stream(Option<Message>),
}

type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
type WssStream = SplitStream<WsStream>;
type ConnectFuture<Stream> =
Pin<Box<dyn Future<Output = eyre::Result<Stream>> + Send + Sync + 'static>>;
type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
type WsStream = SplitStream<Ws>;
type WsSink = SplitSink<Ws, Message>;
type ConnectFuture<Sink, Stream> =
Pin<Box<dyn Future<Output = eyre::Result<(Sink, Stream)>> + Send + Sync + 'static>>;

/// The `WsConnect` trait allows for connecting to a websocket.
///
Expand All @@ -160,13 +186,16 @@ pub trait WsConnect {
/// An associated `Stream` of [`Message`]s wrapped in a [`Result`] that this connection returns.
type Stream;

/// An associated `Sink` of [`Message`]s that this connection sends.
type Sink;

/// Asynchronously connects to a websocket hosted on `ws_url`.
///
/// See the [`WsConnect`] documentation for details.
fn connect(
&mut self,
ws_url: Url,
) -> impl Future<Output = eyre::Result<Self::Stream>> + Send + Sync;
) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send + Sync;
}

/// Establishes a secure websocket subscription.
Expand All @@ -176,12 +205,13 @@ pub trait WsConnect {
pub struct WsConnector;

impl WsConnect for WsConnector {
type Stream = WssStream;
type Stream = WsStream;
type Sink = WsSink;

async fn connect(&mut self, ws_url: Url) -> eyre::Result<WssStream> {
async fn connect(&mut self, ws_url: Url) -> eyre::Result<(WsSink, WsStream)> {
let (stream, _response) = connect_async(ws_url.as_str()).await?;

Ok(stream.split().1)
Ok(stream.split())
}
}

Expand Down Expand Up @@ -231,14 +261,47 @@ mod tests {
}
}

#[derive(Clone)]
struct NoopSink;

impl<T> Sink<T> for NoopSink {
type Error = ();

fn poll_ready(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
unimplemented!()
}

fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
unimplemented!()
}

fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
unimplemented!()
}

fn poll_close(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
unimplemented!()
}
}

impl WsConnect for FakeConnector {
type Stream = FakeStream;
type Sink = NoopSink;

fn connect(
&mut self,
_ws_url: Url,
) -> impl Future<Output = eyre::Result<Self::Stream>> + Send + Sync {
future::ready(Ok(self.0.clone()))
) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send + Sync {
future::ready(Ok((NoopSink, self.0.clone())))
}
}

Expand All @@ -254,11 +317,12 @@ mod tests {

impl WsConnect for FailingConnector {
type Stream = FakeStream;
type Sink = NoopSink;

fn connect(
&mut self,
_ws_url: Url,
) -> impl Future<Output = eyre::Result<Self::Stream>> + Send + Sync {
) -> impl Future<Output = eyre::Result<(Self::Sink, Self::Stream)>> + Send + Sync {
future::ready(Err(eyre::eyre!("{}", &self.0)))
}
}
Expand Down