Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/core/azure_core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

- Added `ErrorKind::Connection` for connection errors.
- The `reqwest` HTTP client now classifies connection errors as `ErrorKind::Connection`.
- Added `SecretBytes` to `azure_core::credentials` for securely passing byte secrets without printing them in `Debug` or `Display` output.

### Breaking Changes

Expand Down
128 changes: 125 additions & 3 deletions sdk/core/azure_core/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

//! Azure authentication and authorization.

use crate::Bytes;
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, fmt::Debug};
use std::{borrow::Cow, fmt};
use typespec_client_core::{fmt::SafeDebug, http::ClientMethodOptions, time::OffsetDateTime};

/// Represents a secret.
Expand Down Expand Up @@ -58,12 +59,133 @@ impl From<&'static str> for Secret {
}
}

impl Debug for Secret {
impl fmt::Debug for Secret {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("Secret")
}
}

/// Represents secret bytes, e.g., certificate data.
///
/// Neither the [`Debug`](fmt::Debug) nor the [`Display`](fmt::Display) implementation will print the bytes.
#[derive(Clone, Eq)]
pub struct SecretBytes(Vec<u8>);

impl SecretBytes {
/// Create a new `SecretBytes`.
pub fn new(bytes: impl Into<Vec<u8>>) -> Self {
Self(bytes.into())
}

/// Get the secret bytes.
pub fn bytes(&self) -> &[u8] {
&self.0
}
}

// NOTE: this is a constant time compare, however LLVM may (and probably will)
// optimize this in unexpected ways.
impl PartialEq for SecretBytes {
fn eq(&self, other: &Self) -> bool {
let a = self.bytes();
let b = other.bytes();

if a.len() != b.len() {
return false;
}

a.iter().zip(b.iter()).fold(0, |acc, (a, b)| acc | (a ^ b)) == 0
}
}

impl fmt::Debug for SecretBytes {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("SecretBytes")
}
}

impl fmt::Display for SecretBytes {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("SecretBytes")
}
}

impl From<Bytes> for SecretBytes {
fn from(bytes: Bytes) -> Self {
Self(bytes.to_vec())
}
}

impl From<&[u8]> for SecretBytes {
fn from(bytes: &[u8]) -> Self {
Self(bytes.to_vec())
}
}

