diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 57673b507cc..f8e64e23dd9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,12 +67,14 @@ jobs: - name: cargo hack run: | cargo hack check --workspace --target wasm32-unknown-unknown \ + --exclude alloy-contract \ + --exclude alloy-node-bindings \ + --exclude alloy-providers \ --exclude alloy-signer \ --exclude alloy-signer-aws \ --exclude alloy-signer-gcp \ --exclude alloy-signer-ledger \ --exclude alloy-signer-trezor \ - --exclude alloy-node-bindings \ --exclude alloy-transport-ipc feature-checks: diff --git a/Cargo.toml b/Cargo.toml index 514a6494887..389e9f600ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,6 +67,7 @@ futures-executor = "0.3.29" hyper = "0.14.27" tokio = "1.33" tokio-util = "0.7" +tokio-stream = "0.1.14" tower = { version = "0.4.13", features = ["util"] } tracing = "0.1.40" diff --git a/crates/contract/src/call.rs b/crates/contract/src/call.rs index 085c49d7fd8..ef5e1b6ac0b 100644 --- a/crates/contract/src/call.rs +++ b/crates/contract/src/call.rs @@ -2,7 +2,7 @@ use crate::{Error, Result}; use alloy_dyn_abi::{DynSolValue, FunctionExt, JsonAbiExt}; use alloy_json_abi::Function; use alloy_primitives::{Address, Bytes, U256, U64}; -use alloy_providers::provider::TempProvider; +use alloy_providers::tmp::TempProvider; use alloy_rpc_types::{ request::{TransactionInput, TransactionRequest}, state::StateOverride, @@ -456,7 +456,7 @@ mod tests { use super::*; use alloy_node_bindings::{Anvil, AnvilInstance}; use alloy_primitives::{address, b256, bytes, hex}; - use alloy_providers::provider::{HttpProvider, Provider}; + use alloy_providers::tmp::{HttpProvider, Provider}; use alloy_sol_types::sol; #[test] diff --git a/crates/contract/src/instance.rs b/crates/contract/src/instance.rs index b0a05bb9c32..e94ee9d9cae 100644 --- a/crates/contract/src/instance.rs +++ b/crates/contract/src/instance.rs @@ -2,7 +2,7 @@ use crate::{CallBuilder, Interface, Result}; use 
alloy_dyn_abi::DynSolValue; use alloy_json_abi::{Function, JsonAbi}; use alloy_primitives::{Address, Selector}; -use alloy_providers::provider::TempProvider; +use alloy_providers::tmp::TempProvider; /// A handle to an Ethereum contract at a specific address. /// diff --git a/crates/contract/src/lib.rs b/crates/contract/src/lib.rs index 5d12208e72d..64561fdce9b 100644 --- a/crates/contract/src/lib.rs +++ b/crates/contract/src/lib.rs @@ -34,5 +34,5 @@ pub use call::*; // NOTE: please avoid changing the API of this module due to its use in the `sol!` macro. #[doc(hidden)] pub mod private { - pub use alloy_providers::provider::TempProvider as Provider; + pub use alloy_providers::tmp::TempProvider as Provider; } diff --git a/crates/json-rpc/src/error.rs b/crates/json-rpc/src/error.rs index 277f87fa5c2..0aaa136f558 100644 --- a/crates/json-rpc/src/error.rs +++ b/crates/json-rpc/src/error.rs @@ -2,7 +2,7 @@ use crate::{ErrorPayload, RpcReturn}; use serde_json::value::RawValue; /// An RPC error. -#[derive(thiserror::Error, Debug)] +#[derive(Debug, thiserror::Error)] pub enum RpcError> { /// Server returned an error response. 
#[error("Server returned an error response: {0}")] diff --git a/crates/json-rpc/src/request.rs b/crates/json-rpc/src/request.rs index 5ff5c175724..b5bc2021185 100644 --- a/crates/json-rpc/src/request.rs +++ b/crates/json-rpc/src/request.rs @@ -150,7 +150,6 @@ where // Params may be omitted if it is 0-sized if sized_params { - // TODO: remove unwrap map.serialize_entry("params", &self.params)?; } diff --git a/crates/providers/Cargo.toml b/crates/providers/Cargo.toml index 3009f71c98f..7cc9ff377c9 100644 --- a/crates/providers/Cargo.toml +++ b/crates/providers/Cargo.toml @@ -13,21 +13,31 @@ exclude.workspace = true [dependencies] alloy-network.workspace = true -alloy-primitives.workspace = true alloy-rpc-client = { workspace = true, features = ["reqwest"] } -alloy-rpc-types.workspace = true alloy-rpc-trace-types.workspace = true +alloy-rpc-types.workspace = true alloy-transport-http = { workspace = true, features = ["reqwest"] } alloy-transport.workspace = true + +alloy-primitives.workspace = true + +async-stream = "0.3.5" async-trait.workspace = true +auto_impl = "1.1.0" +futures.workspace = true +lru = "0.12.2" +reqwest.workspace = true serde.workspace = true thiserror.workspace = true -reqwest.workspace = true -auto_impl = "1.1.0" +tokio = { workspace = true, features = ["sync", "macros"] } +tracing.workspace = true [dev-dependencies] +alloy-consensus.workspace = true alloy-node-bindings.workspace = true -tokio = { version = "1.33.0", features = ["macros"] } +alloy-rlp.workspace = true +tokio = { workspace = true, features = ["macros"] } +tracing-subscriber = { workspace = true, features = ["fmt"] } [features] anvil = [] diff --git a/crates/providers/src/builder.rs b/crates/providers/src/builder.rs index d69a7ad359a..6765cf2efd8 100644 --- a/crates/providers/src/builder.rs +++ b/crates/providers/src/builder.rs @@ -1,4 +1,4 @@ -use crate::{NetworkRpcClient, Provider}; +use crate::new::{Provider, RootProvider}; use alloy_network::Network; use 
alloy_rpc_client::RpcClient; use alloy_transport::Transport; @@ -7,28 +7,27 @@ use std::marker::PhantomData; /// A layering abstraction in the vein of [`tower::Layer`] /// /// [`tower::Layer`]: https://docs.rs/tower/latest/tower/trait.Layer.html -pub trait ProviderLayer, N: Network, T: Transport> { +pub trait ProviderLayer, N: Network, T: Transport + Clone> { type Provider: Provider; fn layer(&self, inner: P) -> Self::Provider; } -pub struct Stack { +pub struct Stack { inner: Inner, outer: Outer, - _pd: std::marker::PhantomData T>, } -impl Stack { +impl Stack { /// Create a new `Stack`. pub fn new(inner: Inner, outer: Outer) -> Self { - Stack { inner, outer, _pd: std::marker::PhantomData } + Stack { inner, outer } } } -impl ProviderLayer for Stack +impl ProviderLayer for Stack where - T: Transport, + T: Transport + Clone, N: Network, P: Provider, Inner: ProviderLayer, @@ -49,14 +48,13 @@ where /// around maintaining the network and transport types. /// /// [`tower::ServiceBuilder`]: https://docs.rs/tower/latest/tower/struct.ServiceBuilder.html -pub struct ProviderBuilder { +pub struct ProviderBuilder { layer: L, - transport: PhantomData, network: PhantomData, } -impl ProviderBuilder { +impl ProviderBuilder { /// Add a layer to the stack being built. This is similar to /// [`tower::ServiceBuilder::layer`]. /// @@ -70,12 +68,8 @@ impl ProviderBuilder { /// [`tower::ServiceBuilder::layer`]: https://docs.rs/tower/latest/tower/struct.ServiceBuilder.html#method.layer /// [`tower::ServiceBuilder`]: https://docs.rs/tower/latest/tower/struct.ServiceBuilder.html - pub fn layer(self, layer: Inner) -> ProviderBuilder> { - ProviderBuilder { - layer: Stack::new(layer, self.layer), - transport: PhantomData, - network: PhantomData, - } + pub fn layer(self, layer: Inner) -> ProviderBuilder> { + ProviderBuilder { layer: Stack::new(layer, self.layer), network: PhantomData } } /// Change the network. 
@@ -87,34 +81,34 @@ impl ProviderBuilder { /// ```rust,ignore /// builder.network::() /// ``` - pub fn network(self) -> ProviderBuilder { - ProviderBuilder { layer: self.layer, transport: self.transport, network: PhantomData } + pub fn network(self) -> ProviderBuilder { + ProviderBuilder { layer: self.layer, network: PhantomData } } - /// Finish the layer stack by providing a root [`RpcClient`], outputting + /// Finish the layer stack by providing a root [`Provider`], outputting /// the final [`Provider`] type with all stack components. - /// - /// This is a convenience function for - /// `ProviderBuilder::provider`. - pub fn client(self, client: RpcClient) -> L::Provider + pub fn provider(self, provider: P) -> L::Provider where - L: ProviderLayer, N, T>, + L: ProviderLayer, + P: Provider, T: Transport + Clone, N: Network, { - self.provider(NetworkRpcClient::from(client)) + self.layer.layer(provider) } - /// Finish the layer stack by providing a root [`Provider`], outputting + /// Finish the layer stack by providing a root [`RpcClient`], outputting /// the final [`Provider`] type with all stack components. - pub fn provider

(self, provider: P) -> L::Provider + /// + /// This is a convenience function for + /// `ProviderBuilder::provider`. + pub fn on_client(self, client: RpcClient) -> L::Provider where - L: ProviderLayer, - P: Provider, - T: Transport, + L: ProviderLayer, N, T>, + T: Transport + Clone, N: Network, { - self.layer.layer(provider) + self.provider(RootProvider::new(client)) } } diff --git a/crates/providers/src/chain.rs b/crates/providers/src/chain.rs new file mode 100644 index 00000000000..a233b2a9894 --- /dev/null +++ b/crates/providers/src/chain.rs @@ -0,0 +1,112 @@ +use crate::{new::RootProviderInner, Provider, RootProvider, WeakProvider}; +use alloy_network::Network; +use alloy_primitives::{BlockNumber, U64}; +use alloy_rpc_client::{PollTask, WeakClient}; +use alloy_rpc_types::Block; +use alloy_transport::{RpcError, Transport}; +use async_stream::stream; +use futures::{Stream, StreamExt}; +use lru::LruCache; +use std::{num::NonZeroUsize, sync::Arc, time::Duration}; + +/// The size of the block cache. +const BLOCK_CACHE_SIZE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(10) }; + +/// Maximum number of retries for fetching a block. +const MAX_RETRIES: usize = 3; + +/// Default block number for when we don't have a block yet. +const NO_BLOCK_NUMBER: BlockNumber = BlockNumber::MAX; + +pub(crate) struct ChainStreamPoller { + provider: WeakProvider

, + poll_task: PollTask, + next_yield: BlockNumber, + known_blocks: LruCache, +} + +impl ChainStreamPoller, T> { + pub(crate) fn from_root(p: &RootProvider) -> Self { + let mut this = Self::new(Arc::downgrade(&p.inner), p.inner.weak_client()); + if p.client().is_local() { + this.poll_task.set_poll_interval(Duration::from_secs(1)); + } + this + } +} + +impl ChainStreamPoller { + pub(crate) fn new(provider: WeakProvider

, client: WeakClient) -> Self { + Self { + provider, + poll_task: PollTask::new(client, "eth_blockNumber", ()), + next_yield: NO_BLOCK_NUMBER, + known_blocks: LruCache::new(BLOCK_CACHE_SIZE), + } + } + + pub(crate) fn into_stream(mut self) -> impl Stream + where + P: Provider, + { + stream! { + let mut poll_task = self.poll_task.spawn().into_stream(); + 'task: loop { + // Clear any buffered blocks. + while let Some(known_block) = self.known_blocks.pop(&self.next_yield) { + debug!(number=self.next_yield, "yielding block"); + self.next_yield += 1; + yield known_block; + } + + // Get the tip. + let block_number = match poll_task.next().await { + Some(Ok(block_number)) => block_number, + Some(Err(err)) => { + // This is fine. + debug!(%err, "polling stream lagged"); + continue 'task; + } + None => { + debug!("polling stream ended"); + break 'task; + } + }; + let block_number = block_number.to::(); + if self.next_yield == NO_BLOCK_NUMBER { + assert!(block_number < NO_BLOCK_NUMBER, "too many blocks"); + self.next_yield = block_number; + } else if block_number < self.next_yield { + debug!(block_number, self.next_yield, "not advanced yet"); + continue 'task; + } + + // Upgrade the provider. + let Some(provider) = self.provider.upgrade() else { + debug!("provider dropped"); + break 'task; + }; + + // Then try to fill as many blocks as possible. 
+ // TODO: Maybe use `join_all` + let mut retries = MAX_RETRIES; + for number in self.next_yield..=block_number { + debug!(number, "fetching block"); + let block = match provider.get_block_by_number(number, false).await { + Ok(block) => block, + Err(RpcError::Transport(err)) if retries > 0 && err.recoverable() => { + debug!(number, %err, "failed to fetch block, retrying"); + retries -= 1; + continue; + } + Err(err) => { + error!(number, %err, "failed to fetch block"); + break 'task; + } + }; + self.known_blocks.put(number, block); + } + } + } + } +} diff --git a/crates/providers/src/heart.rs b/crates/providers/src/heart.rs new file mode 100644 index 00000000000..a34d9c2a968 --- /dev/null +++ b/crates/providers/src/heart.rs @@ -0,0 +1,286 @@ +//! Block Hearbeat and Transaction Watcher + +use alloy_primitives::{B256, U256}; +use alloy_rpc_types::Block; +use alloy_transport::{utils::Spawnable, TransportErrorKind, TransportResult}; +use futures::{stream::StreamExt, FutureExt, Stream}; +use std::{ + collections::{BTreeMap, HashMap}, + fmt, + future::Future, + time::{Duration, Instant}, +}; +use tokio::{ + select, + sync::{mpsc, oneshot, watch}, +}; + +/// A configuration object for watching for transaction confirmation. +#[derive(Debug)] +pub struct WatchConfig { + /// The transaction hash to watch for. + tx_hash: B256, + + /// Require a number of confirmations. + confirmations: u64, + + /// Optional timeout for the transaction. + timeout: Option, +} + +impl WatchConfig { + /// Create a new watch for a transaction. + pub fn new(tx_hash: B256) -> Self { + Self { tx_hash, confirmations: 0, timeout: None } + } + + /// Set the number of confirmations to wait for. + pub fn set_confirmations(&mut self, confirmations: u64) { + self.confirmations = confirmations; + } + + /// Set the number of confirmations to wait for. + pub fn with_confirmations(mut self, confirmations: u64) -> Self { + self.confirmations = confirmations; + self + } + + /// Set the timeout for the transaction. 
+ pub fn set_timeout(&mut self, timeout: Duration) { + self.timeout = Some(timeout); + } + + /// Set the timeout for the transaction. + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } +} + +struct TxWatcher { + config: WatchConfig, + tx: oneshot::Sender<()>, +} + +impl TxWatcher { + /// Notify the waiter. + fn notify(self) { + debug!(tx=%self.config.tx_hash, "notifying"); + let _ = self.tx.send(()); + } +} + +/// A pending transaction that can be awaited. +pub struct PendingTransaction { + /// The transaction hash. + pub(crate) tx_hash: B256, + /// The receiver for the notification. + // TODO: send a receipt? + pub(crate) rx: oneshot::Receiver<()>, +} + +impl fmt::Debug for PendingTransaction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PendingTransaction").field("tx_hash", &self.tx_hash).finish() + } +} + +impl PendingTransaction { + /// Returns this transaction's hash. + pub const fn tx_hash(&self) -> &B256 { + &self.tx_hash + } +} + +impl Future for PendingTransaction { + type Output = TransportResult; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + self.rx + .poll_unpin(cx) + .map(|res| res.map(|()| self.tx_hash).map_err(|_| TransportErrorKind::backend_gone())) + } +} + +/// A handle to the heartbeat task. +#[derive(Debug, Clone)] +pub(crate) struct HeartbeatHandle { + tx: mpsc::Sender, + #[allow(dead_code)] + latest: watch::Receiver>, +} + +impl HeartbeatHandle { + /// Watch for a transaction to be confirmed with the given config. + pub(crate) async fn watch_tx( + &self, + config: WatchConfig, + ) -> Result { + let (tx, rx) = oneshot::channel(); + let tx_hash = config.tx_hash; + match self.tx.send(TxWatcher { config, tx }).await { + Ok(()) => Ok(PendingTransaction { tx_hash, rx }), + Err(e) => Err(e.0.config), + } + } + + /// Returns a watcher that always sees the latest block. 
+ #[allow(dead_code)] + pub(crate) fn latest(&self) -> &watch::Receiver> { + &self.latest + } +} + +// TODO: Parameterize with `Network` +/// A heartbeat task that receives blocks and watches for transactions. +pub(crate) struct Heartbeat { + /// The stream of incoming blocks to watch. + stream: futures::stream::Fuse, + + /// Transactions to watch for. + unconfirmed: HashMap, + + /// Ordered map of transactions waiting for confirmations. + waiting_confs: BTreeMap>, + + /// Ordered map of transactions to reap at a certain time. + reap_at: BTreeMap, +} + +impl> Heartbeat { + /// Create a new heartbeat task. + pub(crate) fn new(stream: S) -> Self { + Self { + stream: stream.fuse(), + unconfirmed: Default::default(), + waiting_confs: Default::default(), + reap_at: Default::default(), + } + } +} + +impl Heartbeat { + /// Check if any transactions have enough confirmations to notify. + fn check_confirmations(&mut self, current_height: &U256) { + let to_keep = self.waiting_confs.split_off(current_height); + let to_notify = std::mem::replace(&mut self.waiting_confs, to_keep); + for watcher in to_notify.into_values().flatten() { + watcher.notify(); + } + } + + /// Get the next time to reap a transaction. If no reaps, this is a very + /// long time from now (i.e. will not be woken). + fn next_reap(&self) -> Instant { + self.reap_at + .first_key_value() + .map(|(k, _)| *k) + .unwrap_or_else(|| Instant::now() + Duration::from_secs(60_000)) + } + + /// Reap any timeout + fn reap_timeouts(&mut self) { + let now = Instant::now(); + let to_keep = self.reap_at.split_off(&now); + let to_reap = std::mem::replace(&mut self.reap_at, to_keep); + + for tx_hash in to_reap.values() { + if self.unconfirmed.remove(tx_hash).is_some() { + debug!(tx=%tx_hash, "reaped"); + } + } + } + + /// Handle a watch instruction by adding it to the watch list, and + /// potentially adding it to our `reap_at` list. 
+ fn handle_watch_ix(&mut self, to_watch: TxWatcher) { + // Start watching for the transaction. + debug!(tx=%to_watch.config.tx_hash, "watching"); + trace!(?to_watch.config); + if let Some(timeout) = to_watch.config.timeout { + self.reap_at.insert(Instant::now() + timeout, to_watch.config.tx_hash); + } + self.unconfirmed.insert(to_watch.config.tx_hash, to_watch); + } + + /// Handle a new block by checking if any of the transactions we're + /// watching are in it, and if so, notifying the watcher. Also updates + /// the latest block. + fn handle_new_block(&mut self, block: Block, latest: &watch::Sender>) { + // Blocks without numbers are ignored, as they're not part of the chain. + let Some(block_height) = &block.header.number else { return }; + + // Check if we are watching for any of the transactions in this block. + let to_check = + block.transactions.hashes().filter_map(|tx_hash| self.unconfirmed.remove(tx_hash)); + for watcher in to_check { + // If `confirmations` is 0 we can notify the watcher immediately. + let confirmations = watcher.config.confirmations; + if confirmations == 0 { + watcher.notify(); + continue; + } + // Otherwise add it to the waiting list. + debug!(tx=%watcher.config.tx_hash, %block_height, confirmations, "adding to waiting list"); + self.waiting_confs + .entry(*block_height + U256::from(confirmations)) + .or_default() + .push(watcher); + } + + self.check_confirmations(block_height); + + // Update the latest block. We use `send_replace` here to ensure the + // latest block is always up to date, even if no receivers exist. + // C.f. 
https://docs.rs/tokio/latest/tokio/sync/watch/struct.Sender.html#method.send + debug!(%block_height, "updating latest block"); + let _ = latest.send_replace(Some(block)); + } +} + +impl + Unpin + Send + 'static> Heartbeat { + /// Spawn the heartbeat task, returning a [`HeartbeatHandle`] + pub(crate) fn spawn(mut self) -> HeartbeatHandle { + let (latest, latest_rx) = watch::channel(None::); + let (ix_tx, mut ixns) = mpsc::channel(16); + + let fut = async move { + 'shutdown: loop { + { + let next_reap = self.next_reap(); + let sleep = std::pin::pin!(tokio::time::sleep_until(next_reap.into())); + + // We bias the select so that we always handle new messages + // before checking blocks, and reap timeouts are last. + select! { + biased; + + // Watch for new transactions. + ix_opt = ixns.recv() => match ix_opt { + Some(to_watch) => self.handle_watch_ix(to_watch), + None => break 'shutdown, // ix channel is closed + }, + + // Wake up to handle new blocks. + block = self.stream.select_next_some() => { + self.handle_new_block(block, &latest); + }, + + // This arm ensures we always wake up to reap timeouts, + // even if there are no other events. 
+ _ = sleep => {}, + } + } + + // Always reap timeouts + self.reap_timeouts(); + } + }; + fut.spawn_task(); + + HeartbeatHandle { tx: ix_tx, latest: latest_rx } + } +} diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index 036cdb1c5a5..dab33965640 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -16,157 +16,21 @@ #![deny(unused_must_use, rust_2018_idioms)] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] -use alloy_network::{Network, Transaction}; -use alloy_primitives::Address; -use alloy_rpc_client::RpcClient; -use alloy_transport::{BoxTransport, Transport, TransportResult}; -use std::{borrow::Cow, marker::PhantomData}; +#[macro_use] +extern crate tracing; mod builder; pub use builder::{ProviderBuilder, ProviderLayer, Stack}; -pub mod provider; -pub mod utils; - -/// A network-wrapped RPC client. -/// -/// This type allows you to specify (at the type-level) that the RPC client is -/// for a specific network. This helps avoid accidentally using the wrong -/// connection to access a network. -#[derive(Debug)] -pub struct NetworkRpcClient { - pub network: PhantomData N>, - pub client: RpcClient, -} - -impl std::ops::Deref for NetworkRpcClient -where - N: Network, - T: Transport, -{ - type Target = RpcClient; - - fn deref(&self) -> &Self::Target { - &self.client - } -} - -impl From> for NetworkRpcClient -where - N: Network, - T: Transport, -{ - fn from(client: RpcClient) -> Self { - Self { network: PhantomData, client } - } -} - -impl From> for RpcClient -where - N: Network, - T: Transport, -{ - fn from(client: NetworkRpcClient) -> Self { - client.client - } -} - -/// Provider is parameterized with a network and a transport. The default -/// transport is type-erased, but you can do `Provider`. 
-#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] -pub trait Provider: Send + Sync { - fn raw_client(&self) -> &RpcClient { - &self.client().client - } - - /// Return a reference to the inner RpcClient. - fn client(&self) -> &NetworkRpcClient; - - /// Return a reference to the inner Provider. - /// - /// Providers are object safe now :) - fn inner(&self) -> &dyn Provider; +mod chain; - async fn estimate_gas( - &self, - tx: &N::TransactionRequest, - ) -> TransportResult { - self.inner().estimate_gas(tx).await - } +mod heart; +pub use heart::{PendingTransaction, WatchConfig}; - /// Get the transaction count for an address. Used for finding the - /// appropriate nonce. - /// - /// TODO: block number/hash/tag - async fn get_transaction_count( - &self, - address: Address, - ) -> TransportResult { - self.inner().get_transaction_count(address).await - } +pub mod new; +pub use new::{Provider, ProviderRef, RootProvider, WeakProvider}; - /// Send a transaction to the network. - /// - /// The transaction type is defined by the network. 
- async fn send_transaction( - &self, - tx: &N::TransactionRequest, - ) -> TransportResult { - self.inner().send_transaction(tx).await - } - - async fn populate_gas(&self, tx: &mut N::TransactionRequest) -> TransportResult<()> { - let gas = self.estimate_gas(&*tx).await; - - gas.map(|gas| tx.set_gas_limit(gas.try_into().unwrap())) - } -} - -#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] -impl Provider for NetworkRpcClient { - fn client(&self) -> &NetworkRpcClient { - self - } - - fn inner(&self) -> &dyn Provider { - panic!("called inner on ") - } - - async fn estimate_gas( - &self, - tx: &::TransactionRequest, - ) -> TransportResult { - self.prepare("eth_estimateGas", Cow::Borrowed(tx)).await - } - - async fn get_transaction_count( - &self, - address: Address, - ) -> TransportResult { - self.prepare( - "eth_getTransactionCount", - Cow::<(Address, String)>::Owned((address, "latest".to_string())), - ) - .await - } - - async fn send_transaction( - &self, - tx: &N::TransactionRequest, - ) -> TransportResult { - self.prepare("eth_sendTransaction", Cow::Borrowed(tx)).await - } -} - -#[cfg(test)] -mod test { - use crate::Provider; - use alloy_network::Network; +pub mod utils; - // checks that `Provider` is object-safe - fn __compile_check() -> Box> { - unimplemented!() - } -} +// TODO: remove +pub mod tmp; diff --git a/crates/providers/src/new.rs b/crates/providers/src/new.rs new file mode 100644 index 00000000000..71b37233e6d --- /dev/null +++ b/crates/providers/src/new.rs @@ -0,0 +1,363 @@ +use crate::{ + chain::ChainStreamPoller, + heart::{Heartbeat, HeartbeatHandle, PendingTransaction, WatchConfig}, +}; +use alloy_network::{Network, Transaction}; +use alloy_primitives::{hex, Address, BlockNumber, B256, U256, U64}; +use alloy_rpc_client::{ClientRef, RpcClient, WeakClient}; +use alloy_rpc_types::Block; +use alloy_transport::{BoxTransport, Transport, TransportErrorKind, 
TransportResult}; +use std::{ + marker::PhantomData, + sync::{Arc, OnceLock, Weak}, +}; + +/// A [`Provider`] in a [`Weak`] reference. +pub type WeakProvider

= Weak

; + +/// A borrowed [`Provider`]. +pub type ProviderRef<'a, P> = &'a P; + +/// The root provider manages the RPC client and the heartbeat. It is at the +/// base of every provider stack. +pub struct RootProvider { + /// The inner state of the root provider. + pub(crate) inner: Arc>, +} + +impl RootProvider { + pub(crate) fn new(client: RpcClient) -> Self { + Self { inner: Arc::new(RootProviderInner::new(client)) } + } +} + +impl RootProvider { + async fn new_pending_transaction(&self, tx_hash: B256) -> TransportResult { + // TODO: Make this configurable. + let cfg = WatchConfig::new(tx_hash); + self.get_heart().watch_tx(cfg).await.map_err(|_| TransportErrorKind::backend_gone()) + } + + #[inline] + fn get_heart(&self) -> &HeartbeatHandle { + self.inner.heart.get_or_init(|| { + let poller = ChainStreamPoller::from_root(self); + // TODO: Can we avoid `Box::pin` here? + Heartbeat::new(Box::pin(poller.into_stream())).spawn() + }) + } +} + +/// The root provider manages the RPC client and the heartbeat. It is at the +/// base of every provider stack. +pub(crate) struct RootProviderInner { + client: RpcClient, + heart: OnceLock, + _network: PhantomData, +} + +impl RootProviderInner { + pub(crate) fn new(client: RpcClient) -> Self { + Self { client, heart: OnceLock::new(), _network: PhantomData } + } + + fn weak_client(&self) -> WeakClient { + self.client.get_weak() + } + + fn client_ref(&self) -> ClientRef<'_, T> { + self.client.get_ref() + } +} + +/// Provider is parameterized with a network and a transport. The default +/// transport is type-erased, but you can do `Provider`. +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[auto_impl::auto_impl(&, &mut, Rc, Arc, Box)] +pub trait Provider: Send + Sync { + /// Returns the RPC client used to send requests. + fn client(&self) -> ClientRef<'_, T>; + + /// Returns a [`Weak`] RPC client used to send requests. 
+ fn weak_client(&self) -> WeakClient; + + async fn new_pending_transaction(&self, tx_hash: B256) -> TransportResult; + + async fn estimate_gas(&self, tx: &N::TransactionRequest) -> TransportResult { + self.client().prepare("eth_estimateGas", (tx,)).await + } + + /// Get the last block number available. + async fn get_block_number(&self) -> TransportResult { + self.client().prepare("eth_blockNumber", ()).await.map(|num: U64| num.to::()) + } + + /// Get the transaction count for an address. Used for finding the + /// appropriate nonce. + /// + /// TODO: block number/hash/tag + async fn get_transaction_count(&self, address: Address) -> TransportResult { + self.client().prepare("eth_getTransactionCount", (address, "latest")).await + } + + /// Get a block by its number. + /// + /// TODO: Network associate + async fn get_block_by_number( + &self, + number: BlockNumber, + hydrate: bool, + ) -> TransportResult { + self.client().prepare("eth_getBlockByNumber", (number, hydrate)).await + } + + /// Populate the gas limit for a transaction. + async fn populate_gas(&self, tx: &mut N::TransactionRequest) -> TransportResult<()> { + let gas = self.estimate_gas(&*tx).await?; + if let Ok(gas) = gas.try_into() { + tx.set_gas_limit(gas); + } + Ok(()) + } + + /// Broadcasts a transaction, returning a [`PendingTransaction`] that resolves once the + /// transaction has been confirmed. + async fn send_transaction( + &self, + tx: &N::TransactionRequest, + ) -> TransportResult { + let tx_hash = self.client().prepare("eth_sendTransaction", (tx,)).await?; + self.new_pending_transaction(tx_hash).await + } + + /// Broadcasts a transaction's raw RLP bytes, returning a [`PendingTransaction`] that resolves + /// once the transaction has been confirmed. 
+ async fn send_raw_transaction(&self, rlp_bytes: &[u8]) -> TransportResult { + let rlp_hex = hex::encode(rlp_bytes); + let tx_hash = self.client().prepare("eth_sendRawTransaction", (rlp_hex,)).await?; + self.new_pending_transaction(tx_hash).await + } +} + +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl Provider for RootProvider { + #[inline] + fn client(&self) -> ClientRef<'_, T> { + self.inner.client_ref() + } + + #[inline] + fn weak_client(&self) -> WeakClient { + self.inner.weak_client() + } + + #[inline] + async fn new_pending_transaction(&self, tx_hash: B256) -> TransportResult { + RootProvider::new_pending_transaction(self, tx_hash).await + } +} + +// Internal implementation for [`ChainStreamPoller`]. +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl Provider for RootProviderInner { + #[inline] + fn client(&self) -> ClientRef<'_, T> { + self.client_ref() + } + + #[inline] + fn weak_client(&self) -> WeakClient { + self.weak_client() + } + + #[inline] + async fn new_pending_transaction(&self, _tx_hash: B256) -> TransportResult { + unreachable!() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use alloy_primitives::address; + use alloy_rpc_types::request::{TransactionInput, TransactionRequest}; + use alloy_transport_http::Http; + use reqwest::Client; + + struct _ObjectSafe(dyn Provider); + + #[derive(Clone)] + struct TxLegacy(alloy_consensus::TxLegacy); + impl serde::Serialize for TxLegacy { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let tx = &self.0; + TransactionRequest { + from: None, + to: tx.to().to(), + gas_price: tx.gas_price(), + max_fee_per_gas: None, + max_priority_fee_per_gas: None, + max_fee_per_blob_gas: None, + gas: Some(U256::from(tx.gas_limit())), + value: Some(tx.value()), + input: 
TransactionInput::new(tx.input().to_vec().into()), + nonce: Some(U64::from(tx.nonce())), + chain_id: tx.chain_id().map(U64::from), + access_list: None, + transaction_type: None, + blob_versioned_hashes: None, + sidecar: None, + other: Default::default(), + } + .serialize(serializer) + } + } + impl<'de> serde::Deserialize<'de> for TxLegacy { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + unimplemented!() + } + } + #[allow(unused)] + impl alloy_network::Transaction for TxLegacy { + type Signature = (); + + fn encode_for_signing(&self, out: &mut dyn alloy_rlp::BufMut) { + todo!() + } + + fn payload_len_for_signature(&self) -> usize { + todo!() + } + + fn into_signed( + self, + signature: alloy_primitives::Signature, + ) -> alloy_network::Signed + where + Self: Sized, + { + todo!() + } + + fn encode_signed( + &self, + signature: &alloy_primitives::Signature, + out: &mut dyn alloy_primitives::bytes::BufMut, + ) { + todo!() + } + + fn decode_signed(buf: &mut &[u8]) -> alloy_rlp::Result> + where + Self: Sized, + { + todo!() + } + + fn input(&self) -> &[u8] { + todo!() + } + + fn input_mut(&mut self) -> &mut alloy_primitives::Bytes { + todo!() + } + + fn set_input(&mut self, data: alloy_primitives::Bytes) { + todo!() + } + + fn to(&self) -> alloy_network::TxKind { + todo!() + } + + fn set_to(&mut self, to: alloy_network::TxKind) { + todo!() + } + + fn value(&self) -> U256 { + todo!() + } + + fn set_value(&mut self, value: U256) { + todo!() + } + + fn chain_id(&self) -> Option { + todo!() + } + + fn set_chain_id(&mut self, chain_id: alloy_primitives::ChainId) { + todo!() + } + + fn nonce(&self) -> u64 { + todo!() + } + + fn set_nonce(&mut self, nonce: u64) { + todo!() + } + + fn gas_limit(&self) -> u64 { + todo!() + } + + fn set_gas_limit(&mut self, limit: u64) { + todo!() + } + + fn gas_price(&self) -> Option { + todo!() + } + + fn set_gas_price(&mut self, price: U256) { + todo!() + } + } + + struct TmpNetwork; + impl Network for 
TmpNetwork { + type TxEnvelope = alloy_consensus::TxEnvelope; + type ReceiptEnvelope = alloy_consensus::ReceiptEnvelope; + type Header = (); + type TransactionRequest = TxLegacy; + type TransactionResponse = (); + type ReceiptResponse = (); + type HeaderResponse = (); + } + + fn init_tracing() { + let _ = tracing_subscriber::fmt::try_init(); + } + + #[tokio::test] + async fn test_send_tx() { + init_tracing(); + + let anvil = alloy_node_bindings::Anvil::new().spawn(); + let url = anvil.endpoint().parse().unwrap(); + let http = Http::::new(url); + let provider = RootProvider::::new(RpcClient::new(http, true)); + + let tx = alloy_consensus::TxLegacy { + value: U256::from(100), + to: address!("d8dA6BF26964aF9D7eEd9e03E53415D37aA96045").into(), + gas_price: 20e9 as u128, + gas_limit: 21000, + ..Default::default() + }; + let pending_tx = provider.send_transaction(&TxLegacy(tx)).await.expect("failed to send tx"); + let hash1 = pending_tx.tx_hash; + let hash2 = pending_tx.await.expect("failed to await pending tx"); + assert_eq!(hash1, hash2); + } +} diff --git a/crates/providers/src/provider.rs b/crates/providers/src/tmp.rs similarity index 99% rename from crates/providers/src/provider.rs rename to crates/providers/src/tmp.rs index 9c032adf214..25c054e52dd 100644 --- a/crates/providers/src/provider.rs +++ b/crates/providers/src/tmp.rs @@ -33,7 +33,7 @@ pub type HttpProvider = Provider>; /// An abstract provider for interacting with the [Ethereum JSON RPC /// API](https://github.com/ethereum/wiki/wiki/JSON-RPC). Must be instantiated /// with a transport which implements the [Transport] trait. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Provider { inner: RpcClient, from: Option

