Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 92 additions & 103 deletions canister/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,139 +124,128 @@ pub struct Providers {
impl Providers {
// Order of providers matters!
// The threshold consensus strategy will consider the first `total` providers in the order
// they are specified (taking the default ones first, followed by the non default ones if necessary)
// if the providers are not explicitly specified by the caller.
const DEFAULT_MAINNET_SUPPORTED_PROVIDERS: &'static [SupportedRpcProviderId] = &[
// they are specified if the providers are not explicitly specified by the caller.
const MAINNET_PROVIDERS: &'static [SupportedRpcProviderId] = &[
SupportedRpcProviderId::AlchemyMainnet,
SupportedRpcProviderId::HeliusMainnet,
SupportedRpcProviderId::DrpcMainnet,
];
const NON_DEFAULT_MAINNET_SUPPORTED_PROVIDERS: &'static [SupportedRpcProviderId] = &[
SupportedRpcProviderId::AnkrMainnet,
SupportedRpcProviderId::PublicNodeMainnet,
SupportedRpcProviderId::ChainstackMainnet,
];

const DEFAULT_DEVNET_SUPPORTED_PROVIDERS: &'static [SupportedRpcProviderId] = &[
const DEVNET_PROVIDERS: &'static [SupportedRpcProviderId] = &[
SupportedRpcProviderId::AlchemyDevnet,
SupportedRpcProviderId::HeliusDevnet,
SupportedRpcProviderId::DrpcDevnet,
];
const NON_DEFAULT_DEVNET_SUPPORTED_PROVIDERS: &'static [SupportedRpcProviderId] = &[
SupportedRpcProviderId::AnkrDevnet,
SupportedRpcProviderId::ChainstackDevnet,
];

