diff --git a/sdk/identity/Cargo.toml b/sdk/identity/Cargo.toml index cb6246ad395..a1bf4af43ff 100644 --- a/sdk/identity/Cargo.toml +++ b/sdk/identity/Cargo.toml @@ -28,6 +28,7 @@ base64 = "0.13.0" uuid = { version = "1.0", features = ["v4"] } # work around https://github.com/rust-lang/rust/issues/63033 fix-hidden-lifetime-bug = "0.2" +pin-project = "1.0" [dev-dependencies] reqwest = { version = "0.11", features = ["json"], default-features = false } diff --git a/sdk/identity/src/lib.rs b/sdk/identity/src/lib.rs index 1642b2371fd..82427b4b340 100644 --- a/sdk/identity/src/lib.rs +++ b/sdk/identity/src/lib.rs @@ -49,6 +49,7 @@ pub mod development; pub mod device_code_flow; mod oauth2_http_client; pub mod refresh_token; +mod timeout; mod token_credentials; pub use crate::token_credentials::*; diff --git a/sdk/identity/src/timeout.rs b/sdk/identity/src/timeout.rs new file mode 100644 index 00000000000..20ec363301d --- /dev/null +++ b/sdk/identity/src/timeout.rs @@ -0,0 +1,70 @@ +// Copyright (c) 2020 Yoshua Wuyts +// +// based on https://crates.io/crates/futures-time +// Licensed under either of Apache License, Version 2.0 or MIT license at your option. + +use azure_core::sleep::{sleep, Sleep}; +use futures::Future; +use std::time::Duration; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +#[pin_project::pin_project] +#[derive(Debug)] +pub(crate) struct Timeout { + #[pin] + future: F, + #[pin] + deadline: D, + completed: bool, +} + +impl Timeout { + pub(crate) fn new(future: F, deadline: D) -> Self { + Self { + future, + deadline, + completed: false, + } + } +} + +impl Future for Timeout { + type Output = azure_core::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + assert!(!*this.completed, "future polled after completing"); + + match this.future.poll(cx) { + Poll::Ready(v) => { + *this.completed = true; + Poll::Ready(Ok(v)) + } + Poll::Pending => match this.deadline.poll(cx) { + Poll::Ready(_) => { + *this.completed = true; + Poll::Ready(Err(azure_core::error::Error::with_message( + azure_core::error::ErrorKind::Other, + || String::from("operation timed out"), + ))) + } + Poll::Pending => Poll::Pending, + }, + } + } +} + +pub(crate) trait TimeoutExt: Future { + fn timeout(self, duration: Duration) -> Timeout + where + Self: Sized, + { + Timeout::new(self, sleep(duration)) + } +} + +impl TimeoutExt for T where T: Future {} diff --git a/sdk/identity/src/token_credentials/default_credentials.rs b/sdk/identity/src/token_credentials/default_credentials.rs index 1b92e5b21b9..75ad6792637 100644 --- a/sdk/identity/src/token_credentials/default_credentials.rs +++ b/sdk/identity/src/token_credentials/default_credentials.rs @@ -1,6 +1,12 @@ -use super::{AzureCliCredential, ImdsManagedIdentityCredential}; -use azure_core::auth::{TokenCredential, TokenResponse}; -use azure_core::error::{Error, ErrorKind, ResultExt}; +use crate::{ + timeout::TimeoutExt, + {AzureCliCredential, ImdsManagedIdentityCredential}, +}; +use azure_core::{ + auth::{TokenCredential, TokenResponse}, + error::{Error, ErrorKind, ResultExt}, +}; +use std::time::Duration; #[derive(Debug)] /// Provides a mechanism of selectively disabling credentials used for a `DefaultAzureCredential` instance @@ -91,10 +97,19 @@ impl TokenCredential for DefaultAzureCredentialEnum { ) } DefaultAzureCredentialEnum::ManagedIdentity(credential) => { - credential.get_token(resource).await.context( - ErrorKind::Credential, - "error getting managed identity credential", - ) + // IMSD timeout is only limited to 1 second when used in DefaultAzureCredential + credential + .get_token(resource) + .timeout(Duration::from_secs(1)) + .await + .context( + ErrorKind::Credential, + "getting managed identity credential timed out", + )? + .context( + ErrorKind::Credential, + "error getting managed identity credential", + ) } DefaultAzureCredentialEnum::AzureCli(credential) => { credential.get_token(resource).await.context( @@ -128,15 +143,7 @@ impl DefaultAzureCredential { impl Default for DefaultAzureCredential { fn default() -> Self { - DefaultAzureCredential { - sources: vec![ - DefaultAzureCredentialEnum::Environment(super::EnvironmentCredential::default()), - DefaultAzureCredentialEnum::ManagedIdentity( - ImdsManagedIdentityCredential::default(), - ), - DefaultAzureCredentialEnum::AzureCli(AzureCliCredential::new()), - ], - } + DefaultAzureCredentialBuilder::new().build() } }