diff --git a/src/protocol/libp2p/ping/config.rs b/src/protocol/libp2p/ping/config.rs index 085f25425..8b44b14ad 100644 --- a/src/protocol/libp2p/ping/config.rs +++ b/src/protocol/libp2p/ping/config.rs @@ -22,6 +22,7 @@ use crate::{ codec::ProtocolCodec, protocol::libp2p::ping::PingEvent, types::protocol::ProtocolName, DEFAULT_CHANNEL_SIZE, }; +use std::time::Duration; use futures::Stream; use tokio::sync::mpsc::{channel, Sender}; @@ -36,6 +37,8 @@ const PING_PAYLOAD_SIZE: usize = 32; /// Maximum PING failures. const MAX_FAILURES: usize = 3; +pub const PING_INTERVAL: Duration = Duration::from_secs(15); + /// Ping configuration. pub struct Config { /// Protocol name. @@ -49,6 +52,8 @@ pub struct Config { /// TX channel for sending events to the user protocol. pub(crate) tx_event: Sender, + + pub(crate) ping_interval: Duration, } impl Config { @@ -61,6 +66,7 @@ impl Config { ( Self { tx_event, + ping_interval: PING_INTERVAL, max_failures: MAX_FAILURES, protocol: ProtocolName::from(PROTOCOL_NAME), codec: ProtocolCodec::Identity(PING_PAYLOAD_SIZE), @@ -80,6 +86,7 @@ pub struct ConfigBuilder { /// Maximum failures before the peer is considered unreachable. max_failures: usize, + ping_interval: Duration, } impl Default for ConfigBuilder { @@ -92,6 +99,7 @@ impl ConfigBuilder { /// Create new default [`Config`] which can be modified by the user. pub fn new() -> Self { Self { + ping_interval: PING_INTERVAL, max_failures: MAX_FAILURES, protocol: ProtocolName::from(PROTOCOL_NAME), codec: ProtocolCodec::Identity(PING_PAYLOAD_SIZE), @@ -104,6 +112,11 @@ impl ConfigBuilder { self } + pub fn with_ping_interval(mut self, ping_interval: Duration) -> Self { + self.ping_interval = ping_interval; + self + } + /// Build [`Config`]. pub fn build(self) -> (Config, Box + Send + Unpin>) { let (tx_event, rx_event) = channel(DEFAULT_CHANNEL_SIZE); @@ -111,6 +124,7 @@ impl ConfigBuilder { ( Config { tx_event, + ping_interval: self.ping_interval, max_failures: self.max_failures, protocol: self.protocol, codec: self.codec, diff --git a/src/protocol/libp2p/ping/mod.rs b/src/protocol/libp2p/ping/mod.rs index fa16069fd..7e895490b 100644 --- a/src/protocol/libp2p/ping/mod.rs +++ b/src/protocol/libp2p/ping/mod.rs @@ -24,26 +24,26 @@ use crate::{ error::{Error, SubstreamError}, protocol::{Direction, TransportEvent, TransportService}, substream::Substream, - types::SubstreamId, PeerId, }; -use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; -use tokio::sync::mpsc::Sender; - +use bytes::Bytes; +use futures::{stream::SplitSink, SinkExt, StreamExt}; use std::{ - collections::HashSet, + collections::HashMap, time::{Duration, Instant}, }; +use tokio::sync::mpsc; +use tokio_stream::StreamMap; pub use config::{Config, ConfigBuilder}; - mod config; // TODO: https://github.com/paritytech/litep2p/issues/132 let the user handle max failures /// Log target for the file. const LOG_TARGET: &str = "litep2p::ipfs::ping"; +const PING_TIMEOUT: Duration = Duration::from_secs(10); /// Events emitted by the ping protocol. #[derive(Debug)] @@ -60,23 +60,32 @@ pub enum PingEvent { /// Ping protocol. pub(crate) struct Ping { - /// Maximum failures before the peer is considered unreachable. - _max_failures: usize, - // Connection service. service: TransportService, /// TX channel for sending events to the user protocol. - tx: Sender, + tx: mpsc::Sender, + + /// Streams we read Pongs from. + outbound_streams: StreamMap>, - /// Connected peers. - peers: HashSet, + /// Sinks we write Pings to. + outbound_sinks: HashMap>, - /// Pending outbound substreams. - pending_outbound: FuturesUnordered>>, + /// Streams we read Pings from. + /// Keyed by a local counter to handle multiple streams per peer if necessary. + inbound_streams: StreamMap>, - /// Pending inbound substreams. - pending_inbound: FuturesUnordered>>, + /// Sinks we write Pongs to. + inbound_sinks: HashMap>, + + /// Counter for generating unique keys for inbound streams. + inbound_id_counter: usize, + + /// We need to track when we sent the ping to calculate the duration. + ping_times: HashMap, + + ping_interval: Duration, } impl Ping { @@ -85,126 +94,130 @@ impl Ping { Self { service, tx: config.tx_event, - peers: HashSet::new(), - pending_outbound: FuturesUnordered::new(), - pending_inbound: FuturesUnordered::new(), - _max_failures: config.max_failures, + ping_interval: config.ping_interval, + outbound_streams: StreamMap::new(), + outbound_sinks: HashMap::new(), + ping_times: HashMap::new(), + inbound_streams: StreamMap::new(), + inbound_sinks: HashMap::new(), + inbound_id_counter: 0, } } /// Connection established to remote peer. - fn on_connection_established(&mut self, peer: PeerId) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, ?peer, "connection established"); - - self.service.open_substream(peer)?; - self.peers.insert(peer); + fn on_connection_established(&mut self, peer: PeerId) { + tracing::debug!(target: LOG_TARGET, ?peer, "connection established, opening ping substream"); - Ok(()) + if let Err(error) = self.service.open_substream(peer) { + tracing::debug!(target: LOG_TARGET, ?peer, ?error, "failed to open substream"); + } } /// Connection closed to remote peer. fn on_connection_closed(&mut self, peer: PeerId) { - tracing::trace!(target: LOG_TARGET, ?peer, "connection closed"); - - self.peers.remove(&peer); + tracing::debug!(target: LOG_TARGET, ?peer, "connection closed"); + self.outbound_streams.remove(&peer); + self.outbound_sinks.remove(&peer); + self.ping_times.remove(&peer); } - /// Handle outbound substream. - fn on_outbound_substream( - &mut self, - peer: PeerId, - substream_id: SubstreamId, - mut substream: Substream, - ) { - tracing::trace!(target: LOG_TARGET, ?peer, "handle outbound substream"); - - self.pending_outbound.push(Box::pin(async move { - let future = async move { - // TODO: https://github.com/paritytech/litep2p/issues/134 generate random payload and verify it - substream.send_framed(vec![0u8; 32].into()).await?; - let now = Instant::now(); - let _ = substream.next().await.ok_or(Error::SubstreamError( - SubstreamError::ReadFailure(Some(substream_id)), - ))?; - let _ = substream.close().await; - - Ok(now.elapsed()) - }; - - match tokio::time::timeout(Duration::from_secs(10), future).await { - Err(_) => Err(Error::Timeout), - Ok(Err(error)) => Err(error), - Ok(Ok(elapsed)) => Ok((peer, elapsed)), - } - })); + /// Handle outbound substream (We initiated) + /// Registers it into the Outbound pipeline. + fn on_outbound_substream(&mut self, peer: PeerId, substream: Substream) { + tracing::trace!(target: LOG_TARGET, ?peer, "outbound ping substream registered"); + let (sink, stream) = substream.split(); + self.outbound_streams.insert(peer, stream); + self.outbound_sinks.insert(peer, sink); } - /// Substream opened to remote peer. - fn on_inbound_substream(&mut self, peer: PeerId, mut substream: Substream) { - tracing::trace!(target: LOG_TARGET, ?peer, "handle inbound substream"); - - self.pending_inbound.push(Box::pin(async move { - let future = async move { - let payload = substream - .next() - .await - .ok_or(Error::SubstreamError(SubstreamError::ReadFailure(None)))??; - substream.send_framed(payload.freeze()).await?; - let _ = substream.next().await.map(|_| ()); - - Ok(()) - }; - - match tokio::time::timeout(Duration::from_secs(10), future).await { - Err(_) => Err(Error::Timeout), - Ok(Err(error)) => Err(error), - Ok(Ok(())) => Ok(()), - } - })); + /// Handle inbound substream (They initiated). + /// Registers it into the Inbound pipeline. + fn on_inbound_substream(&mut self, peer: PeerId, substream: Substream) { + tracing::trace!(target: LOG_TARGET, ?peer, "inbound ping substream registered"); + let (sink, stream) = substream.split(); + + let id = self.inbound_id_counter; + self.inbound_id_counter += 1; + + self.inbound_streams.insert(id, stream); + self.inbound_sinks.insert(id, sink); } /// Start [`Ping`] event loop. pub async fn run(mut self) { tracing::debug!(target: LOG_TARGET, "starting ping event loop"); + let mut interval = tokio::time::interval(self.ping_interval); loop { tokio::select! { event = self.service.next() => match event { Some(TransportEvent::ConnectionEstablished { peer, .. }) => { - let _ = self.on_connection_established(peer); + self.on_connection_established(peer); } Some(TransportEvent::ConnectionClosed { peer }) => { self.on_connection_closed(peer); } - Some(TransportEvent::SubstreamOpened { - peer, - substream, - direction, - .. - }) => match direction { + Some(TransportEvent::SubstreamOpened { peer, substream, direction,.. }) => match direction { Direction::Inbound => { self.on_inbound_substream(peer, substream); } - Direction::Outbound(substream_id) => { - self.on_outbound_substream(peer, substream_id, substream); + Direction::Outbound(_) => { + self.on_outbound_substream(peer, substream); } - }, + } Some(_) => {} None => return, }, - _event = self.pending_inbound.next(), if !self.pending_inbound.is_empty() => {} - event = self.pending_outbound.next(), if !self.pending_outbound.is_empty() => { + + _ = interval.tick() => { + for (peer, sink) in self.outbound_sinks.iter_mut() { + let payload = vec![0u8; 32]; + + self.ping_times.insert(*peer, Instant::now()); + tracing::trace!(target: LOG_TARGET, ?peer, "sending ping"); + + if let Err(error) = sink.send(Bytes::from(payload)).await { + tracing::debug!(target: LOG_TARGET, ?peer, ?error, "failed to send ping"); + + } + } + } + + // Handle Outbound Responses (Pong is expected here) + Some((peer, event)) = self.outbound_streams.next() => { match event { - Some(Ok((peer, elapsed))) => { - let _ = self - .tx - .send(PingEvent::Ping { - peer, - ping: elapsed, - }) - .await; + Ok(payload) => { + if let Some(started) = self.ping_times.remove(&peer) { + + let elapsed = started.elapsed(); + tracing::trace!(target: LOG_TARGET, ?peer, ?elapsed, "pong received"); + let _ = self.tx.send(PingEvent::Ping { peer, ping: elapsed }).await; + } + } + Err(error) => { + tracing::debug!(target: LOG_TARGET, ?peer, ?error, "ping substream closed/error"); + self.outbound_streams.remove(&peer); + self.outbound_sinks.remove(&peer); + self.ping_times.remove(&peer); + } + } + } + + // Handle Outbound Responses (Ping is expected here) + Some((id, event)) = self.inbound_streams.next() => { + match event { + Ok(payload) => { + if let Some(sink) = self.inbound_sinks.get_mut(&id) { + tracing::trace!(target: LOG_TARGET, ?id, "sending pong"); + if let Err(error) = sink.send(payload.freeze()).await { + tracing::debug!(target: LOG_TARGET, ?id, ?error, "failed to send pong"); + } + } + } + Err(_) => { + self.inbound_streams.remove(&id); + self.inbound_sinks.remove(&id); } - event => tracing::debug!(target: LOG_TARGET, "failed to handle ping for an outbound peer: {event:?}"), } } } diff --git a/src/protocol/transport_service.rs b/src/protocol/transport_service.rs index b729e9312..16fa2e0fe 100644 --- a/src/protocol/transport_service.rs +++ b/src/protocol/transport_service.rs @@ -287,6 +287,8 @@ pub struct TransportService { /// Close the connection if no substreams are open within this time frame. keep_alive_tracker: KeepAliveTracker, + + counts_towards_keep_alive: bool, } impl TransportService { @@ -298,6 +300,7 @@ impl TransportService { next_substream_id: Arc, transport_handle: TransportManagerHandle, keep_alive_timeout: Duration, + counts_towards_keep_alive: bool, ) -> (Self, Sender) { let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE); @@ -313,6 +316,7 @@ impl TransportService { next_substream_id, connections: HashMap::new(), keep_alive_tracker, + counts_towards_keep_alive, }, tx, ) @@ -507,8 +511,10 @@ impl TransportService { "open substream", ); - self.keep_alive_tracker.substream_activity(peer, connection_id); - connection.try_upgrade(); + if self.counts_towards_keep_alive { + self.keep_alive_tracker.substream_activity(peer, connection_id); + connection.try_upgrade(); + } connection .open_substream( @@ -592,7 +598,7 @@ impl Stream for TransportService { substream, connection_id, }) => { - if protocol == self.protocol { + if protocol == self.protocol && self.counts_towards_keep_alive { self.keep_alive_tracker.substream_activity(peer, connection_id); if let Some(context) = self.connections.get_mut(&peer) { context.try_upgrade(&connection_id); diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index f44a07a6b..a8237005c 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -348,6 +348,7 @@ impl TransportManager { self.next_substream_id.clone(), self.transport_manager_handle.clone(), keep_alive_timeout, + true, ); self.protocols.insert(