diff --git a/Cargo.lock b/Cargo.lock index 70b83c66043c22..de8ba5b561be32 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9900,6 +9900,7 @@ dependencies = [ "solana-version", "solana-vote-interface", "static_assertions", + "test-case", "tokio", ] diff --git a/rpc-client/Cargo.toml b/rpc-client/Cargo.toml index 9db6b57706e833..7cb61012149c07 100644 --- a/rpc-client/Cargo.toml +++ b/rpc-client/Cargo.toml @@ -64,3 +64,4 @@ solana-pubkey = { workspace = true, features = ["rand"] } solana-signer = { workspace = true } solana-system-transaction = { workspace = true } static_assertions = { workspace = true } +test-case = { workspace = true } diff --git a/rpc-client/src/nonblocking/rpc_client.rs b/rpc-client/src/nonblocking/rpc_client.rs index 8da0562052e1ab..8ab3e03b7938ce 100644 --- a/rpc-client/src/nonblocking/rpc_client.rs +++ b/rpc-client/src/nonblocking/rpc_client.rs @@ -4686,7 +4686,8 @@ impl RpcClient { &self, message: &impl SerializableMessage, ) -> ClientResult { - let serialized_encoded = serialize_and_encode(message, UiTransactionEncoding::Base64)?; + let serialized = message.serialize(); + let serialized_encoded = BASE64_STANDARD.encode(serialized); let result = self .send::>>( RpcRequest::GetFeeForMessage, diff --git a/rpc-client/src/rpc_client.rs b/rpc-client/src/rpc_client.rs index bf42e2fb06ad6c..745c68c16f8bee 100644 --- a/rpc-client/src/rpc_client.rs +++ b/rpc-client/src/rpc_client.rs @@ -62,9 +62,19 @@ impl RpcClientConfig { /// Trait used to add support for versioned messages to RPC APIs while /// retaining backwards compatibility -pub trait SerializableMessage: Serialize {} -impl SerializableMessage for LegacyMessage {} -impl SerializableMessage for v0::Message {} +pub trait SerializableMessage { + fn serialize(&self) -> Vec; +} +impl SerializableMessage for LegacyMessage { + fn serialize(&self) -> Vec { + self.serialize() + } +} +impl SerializableMessage for v0::Message { + fn serialize(&self) -> Vec { + self.serialize() + } +} /// Trait used to add support for versioned transactions to RPC APIs while /// retaining backwards compatibility @@ -3797,19 +3807,23 @@ mod tests { super::*, crate::mock_sender::PUBKEY, assert_matches::assert_matches, + base64::{prelude::BASE64_STANDARD, Engine}, crossbeam_channel::unbounded, jsonrpc_core::{futures::prelude::*, Error, IoHandler, Params}, jsonrpc_http_server::{AccessControlAllowOrigin, DomainsValidation, ServerBuilder}, serde_json::{json, Number}, solana_account_decoder::encode_ui_account, solana_account_decoder_client_types::UiAccountEncoding, + solana_hash::Hash, solana_instruction::error::InstructionError, solana_keypair::Keypair, + solana_message::{compiled_instruction::CompiledInstruction, MessageHeader}, solana_rpc_client_api::client_error::ErrorKind, solana_signer::Signer, solana_system_transaction as system_transaction, solana_transaction_error::TransactionError, std::{io, thread}, + test_case::test_case, }; #[test] @@ -4254,4 +4268,85 @@ mod tests { assert_eq!(expected_result, result1); } } + + #[test_case(LegacyMessage { + header: MessageHeader { + num_required_signatures: 1, + num_readonly_signed_accounts: 0, + num_readonly_unsigned_accounts: 1, + }, + account_keys: vec![Pubkey::new_unique()], + recent_blockhash: Hash::new_unique(), + instructions: vec![CompiledInstruction { + program_id_index: 1, + accounts: vec![0], + data: vec![], + }], + }; "legacy message")] + #[test_case(v0::Message { + header: MessageHeader { + num_required_signatures: 1, + num_readonly_signed_accounts: 0, + num_readonly_unsigned_accounts: 0, + }, + account_keys: vec![Pubkey::new_unique()], + recent_blockhash: Hash::new_unique(), + instructions: vec![CompiledInstruction { + program_id_index: 0, + accounts: vec![], + data: vec![], + }], + address_table_lookups: vec![], + }; "v0 message")] + fn test_get_fee_for_message_sends_properly_serialized_v0_transaction(message: M) + where + M: SerializableMessage, + { + let serialized_message = message.serialize(); + let serialized_message_base64 = BASE64_STANDARD.encode(serialized_message); + + let (sender, receiver) = unbounded(); + thread::spawn(move || { + let rpc_addr = "0.0.0.0:0".parse().unwrap(); + let mut io = IoHandler::default(); + // Successful request + io.add_method("getFeeForMessage", move |params: Params| match params { + Params::Array(p) => { + let first_element = p.first().unwrap(); + if let Value::String(actual_serialized_message) = first_element { + assert_eq!(actual_serialized_message, &serialized_message_base64); + return future::ok(json!(Response { + context: RpcResponseContext { + api_version: None, + slot: 1, + }, + value: json!(42), + })); + } + future::err(Error::invalid_params( + "Expected the serialized message to be the first element of the params", + )) + } + _ => { + panic!("Expected an array of params to be forwarded to `getFeeForMessage"); + } + }); + + let server = ServerBuilder::new(io) + .threads(1) + .cors(DomainsValidation::AllowOnly(vec![ + AccessControlAllowOrigin::Any, + ])) + .start_http(&rpc_addr) + .expect("Unable to start RPC server"); + sender.send(*server.address()).unwrap(); + server.wait(); + }); + + let rpc_addr = receiver.recv().unwrap(); + let rpc_client = RpcClient::new_socket(rpc_addr); + + let fee: u64 = rpc_client.get_fee_for_message(&message).unwrap(); + assert_eq!(fee, 42); + } }