Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse the Client in all cases #327

Merged
merged 7 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 59 additions & 11 deletions src/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,63 @@ use crate::{
InitAndRotationCheck::{NoRotationNeeded, RotationNeeded},
Result,
};
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};

#[cfg(feature = "beta")]
use crate::search::{BlindIndexSearchInitialize, EncryptedBlindIndexSalt};
use tokio::runtime::Runtime;

/// Struct that is used to hold the regular DeviceContext as well as a runtime that will be used
/// when initializing a BlockingIronOxide. This was added to fix a bug where initializing multiple
/// SDK instances with a single device would hang indefinitely (as each initialization call would
/// create its own runtime but share a request client)
#[derive(Clone, Debug)]
pub struct BlockingDeviceContext {
pub device: DeviceContext,
pub(crate) rt: Arc<Runtime>,
}

impl From<DeviceAddResult> for BlockingDeviceContext {
fn from(value: DeviceAddResult) -> Self {
Self {
device: value.into(),
rt: Arc::new(create_runtime()),
}
}
}

impl BlockingDeviceContext {
pub fn new(device: DeviceContext) -> Self {
Self {
device,
rt: Arc::new(create_runtime()),
}
}
/// ID of the device's owner
pub fn account_id(&self) -> &UserId {
&self.device.auth().account_id()
}
/// ID of the segment
pub fn segment_id(&self) -> usize {
self.device.auth().segment_id()
}
/// Private signing key of the device
pub fn signing_private_key(&self) -> &DeviceSigningKeyPair {
&self.device.auth().signing_private_key()
}
/// Private encryption key of the device
pub fn device_private_key(&self) -> &PrivateKey {
&self.device.device_private_key()
}
}

/// Struct that is used to make authenticated requests to the IronCore API. Instantiated with the details
/// of an account's various ids, device, and signing keys. Once instantiated all operations will be
/// performed in the context of the account provided. Identical to IronOxide but also contains a Runtime.
#[derive(Debug)]
pub struct BlockingIronOxide {
pub(crate) ironoxide: IronOxide,
pub(crate) runtime: tokio::runtime::Runtime,
pub(crate) runtime: Arc<tokio::runtime::Runtime>,
}