, @@ -552,7 +552,7 @@ impl<'a> TryFrom<&'a String> for Provider> { #[cfg(test)] mod tests { use crate::{ - provider::{Provider, TempProvider}, + tmp::{Provider, TempProvider}, utils, }; use alloy_node_bindings::Anvil; diff --git a/crates/pubsub/Cargo.toml b/crates/pubsub/Cargo.toml index 273943147a5..ec1459dc15d 100644 --- a/crates/pubsub/Cargo.toml +++ b/crates/pubsub/Cargo.toml @@ -21,5 +21,6 @@ futures.workspace = true serde.workspace = true serde_json.workspace = true tokio = { workspace = true, features = ["macros", "sync"] } +tokio-stream = { workspace = true, features = ["sync"] } tower.workspace = true tracing.workspace = true diff --git a/crates/pubsub/src/service.rs b/crates/pubsub/src/service.rs index a57cbd67791..84ae35b966b 100644 --- a/crates/pubsub/src/service.rs +++ b/crates/pubsub/src/service.rs @@ -123,18 +123,10 @@ impl PubSubService { /// the subscription does not exist, the waiter is sent nothing, and the /// `tx` is dropped. This notifies the waiter that the subscription does /// not exist. - fn service_get_sub( - &mut self, - local_id: U256, - tx: oneshot::Sender, - ) -> TransportResult<()> { - let local_id = local_id.into(); - - if let Some(rx) = self.subs.get_subscription(local_id) { + fn service_get_sub(&mut self, local_id: U256, tx: oneshot::Sender) { + if let Some(rx) = self.subs.get_subscription(local_id.into()) { let _ = tx.send(rx); } - - Ok(()) } /// Service an unsubscribe instruction. 
@@ -153,7 +145,10 @@ impl PubSubService { trace!(?ix, "servicing instruction"); match ix { PubSubInstruction::Request(in_flight) => self.service_request(in_flight), - PubSubInstruction::GetSub(alias, tx) => self.service_get_sub(alias, tx), + PubSubInstruction::GetSub(alias, tx) => { + self.service_get_sub(alias, tx); + Ok(()) + } PubSubInstruction::Unsubscribe(alias) => self.service_unsubscribe(alias), } } diff --git a/crates/pubsub/src/sub.rs b/crates/pubsub/src/sub.rs index 119080882ce..b37cdea884f 100644 --- a/crates/pubsub/src/sub.rs +++ b/crates/pubsub/src/sub.rs @@ -1,7 +1,10 @@ use alloy_primitives::B256; +use futures::{ready, Stream, StreamExt}; use serde::de::DeserializeOwned; use serde_json::value::RawValue; +use std::{pin::Pin, task}; use tokio::sync::broadcast; +use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; /// A Subscription is a feed of notifications from the server, identified by a /// local ID. @@ -18,8 +21,8 @@ pub struct RawSubscription { impl RawSubscription { /// Get the local ID of the subscription. - pub const fn local_id(&self) -> B256 { - self.local_id + pub const fn local_id(&self) -> &B256 { + &self.local_id } /// Wrapper for [`blocking_recv`]. Block the current thread until a message @@ -72,6 +75,11 @@ impl RawSubscription { pub fn try_recv(&mut self) -> Result, broadcast::error::TryRecvError> { self.rx.try_recv() } + + /// Convert the subscription into a stream. + pub fn into_stream(self) -> BroadcastStream> { + self.rx.into() + } } /// An item in a typed [`Subscription`]. This is either the expected type, or @@ -119,7 +127,7 @@ impl From for Subscription { impl Subscription { /// Get the local ID of the subscription. - pub const fn local_id(&self) -> B256 { + pub const fn local_id(&self) -> &B256 { self.inner.local_id() } @@ -206,6 +214,15 @@ impl Subscription { self.inner.try_recv().map(Into::into) } + /// Convert the subscription into a stream that may yield unexpected types. 
+ pub fn into_any_stream(self) -> SubAnyStream { + SubAnyStream { + id: self.inner.local_id, + inner: self.inner.into_stream(), + _pd: std::marker::PhantomData, + } + } + /// Wrapper for [`blocking_recv`]. Block the current thread until a message /// of the expected type is available. /// @@ -245,6 +262,15 @@ impl Subscription { } } + /// Convert the subscription into a stream. + pub fn into_stream(self) -> SubscriptionStream { + SubscriptionStream { + id: self.inner.local_id, + inner: self.inner.into_stream(), + _pd: std::marker::PhantomData, + } + } + /// Wrapper for [`blocking_recv`]. Block the current thread until a message /// is available, deserializing the message and returning the result. /// @@ -274,4 +300,117 @@ impl Subscription { ) -> Result, broadcast::error::TryRecvError> { self.inner.try_recv().map(|value| serde_json::from_str(value.get())) } + + /// Convert the subscription into a stream that returns deserialization + /// results. + pub fn into_result_stream(self) -> SubResultStream { + SubResultStream { + id: self.inner.local_id, + inner: self.inner.into_stream(), + _pd: std::marker::PhantomData, + } + } +} + +/// A stream of notifications from the server, identified by a local ID. This +/// stream may yield unexpected types. +#[derive(Debug)] +pub struct SubAnyStream { + id: B256, + inner: BroadcastStream>, + _pd: std::marker::PhantomData T>, +} + +impl SubAnyStream { + /// Get the local ID of the subscription. + pub const fn id(&self) -> &B256 { + &self.id + } +} + +impl Stream for SubAnyStream { + type Item = Result, BroadcastStreamRecvError>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + match ready!(self.inner.poll_next_unpin(cx)) { + Some(value) => task::Poll::Ready(Some(value.map(Into::into))), + None => task::Poll::Ready(None), + } + } +} + +/// A stream of notifications from the server, identified by a local ID. 
This +/// stream will yield only the expected type, discarding any notifications of +/// unexpected types. +#[derive(Debug)] +pub struct SubscriptionStream { + id: B256, + inner: BroadcastStream>, + _pd: std::marker::PhantomData T>, +} + +impl SubscriptionStream { + /// Get the local ID of the subscription. + pub const fn id(&self) -> &B256 { + &self.id + } +} + +impl Stream for SubscriptionStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + loop { + match ready!(self.inner.poll_next_unpin(cx)) { + Some(Ok(value)) => match serde_json::from_str(value.get()) { + Ok(item) => return task::Poll::Ready(Some(Ok(item))), + Err(e) => { + trace!(value = value.get(), error = ?e, "Received unexpected value in subscription."); + continue; + } + }, + Some(Err(e)) => return task::Poll::Ready(Some(Err(e))), + None => return task::Poll::Ready(None), + } + } + } +} + +/// A stream of notifications from the server, identified by a local ID. +/// +/// This stream will attempt to deserialize the notifications and yield the [`serde_json::Result`] +/// of the deserialization. +#[derive(Debug)] +pub struct SubResultStream { + id: B256, + inner: BroadcastStream>, + _pd: std::marker::PhantomData T>, +} + +impl SubResultStream { + /// Get the local ID of the subscription. 
+ pub const fn id(&self) -> &B256 { + &self.id + } +} + +impl Stream for SubResultStream { + type Item = Result, BroadcastStreamRecvError>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + match ready!(self.inner.poll_next_unpin(cx)) { + Some(Ok(value)) => task::Poll::Ready(Some(Ok(serde_json::from_str(value.get())))), + Some(Err(e)) => task::Poll::Ready(Some(Err(e))), + None => task::Poll::Ready(None), + } + } } diff --git a/crates/rpc-client/Cargo.toml b/crates/rpc-client/Cargo.toml index d66928548c7..00488f3993a 100644 --- a/crates/rpc-client/Cargo.toml +++ b/crates/rpc-client/Cargo.toml @@ -19,6 +19,9 @@ alloy-transport.workspace = true futures.workspace = true pin-project.workspace = true serde_json.workspace = true +serde.workspace = true +tokio = { workspace = true, features = ["sync"] } +tokio-stream = { workspace = true, features = ["sync"] } tower.workspace = true tracing.workspace = true @@ -29,8 +32,6 @@ hyper = { workspace = true, optional = true } reqwest = { workspace = true, optional = true } url = { workspace = true, optional = true } -serde = { workspace = true, optional = true } - [target.'cfg(not(target_arch = "wasm32"))'.dependencies] alloy-transport-ipc = { workspace = true, optional = true } @@ -50,6 +51,6 @@ tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] } default = ["reqwest"] reqwest = ["dep:url", "dep:reqwest", "alloy-transport-http/reqwest"] hyper = ["dep:url", "dep:hyper", "alloy-transport-http/hyper"] -pubsub = ["dep:alloy-pubsub", "dep:alloy-primitives", "dep:serde"] +pubsub = ["dep:alloy-pubsub", "dep:alloy-primitives"] ws = ["pubsub", "dep:alloy-transport-ws"] ipc = ["pubsub", "dep:alloy-transport-ipc"] diff --git a/crates/rpc-client/src/batch.rs b/crates/rpc-client/src/batch.rs index 0e715b5070d..1f208266a9d 100644 --- a/crates/rpc-client/src/batch.rs +++ b/crates/rpc-client/src/batch.rs @@ -1,4 +1,4 @@ -use crate::RpcClient; +use 
crate::{client::RpcClientInner, ClientRef}; use alloy_json_rpc::{ transform_response, try_deserialize_ok, Id, Request, RequestPacket, ResponsePacket, RpcParam, RpcReturn, SerializedRequest, @@ -24,7 +24,7 @@ pub(crate) type ChannelMap = HashMap; #[must_use = "A BatchRequest does nothing unless sent via `send_batch` and `.await`"] pub struct BatchRequest<'a, T> { /// The transport via which the batch will be sent. - transport: &'a RpcClient, + transport: ClientRef<'a, T>, /// The requests to be sent. requests: RequestPacket, @@ -66,10 +66,7 @@ where #[pin_project::pin_project(project = CallStateProj)] #[derive(Debug)] -pub enum BatchFuture -where - Conn: Transport, -{ +pub enum BatchFuture { Prepared { transport: Conn, requests: RequestPacket, @@ -86,7 +83,7 @@ where impl<'a, T> BatchRequest<'a, T> { /// Create a new batch request. - pub fn new(transport: &'a RpcClient) -> Self { + pub fn new(transport: &'a RpcClientInner) -> Self { Self { transport, requests: RequestPacket::Batch(Vec::with_capacity(10)), diff --git a/crates/rpc-client/src/call.rs b/crates/rpc-client/src/call.rs index aaf1839dc87..c2f1560abb8 100644 --- a/crates/rpc-client/src/call.rs +++ b/crates/rpc-client/src/call.rs @@ -33,6 +33,21 @@ where Complete, } +impl Clone for CallState +where + Params: RpcParam, + Conn: Transport + Clone, +{ + fn clone(&self) -> Self { + match self { + Self::Prepared { request, connection } => { + Self::Prepared { request: request.clone(), connection: connection.clone() } + } + _ => panic!("cloned after dispatch"), + } + } +} + impl fmt::Debug for CallState where Params: RpcParam, @@ -152,6 +167,17 @@ where _pd: PhantomData Resp>, } +impl Clone for RpcCall +where + Conn: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, +{ + fn clone(&self) -> Self { + Self { state: self.state.clone(), _pd: PhantomData } + } +} + impl RpcCall where Conn: Transport + Clone, diff --git a/crates/rpc-client/src/client.rs b/crates/rpc-client/src/client.rs index 
b1126235569..be17d740a3b 100644 --- a/crates/rpc-client/src/client.rs +++ b/crates/rpc-client/src/client.rs @@ -1,10 +1,99 @@ -use crate::{BatchRequest, ClientBuilder, RpcCall}; +use crate::{poller::PollTask, BatchRequest, ClientBuilder, RpcCall}; use alloy_json_rpc::{Id, Request, RpcParam, RpcReturn}; -use alloy_transport::{BoxTransport, Transport, TransportConnect, TransportResult}; +use alloy_transport::{BoxTransport, Transport, TransportConnect, TransportError}; use alloy_transport_http::Http; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::{ + ops::Deref, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Weak, + }, +}; use tower::{layer::util::Identity, ServiceBuilder}; +/// An [`RpcClient`] in a [`Weak`] reference. +pub type WeakClient = Weak>; + +/// A borrowed [`RpcClient`]. +pub type ClientRef<'a, T> = &'a RpcClientInner; + +/// A JSON-RPC client. +#[derive(Debug)] +pub struct RpcClient(Arc>); + +impl Clone for RpcClient { + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } +} + +impl RpcClient { + /// Create a new [`ClientBuilder`]. + pub fn builder() -> ClientBuilder { + ClientBuilder { builder: ServiceBuilder::new() } + } +} + +impl RpcClient { + /// Create a new [`RpcClient`] with the given transport. + pub fn new(t: T, is_local: bool) -> Self { + Self(Arc::new(RpcClientInner::new(t, is_local))) + } + + /// Get a [`Weak`] reference to the client. + pub fn get_weak(&self) -> WeakClient { + Arc::downgrade(&self.0) + } + + /// Borrow the client. + pub fn get_ref(&self) -> ClientRef<'_, T> { + &self.0 + } +} + +impl RpcClient { + /// Connect to a transport via a [`TransportConnect`] implementor. + pub async fn connect(connect: C) -> Result + where + C: TransportConnect, + { + ClientBuilder::default().connect(connect).await + } + + /// Poll a method with the given parameters. 
+ /// + /// A [`PollTask`] + pub fn prepare_static_poller( + &self, + method: &'static str, + params: Params, + ) -> PollTask + where + T: Clone, + Params: RpcParam + 'static, + Resp: RpcReturn + Clone, + { + let request: Request = self.make_request(method, params); + PollTask::new(self.get_weak(), method, request.params) + } +} + +impl RpcClient> { + /// Create a new [`BatchRequest`] builder. + #[inline] + pub fn new_batch(&self) -> BatchRequest<'_, Http> { + BatchRequest::new(&self.0) + } +} + +impl Deref for RpcClient { + type Target = RpcClientInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + /// A JSON-RPC client. /// /// This struct manages a [`Transport`] and a request ID counter. It is used to @@ -14,11 +103,11 @@ use tower::{layer::util::Identity, ServiceBuilder}; /// ### Note /// /// IDs are allocated sequentially, starting at 0. IDs are reserved via -/// [`RpcClient::next_id`]. Note that allocated IDs may not be used. There is -/// no guarantee that a prepared [`RpcCall`] will be sent, or that a sent call -/// will receive a response. +/// [`RpcClientInner::next_id`]. Note that allocated IDs may not be used. There +/// is no guarantee that a prepared [`RpcCall`] will be sent, or that a sent +/// call will receive a response. #[derive(Debug)] -pub struct RpcClient { +pub struct RpcClientInner { /// The underlying transport. pub(crate) transport: T, /// `true` if the transport is local. @@ -27,33 +116,17 @@ pub struct RpcClient { pub(crate) id: AtomicU64, } -impl RpcClient { - /// Create a new [`ClientBuilder`]. - pub fn builder() -> ClientBuilder { - ClientBuilder { builder: ServiceBuilder::new() } - } -} - -impl RpcClient { +impl RpcClientInner { /// Create a new [`RpcClient`] with the given transport. pub const fn new(t: T, is_local: bool) -> Self { Self { transport: t, is_local, id: AtomicU64::new(0) } } - /// Connect to a transport via a [`TransportConnect`] implementor. 
- pub async fn connect(connect: C) -> TransportResult - where - T: Transport, - C: TransportConnect, - { - ClientBuilder::default().connect(connect).await - } - /// Build a `JsonRpcRequest` with the given method and params. /// /// This function reserves an ID for the request, however the request - /// is not sent. To send a request, use [`RpcClient::prepare`] and await - /// the returned [`RpcCall`]. + /// is not sent. To send a request, use [`RpcClientInner::prepare`] and + /// await the returned [`RpcCall`]. pub fn make_request( &self, method: &'static str, @@ -91,10 +164,7 @@ impl RpcClient { } } -impl RpcClient -where - T: Transport + Clone, -{ +impl RpcClientInner { /// Prepare an [`RpcCall`]. /// /// This function reserves an ID for the request, however the request @@ -122,9 +192,8 @@ where /// This is for abstracting over `RpcClient` for multiple `T` by /// erasing each type. E.g. if you have `RpcClient` and /// `RpcClient` you can put both into a `Vec>`. - #[inline] - pub fn boxed(self) -> RpcClient { - RpcClient { transport: self.transport.boxed(), is_local: self.is_local, id: self.id } + pub fn boxed(self) -> RpcClientInner { + RpcClientInner { transport: self.transport.boxed(), is_local: self.is_local, id: self.id } } } @@ -132,8 +201,9 @@ where mod pubsub_impl { use super::*; use alloy_pubsub::{PubSubConnect, PubSubFrontend, RawSubscription, Subscription}; + use alloy_transport::TransportResult; - impl RpcClient { + impl RpcClientInner { /// Get a [`RawSubscription`] for the given subscription ID. pub async fn get_raw_subscription(&self, id: alloy_primitives::U256) -> RawSubscription { self.transport.get_subscription(id).await.unwrap() @@ -146,7 +216,9 @@ mod pubsub_impl { ) -> Subscription { Subscription::from(self.get_raw_subscription(id).await) } + } + impl RpcClient { /// Connect to a transport via a [`PubSubConnect`] implementor. pub async fn connect_pubsub(connect: C) -> TransportResult> where @@ -161,26 +233,8 @@ mod pubsub_impl { /// behavior. 
/// /// [`tokio::sync::broadcast`]: https://docs.rs/tokio/latest/tokio/sync/broadcast/index.html - pub const fn channel_size(&self) -> usize { + pub fn channel_size(&self) -> usize { self.transport.channel_size() } - - /// Set the channel size. This is the number of items to buffer in new - /// subscription channels. Defaults to 16. See - /// [`tokio::sync::broadcast`] for a description of relevant - /// behavior. - /// - /// [`tokio::sync::broadcast`]: https://docs.rs/tokio/latest/tokio/sync/broadcast/index.html - pub fn set_channel_size(&mut self, size: usize) { - self.transport.set_channel_size(size); - } - } -} - -impl RpcClient> { - /// Create a new [`BatchRequest`] builder. - #[inline] - pub fn new_batch(&self) -> BatchRequest<'_, Http> { - BatchRequest::new(self) } } diff --git a/crates/rpc-client/src/lib.rs b/crates/rpc-client/src/lib.rs index c0a148c4e8c..1f3b6ad5f71 100644 --- a/crates/rpc-client/src/lib.rs +++ b/crates/rpc-client/src/lib.rs @@ -28,7 +28,10 @@ mod call; pub use call::RpcCall; mod client; -pub use client::RpcClient; +pub use client::{ClientRef, RpcClient, WeakClient}; + +mod poller; +pub use poller::{PollChannel, PollTask}; #[cfg(feature = "ws")] pub use alloy_transport_ws::WsConnect; diff --git a/crates/rpc-client/src/poller.rs b/crates/rpc-client/src/poller.rs new file mode 100644 index 00000000000..aa7934b68d3 --- /dev/null +++ b/crates/rpc-client/src/poller.rs @@ -0,0 +1,242 @@ +use crate::WeakClient; +use alloy_json_rpc::{RpcError, RpcParam, RpcReturn}; +use alloy_transport::{utils::Spawnable, Transport}; +use serde::Serialize; +use serde_json::value::RawValue; +use std::{ + marker::PhantomData, + ops::{Deref, DerefMut}, + time::Duration, +}; +use tokio::sync::broadcast; +use tokio_stream::wrappers::BroadcastStream; +use tracing::Instrument; + +/// The number of retries for polling a request. +const MAX_RETRIES: usize = 3; + +/// A Poller task. 
+#[derive(Debug)] +pub struct PollTask +where + Conn: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, +{ + /// The client to poll with. + client: WeakClient, + + /// Request Method + method: &'static str, + params: Params, + + // config options + channel_size: usize, + poll_interval: Duration, + limit: usize, + + _pd: PhantomData Resp>, +} + +impl PollTask +where + Conn: Transport + Clone, + Params: RpcParam + 'static, + Resp: RpcReturn + Clone, +{ + /// Create a new poller task with cloneable params. + pub fn new(client: WeakClient, method: &'static str, params: Params) -> Self { + Self { + client, + method, + params, + channel_size: 16, + poll_interval: Duration::from_secs(10), + limit: usize::MAX, + _pd: PhantomData, + } + } + + /// Returns the channel size for the poller task. + pub const fn channel_size(&self) -> usize { + self.channel_size + } + + /// Sets the channel size for the poller task. + pub fn set_channel_size(&mut self, channel_size: usize) { + self.channel_size = channel_size; + } + + /// Sets the channel size for the poller task. + pub fn with_channel_size(mut self, channel_size: usize) -> Self { + self.set_channel_size(channel_size); + self + } + + /// Returns the limit on the number of successful polls. + pub const fn limit(&self) -> usize { + self.limit + } + + /// Sets a limit on the number of successful polls. + pub fn set_limit(&mut self, limit: Option) { + self.limit = limit.unwrap_or(usize::MAX); + } + + /// Sets a limit on the number of successful polls. + pub fn with_limit(mut self, limit: Option) -> Self { + self.set_limit(limit); + self + } + + /// Returns the duration between polls. + pub const fn poll_interval(&self) -> Duration { + self.poll_interval + } + + /// Sets the duration between polls. + pub fn set_poll_interval(&mut self, poll_interval: Duration) { + self.poll_interval = poll_interval; + } + + /// Sets the duration between polls. 
+ pub fn with_poll_interval(mut self, poll_interval: Duration) -> Self { + self.set_poll_interval(poll_interval); + self + } + + /// Spawn the poller task, producing a stream of responses. + pub fn spawn(self) -> PollChannel { + let (tx, rx) = broadcast::channel(self.channel_size); + let span = debug_span!("poller", method = self.method); + let fut = async move { + let mut params = ParamsOnce::Typed(self.params); + let mut retries = MAX_RETRIES; + 'outer: for _ in 0..self.limit { + let Some(client) = self.client.upgrade() else { + debug!("client dropped"); + break; + }; + + // Avoid serializing the params more than once. + let params = match params.get() { + Ok(p) => p, + Err(err) => { + error!(%err, "failed to serialize params"); + break; + } + }; + + loop { + trace!("polling"); + match client.prepare(self.method, params).await { + Ok(resp) => { + if tx.send(resp).is_err() { + debug!("channel closed"); + break 'outer; + } + } + Err(RpcError::Transport(err)) if retries > 0 && err.recoverable() => { + debug!(%err, "failed to poll, retrying"); + retries -= 1; + continue; + } + Err(err) => { + error!(%err, "failed to poll"); + break 'outer; + } + } + break; + } + + trace!(duration=?self.poll_interval, "sleeping"); + tokio::time::sleep(self.poll_interval).await; + } + }; + fut.instrument(span).spawn_task(); + rx.into() + } +} + +/// A channel yielding responses from a poller task. +/// +/// This stream is backed by a coroutine, and will continue to produce responses +/// until the poller task is dropped. The poller task is dropped when all +/// [`RpcClient`] instances are dropped, or when all listening `PollChannel` are +/// dropped. +/// +/// The poller task also ignores errors from the server and deserialization +/// errors, and will continue to poll until the client is dropped. 
+/// +/// [`RpcClient`]: crate::RpcClient +#[derive(Debug)] +pub struct PollChannel { + rx: broadcast::Receiver, +} + +impl From> for PollChannel { + fn from(rx: broadcast::Receiver) -> Self { + Self { rx } + } +} + +impl Deref for PollChannel { + type Target = broadcast::Receiver; + + fn deref(&self) -> &Self::Target { + &self.rx + } +} + +impl DerefMut for PollChannel { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.rx + } +} + +impl PollChannel +where + Resp: RpcReturn + Clone, +{ + /// Resubscribe to the poller task. + pub fn resubscribe(&self) -> Self { + Self { rx: self.rx.resubscribe() } + } + + /// Convert the poll channel into a stream. + pub fn into_stream(self) -> BroadcastStream { + self.rx.into() + } +} + +// Serializes the parameters only once. +enum ParamsOnce

