diff --git a/canister/src/rpc_client/mod.rs b/canister/src/rpc_client/mod.rs index 4deae232..40fa22f5 100644 --- a/canister/src/rpc_client/mod.rs +++ b/canister/src/rpc_client/mod.rs @@ -261,6 +261,7 @@ impl GetRecentPrioritizationFeesRequest { config: GetRecentPrioritizationFeesRpcConfig, params: Params, ) -> Result { + let max_length = config.max_length(); let consensus_strategy = config.response_consensus.unwrap_or_default(); let providers = Providers::new(rpc_sources, consensus_strategy.clone())?; let max_response_bytes = config @@ -272,8 +273,8 @@ impl GetRecentPrioritizationFeesRequest { JsonRpcRequest::new("getRecentPrioritizationFees", params.into()), max_response_bytes, ResponseTransform::GetRecentPrioritizationFees { + max_length: max_length.into(), max_slot_rounding_error: config.max_slot_rounding_error.unwrap_or_default(), - max_length: config.max_length.unwrap_or(100), }, ReductionStrategy::from(consensus_strategy), )) diff --git a/canister/src/rpc_client/sol_rpc/mod.rs b/canister/src/rpc_client/sol_rpc/mod.rs index 5287862b..fa3455b8 100644 --- a/canister/src/rpc_client/sol_rpc/mod.rs +++ b/canister/src/rpc_client/sol_rpc/mod.rs @@ -14,12 +14,14 @@ use sol_rpc_types::{PrioritizationFee, RoundingError}; use solana_clock::Slot; use solana_transaction_status_client_types::TransactionStatus; use std::fmt::Debug; -use strum::EnumIter; +use std::num::NonZeroU8; /// Describes a payload transformation to execute before passing the HTTP response to consensus. /// The purpose of these transformations is to ensure that the response encoding is deterministic /// (the field order is the same). -#[derive(Clone, Debug, Decode, Encode, EnumIter)] +#[derive(Clone, Debug, Decode, Encode)] +#[cfg_attr(test, derive(strum::EnumDiscriminants))] +#[cfg_attr(test, strum_discriminants(derive(strum::EnumIter)))] pub enum ResponseTransform { #[n(0)] GetAccountInfo, @@ -32,7 +34,7 @@ pub enum ResponseTransform { #[cbor(n(0), with = "crate::rpc_client::cbor::rounding_error")] max_slot_rounding_error: RoundingError, #[n(1)] - max_length: u8, + max_length: NonZeroU8, }, #[n(4)] GetSignaturesForAddress, @@ -110,7 +112,7 @@ impl ResponseTransform { // "Currently, a node's prioritization-fee cache stores data from up to 150 blocks." // Manual testing shows that the result seems to always contain 150 elements on mainnet (also for not used addresses) // but not necessarily when using a local validator. - if fees.is_empty() || max_length == &0 { + if fees.is_empty() { return Vec::default(); } // The order of the prioritization fees in the response is not specified in the @@ -128,7 +130,7 @@ impl ResponseTransform { fees.into_iter() .skip_while(|fee| fee.slot > max_rounded_slot) - .take(*max_length as usize) + .take(max_length.get() as usize) .collect::>() .into_iter() .rev() diff --git a/canister/src/rpc_client/sol_rpc/tests.rs b/canister/src/rpc_client/sol_rpc/tests.rs index 9c43733c..818d47b0 100644 --- a/canister/src/rpc_client/sol_rpc/tests.rs +++ b/canister/src/rpc_client/sol_rpc/tests.rs @@ -15,6 +15,8 @@ use strum::IntoEnumIterator; mod normalization_tests { use super::*; + use crate::rpc_client::sol_rpc::ResponseTransformDiscriminants; + use std::num::NonZeroU8; #[test] fn should_normalize_raw_response() { @@ -375,7 +377,7 @@ mod normalization_tests { bytes } - for transform in ResponseTransform::iter() { + for transform in all_response_transforms() { let left = r#"{ "jsonrpc": "2.0", "error": { "code": -32602, "message": "Invalid param: could not find account" }, "id": 1 }"#; let right = r#"{ "error": { "message": "Invalid param: could not find account", "code": -32602 }, "id": 1, "jsonrpc": "2.0" }"#; let normalized_left = normalize_json(&transform, left); @@ -434,6 +436,35 @@ mod normalization_tests { normalize_result(transform, right) ); } + + fn all_response_transforms() -> impl Iterator { + ResponseTransformDiscriminants::iter().map(|variant| match variant { + ResponseTransformDiscriminants::GetAccountInfo => ResponseTransform::GetAccountInfo, + ResponseTransformDiscriminants::GetBalance => ResponseTransform::GetBalance, + ResponseTransformDiscriminants::GetBlock => ResponseTransform::GetBlock, + ResponseTransformDiscriminants::GetRecentPrioritizationFees => { + ResponseTransform::GetRecentPrioritizationFees { + max_slot_rounding_error: RoundingError::default(), + max_length: NonZeroU8::new(100).unwrap(), + } + } + ResponseTransformDiscriminants::GetSignatureStatuses => { + ResponseTransform::GetSignatureStatuses + } + ResponseTransformDiscriminants::GetSignaturesForAddress => { + ResponseTransform::GetSignaturesForAddress + } + ResponseTransformDiscriminants::GetSlot => { + ResponseTransform::GetSlot(RoundingError::default()) + } + ResponseTransformDiscriminants::GetTokenAccountBalance => { + ResponseTransform::GetTokenAccountBalance + } + ResponseTransformDiscriminants::GetTransaction => ResponseTransform::GetTransaction, + ResponseTransformDiscriminants::SendTransaction => ResponseTransform::SendTransaction, + ResponseTransformDiscriminants::Raw => ResponseTransform::Raw, + }) + } } mod get_recent_prioritization_fees { @@ -458,28 +489,21 @@ mod get_recent_prioritization_fees { ( ResponseTransform::GetRecentPrioritizationFees { max_slot_rounding_error: RoundingError::new(2), - max_length: 2, + max_length: 2.try_into().unwrap(), }, prioritization_fees(vec![3, 4]), ), ( ResponseTransform::GetRecentPrioritizationFees { max_slot_rounding_error: RoundingError::new(2), - max_length: 0, - }, - prioritization_fees(vec![]), - ), - ( - ResponseTransform::GetRecentPrioritizationFees { - max_slot_rounding_error: RoundingError::new(2), - max_length: u8::MAX, + max_length: u8::MAX.try_into().unwrap(), }, prioritization_fees(vec![1, 2, 3, 4]), ), ( ResponseTransform::GetRecentPrioritizationFees { max_slot_rounding_error: RoundingError::new(10), - max_length: 2, + max_length: 2.try_into().unwrap(), }, prioritization_fees(vec![]), ), @@ -498,7 +522,7 @@ mod get_recent_prioritization_fees { let raw_response = json_response::(&[]); let transform = ResponseTransform::GetRecentPrioritizationFees { max_slot_rounding_error: RoundingError::new(2), - max_length: 2, + max_length: 2.try_into().unwrap(), }; let original_bytes = serde_json::to_vec(&raw_response).unwrap(); let mut transformed_bytes = original_bytes.clone(); @@ -532,7 +556,7 @@ mod get_recent_prioritization_fees { let transform = ResponseTransform::GetRecentPrioritizationFees { max_slot_rounding_error: RoundingError::new(10), - max_length: 100, + max_length: 100.try_into().unwrap(), }; let mut raw_bytes = serde_json::to_vec(&json_response(&fees)).unwrap(); transform.apply(&mut raw_bytes); @@ -560,7 +584,7 @@ mod get_recent_prioritization_fees { let transform = ResponseTransform::GetRecentPrioritizationFees { max_slot_rounding_error: RoundingError::new(10), - max_length: 100, + max_length: 100.try_into().unwrap(), }; let mut raw_bytes = to_vec(&json_response(&fees)).unwrap(); transform.apply(&mut raw_bytes); @@ -574,7 +598,7 @@ mod get_recent_prioritization_fees { fn should_be_nop_when_failed_to_deserialize(original_bytes in prop::collection::vec(any::(), 0..1000)) { let transform = ResponseTransform::GetRecentPrioritizationFees { max_slot_rounding_error: RoundingError::new(2), - max_length: 2, + max_length: 2.try_into().unwrap(), }; let mut transformed_bytes = original_bytes.clone(); transform.apply(&mut transformed_bytes); @@ -587,7 +611,7 @@ mod get_recent_prioritization_fees { let raw_response = json_response(&fees); let transform = ResponseTransform::GetRecentPrioritizationFees { max_slot_rounding_error: RoundingError::new(20), - max_length: 100, + max_length: 100.try_into().unwrap(), }; let mut raw_bytes = serde_json::to_vec(&raw_response).unwrap(); transform.apply(&mut raw_bytes); @@ -619,7 +643,7 @@ mod get_recent_prioritization_fees { }; let transform = ResponseTransform::GetRecentPrioritizationFees { max_slot_rounding_error: RoundingError::new(20), - max_length: 100, + max_length: 100.try_into().unwrap(), }; let sorted_fees_bytes = { diff --git a/integration_tests/tests/solana_test_validator.rs b/integration_tests/tests/solana_test_validator.rs index b8429b7b..13fd155a 100644 --- a/integration_tests/tests/solana_test_validator.rs +++ b/integration_tests/tests/solana_test_validator.rs @@ -33,6 +33,7 @@ use solana_signature::Signature; use solana_signer::Signer; use solana_transaction::Transaction; use solana_transaction_status_client_types::UiTransactionEncoding; +use std::num::NonZeroU8; use std::{ future::Future, iter::zip, @@ -124,7 +125,7 @@ async fn should_get_recent_prioritization_fees() { |ic| async move { ic.get_recent_prioritization_fees(&[account]) .unwrap() - .with_max_length(150) + .with_max_length(NonZeroU8::new(150).unwrap()) .with_max_slot_rounding_error(1) .send() .await diff --git a/integration_tests/tests/tests.rs b/integration_tests/tests/tests.rs index 80a189bf..3406dbd9 100644 --- a/integration_tests/tests/tests.rs +++ b/integration_tests/tests/tests.rs @@ -476,6 +476,7 @@ mod get_recent_prioritization_fees_tests { use serde_json::json; use sol_rpc_int_tests::{mock::MockOutcallBuilder, Setup, SolRpcTestClient}; use sol_rpc_types::PrioritizationFee; + use std::num::NonZeroU8; #[tokio::test] async fn should_get_fees_with_rounding() { @@ -1107,7 +1108,7 @@ mod get_recent_prioritization_fees_tests { .get_recent_prioritization_fees(&[USDC_PUBLIC_KEY]) .unwrap() .with_max_slot_rounding_error(10) - .with_max_length(5) + .with_max_length(NonZeroU8::new(5).unwrap()) .send() .await .expect_consistent(); diff --git a/libs/client/src/lib.rs b/libs/client/src/lib.rs index 204decb4..1732bb19 100644 --- a/libs/client/src/lib.rs +++ b/libs/client/src/lib.rs @@ -463,10 +463,11 @@ impl SolRpcClient { /// use sol_rpc_client::SolRpcClient; /// use sol_rpc_types::{RpcSources, SolanaCluster}; /// use solana_pubkey::pubkey; - /// + /// # /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { - /// # use sol_rpc_types::{MultiRpcResult, PrioritizationFee, TokenAmount}; + /// use std::num::NonZeroU8; + /// use sol_rpc_types::{MultiRpcResult, PrioritizationFee, TokenAmount}; /// let client = SolRpcClient::builder_for_ic() /// # .with_mocked_response(MultiRpcResult::Consistent(Ok(vec![PrioritizationFee{slot: 338637772, prioritization_fee: 166667}]))) /// .with_rpc_sources(RpcSources::Default(SolanaCluster::Mainnet)) @@ -475,7 +476,7 @@ impl SolRpcClient { /// let fees = client /// .get_recent_prioritization_fees(&[pubkey!("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v")]) /// .unwrap() - /// .with_max_length(1) + /// .with_max_length(NonZeroU8::MIN) /// .send() /// .await /// .expect_consistent(); diff --git a/libs/client/src/request/mod.rs b/libs/client/src/request/mod.rs index b3772ac9..72e218f9 100644 --- a/libs/client/src/request/mod.rs +++ b/libs/client/src/request/mod.rs @@ -11,8 +11,8 @@ use sol_rpc_types::{ GetRecentPrioritizationFeesParams, GetRecentPrioritizationFeesRpcConfig, GetSignatureStatusesParams, GetSignaturesForAddressLimit, GetSignaturesForAddressParams, GetSlotParams, GetSlotRpcConfig, GetTokenAccountBalanceParams, GetTransactionParams, Lamport, - PrioritizationFee, RoundingError, RpcConfig, RpcResult, RpcSources, SendTransactionParams, - Signature, Slot, TokenAmount, TransactionInfo, TransactionStatus, + NonZeroU8, PrioritizationFee, RoundingError, RpcConfig, RpcResult, RpcSources, + SendTransactionParams, Signature, Slot, TokenAmount, TransactionInfo, TransactionStatus, }; use solana_account_decoder_client_types::token::UiTokenAmount; use solana_transaction_status_client_types::EncodedConfirmedTransactionWithStatusMeta; @@ -585,9 +585,9 @@ impl } /// Change the maximum number of entries for a `getRecentPrioritizationFees` response. - pub fn with_max_length(mut self, len: u8) -> Self { + pub fn with_max_length>(mut self, len: T) -> Self { let config = self.request.rpc_config_mut().get_or_insert_default(); - config.max_length = Some(len); + config.set_max_length(len.into()); self } } diff --git a/libs/types/src/lib.rs b/libs/types/src/lib.rs index ee0ae10c..baa5ba76 100644 --- a/libs/types/src/lib.rs +++ b/libs/types/src/lib.rs @@ -18,7 +18,7 @@ pub use lifecycle::{InstallArgs, Mode, NumSubnetNodes}; pub use response::MultiRpcResult; pub use rpc_client::{ ConsensusStrategy, GetRecentPrioritizationFeesRpcConfig, GetSlotRpcConfig, HttpHeader, - HttpOutcallError, JsonRpcError, OverrideProvider, ProviderError, RegexString, + HttpOutcallError, JsonRpcError, NonZeroU8, OverrideProvider, ProviderError, RegexString, RegexSubstitution, RoundingError, RpcAccess, RpcAuth, RpcConfig, RpcEndpoint, RpcError, RpcResult, RpcSource, RpcSources, SolanaCluster, SupportedRpcProvider, SupportedRpcProviderId, }; diff --git a/libs/types/src/rpc_client/mod.rs b/libs/types/src/rpc_client/mod.rs index dc68e13e..21ebb5e8 100644 --- a/libs/types/src/rpc_client/mod.rs +++ b/libs/types/src/rpc_client/mod.rs @@ -1,6 +1,7 @@ #[cfg(test)] mod tests; +use candid::types::{Serializer, Type, TypeInner}; use candid::CandidType; use derive_more::{From, Into}; use ic_cdk::api::call::RejectionCode; @@ -8,6 +9,7 @@ pub use ic_cdk::api::management_canister::http_request::HttpHeader; use regex::Regex; use serde::{Deserialize, Serialize}; use std::fmt::Debug; +use std::num::TryFromIntError; use strum::Display; use thiserror::Error; @@ -206,16 +208,30 @@ pub struct GetRecentPrioritizationFeesRpcConfig { #[serde(rename = "maxSlotRoundingError")] pub max_slot_rounding_error: Option, - /// Limit the number of returned priority fees. + #[serde(rename = "maxLength")] + max_length: Option, +} + +impl GetRecentPrioritizationFeesRpcConfig { + /// Default number of priority fees to return. + pub const DEFAULT_MAX_LENGTH: NonZeroU8 = + NonZeroU8::new(std::num::NonZeroU8::new(100_u8).unwrap()); + + /// Number of priority fees to return. + /// + /// Returns the current value or the default [`Self::DEFAULT_MAX_LENGTH`]. + pub fn max_length(&self) -> NonZeroU8 { + self.max_length.unwrap_or(Self::DEFAULT_MAX_LENGTH) + } + + /// Change the number of priority fees to return. /// /// A Solana validator returns at most 150 entries, so that bigger values are possible but not useful. - /// MUST be non-zero to avoid useless call. - /// Default value is 100. /// Increasing that value can help in estimating the current priority fee /// but will reduce the likelihood of nodes reaching consensus. - #[serde(rename = "maxLength")] - // TODO XC-326: Use a wrapper type to implement Candid on `NonZeroU8` to prohibit the value 0. - pub max_length: Option, + pub fn set_max_length(&mut self, len: NonZeroU8) { + self.max_length = Some(len) + } } impl From for GetRecentPrioritizationFeesRpcConfig { @@ -531,3 +547,62 @@ impl RoundingError { } } } + +/// A wrapper around the primitive [`std::num::NonZeroU8`] to implement [`candid::CandidType`]. +/// +/// From the point of view of Candid, this is like a [`u8`], except that a zero value will fail to be deserialized. +/// +/// # Examples +/// +/// ```rust +/// use candid::{Decode, Encode}; +/// use sol_rpc_types::NonZeroU8; +/// +/// let one = 1_u8; +/// let non_zero_one = NonZeroU8::try_from(one).unwrap(); +/// let encoded_one = Encode!(&one).unwrap(); +/// assert_eq!(encoded_one, Encode!(&non_zero_one).unwrap()); +/// assert_eq!(non_zero_one, Decode!(&encoded_one, NonZeroU8).unwrap()); +/// +/// let encoded_zero = Encode!(&0_u8).unwrap(); +/// assert!(Decode!(&encoded_zero, NonZeroU8).is_err()); +/// ``` +#[derive( + Debug, Clone, Copy, Eq, Ord, PartialEq, PartialOrd, From, Into, Serialize, Deserialize, +)] +#[serde(try_from = "u8", into = "u8")] +pub struct NonZeroU8(std::num::NonZeroU8); + +impl CandidType for NonZeroU8 { + fn _ty() -> Type { + Type(TypeInner::Nat8.into()) + } + + fn idl_serialize(&self, serializer: S) -> Result<(), S::Error> + where + S: Serializer, + { + serializer.serialize_nat8(self.0.get()) + } +} + +impl NonZeroU8 { + /// Construct a new instance of [`NonZeroU8`]. + pub const fn new(value: std::num::NonZeroU8) -> Self { + Self(value) + } +} + +impl From for u8 { + fn from(value: NonZeroU8) -> Self { + value.0.get() + } +} + +impl TryFrom for NonZeroU8 { + type Error = TryFromIntError; + + fn try_from(value: u8) -> Result { + std::num::NonZeroU8::try_from(value).map(Self) + } +} diff --git a/libs/types/src/rpc_client/tests.rs b/libs/types/src/rpc_client/tests.rs index 1f5eddb4..8b2afecb 100644 --- a/libs/types/src/rpc_client/tests.rs +++ b/libs/types/src/rpc_client/tests.rs @@ -1,4 +1,8 @@ use crate::{HttpHeader, RpcEndpoint}; +use candid::{CandidType, Decode, Encode}; +use proptest::prelude::TestCaseError; +use proptest::prop_assert_eq; +use serde::de::DeserializeOwned; #[test] fn should_contain_host_without_sensitive_information() { @@ -28,6 +32,7 @@ fn should_contain_host_without_sensitive_information() { } mod rounding_error_tests { + use crate::rpc_client::tests::encode_decode_roundtrip; use crate::RoundingError; use proptest::proptest; @@ -59,5 +64,55 @@ mod rounding_error_tests { fn should_not_panic (rounding_error: u64, slot: u64) { let _result = RoundingError::new(rounding_error).round(slot); } + + #[test] + fn should_encode_decode (rounding_error: u64) { + encode_decode_roundtrip(RoundingError::new(rounding_error), rounding_error)?; + } + + } +} + +mod non_zero_u8 { + use crate::rpc_client::tests::encode_decode_roundtrip; + use crate::rpc_client::NonZeroU8; + use candid::{Decode, Encode}; + use proptest::proptest; + + proptest! { + #[test] + fn should_encode_decode(v in 1..255_u8) { + encode_decode_roundtrip(NonZeroU8::try_from(v).unwrap(), v)?; + } } + + #[test] + fn should_fail_deserialization_when_zero() { + let encoded_zero = Encode!(&0_u8).unwrap(); + assert!(Decode!(&encoded_zero, NonZeroU8).is_err()); + } +} + +fn encode_decode_roundtrip(wrapped_value: T, inner_value: U) -> Result<(), TestCaseError> +where + T: CandidType + DeserializeOwned + PartialEq + std::fmt::Debug, + U: CandidType, +{ + let encoded_wrapped_value = Encode!(&wrapped_value).unwrap(); + let encoded_inner_value = Encode!(&inner_value).unwrap(); + prop_assert_eq!( + &encoded_wrapped_value, + &encoded_inner_value, + "Encoded value differ for {:?}", + wrapped_value + ); + + let decoded_wrapped_value = Decode!(&encoded_wrapped_value, T).unwrap(); + prop_assert_eq!( + &decoded_wrapped_value, + &wrapped_value, + "Decoded value differ for {:?}", + wrapped_value + ); + Ok(()) }