impl BlockingIronOxide {
Expand Down Expand Up @@ -293,35 +338,38 @@ fn create_runtime() -> tokio::runtime::Runtime {
/// Initialize the BlockingIronOxide SDK with a device. Verifies that the provided user/segment exists and the provided device
/// keys are valid and exist for the provided account. If successful, returns instance of the BlockingIronOxide SDK.
pub fn initialize(
device_context: &DeviceContext,
device_context: &BlockingDeviceContext,
config: &IronOxideConfig,
) -> Result<BlockingIronOxide> {
let rt = create_runtime();
let maybe_io = rt.block_on(crate::initialize(device_context, config));
let maybe_io = device_context
.rt
.block_on(crate::initialize(&device_context.device, config));
maybe_io.map(|io| BlockingIronOxide {
ironoxide: io,
runtime: rt,
runtime: device_context.rt.clone(),
})
}

/// Initialize the BlockingIronOxide SDK and check to see if the user that owns this `DeviceContext` is
/// marked for private key rotation, or if any of the groups that the user is an admin of are marked
/// for private key rotation.
pub fn initialize_check_rotation(
device_context: &DeviceContext,
device_context: &BlockingDeviceContext,
config: &IronOxideConfig,
) -> Result<InitAndRotationCheck<BlockingIronOxide>> {
let rt = create_runtime();
let maybe_init = rt.block_on(crate::initialize_check_rotation(device_context, config));
let maybe_init = device_context.rt.block_on(crate::initialize_check_rotation(
&device_context.device,
config,
));
maybe_init.map(|init| match init {
NoRotationNeeded(io) => NoRotationNeeded(BlockingIronOxide {
ironoxide: io,
runtime: rt,
runtime: device_context.rt.clone(),
}),
RotationNeeded(io, rot) => RotationNeeded(
BlockingIronOxide {
ironoxide: io,
runtime: rt,
runtime: device_context.rt.clone(),
},
rot,
),
Expand Down
6 changes: 2 additions & 4 deletions src/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

use crate::internal::{
group_api::GroupId,
rest::{Authorization, IronCoreRequest, SignatureUrlString},
rest::{Authorization, SignatureUrlString},
user_api::UserId,
};
use base64::engine::Engine;
Expand Down Expand Up @@ -34,6 +34,7 @@ pub mod document_api;
pub mod group_api;
mod rest;
pub mod user_api;
pub use rest::IronCoreRequest;

lazy_static! {
pub static ref URL_STRING: String = match std::env::var("IRONCORE_ENV") {
Expand All @@ -45,9 +46,6 @@ lazy_static! {
.to_string(),
_ => "https://api.ironcorelabs.com/api/1/".to_string(),
};
static ref SHARED_CLIENT: reqwest::Client = reqwest::Client::new();
pub static ref OUR_REQUEST: IronCoreRequest =
IronCoreRequest::new(URL_STRING.as_str(), &SHARED_CLIENT);
}

#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
Expand Down
44 changes: 25 additions & 19 deletions src/internal/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::internal::{
auth_v2::AuthV2Builder,
user_api::{Jwt, UserId},
DeviceSigningKeyPair, IronOxideErr, RequestErrorCode, OUR_REQUEST,
DeviceSigningKeyPair, IronOxideErr, RequestErrorCode, URL_STRING,
};
use base64::engine::Engine;
use base64::prelude::BASE64_STANDARD;
Expand Down Expand Up @@ -303,20 +303,20 @@ impl<'a> HeaderIronCoreRequestSig<'a> {
}

///A struct which holds the basic info that will be needed for making requests to an ironcore service. Currently just the base_url.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct IronCoreRequest {
base_url: &'static str,
#[serde(skip_serializing, skip_deserializing, default = "default_client")]
client: &'static reqwest::Client,
pub(crate) client: reqwest::Client,
}

fn default_client() -> &'static reqwest::Client {
OUR_REQUEST.client
fn default_client() -> reqwest::Client {
Client::new()
}

impl Default for IronCoreRequest {
fn default() -> Self {
*OUR_REQUEST
IronCoreRequest::new(&URL_STRING, default_client())
}
}
impl Hash for IronCoreRequest {
Expand All @@ -332,7 +332,7 @@ impl PartialEq for IronCoreRequest {
impl Eq for IronCoreRequest {}

impl IronCoreRequest {
pub const fn new(base_url: &'static str, client: &'static reqwest::Client) -> IronCoreRequest {
pub const fn new(base_url: &'static str, client: reqwest::Client) -> IronCoreRequest {
IronCoreRequest { base_url, client }
}

Expand Down Expand Up @@ -415,7 +415,7 @@ impl IronCoreRequest {
replace_headers(req.headers_mut(), auth.to_auth_header());
replace_headers(req.headers_mut(), request_sig.to_header());

Self::send_req(req, error_code, move |server_resp| {
self.send_req(req, error_code, move |server_resp| {
IronCoreRequest::deserialize_body(server_resp, error_code)
})
.await
Expand Down Expand Up @@ -551,7 +551,7 @@ impl IronCoreRequest {
Q: Serialize + ?Sized,
F: FnOnce(&Bytes) -> Result<B, IronOxideErr>,
{
let client = Client::new();
let client = self.client.clone();
let mut builder = client.request(
method,
format!("{}{}", self.base_url, relative_url).as_str(),
Expand Down Expand Up @@ -632,7 +632,7 @@ impl IronCoreRequest {
replace_headers(req.headers_mut(), auth.to_auth_header());
replace_headers(req.headers_mut(), request_sig.to_header());

Self::send_req(req, error_code, resp_handler).await
self.send_req(req, error_code, resp_handler).await
} else {
panic!("authorized requests must use version 2 of API authentication")
}
Expand All @@ -653,6 +653,7 @@ impl IronCoreRequest {
}

async fn send_req<B, F>(
&self,
req: Request,
error_code: RequestErrorCode,
resp_handler: F,
Expand All @@ -661,7 +662,7 @@ impl IronCoreRequest {
B: DeserializeOwned,
F: FnOnce(&Bytes) -> Result<B, IronOxideErr>,
{
let client = Client::new();
let client = self.client.clone();
let server_res = client.execute(req).await;
let res = server_res.map_err(|e| (e, error_code))?;
//Parse the body content into bytes
Expand Down Expand Up @@ -1049,12 +1050,11 @@ mod tests {

use recrypt::api::{Ed25519Signature, PublicSigningKey};

lazy_static! {
static ref SHARED_CLIENT: reqwest::Client = reqwest::Client::new();
static ref TEST_REQUEST: IronCoreRequest = IronCoreRequest {
fn create_test_request() -> IronCoreRequest {
IronCoreRequest {
base_url: "https://example.com",
client: &SHARED_CLIENT
};
client: Client::new(),
}
}

#[test]
Expand Down Expand Up @@ -1238,7 +1238,8 @@ mod tests {
public_signing_key: signing_keys.public_key(),
};

let build_url = |relative_url| format!("{}{}", OUR_REQUEST.base_url(), relative_url);
let build_url =
|relative_url| format!("{}{}", IronCoreRequest::default().base_url(), relative_url);
let signing_url_string = SignatureUrlString::new(&build_url("users?id=user-10")).unwrap();

// note that this and the expected value must correspond
Expand Down Expand Up @@ -1378,7 +1379,7 @@ mod tests {
fn query_params_encoded_correctly() {
let mut req = Request::new(
Method::GET,
url::Url::parse(&format!("{}/{}", TEST_REQUEST.base_url(), "users")).unwrap(),
url::Url::parse(&format!("{}/{}", create_test_request().base_url(), "users")).unwrap(),
);
let q = "!\"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~";
IronCoreRequest::req_add_query(&mut req, &[("id".to_string(), url_encode(q))]);
Expand All @@ -1391,7 +1392,12 @@ mod tests {
fn empty_query_params_encoded_correctly() {
let mut req = Request::new(
Method::GET,
url::Url::parse(&format!("{}/{}", TEST_REQUEST.base_url(), "policies")).unwrap(),
url::Url::parse(&format!(
"{}/{}",
create_test_request().base_url(),
"policies"
))
.unwrap(),
);
IronCoreRequest::req_add_query(&mut req, &[]);
assert_eq!(req.url().query(), None);
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ pub mod search;
#[cfg(feature = "blocking")]
pub mod blocking;

pub use crate::internal::IronOxideErr;
pub use crate::internal::{IronCoreRequest, IronOxideErr};

use crate::{
common::{DeviceContext, DeviceSigningKeyPair, PublicKey, SdkOperation},
Expand Down
10 changes: 5 additions & 5 deletions src/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ pub use crate::internal::user_api::{
};
use crate::{
common::{PublicKey, SdkOperation},
internal::{add_optional_timeout, user_api, OUR_REQUEST},
IronOxide, Result,
internal::{add_optional_timeout, user_api},
IronCoreRequest, IronOxide, Result,
};
use futures::Future;
use recrypt::api::Recrypt;
Expand Down Expand Up @@ -299,7 +299,7 @@ impl UserOps for IronOxide {
jwt,
password.try_into()?,
user_create_opts.needs_rotation,
*OUR_REQUEST,
IronCoreRequest::default(),
),
timeout,
SdkOperation::UserCreate,
Expand All @@ -324,7 +324,7 @@ impl UserOps for IronOxide {
password.try_into()?,
device_create_options.device_name,
&std::time::SystemTime::now().into(),
&OUR_REQUEST,
&IronCoreRequest::default(),
),
timeout,
SdkOperation::GenerateNewDevice,
Expand All @@ -337,7 +337,7 @@ impl UserOps for IronOxide {
timeout: Option<std::time::Duration>,
) -> Result<Option<UserResult>> {
add_optional_timeout(
user_api::user_verify(jwt, *OUR_REQUEST),
user_api::user_verify(jwt, IronCoreRequest::default()),
timeout,
SdkOperation::UserVerify,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/blocking_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod common;
// Note: The blocking functions need minimal testing as they primarily just call their async counterparts

#[cfg(feature = "blocking")]
mod integration_tests {
mod blocking_integration_tests {
use crate::common::{create_id_all_classes, gen_jwt, USER_PASSWORD};
use galvanic_assert::{matchers::*, *};
use ironoxide::prelude::*;
Expand Down
Loading