impl From<Vec<u8>> for SecretBytes {
fn from(bytes: Vec<u8>) -> Self {
Self(bytes)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn debug_does_not_print_bytes() {
let secret = SecretBytes::new(b"super-secret".to_vec());
assert_eq!("SecretBytes", format!("{secret:?}"));
}

#[test]
fn display_does_not_print_bytes() {
let secret = SecretBytes::new(b"super-secret".to_vec());
assert_eq!("SecretBytes", format!("{secret}"));
}

#[test]
fn eq_same_bytes() {
let a = SecretBytes::new(b"hello".to_vec());
let b = SecretBytes::new(b"hello".to_vec());
assert_eq!(a, b);
}

#[test]
fn ne_different_bytes() {
let a = SecretBytes::new(b"hello".to_vec());
let b = SecretBytes::new(b"world".to_vec());
assert_ne!(a, b);
}

#[test]
fn ne_different_lengths() {
let a = SecretBytes::new(b"hello".to_vec());
let b = SecretBytes::new(b"hello!".to_vec());
assert_ne!(a, b);
}

#[test]
fn from_bytes_type() {
let bytes = Bytes::from_static(b"data");
let secret = SecretBytes::from(bytes);
assert_eq!(b"data", secret.bytes());
}

#[test]
fn from_slice() {
let data: &[u8] = b"data";
let secret = SecretBytes::from(data);
assert_eq!(b"data", secret.bytes());
}

#[test]
fn from_vec() {
let secret = SecretBytes::from(b"data".to_vec());
assert_eq!(b"data", secret.bytes());
}
}

/// Represents an Azure service bearer access token with expiry information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessToken {
Expand Down Expand Up @@ -95,7 +217,7 @@ pub struct TokenRequestOptions<'a> {

/// Represents a credential capable of providing an OAuth token.
#[async_trait::async_trait]
pub trait TokenCredential: Send + Sync + Debug {
pub trait TokenCredential: Send + Sync + fmt::Debug {
/// Gets an [`AccessToken`] for the specified scopes
async fn get_token(
&self,
Expand Down
3 changes: 1 addition & 2 deletions sdk/identity/azure_identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

## 0.33.0 (Unreleased)

### Features Added

### Breaking Changes

- Support for `wasm32-unknown-unknown` has been removed ([#3377](https://github.com/Azure/azure-sdk-for-rust/issues/3377))
- `ClientCertificateCredential::new()` now takes `SecretBytes` instead of `Secret` for the `certificate` parameter. Pass the raw PKCS12 bytes wrapped in `SecretBytes` instead of a base64-encoded string wrapped in `Secret`.

### Bugs Fixed

Expand Down
53 changes: 21 additions & 32 deletions sdk/identity/azure_identity/src/client_certificate_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
};
use azure_core::{
base64,
credentials::{AccessToken, Secret, TokenCredential, TokenRequestOptions},
credentials::{AccessToken, Secret, SecretBytes, TokenCredential, TokenRequestOptions},
error::{Error, ErrorKind, ResultExt},
http::{
headers::{self, content_type},
Expand Down Expand Up @@ -64,26 +64,22 @@ impl ClientCertificateCredential {
/// # Arguments
/// - `tenant_id`: The tenant (directory) ID of the service principal.
/// - `client_id`: The client (application) ID of the service principal.
/// - `certificate`: A base64-encoded PKCS12 certificate with its RSA private key.
/// - `certificate`: The PKCS12 certificate bytes with its RSA private key.
/// - `options`: Options for configuring the credential. If `None`, the credential uses its default options.
///
pub fn new(
tenant_id: String,
client_id: String,
certificate: Secret,
certificate: SecretBytes,
options: Option<ClientCertificateCredentialOptions>,
) -> azure_core::Result<Arc<ClientCertificateCredential>> {
validate_tenant_id(&tenant_id)?;
validate_not_empty(&client_id, "no client ID specified")?;

let options = options.unwrap_or_default();

let cert_bytes = base64::decode(certificate.secret())
.with_context_fn(ErrorKind::Credential, || {
"failed to decode base64 certificate data"
})?;

let (key, cert, ca_chain) = parse_certificate(&cert_bytes, options.password.as_ref())?;
let (key, cert, ca_chain) =
parse_certificate(certificate.bytes(), options.password.as_ref())?;
let thumbprint = cert
.digest(MessageDigest::sha1())
.with_context(ErrorKind::Credential, "failed to compute thumbprint")?
Expand Down Expand Up @@ -212,7 +208,7 @@ impl ClientCertificateCredential {
}
}

/// Parse a base64-encoded PKCS12 certificate into key, certificate, and optional CA chain.
/// Parse a PKCS12 certificate into key, certificate, and optional CA chain.
fn parse_certificate(
cert_bytes: &[u8],
password: Option<&Secret>,
Expand Down Expand Up @@ -289,6 +285,7 @@ mod tests {
use super::*;
use crate::{client_assertion_credential::tests::is_valid_request, tests::*};
use azure_core::{
credentials::SecretBytes,
http::{
headers::Headers,
policies::{Policy, PolicyResult},
Expand All @@ -302,13 +299,12 @@ mod tests {
sync::{Arc, LazyLock},
};

static TEST_CERT: LazyLock<String> = LazyLock::new(|| {
let pfx = std::fs::read(concat!(
static TEST_CERT: LazyLock<Vec<u8>> = LazyLock::new(|| {
std::fs::read(concat!(
env!("CARGO_MANIFEST_DIR"),
"/tests/certificate.pfx"
))
.expect("failed to read test certificate");
base64::encode(pfx)
.expect("failed to read test certificate")
});

#[derive(Debug, Clone)]
Expand All @@ -319,9 +315,8 @@ mod tests {
}

impl VerifyAssertionPolicy {
fn new(certificate: String, expect_x5c: bool) -> Self {
let pfx = base64::decode(certificate).expect("base64 encoding");
let (_, cert, _) = parse_certificate(&pfx, None).expect("valid certificate");
fn new(certificate: &[u8], expect_x5c: bool) -> Self {
let (_, cert, _) = parse_certificate(certificate, None).expect("valid certificate");
let public_key = cert.public_key().expect("public key");
let cert_der = cert.to_der().expect("valid certificate");
Self {
Expand Down Expand Up @@ -430,7 +425,7 @@ mod tests {
let credential = ClientCertificateCredential::new(
FAKE_TENANT_ID.to_string(),
FAKE_CLIENT_ID.to_string(),
Secret::new(TEST_CERT.to_string()),
SecretBytes::from(TEST_CERT.as_slice()),
Some(ClientCertificateCredentialOptions {
client_options: ClientOptions {
transport: Some(Transport::new(Arc::new(sts))),
Expand Down Expand Up @@ -468,7 +463,7 @@ mod tests {
let credential = ClientCertificateCredential::new(
FAKE_TENANT_ID.to_string(),
FAKE_CLIENT_ID.to_string(),
TEST_CERT.as_str().into(),
TEST_CERT.as_slice().into(),
Some(ClientCertificateCredentialOptions {
client_options: ClientOptions {
transport: Some(Transport::new(Arc::new(sts))),
Expand Down Expand Up @@ -518,14 +513,11 @@ mod tests {
let credential = ClientCertificateCredential::new(
FAKE_TENANT_ID.to_string(),
FAKE_CLIENT_ID.to_string(),
TEST_CERT.as_str().into(),
TEST_CERT.as_slice().into(),
Some(ClientCertificateCredentialOptions {
client_options: ClientOptions {
transport: Some(Transport::new(Arc::new(sts))),
per_try_policies: vec![Arc::new(VerifyAssertionPolicy::new(
TEST_CERT.to_string(),
false,
))],
per_try_policies: vec![Arc::new(VerifyAssertionPolicy::new(&TEST_CERT, false))],
..Default::default()
},
..Default::default()
Expand Down Expand Up @@ -559,7 +551,7 @@ mod tests {
ClientCertificateCredential::new(
FAKE_TENANT_ID.to_string(),
FAKE_CLIENT_ID.to_string(),
"not a certificate".into(),
b"not a certificate".as_slice().into(),
None,
)
.expect_err("invalid certificate");
Expand All @@ -570,7 +562,7 @@ mod tests {
ClientCertificateCredential::new(
"not a valid tenant".to_string(),
FAKE_CLIENT_ID.to_string(),
TEST_CERT.as_str().into(),
TEST_CERT.as_slice().into(),
None,
)
.expect_err("invalid tenant ID");
Expand All @@ -581,7 +573,7 @@ mod tests {
ClientCertificateCredential::new(
FAKE_TENANT_ID.to_string(),
FAKE_CLIENT_ID.to_string(),
TEST_CERT.as_str().into(),
TEST_CERT.as_slice().into(),
None,
)
.expect("valid credential")
Expand All @@ -602,14 +594,11 @@ mod tests {
let credential = ClientCertificateCredential::new(
FAKE_TENANT_ID.to_string(),
FAKE_CLIENT_ID.to_string(),
TEST_CERT.as_str().into(),
TEST_CERT.as_slice().into(),
Some(ClientCertificateCredentialOptions {
client_options: ClientOptions {
transport: Some(Transport::new(Arc::new(sts))),
per_try_policies: vec![Arc::new(VerifyAssertionPolicy::new(
TEST_CERT.to_string(),
true,
))],
per_try_policies: vec![Arc::new(VerifyAssertionPolicy::new(&TEST_CERT, true))],
..Default::default()
},
env: Some(Env::from(
Expand Down
Loading