diff --git a/canister/src/providers/mod.rs b/canister/src/providers/mod.rs index 00506296..d8ec0d82 100644 --- a/canister/src/providers/mod.rs +++ b/canister/src/providers/mod.rs @@ -124,139 +124,128 @@ pub struct Providers { impl Providers { // Order of providers matters! // The threshold consensus strategy will consider the first `total` providers in the order - // they are specified (taking the default ones first, followed by the non default ones if necessary) - // if the providers are not explicitly specified by the caller. - const DEFAULT_MAINNET_SUPPORTED_PROVIDERS: &'static [SupportedRpcProviderId] = &[ + // they are specified if the providers are not explicitly specified by the caller. + const MAINNET_PROVIDERS: &'static [SupportedRpcProviderId] = &[ SupportedRpcProviderId::AlchemyMainnet, SupportedRpcProviderId::HeliusMainnet, SupportedRpcProviderId::DrpcMainnet, - ]; - const NON_DEFAULT_MAINNET_SUPPORTED_PROVIDERS: &'static [SupportedRpcProviderId] = &[ SupportedRpcProviderId::AnkrMainnet, SupportedRpcProviderId::PublicNodeMainnet, SupportedRpcProviderId::ChainstackMainnet, ]; - const DEFAULT_DEVNET_SUPPORTED_PROVIDERS: &'static [SupportedRpcProviderId] = &[ + const DEVNET_PROVIDERS: &'static [SupportedRpcProviderId] = &[ SupportedRpcProviderId::AlchemyDevnet, SupportedRpcProviderId::HeliusDevnet, SupportedRpcProviderId::DrpcDevnet, - ]; - const NON_DEFAULT_DEVNET_SUPPORTED_PROVIDERS: &'static [SupportedRpcProviderId] = &[ SupportedRpcProviderId::AnkrDevnet, SupportedRpcProviderId::ChainstackDevnet, ]; - pub fn new(source: RpcSources, strategy: ConsensusStrategy) -> Result { - fn get_sources(provider_ids: &[SupportedRpcProviderId]) -> Vec { - provider_ids - .iter() - .map(|provider| RpcSource::Supported(*provider)) - .collect() - } + const DEFAULT_NUM_PROVIDERS_FOR_EQUALITY: usize = 3; - let providers: BTreeSet<_> = match source { - RpcSources::Custom(sources) => { - choose_providers(Some(sources), vec![], vec![], strategy)? + pub fn new(source: RpcSources, strategy: ConsensusStrategy) -> Result { + fn supported_providers( + cluster: &SolanaCluster, + ) -> Result<&[SupportedRpcProviderId], ProviderError> { + match cluster { + SolanaCluster::Mainnet => Ok(Providers::MAINNET_PROVIDERS), + SolanaCluster::Devnet => Ok(Providers::DEVNET_PROVIDERS), + SolanaCluster::Testnet => { + Err(ProviderError::UnsupportedCluster(format!("{:?}", cluster))) + } } - RpcSources::Default(cluster) => match cluster { - SolanaCluster::Mainnet => choose_providers( - None, - get_sources(Self::DEFAULT_MAINNET_SUPPORTED_PROVIDERS), - get_sources(Self::NON_DEFAULT_MAINNET_SUPPORTED_PROVIDERS), - strategy, - )?, - SolanaCluster::Devnet => choose_providers( - None, - get_sources(Self::DEFAULT_DEVNET_SUPPORTED_PROVIDERS), - get_sources(Self::NON_DEFAULT_DEVNET_SUPPORTED_PROVIDERS), - strategy, - )?, - cluster => return Err(ProviderError::UnsupportedCluster(format!("{:?}", cluster))), - }, - }; - - if providers.is_empty() { - return Err(ProviderError::InvalidRpcConfig( - "No matching providers found".to_string(), - )); } - Ok(Self { sources: providers }) - } -} - -fn choose_providers( - user_input: Option>, - default_providers: Vec, - non_default_providers: Vec, - strategy: ConsensusStrategy, -) -> Result, ProviderError> { - match strategy { - ConsensusStrategy::Equality => Ok(user_input - .unwrap_or_else(|| default_providers.to_vec()) - .into_iter() - .collect()), - ConsensusStrategy::Threshold { total, min } => { - // Ensure that - // 0 < min <= total <= all_providers.len() - if min == 0 { - return Err(ProviderError::InvalidRpcConfig( - "min must be greater than 0".to_string(), - )); - } - match user_input { - None => { - let all_providers_len = default_providers.len() + non_default_providers.len(); - let total = total.ok_or_else(|| { - ProviderError::InvalidRpcConfig( - "total must be specified when using default providers".to_string(), - ) - })?; - - if min > total { - return Err(ProviderError::InvalidRpcConfig(format!( - "min {} is greater than total {}", - min, total - ))); - } + fn supported_rpc_source(supported_provider: &SupportedRpcProviderId) -> RpcSource { + RpcSource::Supported(*supported_provider) + } - if total > all_providers_len as u8 { - return Err(ProviderError::InvalidRpcConfig(format!( - "total {} is greater than the number of all supported providers {}", - total, all_providers_len - ))); - } - let providers: BTreeSet<_> = default_providers + let providers: BTreeSet<_> = match strategy { + ConsensusStrategy::Equality => match source { + RpcSources::Custom(custom_providers) => Ok(custom_providers.into_iter().collect()), + RpcSources::Default(cluster) => { + let supported_providers = supported_providers(&cluster)?; + assert!( + supported_providers.len() >= Self::DEFAULT_NUM_PROVIDERS_FOR_EQUALITY, + "BUG: need at least 3 providers, but got {supported_providers:?}" + ); + Ok(supported_providers .iter() - .chain(non_default_providers.iter()) - .take(total as usize) - .cloned() - .collect(); - assert_eq!(providers.len(), total as usize, "BUG: duplicate providers"); - Ok(providers) + .take(Self::DEFAULT_NUM_PROVIDERS_FOR_EQUALITY) + .map(supported_rpc_source) + .collect()) } - Some(providers) => { - if min > providers.len() as u8 { - return Err(ProviderError::InvalidRpcConfig(format!( - "min {} is greater than the number of specified providers {}", - min, - providers.len() - ))); + }, + ConsensusStrategy::Threshold { total, min } => { + // Ensure that + // 0 < min <= total <= all_providers.len() + if min == 0 { + return Err(ProviderError::InvalidRpcConfig( + "min must be greater than 0".to_string(), + )); + } + match source { + RpcSources::Custom(custom_providers) => { + if min > custom_providers.len() as u8 { + return Err(ProviderError::InvalidRpcConfig(format!( + "min {} is greater than the number of specified providers {}", + min, + custom_providers.len() + ))); + } + if let Some(total) = total { + if total != custom_providers.len() as u8 { + return Err(ProviderError::InvalidRpcConfig(format!( + "total {} is different than the number of specified providers {}", + total, + custom_providers.len() + ))); + } + }; + Ok(custom_providers.into_iter().collect()) } - if let Some(total) = total { - if total != providers.len() as u8 { + RpcSources::Default(cluster) => { + let supported_providers = supported_providers(&cluster)?; + let all_providers_len = supported_providers.len(); + let total = total.ok_or_else(|| { + ProviderError::InvalidRpcConfig( + "total must be specified when using default providers".to_string(), + ) + })?; + + if min > total { return Err(ProviderError::InvalidRpcConfig(format!( - "total {} is different than the number of specified providers {}", - total, - providers.len() + "min {} is greater than total {}", + min, total ))); } + + if total > all_providers_len as u8 { + return Err(ProviderError::InvalidRpcConfig(format!( + "total {} is greater than the number of all supported providers {}", + total, all_providers_len + ))); + } + let providers: BTreeSet<_> = supported_providers + .iter() + .take(total as usize) + .map(supported_rpc_source) + .collect(); + assert_eq!(providers.len(), total as usize, "BUG: duplicate providers"); + Ok(providers) } - Ok(providers.into_iter().collect()) } } + }?; + + if providers.is_empty() { + return Err(ProviderError::InvalidRpcConfig( + "No matching providers found".to_string(), + )); } + + Ok(Self { sources: providers }) } } diff --git a/canister/src/providers/tests.rs b/canister/src/providers/tests.rs index 19f55e91..a3e8f7c3 100644 --- a/canister/src/providers/tests.rs +++ b/canister/src/providers/tests.rs @@ -1,6 +1,8 @@ -use super::PROVIDERS; +use super::{Providers, PROVIDERS}; use crate::constants::API_KEY_REPLACE_STRING; use sol_rpc_types::{RpcAccess, RpcAuth, SupportedRpcProvider, SupportedRpcProviderId}; +use std::collections::BTreeSet; +use strum::IntoEnumIterator; #[test] fn test_rpc_provider_url_patterns() { @@ -51,6 +53,23 @@ fn should_have_consistent_name_for_cluster() { }) } +#[test] +fn should_partition_providers_between_solana_cluster() { + let mainnet_providers: BTreeSet<_> = Providers::MAINNET_PROVIDERS.iter().collect(); + let devnet_providers: BTreeSet<_> = Providers::DEVNET_PROVIDERS.iter().collect(); + let common_providers: BTreeSet<_> = mainnet_providers.intersection(&devnet_providers).collect(); + assert_eq!(common_providers, BTreeSet::default()); + + let all_providers: BTreeSet<_> = SupportedRpcProviderId::iter().collect(); + let partitioned_providers: BTreeSet<_> = mainnet_providers + .into_iter() + .chain(devnet_providers) + .copied() + .collect(); + + assert_eq!(all_providers, partitioned_providers); +} + mod providers_new { use crate::providers::Providers; use assert_matches::assert_matches; diff --git a/libs/types/src/rpc_client/mod.rs b/libs/types/src/rpc_client/mod.rs index 7e105d7d..e9fb00bc 100644 --- a/libs/types/src/rpc_client/mod.rs +++ b/libs/types/src/rpc_client/mod.rs @@ -11,7 +11,7 @@ pub use ic_cdk::api::management_canister::http_request::HttpHeader; use regex::Regex; use serde::{Deserialize, Serialize}; use std::{fmt::Debug, num::TryFromIntError}; -use strum::Display; +use strum::{Display, EnumIter}; use thiserror::Error; /// An RPC result type. @@ -325,6 +325,7 @@ pub enum SolanaCluster { PartialEq, PartialOrd, CandidType, + EnumIter, Deserialize, Serialize, Display,