pub fn new(source: RpcSources, strategy: ConsensusStrategy) -> Result<Self, ProviderError> {
fn get_sources(provider_ids: &[SupportedRpcProviderId]) -> Vec<RpcSource> {
provider_ids
.iter()
.map(|provider| RpcSource::Supported(*provider))
.collect()
}
const DEFAULT_NUM_PROVIDERS_FOR_EQUALITY: usize = 3;

let providers: BTreeSet<_> = match source {
RpcSources::Custom(sources) => {
choose_providers(Some(sources), vec![], vec![], strategy)?
pub fn new(source: RpcSources, strategy: ConsensusStrategy) -> Result<Self, ProviderError> {
fn supported_providers(
cluster: &SolanaCluster,
) -> Result<&[SupportedRpcProviderId], ProviderError> {
match cluster {
SolanaCluster::Mainnet => Ok(Providers::MAINNET_PROVIDERS),
SolanaCluster::Devnet => Ok(Providers::DEVNET_PROVIDERS),
SolanaCluster::Testnet => {
Err(ProviderError::UnsupportedCluster(format!("{:?}", cluster)))
}
}
RpcSources::Default(cluster) => match cluster {
SolanaCluster::Mainnet => choose_providers(
None,
get_sources(Self::DEFAULT_MAINNET_SUPPORTED_PROVIDERS),
get_sources(Self::NON_DEFAULT_MAINNET_SUPPORTED_PROVIDERS),
strategy,
)?,
SolanaCluster::Devnet => choose_providers(
None,
get_sources(Self::DEFAULT_DEVNET_SUPPORTED_PROVIDERS),
get_sources(Self::NON_DEFAULT_DEVNET_SUPPORTED_PROVIDERS),
strategy,
)?,
cluster => return Err(ProviderError::UnsupportedCluster(format!("{:?}", cluster))),
},
};

if providers.is_empty() {
return Err(ProviderError::InvalidRpcConfig(
"No matching providers found".to_string(),
));
}

Ok(Self { sources: providers })
}
}

fn choose_providers(
user_input: Option<Vec<RpcSource>>,
default_providers: Vec<RpcSource>,
non_default_providers: Vec<RpcSource>,
strategy: ConsensusStrategy,
) -> Result<BTreeSet<RpcSource>, ProviderError> {
match strategy {
ConsensusStrategy::Equality => Ok(user_input
.unwrap_or_else(|| default_providers.to_vec())
.into_iter()
.collect()),
ConsensusStrategy::Threshold { total, min } => {
// Ensure that
// 0 < min <= total <= all_providers.len()
if min == 0 {
return Err(ProviderError::InvalidRpcConfig(
"min must be greater than 0".to_string(),
));
}
match user_input {
None => {
let all_providers_len = default_providers.len() + non_default_providers.len();
let total = total.ok_or_else(|| {
ProviderError::InvalidRpcConfig(
"total must be specified when using default providers".to_string(),
)
})?;

if min > total {
return Err(ProviderError::InvalidRpcConfig(format!(
"min {} is greater than total {}",
min, total
)));
}
fn supported_rpc_source(supported_provider: &SupportedRpcProviderId) -> RpcSource {
RpcSource::Supported(*supported_provider)
}

if total > all_providers_len as u8 {
return Err(ProviderError::InvalidRpcConfig(format!(
"total {} is greater than the number of all supported providers {}",
total, all_providers_len
)));
}
let providers: BTreeSet<_> = default_providers
let providers: BTreeSet<_> = match strategy {
ConsensusStrategy::Equality => match source {
RpcSources::Custom(custom_providers) => Ok(custom_providers.into_iter().collect()),
RpcSources::Default(cluster) => {
let supported_providers = supported_providers(&cluster)?;
assert!(
supported_providers.len() >= Self::DEFAULT_NUM_PROVIDERS_FOR_EQUALITY,
"BUG: need at least 3 providers, but got {supported_providers:?}"
);
Ok(supported_providers
.iter()
.chain(non_default_providers.iter())
.take(total as usize)
.cloned()
.collect();
assert_eq!(providers.len(), total as usize, "BUG: duplicate providers");
Ok(providers)
.take(Self::DEFAULT_NUM_PROVIDERS_FOR_EQUALITY)
.map(supported_rpc_source)
.collect())
}
Some(providers) => {
if min > providers.len() as u8 {
return Err(ProviderError::InvalidRpcConfig(format!(
"min {} is greater than the number of specified providers {}",
min,
providers.len()
)));
},
ConsensusStrategy::Threshold { total, min } => {
// Ensure that
// 0 < min <= total <= all_providers.len()
if min == 0 {
return Err(ProviderError::InvalidRpcConfig(
"min must be greater than 0".to_string(),
));
}
match source {
RpcSources::Custom(custom_providers) => {
if min > custom_providers.len() as u8 {
return Err(ProviderError::InvalidRpcConfig(format!(
"min {} is greater than the number of specified providers {}",
min,
custom_providers.len()
)));
}
if let Some(total) = total {
if total != custom_providers.len() as u8 {
return Err(ProviderError::InvalidRpcConfig(format!(
"total {} is different than the number of specified providers {}",
total,
custom_providers.len()
)));
}
};
Ok(custom_providers.into_iter().collect())
}
if let Some(total) = total {
if total != providers.len() as u8 {
RpcSources::Default(cluster) => {
let supported_providers = supported_providers(&cluster)?;
let all_providers_len = supported_providers.len();
let total = total.ok_or_else(|| {
ProviderError::InvalidRpcConfig(
"total must be specified when using default providers".to_string(),
)
})?;

if min > total {
return Err(ProviderError::InvalidRpcConfig(format!(
"total {} is different than the number of specified providers {}",
total,
providers.len()
"min {} is greater than total {}",
min, total
)));
}

if total > all_providers_len as u8 {
return Err(ProviderError::InvalidRpcConfig(format!(
"total {} is greater than the number of all supported providers {}",
total, all_providers_len
)));
}
let providers: BTreeSet<_> = supported_providers
.iter()
.take(total as usize)
.map(supported_rpc_source)
.collect();
assert_eq!(providers.len(), total as usize, "BUG: duplicate providers");
Ok(providers)
}
Ok(providers.into_iter().collect())
}
}
}?;

if providers.is_empty() {
return Err(ProviderError::InvalidRpcConfig(
"No matching providers found".to_string(),
));
}

Ok(Self { sources: providers })
}
}

Expand Down
21 changes: 20 additions & 1 deletion canister/src/providers/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::PROVIDERS;
use super::{Providers, PROVIDERS};
use crate::constants::API_KEY_REPLACE_STRING;
use sol_rpc_types::{RpcAccess, RpcAuth, SupportedRpcProvider, SupportedRpcProviderId};
use std::collections::BTreeSet;
use strum::IntoEnumIterator;

#[test]
fn test_rpc_provider_url_patterns() {
Expand Down Expand Up @@ -51,6 +53,23 @@ fn should_have_consistent_name_for_cluster() {
})
}

#[test]
fn should_partition_providers_between_solana_cluster() {
let mainnet_providers: BTreeSet<_> = Providers::MAINNET_PROVIDERS.iter().collect();
let devnet_providers: BTreeSet<_> = Providers::DEVNET_PROVIDERS.iter().collect();
let common_providers: BTreeSet<_> = mainnet_providers.intersection(&devnet_providers).collect();
assert_eq!(common_providers, BTreeSet::default());

let all_providers: BTreeSet<_> = SupportedRpcProviderId::iter().collect();
let partitioned_providers: BTreeSet<_> = mainnet_providers
.into_iter()
.chain(devnet_providers)
.copied()
.collect();

assert_eq!(all_providers, partitioned_providers);
}

mod providers_new {
use crate::providers::Providers;
use assert_matches::assert_matches;
Expand Down
3 changes: 2 additions & 1 deletion libs/types/src/rpc_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub use ic_cdk::api::management_canister::http_request::HttpHeader;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{fmt::Debug, num::TryFromIntError};
use strum::Display;
use strum::{Display, EnumIter};
use thiserror::Error;

/// An RPC result type.
Expand Down Expand Up @@ -325,6 +325,7 @@ pub enum SolanaCluster {
PartialEq,
PartialOrd,
CandidType,
EnumIter,
Deserialize,
Serialize,
Display,
Expand Down
Loading