Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk)!: allow setting CA cert to use when connecting to dapi servers #1924

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion packages/rs-dapi-client/src/dapi_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use backon::{ExponentialBuilder, Retryable};
use dapi_grpc::mock::Mockable;
use dapi_grpc::tonic::async_trait;
use dapi_grpc::tonic::transport::Certificate;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
use std::time::Duration;
Expand Down Expand Up @@ -94,6 +95,8 @@ pub struct DapiClient {
address_list: Arc<RwLock<AddressList>>,
settings: RequestSettings,
pool: ConnectionPool,
/// Certificate Authority certificate to use for verifying the server's certificate.
pub ca_certificate: Option<Certificate>,
#[cfg(feature = "dump")]
pub(crate) dump_dir: Option<std::path::PathBuf>,
}
Expand All @@ -110,8 +113,23 @@ impl DapiClient {
pool: ConnectionPool::new(address_count),
#[cfg(feature = "dump")]
dump_dir: None,
ca_certificate: None,
}
}

/// Set CA certificate to use when verifying the server's certificate.
///
/// # Arguments
///
/// * `pem_ca_cert` - CA certificate in PEM format.
///
/// # Returns
/// [DapiClient] with CA certificate set.
pub fn with_ca_certificate(mut self, pem_ca_cert: &[u8]) -> Self {
self.ca_certificate = Some(Certificate::from_pem(pem_ca_cert));

self
}
}

