diff --git a/sdk/core/azure_core/CHANGELOG.md b/sdk/core/azure_core/CHANGELOG.md index 450502f624b..c007ad9e6a9 100644 --- a/sdk/core/azure_core/CHANGELOG.md +++ b/sdk/core/azure_core/CHANGELOG.md @@ -6,6 +6,7 @@ - Added `ErrorKind::Connection` for connection errors. - The `reqwest` HTTP client now classifies connection errors as `ErrorKind::Connection`. +- Added `SecretBytes` to `azure_core::credentials` for securely passing byte secrets without printing them in `Debug` or `Display` output. ### Breaking Changes diff --git a/sdk/core/azure_core/src/credentials.rs b/sdk/core/azure_core/src/credentials.rs index 3b8009a89c6..9e1276bfed3 100644 --- a/sdk/core/azure_core/src/credentials.rs +++ b/sdk/core/azure_core/src/credentials.rs @@ -3,8 +3,9 @@ //! Azure authentication and authorization. +use crate::Bytes; use serde::{Deserialize, Serialize}; -use std::{borrow::Cow, fmt::Debug}; +use std::{borrow::Cow, fmt}; use typespec_client_core::{fmt::SafeDebug, http::ClientMethodOptions, time::OffsetDateTime}; /// Represents a secret. @@ -58,12 +59,133 @@ impl From<&'static str> for Secret { } } -impl Debug for Secret { +impl fmt::Debug for Secret { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str("Secret") } } +/// Represents secret bytes, e.g., certificate data. +/// +/// Neither the [`Debug`](fmt::Debug) nor the [`Display`](fmt::Display) implementation will print the bytes. +#[derive(Clone, Eq)] +pub struct SecretBytes(Vec); + +impl SecretBytes { + /// Create a new `SecretBytes`. + pub fn new(bytes: impl Into>) -> Self { + Self(bytes.into()) + } + + /// Get the secret bytes. + pub fn bytes(&self) -> &[u8] { + &self.0 + } +} + +// NOTE: this is a constant time compare, however LLVM may (and probably will) +// optimize this in unexpected ways. +impl PartialEq for SecretBytes { + fn eq(&self, other: &Self) -> bool { + let a = self.bytes(); + let b = other.bytes(); + + if a.len() != b.len() { + return false; + } + + a.iter().zip(b.iter()).fold(0, |acc, (a, b)| acc | (a ^ b)) == 0 + } +} + +impl fmt::Debug for SecretBytes { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("SecretBytes") + } +} + +impl fmt::Display for SecretBytes { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("SecretBytes") + } +} + +impl From for SecretBytes { + fn from(bytes: Bytes) -> Self { + Self(bytes.to_vec()) + } +} + +impl From<&[u8]> for SecretBytes { + fn from(bytes: &[u8]) -> Self { + Self(bytes.to_vec()) + } +} + +impl From> for SecretBytes { + fn from(bytes: Vec) -> Self { + Self(bytes) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn debug_does_not_print_bytes() { + let secret = SecretBytes::new(b"super-secret".to_vec()); + assert_eq!("SecretBytes", format!("{secret:?}")); + } + + #[test] + fn display_does_not_print_bytes() { + let secret = SecretBytes::new(b"super-secret".to_vec()); + assert_eq!("SecretBytes", format!("{secret}")); + } + + #[test] + fn eq_same_bytes() { + let a = SecretBytes::new(b"hello".to_vec()); + let b = SecretBytes::new(b"hello".to_vec()); + assert_eq!(a, b); + } + + #[test] + fn ne_different_bytes() { + let a = SecretBytes::new(b"hello".to_vec()); + let b = SecretBytes::new(b"world".to_vec()); + assert_ne!(a, b); + } + + #[test] + fn ne_different_lengths() { + let a = SecretBytes::new(b"hello".to_vec()); + let b = SecretBytes::new(b"hello!".to_vec()); + assert_ne!(a, b); + } + + #[test] + fn from_bytes_type() { + let bytes = Bytes::from_static(b"data"); + let secret = SecretBytes::from(bytes); + assert_eq!(b"data", secret.bytes()); + } + + #[test] + fn from_slice() { + let data: &[u8] = b"data"; + let secret = SecretBytes::from(data); + assert_eq!(b"data", secret.bytes()); + } + + #[test] + fn from_vec() { + let secret = SecretBytes::from(b"data".to_vec()); + assert_eq!(b"data", secret.bytes()); + } +} + /// Represents an Azure service bearer access token with expiry information. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AccessToken { @@ -95,7 +217,7 @@ pub struct TokenRequestOptions<'a> { /// Represents a credential capable of providing an OAuth token. #[async_trait::async_trait] -pub trait TokenCredential: Send + Sync + Debug { +pub trait TokenCredential: Send + Sync + fmt::Debug { /// Gets an [`AccessToken`] for the specified scopes async fn get_token( &self, diff --git a/sdk/identity/azure_identity/CHANGELOG.md b/sdk/identity/azure_identity/CHANGELOG.md index 2f1b2414799..ea363ac5b5a 100644 --- a/sdk/identity/azure_identity/CHANGELOG.md +++ b/sdk/identity/azure_identity/CHANGELOG.md @@ -2,11 +2,10 @@ ## 0.33.0 (Unreleased) -### Features Added - ### Breaking Changes - Support for `wasm32-unknown-unknown` has been removed ([#3377](https://github.com/Azure/azure-sdk-for-rust/issues/3377)) +- `ClientCertificateCredential::new()` now takes `SecretBytes` instead of `Secret` for the `certificate` parameter. Pass the raw PKCS12 bytes wrapped in `SecretBytes` instead of a base64-encoded string wrapped in `Secret`. ### Bugs Fixed diff --git a/sdk/identity/azure_identity/src/client_certificate_credential.rs b/sdk/identity/azure_identity/src/client_certificate_credential.rs index c236147b6d5..2c7e7ee90ce 100644 --- a/sdk/identity/azure_identity/src/client_certificate_credential.rs +++ b/sdk/identity/azure_identity/src/client_certificate_credential.rs @@ -7,7 +7,7 @@ use crate::{ }; use azure_core::{ base64, - credentials::{AccessToken, Secret, TokenCredential, TokenRequestOptions}, + credentials::{AccessToken, Secret, SecretBytes, TokenCredential, TokenRequestOptions}, error::{Error, ErrorKind, ResultExt}, http::{ headers::{self, content_type}, @@ -64,13 +64,13 @@ impl ClientCertificateCredential { /// # Arguments /// - `tenant_id`: The tenant (directory) ID of the service principal. /// - `client_id`: The client (application) ID of the service principal. - /// - `certificate`: A base64-encoded PKCS12 certificate with its RSA private key. + /// - `certificate`: The PKCS12 certificate bytes with its RSA private key. /// - `options`: Options for configuring the credential. If `None`, the credential uses its default options. /// pub fn new( tenant_id: String, client_id: String, - certificate: Secret, + certificate: SecretBytes, options: Option, ) -> azure_core::Result> { validate_tenant_id(&tenant_id)?; @@ -78,12 +78,8 @@ impl ClientCertificateCredential { let options = options.unwrap_or_default(); - let cert_bytes = base64::decode(certificate.secret()) - .with_context_fn(ErrorKind::Credential, || { - "failed to decode base64 certificate data" - })?; - - let (key, cert, ca_chain) = parse_certificate(&cert_bytes, options.password.as_ref())?; + let (key, cert, ca_chain) = + parse_certificate(certificate.bytes(), options.password.as_ref())?; let thumbprint = cert .digest(MessageDigest::sha1()) .with_context(ErrorKind::Credential, "failed to compute thumbprint")? @@ -212,7 +208,7 @@ impl ClientCertificateCredential { } } -/// Parse a base64-encoded PKCS12 certificate into key, certificate, and optional CA chain. +/// Parse a PKCS12 certificate into key, certificate, and optional CA chain. fn parse_certificate( cert_bytes: &[u8], password: Option<&Secret>, @@ -289,6 +285,7 @@ mod tests { use super::*; use crate::{client_assertion_credential::tests::is_valid_request, tests::*}; use azure_core::{ + credentials::SecretBytes, http::{ headers::Headers, policies::{Policy, PolicyResult}, @@ -302,13 +299,12 @@ mod tests { sync::{Arc, LazyLock}, }; - static TEST_CERT: LazyLock = LazyLock::new(|| { - let pfx = std::fs::read(concat!( + static TEST_CERT: LazyLock> = LazyLock::new(|| { + std::fs::read(concat!( env!("CARGO_MANIFEST_DIR"), "/tests/certificate.pfx" )) - .expect("failed to read test certificate"); - base64::encode(pfx) + .expect("failed to read test certificate") }); #[derive(Debug, Clone)] @@ -319,9 +315,8 @@ mod tests { } impl VerifyAssertionPolicy { - fn new(certificate: String, expect_x5c: bool) -> Self { - let pfx = base64::decode(certificate).expect("base64 encoding"); - let (_, cert, _) = parse_certificate(&pfx, None).expect("valid certificate"); + fn new(certificate: &[u8], expect_x5c: bool) -> Self { + let (_, cert, _) = parse_certificate(certificate, None).expect("valid certificate"); let public_key = cert.public_key().expect("public key"); let cert_der = cert.to_der().expect("valid certificate"); Self { @@ -430,7 +425,7 @@ mod tests { let credential = ClientCertificateCredential::new( FAKE_TENANT_ID.to_string(), FAKE_CLIENT_ID.to_string(), - Secret::new(TEST_CERT.to_string()), + SecretBytes::from(TEST_CERT.as_slice()), Some(ClientCertificateCredentialOptions { client_options: ClientOptions { transport: Some(Transport::new(Arc::new(sts))), @@ -468,7 +463,7 @@ mod tests { let credential = ClientCertificateCredential::new( FAKE_TENANT_ID.to_string(), FAKE_CLIENT_ID.to_string(), - TEST_CERT.as_str().into(), + TEST_CERT.as_slice().into(), Some(ClientCertificateCredentialOptions { client_options: ClientOptions { transport: Some(Transport::new(Arc::new(sts))), @@ -518,14 +513,11 @@ mod tests { let credential = ClientCertificateCredential::new( FAKE_TENANT_ID.to_string(), FAKE_CLIENT_ID.to_string(), - TEST_CERT.as_str().into(), + TEST_CERT.as_slice().into(), Some(ClientCertificateCredentialOptions { client_options: ClientOptions { transport: Some(Transport::new(Arc::new(sts))), - per_try_policies: vec![Arc::new(VerifyAssertionPolicy::new( - TEST_CERT.to_string(), - false, - ))], + per_try_policies: vec![Arc::new(VerifyAssertionPolicy::new(&TEST_CERT, false))], ..Default::default() }, ..Default::default() @@ -559,7 +551,7 @@ mod tests { ClientCertificateCredential::new( FAKE_TENANT_ID.to_string(), FAKE_CLIENT_ID.to_string(), - "not a certificate".into(), + b"not a certificate".as_slice().into(), None, ) .expect_err("invalid certificate"); @@ -570,7 +562,7 @@ mod tests { ClientCertificateCredential::new( "not a valid tenant".to_string(), FAKE_CLIENT_ID.to_string(), - TEST_CERT.as_str().into(), + TEST_CERT.as_slice().into(), None, ) .expect_err("invalid tenant ID"); @@ -581,7 +573,7 @@ mod tests { ClientCertificateCredential::new( FAKE_TENANT_ID.to_string(), FAKE_CLIENT_ID.to_string(), - TEST_CERT.as_str().into(), + TEST_CERT.as_slice().into(), None, ) .expect("valid credential") @@ -602,14 +594,11 @@ mod tests { let credential = ClientCertificateCredential::new( FAKE_TENANT_ID.to_string(), FAKE_CLIENT_ID.to_string(), - TEST_CERT.as_str().into(), + TEST_CERT.as_slice().into(), Some(ClientCertificateCredentialOptions { client_options: ClientOptions { transport: Some(Transport::new(Arc::new(sts))), - per_try_policies: vec![Arc::new(VerifyAssertionPolicy::new( - TEST_CERT.to_string(), - true, - ))], + per_try_policies: vec![Arc::new(VerifyAssertionPolicy::new(&TEST_CERT, true))], ..Default::default() }, env: Some(Env::from(