{ + Typed(P), + Serialized(Box), +} + +impl ParamsOnce

{ + #[inline] + fn get(&mut self) -> serde_json::Result<&RawValue> { + match self { + ParamsOnce::Typed(_) => self.init(), + ParamsOnce::Serialized(p) => Ok(p), + } + } + + #[cold] + fn init(&mut self) -> serde_json::Result<&RawValue> { + let Self::Typed(p) = self else { unreachable!() }; + let v = serde_json::value::to_raw_value(p)?; + *self = ParamsOnce::Serialized(v); + let Self::Serialized(v) = self else { unreachable!() }; + Ok(v) + } +} + +#[cfg(test)] +#[allow(clippy::missing_const_for_fn)] +fn _assert_unpin() { + fn _assert() {} + _assert::>(); +} diff --git a/crates/signer-aws/src/signer.rs b/crates/signer-aws/src/signer.rs index d1e4506f03d..9cfc7154600 100644 --- a/crates/signer-aws/src/signer.rs +++ b/crates/signer-aws/src/signer.rs @@ -96,7 +96,7 @@ pub enum AwsSignerError { #[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl Signer for AwsSigner { #[instrument(err)] - #[allow(clippy::blocks_in_conditions)] // tracing::instrument on async fn + #[allow(clippy::blocks_in_conditions)] // instrument on async fn async fn sign_hash(&self, hash: B256) -> Result { self.sign_digest_inner(hash).await.map_err(alloy_signer::Error::other) } diff --git a/crates/transport-ipc/src/lib.rs b/crates/transport-ipc/src/lib.rs index f6912d7d22d..03f2df86602 100644 --- a/crates/transport-ipc/src/lib.rs +++ b/crates/transport-ipc/src/lib.rs @@ -71,8 +71,8 @@ impl IpcBackend { match item { Some(msg) => { let bytes = msg.get(); - if let Err(e) = writer.write_all(bytes.as_bytes()).await { - error!(%e, "Failed to write to IPC socket"); + if let Err(err) = writer.write_all(bytes.as_bytes()).await { + error!(%err, "Failed to write to IPC socket"); break true; } }, @@ -205,8 +205,8 @@ where // can try decoding again *this.drained = false; } - Err(e) => { - error!(%e, "Failed to read from IPC socket, shutting down"); + Err(err) => { + error!(%err, "Failed to read from IPC socket, shutting down"); return Ready(None); } } diff --git a/crates/transport-ws/src/lib.rs 
b/crates/transport-ws/src/lib.rs index 218b19ecc76..8f6286a92a5 100644 --- a/crates/transport-ws/src/lib.rs +++ b/crates/transport-ws/src/lib.rs @@ -60,8 +60,8 @@ impl WsBackend { return Err(()); } } - Err(e) => { - error!(e = %e, "Failed to deserialize message"); + Err(err) => { + error!(%err, "Failed to deserialize message"); return Err(()); } } diff --git a/crates/transport-ws/src/native.rs b/crates/transport-ws/src/native.rs index 3196c869b87..8a71032245e 100644 --- a/crates/transport-ws/src/native.rs +++ b/crates/transport-ws/src/native.rs @@ -88,7 +88,7 @@ impl WsBackend { /// Spawn a new backend task. pub fn spawn(mut self) { let fut = async move { - let mut err = false; + let mut errored = false; let keepalive = sleep(Duration::from_secs(KEEPALIVE)); tokio::pin!(keepalive); loop { @@ -111,9 +111,9 @@ impl WsBackend { Some(msg) => { // Reset the keepalive timer. keepalive.set(sleep(Duration::from_secs(KEEPALIVE))); - if let Err(e) = self.send(msg).await { - error!(err = %e, "WS connection error"); - err = true; + if let Err(err) = self.send(msg).await { + error!(%err, "WS connection error"); + errored = true; break } }, @@ -128,33 +128,33 @@ impl WsBackend { _ = &mut keepalive => { // Reset the keepalive timer. 
keepalive.set(sleep(Duration::from_secs(KEEPALIVE))); - if let Err(e) = self.socket.send(Message::Ping(vec![])).await { - error!(err = %e, "WS connection error"); - err = true; + if let Err(err) = self.socket.send(Message::Ping(vec![])).await { + error!(%err, "WS connection error"); + errored = true; break } } resp = self.socket.next() => { match resp { Some(Ok(item)) => { - err = self.handle(item).await.is_err(); - if err { break } + errored = self.handle(item).await.is_err(); + if errored { break } }, - Some(Err(e)) => { - error!(err = %e, "WS connection error"); - err = true; + Some(Err(err)) => { + error!(%err, "WS connection error"); + errored = true; break } None => { error!("WS server has gone away"); - err = true; + errored = true; break }, } } } } - if err { + if errored { self.interface.close_with_error(); } }; diff --git a/crates/transport-ws/src/wasm.rs b/crates/transport-ws/src/wasm.rs index 6fc965187f7..3789bbc7c97 100644 --- a/crates/transport-ws/src/wasm.rs +++ b/crates/transport-ws/src/wasm.rs @@ -53,7 +53,7 @@ impl WsBackend> { /// Spawn this backend on a loop. pub fn spawn(mut self) { let fut = async move { - let mut err = false; + let mut errored = false; loop { // We bias the loop as follows // 1. New dispatch to server. 
@@ -71,9 +71,9 @@ impl WsBackend> { inst = self.interface.recv_from_frontend() => { match inst { Some(msg) => { - if let Err(e) = self.send(msg).await { - error!(err = %e, "WS connection error"); - err = true; + if let Err(err) = self.send(msg).await { + error!(%err, "WS connection error"); + errored = true; break } }, @@ -86,19 +86,19 @@ impl WsBackend> { resp = self.socket.next() => { match resp { Some(item) => { - err = self.handle(item).await.is_err(); - if err { break } + errored = self.handle(item).await.is_err(); + if errored { break } }, None => { error!("WS server has gone away"); - err = true; + errored = true; break }, } } } } - if err { + if errored { self.interface.close_with_error(); } }; diff --git a/crates/transport/src/error.rs b/crates/transport/src/error.rs index b7ecb329705..0ba76feb3d2 100644 --- a/crates/transport/src/error.rs +++ b/crates/transport/src/error.rs @@ -32,6 +32,12 @@ pub enum TransportErrorKind { } impl TransportErrorKind { + /// Returns `true` if the error is potentially recoverable. + /// This is a naive heuristic and should be used with caution. + pub const fn recoverable(&self) -> bool { + matches!(self, Self::MissingBatchResponse(_)) + } + /// Instantiate a new `TransportError` from a custom error. pub fn custom_str(err: &str) -> TransportError { RpcError::Transport(Self::Custom(err.into())) diff --git a/crates/transport/src/lib.rs b/crates/transport/src/lib.rs index 7a7627f7f15..9db9bed1a62 100644 --- a/crates/transport/src/lib.rs +++ b/crates/transport/src/lib.rs @@ -32,7 +32,7 @@ pub use error::{TransportError, TransportResult}; mod r#trait; pub use r#trait::Transport; -pub use alloy_json_rpc::RpcResult; +pub use alloy_json_rpc::{RpcError, RpcResult}; /// Misc. utilities for building transports. 
pub mod utils; diff --git a/crates/transport/src/trait.rs b/crates/transport/src/trait.rs index db928a2893e..ca6b20c4ab7 100644 --- a/crates/transport/src/trait.rs +++ b/crates/transport/src/trait.rs @@ -47,10 +47,18 @@ pub trait Transport: /// Convert this transport into a boxed trait object. fn boxed(self) -> BoxTransport where - Self: Sized + Clone + Send + Sync + 'static, + Self: Sized + Clone, { BoxTransport::new(self) } + + /// Make a boxed trait object by cloning this transport. + fn as_boxed(&self) -> BoxTransport + where + Self: Sized + Clone, + { + self.clone().boxed() + } } impl Transport for T where