#[async_trait]
Expand All @@ -128,12 +146,17 @@ impl DapiRequestExecutor for DapiClient {
<R::Client as TransportClient>::Error: Mockable,
{
// Join settings of different sources to get final version of the settings for this execution:
let applied_settings = self
let mut applied_settings = self
.settings
.override_by(R::SETTINGS_OVERRIDES)
.override_by(settings)
.finalize();

// Setup CA certificate
if let Some(ca_certificate) = &self.ca_certificate {
applied_settings = applied_settings.with_ca_certificate(ca_certificate.clone());
}

// Setup retry policy:
let retry_settings = ExponentialBuilder::default()
.with_max_times(applied_settings.retries)
Expand All @@ -151,6 +174,7 @@ impl DapiRequestExecutor for DapiClient {
// Setup DAPI request execution routine future. It's a closure that will be called
// more once to build new future on each retry.
let routine = move || {
let applied_settings = applied_settings.clone();
// Try to get an address to initialize transport on:

let address_list = self
Expand Down
13 changes: 12 additions & 1 deletion packages/rs-dapi-client/src/request_settings.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! DAPI client request settings processing.

use dapi_grpc::tonic::transport::Certificate;
use std::time::Duration;

/// Default low-level client timeout
Expand Down Expand Up @@ -60,12 +61,13 @@ impl RequestSettings {
ban_failed_address: self
.ban_failed_address
.unwrap_or(DEFAULT_BAN_FAILED_ADDRESS),
ca_certificate: None,
}
}
}

/// DAPI settings ready to use.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
pub struct AppliedRequestSettings {
/// Timeout for establishing a connection.
pub connect_timeout: Option<Duration>,
Expand All @@ -75,4 +77,13 @@ pub struct AppliedRequestSettings {
pub retries: usize,
/// Ban DAPI address if node not responded or responded with error.
pub ban_failed_address: bool,
/// Certificate Authority certificate to use for verifying the server's certificate.
pub ca_certificate: Option<Certificate>,
}
impl AppliedRequestSettings {
/// Use provided CA certificate for verifying the server's certificate.
pub fn with_ca_certificate(mut self, ca_cert: Certificate) -> Self {
self.ca_certificate = Some(ca_cert);
self
}
}
25 changes: 20 additions & 5 deletions packages/rs-dapi-client/src/transport/grpc.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,41 @@
//! Listing of gRPC requests used in DAPI.

use std::time::Duration;

use super::{CanRetry, TransportClient, TransportRequest};
use crate::connection_pool::{ConnectionPool, PoolPrefix};
use crate::{request_settings::AppliedRequestSettings, RequestSettings};
use dapi_grpc::core::v0::core_client::CoreClient;
use dapi_grpc::core::v0::{self as core_proto};
use dapi_grpc::platform::v0::{self as platform_proto, platform_client::PlatformClient};
use dapi_grpc::tonic::transport::Uri;
use dapi_grpc::tonic::transport::{Certificate, ClientTlsConfig};
use dapi_grpc::tonic::Streaming;
use dapi_grpc::tonic::{transport::Channel, IntoRequest};
use futures::{future::BoxFuture, FutureExt, TryFutureExt};
use std::time::Duration;

/// Platform Client using gRPC transport.
pub type PlatformGrpcClient = PlatformClient<Channel>;
/// Core Client using gRPC transport.
pub type CoreGrpcClient = CoreClient<Channel>;

fn create_channel(uri: Uri, settings: Option<&AppliedRequestSettings>) -> Channel {
let host = uri.host().expect("Failed to get host from URI").to_string();

let mut builder = Channel::builder(uri);

if let Some(settings) = settings {
if let Some(timeout) = settings.connect_timeout {
builder = builder.connect_timeout(timeout);
}
if let Some(pem) = settings.ca_certificate.as_ref() {
let cert = Certificate::from_pem(pem);
let tls_config = ClientTlsConfig::new()
.domain_name(host)
.ca_certificate(cert);
builder = builder
.tls_config(tls_config)
.expect("Failed to set TLS config");
}
}

builder.connect_lazy()
Expand Down Expand Up @@ -186,8 +197,10 @@ impl_transport_request_grpc!(
platform_proto::WaitForStateTransitionResultResponse,
PlatformGrpcClient,
RequestSettings {
timeout: Some(Duration::from_secs(120)),
..RequestSettings::default()
timeout: Some(Duration::from_secs(80)),
retries: Some(0),
ban_failed_address: None,
connect_timeout: None,
},
wait_for_state_transition_result
);
Expand Down Expand Up @@ -382,7 +395,9 @@ impl_transport_request_grpc!(
CoreGrpcClient,
RequestSettings {
timeout: Some(STREAMING_TIMEOUT),
..RequestSettings::default()
ban_failed_address: None,
connect_timeout: None,
retries: None,
},
subscribe_to_transactions_with_proofs
);
33 changes: 32 additions & 1 deletion packages/rs-sdk/src/sdk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,9 @@ pub struct SdkBuilder {

/// Cancellation token; once cancelled, all pending requests should be aborted.
pub(crate) cancel_token: CancellationToken,

/// CA certificate to use for TLS connections.
ca_certificate: Option<Vec<u8>>,
}

impl Default for SdkBuilder {
Expand Down Expand Up @@ -616,6 +619,8 @@ impl Default for SdkBuilder {

version: PlatformVersion::latest(),

ca_certificate: None,

#[cfg(feature = "mocks")]
dump_dir: None,
}
Expand Down Expand Up @@ -665,6 +670,28 @@ impl SdkBuilder {
self
}

/// Configure CA certificate to use when verifying TLS connections.
///
/// Used mainly for testing purposes and local networks.
///
/// If not set, uses standard system CA certificates.
pub fn with_ca_certificate(mut self, pem_certificate: &[u8]) -> Self {
self.ca_certificate = Some(pem_certificate.to_vec());
self
}

/// Load CA certificate from file.
///
/// This is a convenience method that reads the certificate from a file and sets it using
/// [SdkBuilder::with_ca_certificate()].
pub fn with_ca_certificate_file(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a better approach would be to have just a single with_ca_certificate and expect an actual certificate value, users can figure out where and how to get it as long as it can be constructed

self,
certificate_file_path: impl AsRef<std::path::Path>,
) -> std::io::Result<Self> {
let pem = std::fs::read(certificate_file_path).expect("failed to read file");
Ok(self.with_ca_certificate(&pem))
}

/// Configure request settings.
///
/// Tune request settings used to connect to the Dash Platform.
Expand Down Expand Up @@ -757,7 +784,11 @@ impl SdkBuilder {
let sdk= match self.addresses {
// non-mock mode
Some(addresses) => {
let dapi = DapiClient::new(addresses, self.settings);
let mut dapi = DapiClient::new(addresses, self.settings);
if let Some(pem) = self.ca_certificate {
dapi = dapi.with_ca_certificate(&pem);
}

#[cfg(feature = "mocks")]
let dapi = dapi.dump_dir(self.dump_dir.clone());

Expand Down
1 change: 1 addition & 0 deletions packages/rs-sdk/tests/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
DASH_SDK_PLATFORM_HOST="127.0.0.1"
DASH_SDK_PLATFORM_PORT=2443
DASH_SDK_PLATFORM_SSL=false
# DASH_SDK_PLATFORM_CA_CERT_PATH=/some/path/to/ca.pem

# ProTxHash of masternode that has at least 1 vote casted for DPNS name `testname`
DASH_SDK_MASTERNODE_OWNER_PRO_REG_TX_HASH="6ac88f64622d9bc0cb79ad0f69657aa9488b213157d20ae0ca371fa5f04fb222"
Expand Down
13 changes: 10 additions & 3 deletions packages/rs-sdk/tests/fetch/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ pub struct Config {
#[serde(default)]
pub platform_ssl: bool,

/// When platform_ssl is true, use the PEM-encoded CA certificate from provided absolute path to verify the server
#[serde(default)]
pub platform_ca_cert_path: Option<PathBuf>,

/// Directory where all generated test vectors will be saved.
///
/// See [SdkBuilder::with_dump_dir()](crate::SdkBuilder::with_dump_dir()) for more details.
Expand Down Expand Up @@ -176,13 +180,17 @@ impl Config {
#[cfg(all(feature = "network-testing", not(feature = "offline-testing")))]
let sdk = {
// Dump all traffic to disk
let builder = dash_sdk::SdkBuilder::new(self.address_list()).with_core(
let mut builder = dash_sdk::SdkBuilder::new(self.address_list()).with_core(
&self.platform_host,
self.core_port,
&self.core_user,
&self.core_password,
);

if let Some(cert_file) = &self.platform_ca_cert_path {
builder = builder
.with_ca_certificate_file(cert_file)
.expect("load CA cert");
}
#[cfg(feature = "generate-test-vectors")]
let builder = {
// When we use namespaces, clean up the namespaced dump dir before starting
Expand Down Expand Up @@ -235,7 +243,6 @@ impl Config {
Encoding::Base58,
)
.unwrap()
.into()
}

fn default_data_contract_id() -> Identifier {
Expand Down
2 changes: 2 additions & 0 deletions packages/rs-sdk/tests/fetch/data_contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use drive_proof_verifier::types::DataContractHistory;
/// Given some dummy data contract ID, when I fetch data contract, I get None because it doesn't exist.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_data_contract_read_not_found() {
super::common::setup_logs();

pub const DATA_CONTRACT_ID_BYTES: [u8; 32] = [1; 32];
let id = Identifier::from_bytes(&DATA_CONTRACT_ID_BYTES).expect("parse identity id");

Expand Down
Loading